
 @triton.jit
 def moe_mmk(
-    a_ptrs,
-    b_ptrs,
+    a_desc,
+    b_desc,
     K,
     expert_id,
     a_scale_ptr,
@@ -41,9 +41,6 @@ def moe_mmk(
     # moving by 1 element in a particular dimension. E.g. `stride_am` is
     # how much to increase `a_ptr` by to get the element one row down
     # (A has M rows).
-    stride_ak: tl.int64,
-    stride_bk: tl.int64,
-    stride_ase: tl.int64,
     stride_asm: tl.int64,
     stride_ask: tl.int64,
     stride_bse: tl.int64,
@@ -68,7 +65,6 @@ def moe_mmk(
     use_w8a16: tl.constexpr,
     per_act_token_quant: tl.constexpr,
 ):
-    offs_k = tl.arange(0, BLOCK_K)

     if use_w8a16:
         b_scale_ptrs = b_scale_ptr + expert_id * stride_bse + offs_n[None, :] * stride_bsn
@@ -103,12 +99,8 @@ def moe_mmk(
     accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
     for k in range(0, tl.cdiv(K, BLOCK_K)):
         # Load the next block of A and B using tensor descriptors
-        a = tl.load(
-            a_ptrs,
-            mask=mask_m[:, None] & (offs_k[None, :] < K - k * BLOCK_K),
-            other=0.0,
-        )
-        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0.0)
+        a = a_desc.load([pid_m * BLOCK_M, k * BLOCK_K])
+        b = b_desc.load([k * BLOCK_K, pid_n * BLOCK_N])

         # We accumulate along the K dimension.
         if use_w8a16:
@@ -127,9 +119,6 @@ def moe_mmk(
         else:
             accumulator += tl.dot(a, b)

-        a_ptrs += BLOCK_K * stride_ak
-        b_ptrs += BLOCK_K * stride_bk
-
     if use_w8a16:
         accumulator = (accumulator * b_scale).to(compute_type)
     elif use_w8a8:
@@ -145,9 +134,9 @@

 @triton.jit
 def expert_triton_kernel(
-    a_ptr,
-    b_ptr,
-    c_ptr,
+    a_desc,  # [max_tokens, K]
+    b_desc,  # [K, N]
+    c_desc,  # [max_tokens, N]
     expert_id,
     compute_type: tl.constexpr,
     # Dimensions
@@ -158,12 +147,8 @@ def expert_triton_kernel(
     a_scale_ptr,
     b_scale_ptr,
     # strides
-    stride_am: tl.int64,
     stride_ak: tl.int64,
     stride_bk: tl.int64,
-    stride_bn: tl.int64,
-    stride_cm: tl.int64,
-    stride_cn: tl.int64,
     stride_ase: tl.int64,
     stride_asm: tl.int64,
     stride_ask: tl.int64,
@@ -189,19 +174,15 @@ def expert_triton_kernel(

     offs_m = tl.arange(0, BLOCK_M)
     offs_n = tl.arange(0, BLOCK_N) % N
-    offs_k = tl.arange(0, BLOCK_K)
+    # offs_k = tl.arange(0, BLOCK_K)
     mask_m = offs_m < M

-    a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
-    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
-
     accumulator = moe_mmk(
-        a_ptrs, b_ptrs, K, expert_id, a_scale_ptr, b_scale_ptr,
+        a_desc, b_desc, K, expert_id, a_scale_ptr, b_scale_ptr,
         # The stride variables represent how much to increase the ptr by when
         # moving by 1 element in a particular dimension. E.g. `stride_am` is
         # how much to increase `a_ptr` by to get the element one row down
         # (A has M rows).
-        stride_ak, stride_bk, stride_ase,
         stride_asm, stride_ask, stride_bse, stride_bsk, stride_bsn,
         # Offsets and masks
         offs_m, offs_n, offs_bn, mask_m,
@@ -211,10 +192,11 @@ def expert_triton_kernel(
         BLOCK_M, BLOCK_N, BLOCK_K, compute_type, use_fp8_w8a8, use_int8_w8a16, per_act_token_quant)

     # store in C
-    offs_cn = tl.arange(0, BLOCK_N)
-    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_cn[None, :] * stride_cn
-    c_mask = mask_m[:, None] & (offs_cn[None, :] < N)
-    tl.store(c_ptrs, accumulator, mask=c_mask)
+    # offs_cn = tl.arange(0, BLOCK_N)
+    # c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_cn[None, :] * stride_cn
+    # c_mask = mask_m[:, None] & (offs_cn[None, :] < N)
+    c_desc.store([pid_m * BLOCK_M, pid_n * BLOCK_N], accumulator)
+    # tl.store(c_ptrs, accumulator, mask=c_mask)


 def get_matmul_batched_autotune_configs() -> List[triton.Config]:
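
Taken together, `moe_mmk` and `expert_triton_kernel` now receive tensor descriptors instead of raw pointers. The sketch below is a minimal, self-contained version of the same pattern, not this change's kernel: the quantization branches and expert indexing are omitted, and the output is assumed to be fp16. It shows why the `offs_k` masks, the `a_ptrs`/`b_ptrs` advances, and the `c_mask` can go away: every `load`/`store` takes an absolute block offset, and the descriptor's `shape` is assumed to handle the boundary (zero-fill on out-of-bounds loads, clipped stores).

```python
import triton
import triton.language as tl

@triton.jit
def matmul_td_kernel(a_ptr, b_ptr, c_ptr, M, N, K,
                     stride_am, stride_ak, stride_bk, stride_bn,
                     stride_cm, stride_cn,
                     BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    # One descriptor per operand; block_shape is the tile moved by each load/store.
    a_desc = tl.make_tensor_descriptor(base=a_ptr, shape=(M, K),
                                       strides=(stride_am, stride_ak),
                                       block_shape=(BLOCK_M, BLOCK_K))
    b_desc = tl.make_tensor_descriptor(base=b_ptr, shape=(K, N),
                                       strides=(stride_bk, stride_bn),
                                       block_shape=(BLOCK_K, BLOCK_N))
    c_desc = tl.make_tensor_descriptor(base=c_ptr, shape=(M, N),
                                       strides=(stride_cm, stride_cn),
                                       block_shape=(BLOCK_M, BLOCK_N))

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        # Absolute element offsets select the tile; no offs_k mask, no pointer advance.
        a = a_desc.load([pid_m * BLOCK_M, k * BLOCK_K])
        b = b_desc.load([k * BLOCK_K, pid_n * BLOCK_N])
        acc += tl.dot(a, b)

    # The store is clipped to `shape`, so no explicit c_mask is needed.
    c_desc.store([pid_m * BLOCK_M, pid_n * BLOCK_N], acc.to(tl.float16))
```
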
@@ -310,10 +292,17 @@ def batched_triton_kernel(
     cta_m_size = min(BLOCK_M, e_num_tokens - cta_m_start)
     cta_n_size = min(BLOCK_N, N - cta_n_start)

-    a_ptr = a_ptr + expert_id * stride_ae + cta_m_start * stride_am
-    b_ptr = b_ptr + expert_id * stride_be + cta_n_start * stride_bn
-    c_ptr = (c_ptr + expert_id * stride_ce + cta_m_start * stride_cm +
-             cta_n_start * stride_cn)
+    a_desc = tl.make_tensor_descriptor(base=a_ptr + expert_id * stride_ae, shape=(e_num_tokens, K),
+                                       strides=(stride_am, stride_ak), block_shape=(BLOCK_M, BLOCK_K))
+    b_desc = tl.make_tensor_descriptor(base=b_ptr + expert_id * stride_be, shape=(K, N), strides=(stride_bk, stride_bn),
+                                       block_shape=(BLOCK_K, BLOCK_N))
+    c_desc = tl.make_tensor_descriptor(base=c_ptr + expert_id * stride_ce, shape=(e_num_tokens, N),
+                                       strides=(stride_cm, stride_cn), block_shape=(BLOCK_M, BLOCK_N))
+
+    # a_ptr = a_ptr + expert_id * stride_ae + cta_m_start * stride_am
+    # b_ptr = b_ptr + expert_id * stride_be + cta_n_start * stride_bn
+    # c_ptr = (c_ptr + expert_id * stride_ce + cta_m_start * stride_cm +
+    #          cta_n_start * stride_cn)

     offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N).to(tl.int64)) % N

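
One host-side caveat that does not show up in the diff: building descriptors inside the kernel with `tl.make_tensor_descriptor` needs scratch memory, and recent Triton versions expect the caller to register an allocator before the first launch. A hedged sketch of that setup (the hook name `triton.set_allocator` and the allocator signature should be checked against the Triton version in use):

```python
import torch
import triton

# Assumed allocator signature: (size_in_bytes, alignment, stream) -> buffer.
def _scratch_alloc(size: int, alignment: int, stream):
    # Device choice depends on the backend ("xpu" for Intel GPUs, "cuda" for NVIDIA).
    return torch.empty(size, device="xpu", dtype=torch.int8)

triton.set_allocator(_scratch_alloc)
```
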
@@ -325,12 +314,12 @@
     if group_k > 0 and group_n > 0 or per_act_token_quant:
         a_scale_ptr = a_scale_ptr + cta_m_start * stride_asm

-    expert_triton_kernel(a_ptr, b_ptr, c_ptr, expert_id, compute_type, cta_m_size,  # M
+    expert_triton_kernel(a_desc, b_desc, c_desc, expert_id, compute_type, cta_m_size,  # M
                          cta_n_size,  # N
                          K,  # K
                          a_scale_ptr, b_scale_ptr,
                          # Strides
-                         stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, stride_ase, stride_asm, stride_ask, stride_bse, stride_bsk, stride_bsn,
+                         stride_ak, stride_bk, stride_ase, stride_asm, stride_ask, stride_bse, stride_bsk, stride_bsn,
                          # offsets
                          offs_bn,
                          # Blockwise quantization data
@@ -513,8 +502,13 @@ def get_batched_mm_benchmark(
     Returns a Mark object containing a Benchmark object for batched matrix multiplication.
     """
     supported_providers = {
+        'triton': 'triton',
         'triton-td': 'triton-td',
+        'pytorch': 'pytorch',
     }
+    if fp8:
+        # pytorch is very slow in the fp8 case: for (8, 64, 1024, 2048) it reaches ~0.15 TFLOPS vs ~1.5 for triton
+        del supported_providers['pytorch']

     providers = benchmark_suite.filter_providers(supported_providers, providers_filter)
     configs = MM_CONFIGS_FP8 if fp8 else MM_CONFIGS_BF16
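
For reference, the provider filtering above is assumed to be a plain intersection of the supported dict with the user-supplied filter. A hypothetical stand-in for `benchmark_suite.filter_providers` (the real helper lives in the benchmark suite and may differ):

```python
from typing import Dict, List, Optional

def filter_providers(supported: Dict[str, str],
                     requested: Optional[List[str]]) -> Dict[str, str]:
    """Hypothetical equivalent: keep only requested providers, or all when no filter is given."""
    if not requested:
        return dict(supported)
    return {name: label for name, label in supported.items() if name in requested}
```
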