@@ -384,6 +384,17 @@ def _p_matmul_ogs(
384384 block_shape = [BLOCK_M , OUT_BLOCK_N ],
385385 )
386386
387+ # bias + scale
388+ offs_y_n = off_n1 + tl .arange (0 , BLOCK_N )
389+ mask_n = offs_y_n < N
390+ if B is not None :
391+ BPtrs = B + expt_id1 * stride_b_e + offs_y_n
392+ if pid_k1 == 0 :
393+ bias = tl .load (BPtrs , mask = mask_n , other = 0 )
394+ else :
395+ bias = tl .full ([BLOCK_N ], 0 , dtype = tl .float32 )
396+ else :
397+ bias = tl .full ([BLOCK_N ], 0 , dtype = tl .float32 )
387398 if Betas is not None :
388399 betas = tl .load (Betas + start_m1 + offs_m , mask = mask_m , other = 0.0 )
389400 else :
@@ -399,15 +410,21 @@ def _p_matmul_ogs(
399410 w_scale = load_scale (WScale )
400411
401412 accs = (acc ,)
413+ biases = (bias ,)
402414
403415 if SUBTILE_FACTOR >= 2 :
404416 acc0 , acc1 = acc .reshape (BLOCK_M , 2 , BLOCK_N // 2 ).permute (0 , 2 , 1 ).split ()
405417 accs = (acc0 , acc1 )
418+ bias0 , bias1 = bias .reshape (2 , BLOCK_N // 2 ).permute (1 , 0 ).split ()
419+ biases = (bias0 , bias1 )
406420
407421 if SUBTILE_FACTOR >= 4 :
408422 acc00 , acc01 = acc0 .reshape (BLOCK_M , 2 , BLOCK_N // 4 ).permute (0 , 2 , 1 ).split ()
409423 acc10 , acc11 = acc1 .reshape (BLOCK_M , 2 , BLOCK_N // 4 ).permute (0 , 2 , 1 ).split ()
410424 accs = (acc00 , acc01 , acc10 , acc11 )
425+ bias00 , bias01 = bias0 .reshape (2 , BLOCK_N // 4 ).permute (1 , 0 ).split ()
426+ bias10 , bias11 = bias1 .reshape (2 , BLOCK_N // 4 ).permute (1 , 0 ).split ()
427+ biases = (bias00 , bias01 , bias10 , bias11 )
411428
412429 tl .static_assert (EPILOGUE_BLOCK_N == BLOCK_N // SUBTILE_FACTOR )
413430 tl .static_assert (len (accs ) == SUBTILE_FACTOR )
@@ -419,18 +436,7 @@ def _p_matmul_ogs(
419436 if SWAP_XW :
420437 acc_tile = acc_tile .T
421438
422- if B is not None :
423- offs_y_n = off_n1 + EPILOGUE_BLOCK_N * a_i + tl .arange (0 , EPILOGUE_BLOCK_N )
424- mask_n = offs_y_n < N
425- BPtrs = B + expt_id1 * stride_b_e + offs_y_n
426- if pid_k1 == 0 :
427- bias = tl .load (BPtrs , mask = mask_n , other = 0 )
428- else :
429- bias = tl .full ([EPILOGUE_BLOCK_N ], 0 , dtype = tl .float32 )
430- else :
431- bias = tl .full ([EPILOGUE_BLOCK_N ], 0 , dtype = tl .float32 )
432-
433- acc_tile = acc_tile + bias [None , :] * betas [:, None ]
439+ acc_tile = acc_tile + biases [a_i ][None , :] * betas [:, None ]
434440 if out_alpha is not None :
435441 acc_tile *= out_alpha
436442
0 commit comments