@@ -25,7 +25,9 @@
 
 
 @functools.lru_cache()
-def get_config(gmm_type: str, M: int, K: int, N: int, G: int) -> dict[str, int]:
+def get_config(
+    gmm_type: str, M: int, K: int, N: int, G: int, accumulate: bool = False
+) -> dict[str, int]:
     assert gmm_type in {
         "gmm",
         "ptgmm",
@@ -49,7 +51,8 @@ def get_config(gmm_type: str, M: int, K: int, N: int, G: int) -> dict[str, int]:
     assert (
         "default" in get_config._config_dict[gmm_type]
     ), "Default configuration is absent."
-    return get_config._config_dict[gmm_type]["default"]
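+    # Assumes the per-kernel config dict also carries an "accumulate" entry
+    # tuned for the beta=1 path; only the presence of "default" is asserted above.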
+    key = "accumulate" if accumulate else "default"
+    return get_config._config_dict[gmm_type][key]
 
 
 # Common code shared by GMM and TGMM kernels.
@@ -90,6 +93,7 @@ def gmm_kernel(
     rhs_ptr,
     group_sizes_ptr,
     out_ptr,
+    bias_ptr,
     # Tensor shapes:
     M: int,
     K: int,
@@ -103,6 +107,7 @@ def gmm_kernel(
     K_DIVISIBLE_BY_BLOCK_SIZE_K: tl.constexpr,
     GROUP_SIZE: tl.constexpr,
     GRID_DIM: tl.constexpr,
+    USE_BIAS: tl.constexpr,
 ):
     tl.assume(M > 0)
     tl.assume(K > 0)
@@ -204,6 +209,19 @@ def gmm_kernel(
             else:
                 rhs_ptrs += BLOCK_SIZE_K * N
 
+        # Add bias if enabled
+        if USE_BIAS:
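+            # Bias is assumed to be laid out per expert as [G, N]: offset by
+            # group index g, then by this tile's column offsets within N.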
+            offs_bias_n = tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(
+                0, BLOCK_SIZE_N
+            )
+            bias_ptrs = bias_ptr + g.to(tl.int64) * N + offs_bias_n
+            bias = tl.load(bias_ptrs, mask=offs_bias_n < N, other=0.0)
+            # Convert bias to float32 to match accumulator precision
+            bias = bias.to(tl.float32)
+            # Broadcast bias across M dimension and add in float32
+            acc += bias[None, :]
+
+        # Convert to output dtype after all computations
         acc = acc.to(out_ptr.type.element_ty)
 
         offs_out_m = tile_m.to(tl.int64) * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
@@ -246,6 +264,7 @@ def tgmm_persistent_kernel(
     rhs_ptr,
     group_sizes_ptr,
     out_ptr,
+    bias_grad_ptr,
     # Tensor shapes:
     M: int,
     K: int,
@@ -258,6 +277,8 @@ def tgmm_persistent_kernel(
     BLOCK_SIZE_N: tl.constexpr,
     GROUP_SIZE: tl.constexpr,
     GRID_DIM: tl.constexpr,
+    COMPUTE_BIAS_GRAD: tl.constexpr,
+    ACCUMULATE: tl.constexpr,
 ):
     tl.assume(M > 0)
     tl.assume(K > 0)
@@ -334,12 +355,21 @@ def tgmm_persistent_kernel(
 
         acc = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_N), dtype=tl.float32)
 
+        # Initialize bias accumulator
+        bias_acc = tl.zeros((BLOCK_SIZE_K,), dtype=tl.float32)
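+        # bias_acc[k] will accumulate sum_m lhs[k, m] over this program's M
+        # chunks and is flushed to bias_grad via atomic_add after the store.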
+
         for _ in range(0, loop_m):
             lhs = tl.load(lhs_ptrs)
             rhs = tl.load(rhs_ptrs)
 
             acc += tl.dot(lhs, rhs, input_precision="ieee")
 
+            # Accumulate for bias gradient: sum lhs across M dimension
+            if COMPUTE_BIAS_GRAD and tile_n == 0:
+                bias_acc += tl.sum(
+                    lhs, axis=1
+                )  # Sum across M dimension [K, M] -> [K]
+
             if TRANS_LHS:
                 lhs_ptrs += BLOCK_SIZE_M * K
             else:
@@ -359,6 +389,10 @@ def tgmm_persistent_kernel(
         rhs = tl.load(rhs_ptrs, mask=offs_m[:, None] < m, other=0)
         acc += tl.dot(lhs, rhs, input_precision="ieee")
 
+        # Accumulate last chunk for bias gradient
+        if COMPUTE_BIAS_GRAD and tile_n == 0:
+            bias_acc += tl.sum(lhs, axis=1)
+
         acc = acc.to(out_ptr.type.element_ty)
 
         offs_out_k = tile_k.to(tl.int64) * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
@@ -371,11 +405,23 @@ def tgmm_persistent_kernel(
             + offs_out_n[None, :]
         )
 
-        tl.store(
-            out_ptrs,
-            acc,
-            mask=(offs_out_k[:, None] < K) & (offs_out_n[None, :] < N),
-        )
+        mask = (offs_out_k[:, None] < K) & (offs_out_n[None, :] < N)
+        if ACCUMULATE:
+            # Load existing values and add to them (like beta=1 in BLAS)
+            old_vals = tl.load(out_ptrs, mask=mask, other=0.0)
+            tl.store(out_ptrs, acc + old_vals, mask=mask)
+        else:
+            # Overwrite output (like beta=0 in BLAS)
+            tl.store(out_ptrs, acc, mask=mask)
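+        # Note: the ACCUMULATE read-modify-write above is not atomic; it relies
+        # on each output tile having exactly one writer in this schedule.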
+
+        # Store bias gradient (only for first N tile, sum across all M)
+        if COMPUTE_BIAS_GRAD and tile_n == 0:
+            # Keep as float32 for atomic_add (bf16 not supported for atomics)
+            bias_grad_ptrs = bias_grad_ptr + g.to(tl.int64) * K + offs_out_k
+            # Use atomic add since multiple K-tiles may write to same expert's bias
+            tl.atomic_add(
+                bias_grad_ptrs, bias_acc, mask=offs_out_k < K, sem="relaxed"
+            )
 
         # Go to the next tile by advancing number of programs.
         tile += GRID_DIM
@@ -405,6 +451,7 @@ def tgmm_non_persistent_kernel(
     rhs_ptr,
     group_sizes_ptr,
     out_ptr,
+    bias_grad_ptr,
     # Tensor shapes:
     M: int,
     K: int,
@@ -417,6 +464,8 @@ def tgmm_non_persistent_kernel(
     BLOCK_SIZE_K: tl.constexpr,
     BLOCK_SIZE_N: tl.constexpr,
     GROUP_SIZE: tl.constexpr,
+    COMPUTE_BIAS_GRAD: tl.constexpr,
+    ACCUMULATE: tl.constexpr,
 ):
     tl.assume(M > 0)
     tl.assume(K > 0)
@@ -477,13 +526,19 @@ def tgmm_non_persistent_kernel(
         loop_m -= 1
 
     acc = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_N), dtype=tl.float32)
+    # Initialize bias accumulator
+    bias_acc = tl.zeros((BLOCK_SIZE_K,), dtype=tl.float32)
 
     for _ in range(0, loop_m):
         lhs = tl.load(lhs_ptrs)
         rhs = tl.load(rhs_ptrs)
 
         acc += tl.dot(lhs, rhs, input_precision="ieee")
 
+        # Accumulate for bias gradient: sum lhs across M dimension
+        if COMPUTE_BIAS_GRAD and tile_n == 0:
+            bias_acc += tl.sum(lhs, axis=1)  # [K, M] -> [K]
+
         if TRANS_LHS:
             lhs_ptrs += BLOCK_SIZE_M * K
         else:
@@ -502,6 +557,9 @@ def tgmm_non_persistent_kernel(
     lhs = tl.load(lhs_ptrs, mask=offs_m[None, :] < m, other=0)
     rhs = tl.load(rhs_ptrs, mask=offs_m[:, None] < m, other=0)
     acc += tl.dot(lhs, rhs, input_precision="ieee")
+    # Accumulate last chunk for bias gradient
+    if COMPUTE_BIAS_GRAD and tile_n == 0:
+        bias_acc += tl.sum(lhs, axis=1)
 
     acc = acc.to(out_ptr.type.element_ty)
 
@@ -512,8 +570,18 @@ def tgmm_non_persistent_kernel(
        out_ptr + g.to(tl.int64) * K * N + offs_out_k[:, None] * N + offs_out_n[None, :]
    )
 
-    tl.store(
-        out_ptrs,
-        acc,
-        mask=(offs_out_k[:, None] < K) & (offs_out_n[None, :] < N),
-    )
+    mask = (offs_out_k[:, None] < K) & (offs_out_n[None, :] < N)
+    if ACCUMULATE:
+        # Load existing values and add to them (like beta=1 in BLAS)
+        old_vals = tl.load(out_ptrs, mask=mask, other=0.0)
+        tl.store(out_ptrs, acc + old_vals, mask=mask)
+    else:
+        # Overwrite output (like beta=0 in BLAS)
+        tl.store(out_ptrs, acc, mask=mask)
+
+    # Store bias gradient (only for first N tile, sum across all M)
+    if COMPUTE_BIAS_GRAD and tile_n == 0:
+        # Keep as float32 for atomic_add (bf16/fp16 not supported for atomics)
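+        # Assumes the caller zero-initializes bias_grad, since atomic_add
+        # accumulates into whatever the buffer already holds.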
+        bias_grad_ptrs = bias_grad_ptr + g.to(tl.int64) * K + offs_out_k
+        # Use atomic add since multiple K-tiles may write to same expert's bias
+        tl.atomic_add(bias_grad_ptrs, bias_acc, mask=offs_out_k < K, sem="relaxed")
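
For intuition, this is the reduction the bias-gradient path performs, written as
a dense PyTorch reference (a minimal sketch; the function name and the [M, K]
row-major, group-contiguous layout are assumptions, not part of this commit):

import torch

def bias_grad_reference(lhs: torch.Tensor, group_sizes: torch.Tensor) -> torch.Tensor:
    """db[g, k] = sum of lhs rows belonging to group g, accumulated in float32."""
    G, K = group_sizes.numel(), lhs.shape[1]
    db = torch.zeros(G, K, dtype=torch.float32)
    start = 0
    for g, size in enumerate(group_sizes.tolist()):
        # Mirrors bias_acc in the kernels: sum across the M dimension in fp32.
        db[g] = lhs[start:start + size].float().sum(dim=0)
        start += size
    return db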