
Commit cda4229

[BENCH] add renormalize knob in routing, so that the kernel can be used for qwen and mixtral moe family (#6896)
In some model families, such as Qwen1.5 or the Mixtral 8x7B / 8x22B MoE models, the expert weights are computed by applying softmax over all experts first and then taking top-k (without renormalization). To make the `routing` kernel compatible with those models, a new knob is added that turns off the softmax applied after top-k; instead, logits that are already softmax-ed are passed in and used directly.
1 parent ece03bb commit cda4229
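For reference, here is a minimal PyTorch sketch (illustrative only, not code from this commit) of the two conventions the knob switches between: the kernel's default of top-k on raw logits followed by a softmax over the selected values, versus the softmax-first, top-k-after convention described above.

```python
import torch

def route_renormalized(logits, k):
    # default kernel behaviour (renormalize=True):
    # top-k on the raw logits, then softmax over just the k selected values
    vals, idx = torch.topk(logits, k, dim=-1)
    return torch.softmax(vals, dim=-1), idx

def route_presoftmaxed(logits, k):
    # convention described above (renormalize=False):
    # softmax over all experts first, then top-k, with no renormalization
    probs = torch.softmax(logits, dim=-1)
    return torch.topk(probs, k, dim=-1)

logits = torch.randn(4, 8)  # 4 tokens, 8 experts (illustrative sizes)
w1, i1 = route_renormalized(logits, 2)
w2, i2 = route_presoftmaxed(logits, 2)
# softmax is monotonic per row, so the same experts are selected either way,
# but the weights differ: rows of w1 sum to 1, rows of w2 generally sum to less than 1
```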

File tree

4 files changed: +18, -12 lines

python/triton_kernels/tests/test_routing.py

Lines changed: 7 additions & 3 deletions
@@ -44,7 +44,8 @@ def ref_expt_data(routing_data, n_gates, block_m):
 @pytest.mark.parametrize("n_expts_tot, n_expts_act", [(128, 4), (1500, 8)])
 @pytest.mark.parametrize("block_m", [64, 128])
 @pytest.mark.parametrize("use_expt_indx", [False, True])
-def test_op(n_tokens, n_expts_tot, n_expts_act, block_m, use_expt_indx, device):
+@pytest.mark.parametrize("renormalize", [True, False])
+def test_op(n_tokens, n_expts_tot, n_expts_act, renormalize, block_m, use_expt_indx, device):
     torch.manual_seed(2)
     tri_logits = init_data(n_tokens, n_expts_tot, device=device).detach()
     ref_logits = tri_logits.clone()
@@ -55,8 +56,11 @@ def test_op(n_tokens, n_expts_tot, n_expts_act, block_m, use_expt_indx, device):
         ref_expt_indx = tri_expt_indx[:n_tokens]
     else:
         tri_expt_indx = ref_expt_indx = None
-    ref_routing_data, ref_gather, ref_scatter = routing_torch(ref_logits, n_expts_act, ref_expt_indx)
-    tri_routing_data, tri_gather, tri_scatter = routing(tri_logits, n_expts_act, tri_expt_indx)
+    if not renormalize:
+        tri_logits = torch.softmax(tri_logits, dim=-1)
+        ref_logits = torch.softmax(ref_logits, dim=-1)
+    ref_routing_data, ref_gather, ref_scatter = routing_torch(ref_logits, n_expts_act, renormalize, ref_expt_indx)
+    tri_routing_data, tri_gather, tri_scatter = routing(tri_logits, n_expts_act, renormalize, tri_expt_indx)
     ref_metadata = ref_expt_data(ref_routing_data, n_tokens * n_expts_act, block_m)
     tri_metadata = compute_metadata(tri_routing_data, n_tokens * n_expts_act, block_m)

python/triton_kernels/triton_kernels/routing.py

Lines changed: 5 additions & 4 deletions
@@ -53,7 +53,7 @@ def n_blocks(self, n_rows, block_m):
 # --------------------------


-def routing(logits, n_expts_act, expt_indx=None, simulated_ep=1):
+def routing(logits, n_expts_act, renormalize=True, expt_indx=None, simulated_ep=1):
     from .topk import topk
     from .compaction import compaction
     cdiv = triton.cdiv
@@ -63,7 +63,7 @@ def routing(logits, n_expts_act, expt_indx=None, simulated_ep=1):
     n_tokens, n_expts_tot = logits.shape
     n_gates = n_tokens * n_expts_act
     device = logits.device
-    expt_scal, expt_indx, bitmatrix = topk(logits, n_expts_act, y_indx=expt_indx)
+    expt_scal, expt_indx, bitmatrix = topk(logits, n_expts_act, apply_softmax=renormalize, y_indx=expt_indx)
     # mutate bitmatrix
     if simulated_ep > 1:
         assert n_expts_tot % simulated_ep == 0
@@ -108,7 +108,7 @@ def routing(logits, n_expts_act, expt_indx=None, simulated_ep=1):
     return RoutingData(gate_scal, hist, n_expts_tot, n_expts_act), gather_indx, scatter_indx


-def routing_torch(logits, n_expts_act, expt_indx=None):
+def routing_torch(logits, n_expts_act, renormalize=True, expt_indx=None):

     def topk(vals, k, expt_indx):
         # topk of experts
@@ -121,7 +121,8 @@ def topk(vals, k, expt_indx):

     _, n_expts_tot = logits.shape
     expt_scal, expt_indx = topk(logits, n_expts_act, expt_indx)
-    expt_scal = torch.softmax(expt_scal, dim=-1)
+    if renormalize:
+        expt_scal = torch.softmax(expt_scal, dim=-1)
     # flatten topk data
     expt_scal = expt_scal.reshape(-1)
     expt_indx = expt_indx.reshape(-1).to(torch.int32)
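A hedged usage sketch of the new knob (the `triton_kernels.routing` import path, device, and tensor sizes are assumptions for illustration; only the `routing` signature comes from the diff above):

```python
import torch
from triton_kernels.routing import routing  # assumed import path

# illustrative sizes: 8 tokens routed over 128 experts, top-4
logits = torch.randn(8, 128, device="cuda", dtype=torch.float16)

# default path: the kernel softmaxes the selected top-k values itself
rdata, gather_indx, scatter_indx = routing(logits, 4)

# Qwen1.5 / Mixtral-style path: softmax over all experts is done up front,
# so pass the probabilities and disable the in-kernel softmax
probs = torch.softmax(logits, dim=-1)
rdata, gather_indx, scatter_indx = routing(probs, 4, renormalize=False)
```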

python/triton_kernels/triton_kernels/topk.py

Lines changed: 2 additions & 2 deletions
@@ -3,7 +3,7 @@
 from .bitmatrix import Bitmatrix


-def topk(x, k, dim=1, return_bitmatrix=True, y_indx=None):
+def topk(x, k, apply_softmax=True, dim=1, return_bitmatrix=True, y_indx=None):
     cdiv = lambda a, b: (a + b - 1) // b
     BLOCK_M = 32
     BLOCK_N = 32
@@ -39,5 +39,5 @@ def topk(x, k, dim=1, return_bitmatrix=True, y_indx=None):
              S, BLOCK_S, s_blocks, # thing to memset to zero
              BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, # tunable parameter
              N_EXPTS_PAD=n_cols_pad, N_EXPTS_ACT=k, # constants
-             )
+             APPLY_SOFTMAX=apply_softmax)
     return y_vals, y_indx, Bitmatrix(bitmatrix, [n_rows, n_cols], S)
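The standalone helper can be exercised the same way; a minimal sketch, assuming the module is importable as `triton_kernels.topk` and a CUDA device is available:

```python
import torch
from triton_kernels.topk import topk  # assumed import path

# caller applies softmax over all experts up front
x = torch.softmax(torch.randn(16, 128, device="cuda", dtype=torch.float16), dim=-1)
# apply_softmax=False keeps the selected values as-is instead of re-normalizing them in-kernel
y_vals, y_indx, bitmatrix = topk(x, 4, apply_softmax=False)
```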

python/triton_kernels/triton_kernels/topk_details/_topk.py

Lines changed: 4 additions & 3 deletions
@@ -72,7 +72,8 @@ def _topk(X, stride_xm, # inputs
           Yv, Yi, stride_ym, # topk values/indices
           USE_PROVIDED_INDX: tl.constexpr, Bits, stride_rm: tl.constexpr, stride_rn: tl.constexpr, n_rows, # bitmatrix
           n_expts_tot, S, BLOCK_S: tl.constexpr, s_blocks, # thing to memset
-          BLOCK_M: tl.constexpr, N_EXPTS_PAD: tl.constexpr, N_EXPTS_ACT: tl.constexpr, BLOCK_N: tl.constexpr):
+          BLOCK_M: tl.constexpr, N_EXPTS_PAD: tl.constexpr, N_EXPTS_ACT: tl.constexpr, BLOCK_N: tl.constexpr,
+          APPLY_SOFTMAX: tl.constexpr):

     pid = tl.program_id(0)

@@ -105,8 +106,8 @@ def _topk(X, stride_xm, # inputs
     y_indices = y & 0x0000FFFF
     y_values = (y >> x_nbits).to(x_utype).to(x_dtype, bitcast=True)

-    # normalize selected values
-    y_values = tl.softmax(y_values.to(tl.float32), dim=1, keep_dims=True).to(x_dtype)
+    if APPLY_SOFTMAX:
+        y_values = tl.softmax(y_values.to(tl.float32), dim=1, keep_dims=True).to(x_dtype)

     # write back
     Yv_ptrs = Yv + offs_m[:, None] * stride_ym + offs_y_n[None, :]
