3131
3232@triton .jit
3333def moe_mmk (
34- a_desc ,
35- b_desc ,
34+ a_ptrs ,
35+ b_ptrs ,
3636 K ,
3737 expert_id ,
3838 a_scale_ptr ,
@@ -41,6 +41,9 @@ def moe_mmk(
4141 # moving by 1 element in a particular dimension. E.g. `stride_am` is
4242 # how much to increase `a_ptr` by to get the element one row down
4343 # (A has M rows).
44+ stride_ak : tl .int64 ,
45+ stride_bk : tl .int64 ,
46+ stride_ase : tl .int64 ,
4447 stride_asm : tl .int64 ,
4548 stride_ask : tl .int64 ,
4649 stride_bse : tl .int64 ,
@@ -65,6 +68,7 @@ def moe_mmk(
6568 use_w8a16 : tl .constexpr ,
6669 per_act_token_quant : tl .constexpr ,
6770):
71+ offs_k = tl .arange (0 , BLOCK_K )
6872
6973 if use_w8a16 :
7074 b_scale_ptrs = b_scale_ptr + expert_id * stride_bse + offs_n [None , :] * stride_bsn
@@ -99,8 +103,12 @@ def moe_mmk(
99103 accumulator = tl .zeros ((BLOCK_M , BLOCK_N ), dtype = tl .float32 )
100104 for k in range (0 , tl .cdiv (K , BLOCK_K )):
101105 # Load the next block of A and B through pointers, masking out-of-range K (and rows of A beyond M)
102- a = a_desc .load ([pid_m * BLOCK_M , k * BLOCK_K ])
103- b = b_desc .load ([k * BLOCK_K , pid_n * BLOCK_N ])
106+ a = tl .load (
107+ a_ptrs ,
108+ mask = mask_m [:, None ] & (offs_k [None , :] < K - k * BLOCK_K ),
109+ other = 0.0 ,
110+ )
111+ b = tl .load (b_ptrs , mask = offs_k [:, None ] < K - k * BLOCK_K , other = 0.0 )
104112
105113 # We accumulate along the K dimension.
106114 if use_w8a16 :
@@ -119,6 +127,9 @@ def moe_mmk(
119127 else :
120128 accumulator += tl .dot (a , b )
121129
130+ a_ptrs += BLOCK_K * stride_ak
131+ b_ptrs += BLOCK_K * stride_bk
132+
122133 if use_w8a16 :
123134 accumulator = (accumulator * b_scale ).to (compute_type )
124135 elif use_w8a8 :
@@ -134,9 +145,9 @@ def moe_mmk(
134145
135146@triton .jit
136147def expert_triton_kernel (
137- a_desc , #[max_tokens, K]
138- b_desc , #[K, N]
139- c_desc , #[max_tokens, N]
148+ a_ptr ,
149+ b_ptr ,
150+ c_ptr ,
140151 expert_id ,
141152 compute_type : tl .constexpr ,
142153 # Dimensions
@@ -147,8 +158,12 @@ def expert_triton_kernel(
147158 a_scale_ptr ,
148159 b_scale_ptr ,
149160 # strides
161+ stride_am : tl .int64 ,
150162 stride_ak : tl .int64 ,
151163 stride_bk : tl .int64 ,
164+ stride_bn : tl .int64 ,
165+ stride_cm : tl .int64 ,
166+ stride_cn : tl .int64 ,
152167 stride_ase : tl .int64 ,
153168 stride_asm : tl .int64 ,
154169 stride_ask : tl .int64 ,
@@ -174,15 +189,19 @@ def expert_triton_kernel(
174189
175190 offs_m = tl .arange (0 , BLOCK_M )
176191 offs_n = tl .arange (0 , BLOCK_N ) % N
177- # offs_k = tl.arange(0, BLOCK_K)
192+ offs_k = tl .arange (0 , BLOCK_K )
178193 mask_m = offs_m < M
179194
195+ a_ptrs = a_ptr + offs_m [:, None ] * stride_am + offs_k [None , :] * stride_ak
196+ b_ptrs = b_ptr + offs_k [:, None ] * stride_bk + offs_n [None , :] * stride_bn
197+
180198 accumulator = moe_mmk (
181- a_desc , b_desc , K , expert_id , a_scale_ptr , b_scale_ptr ,
199+ a_ptrs , b_ptrs , K , expert_id , a_scale_ptr , b_scale_ptr ,
182200 # The stride variables represent how much to increase the ptr by when
183201 # moving by 1 element in a particular dimension. E.g. `stride_am` is
184202 # how much to increase `a_ptr` by to get the element one row down
185203 # (A has M rows).
204+ stride_ak , stride_bk , stride_ase ,
186205 stride_asm , stride_ask , stride_bse , stride_bsk , stride_bsn ,
187206 # Offsets and masks
188207 offs_m , offs_n , offs_bn , mask_m ,
@@ -192,11 +211,10 @@ def expert_triton_kernel(
192211 BLOCK_M , BLOCK_N , BLOCK_K , compute_type , use_fp8_w8a8 , use_int8_w8a16 , per_act_token_quant )
193212
194213 # store in C
195- # offs_cn = tl.arange(0, BLOCK_N)
196- # c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_cn[None, :] * stride_cn
197- # c_mask = mask_m[:, None] & (offs_cn[None, :] < N)
198- c_desc .store ([pid_m * BLOCK_M , pid_n * BLOCK_N ], accumulator )
199- # tl.store(c_ptrs, accumulator, mask=c_mask)
214+ offs_cn = tl .arange (0 , BLOCK_N )
215+ c_ptrs = c_ptr + offs_m [:, None ] * stride_cm + offs_cn [None , :] * stride_cn
216+ c_mask = mask_m [:, None ] & (offs_cn [None , :] < N )
217+ tl .store (c_ptrs , accumulator , mask = c_mask )
200218
201219
202220def get_matmul_batched_autotune_configs () -> List [triton .Config ]:
@@ -292,17 +310,10 @@ def batched_triton_kernel(
292310 cta_m_size = min (BLOCK_M , e_num_tokens - cta_m_start )
293311 cta_n_size = min (BLOCK_N , N - cta_n_start )
294312
295- a_desc = tl .make_tensor_descriptor (base = a_ptr + expert_id * stride_ae , shape = (e_num_tokens , K ),
296- strides = (stride_am , stride_ak ), block_shape = (BLOCK_M , BLOCK_K ))
297- b_desc = tl .make_tensor_descriptor (base = b_ptr + expert_id * stride_be , shape = (K , N ), strides = (stride_bk , stride_bn ),
298- block_shape = (BLOCK_K , BLOCK_N ))
299- c_desc = tl .make_tensor_descriptor (base = c_ptr + expert_id * stride_ce , shape = (e_num_tokens , N ),
300- strides = (stride_cm , stride_cn ), block_shape = (BLOCK_M , BLOCK_N ))
301-
302- # a_ptr = a_ptr + expert_id * stride_ae + cta_m_start * stride_am
303- # b_ptr = b_ptr + expert_id * stride_be + cta_n_start * stride_bn
304- # c_ptr = (c_ptr + expert_id * stride_ce + cta_m_start * stride_cm +
305- # cta_n_start * stride_cn)
313+ a_ptr = a_ptr + expert_id * stride_ae + cta_m_start * stride_am
314+ b_ptr = b_ptr + expert_id * stride_be + cta_n_start * stride_bn
315+ c_ptr = (c_ptr + expert_id * stride_ce + cta_m_start * stride_cm +
316+ cta_n_start * stride_cn )
306317
307318 offs_bn = (pid_n * BLOCK_N + tl .arange (0 , BLOCK_N ).to (tl .int64 )) % N
308319
@@ -314,12 +325,12 @@ def batched_triton_kernel(
314325 if group_k > 0 and group_n > 0 or per_act_token_quant :
315326 a_scale_ptr = a_scale_ptr + cta_m_start * stride_asm
316327
317- expert_triton_kernel (a_desc , b_desc , c_desc , expert_id , compute_type , cta_m_size , # M
328+ expert_triton_kernel (a_ptr , b_ptr , c_ptr , expert_id , compute_type , cta_m_size , # M
318329 cta_n_size , # N
319330 K , # K
320331 a_scale_ptr , b_scale_ptr ,
321332 # Strides
322- stride_ak , stride_bk , stride_ase , stride_asm , stride_ask , stride_bse , stride_bsk , stride_bsn ,
333+ stride_am , stride_ak , stride_bk , stride_bn , stride_cm , stride_cn , stride_ase , stride_asm , stride_ask , stride_bse , stride_bsk , stride_bsn ,
323334 # offsets
324335 offs_bn ,
325336 # Blockwise quantization data
@@ -502,13 +513,8 @@ def get_batched_mm_benchmark(
502513 Returns a Mark object containing a Benchmark object for batched matrix multiplication.
503514 """
504515 supported_providers = {
505- 'triton' : 'triton' ,
506516 'triton-td' : 'triton-td' ,
507- 'pytorch' : 'pytorch' ,
508517 }
509- if fp8 :
510- # pytorch is very slow with fp8 case, for (8, 64, 1024, 2048) case it has ~0.15 TFlops vs 1.5 for triton
511- del supported_providers ['pytorch' ]
512518
513519 providers = benchmark_suite .filter_providers (supported_providers , providers_filter )
514520 configs = MM_CONFIGS_FP8 if fp8 else MM_CONFIGS_BF16
0 commit comments