@@ -384,6 +384,17 @@ def _p_matmul_ogs(
384
384
block_shape = [BLOCK_M , OUT_BLOCK_N ],
385
385
)
386
386
387
+ # bias + scale
388
+ offs_y_n = off_n1 + tl .arange (0 , BLOCK_N )
389
+ mask_n = offs_y_n < N
390
+ if B is not None :
391
+ BPtrs = B + expt_id1 * stride_b_e + offs_y_n
392
+ if pid_k1 == 0 :
393
+ bias = tl .load (BPtrs , mask = mask_n , other = 0 )
394
+ else :
395
+ bias = tl .full ([BLOCK_N ], 0 , dtype = tl .float32 )
396
+ else :
397
+ bias = tl .full ([BLOCK_N ], 0 , dtype = tl .float32 )
387
398
if Betas is not None :
388
399
betas = tl .load (Betas + start_m1 + offs_m , mask = mask_m , other = 0.0 )
389
400
else :
@@ -399,15 +410,21 @@ def _p_matmul_ogs(
399
410
w_scale = load_scale (WScale )
400
411
401
412
accs = (acc ,)
413
+ biases = (bias ,)
402
414
403
415
if SUBTILE_FACTOR >= 2 :
404
416
acc0 , acc1 = acc .reshape (BLOCK_M , 2 , BLOCK_N // 2 ).permute (0 , 2 , 1 ).split ()
405
417
accs = (acc0 , acc1 )
418
+ bias0 , bias1 = bias .reshape (2 , BLOCK_N // 2 ).permute (1 , 0 ).split ()
419
+ biases = (bias0 , bias1 )
406
420
407
421
if SUBTILE_FACTOR >= 4 :
408
422
acc00 , acc01 = acc0 .reshape (BLOCK_M , 2 , BLOCK_N // 4 ).permute (0 , 2 , 1 ).split ()
409
423
acc10 , acc11 = acc1 .reshape (BLOCK_M , 2 , BLOCK_N // 4 ).permute (0 , 2 , 1 ).split ()
410
424
accs = (acc00 , acc01 , acc10 , acc11 )
425
+ bias00 , bias01 = bias0 .reshape (2 , BLOCK_N // 4 ).permute (1 , 0 ).split ()
426
+ bias10 , bias11 = bias1 .reshape (2 , BLOCK_N // 4 ).permute (1 , 0 ).split ()
427
+ biases = (bias00 , bias01 , bias10 , bias11 )
411
428
412
429
tl .static_assert (EPILOGUE_BLOCK_N == BLOCK_N // SUBTILE_FACTOR )
413
430
tl .static_assert (len (accs ) == SUBTILE_FACTOR )
@@ -419,18 +436,7 @@ def _p_matmul_ogs(
419
436
if SWAP_XW :
420
437
acc_tile = acc_tile .T
421
438
422
- if B is not None :
423
- offs_y_n = off_n1 + EPILOGUE_BLOCK_N * a_i + tl .arange (0 , EPILOGUE_BLOCK_N )
424
- mask_n = offs_y_n < N
425
- BPtrs = B + expt_id1 * stride_b_e + offs_y_n
426
- if pid_k1 == 0 :
427
- bias = tl .load (BPtrs , mask = mask_n , other = 0 )
428
- else :
429
- bias = tl .full ([EPILOGUE_BLOCK_N ], 0 , dtype = tl .float32 )
430
- else :
431
- bias = tl .full ([EPILOGUE_BLOCK_N ], 0 , dtype = tl .float32 )
432
-
433
- acc_tile = acc_tile + bias [None , :] * betas [:, None ]
439
+ acc_tile = acc_tile + biases [a_i ][None , :] * betas [:, None ]
434
440
if out_alpha is not None :
435
441
acc_tile *= out_alpha
436
442
0 commit comments