
Commit 04dc719

sudhu2k (Sudharshan Govindan) and co-authors authored
Add fused bias support for GMM and bias‑gradient/accumulate support for TGMM (ROCm#1541)
* Initial commit
* Added TGMM test and fixed the wrapper
* Added test for accumulate
* atol and rtol changes, comment fixes
* Addressing PR comments
* Added benchmark script modifications
* Added bias and accumulate to non persistent
* Refactor TGMM test: enabled bias grad test for nptgmm and relaxed tolerance for gfx950
* Removed accumulate boolean for bias grad and added tests for accumulate
* Applied black formatting
* Adding accumulate configs because of triton.runtime.errors.OutOfResources: out of resource: shared memory (Required: 131072, Hardware limit: 65536) when accumulate = True
* Fixed black formatting

---------

Co-authored-by: Sudharshan Govindan <sugovind@amd.com>
1 parent f6fad2f commit 04dc719
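For orientation, here is a minimal PyTorch sketch of the semantics the fused-bias GMM path targets. This is an illustration under assumed conventions, not the aiter wrapper API: lhs is [M, K] with rows split into G contiguous groups by group_sizes, rhs is [G, K, N], and the optional bias is [G, N], broadcast over each group's rows.

import torch


def gmm_reference(lhs, rhs, group_sizes, bias=None):
    """Grouped matmul with optional fused bias (reference semantics only)."""
    # lhs: [M, K], rhs: [G, K, N], group_sizes: [G], bias: [G, N] or None
    out = torch.empty(lhs.shape[0], rhs.shape[-1], dtype=lhs.dtype, device=lhs.device)
    start = 0
    for g, size in enumerate(group_sizes.tolist()):
        # Rows belonging to group g are multiplied by that group's weight matrix.
        acc = lhs[start:start + size] @ rhs[g]
        if bias is not None:
            # Fused bias: one bias row per group, broadcast across the M dimension.
            acc = acc + bias[g]
        out[start:start + size] = acc
        start += size
    return out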

File tree

7 files changed: +493 -42 lines

aiter/ops/triton/_triton_kernels/gmm.py

Lines changed: 80 additions & 12 deletions
@@ -25,7 +25,9 @@
 
 
 @functools.lru_cache()
-def get_config(gmm_type: str, M: int, K: int, N: int, G: int) -> dict[str, int]:
+def get_config(
+    gmm_type: str, M: int, K: int, N: int, G: int, accumulate: bool = False
+) -> dict[str, int]:
     assert gmm_type in {
         "gmm",
         "ptgmm",
@@ -49,7 +51,8 @@ def get_config(gmm_type: str, M: int, K: int, N: int, G: int) -> dict[str, int]:
     assert (
         "default" in get_config._config_dict[gmm_type]
     ), "Default configuration is absent."
-    return get_config._config_dict[gmm_type]["default"]
+    key = "accumulate" if accumulate else "default"
+    return get_config._config_dict[gmm_type][key]
 
 
 # Common code shared by GMM and TGMM kernels.
@@ -90,6 +93,7 @@ def gmm_kernel(
     rhs_ptr,
     group_sizes_ptr,
     out_ptr,
+    bias_ptr,
     # Tensor shapes:
     M: int,
     K: int,
@@ -103,6 +107,7 @@ def gmm_kernel(
     K_DIVISIBLE_BY_BLOCK_SIZE_K: tl.constexpr,
     GROUP_SIZE: tl.constexpr,
     GRID_DIM: tl.constexpr,
+    USE_BIAS: tl.constexpr,
 ):
     tl.assume(M > 0)
     tl.assume(K > 0)
@@ -204,6 +209,19 @@ def gmm_kernel(
             else:
                 rhs_ptrs += BLOCK_SIZE_K * N
 
+        # Add bias if enabled
+        if USE_BIAS:
+            offs_bias_n = tile_n.to(tl.int64) * BLOCK_SIZE_N + tl.arange(
+                0, BLOCK_SIZE_N
+            )
+            bias_ptrs = bias_ptr + g.to(tl.int64) * N + offs_bias_n
+            bias = tl.load(bias_ptrs, mask=offs_bias_n < N, other=0.0)
+            # Convert bias to float32 to match accumulator precision
+            bias = bias.to(tl.float32)
+            # Broadcast bias across M dimension and add in float32
+            acc += bias[None, :]
+
+        # Convert to output dtype after all computations
         acc = acc.to(out_ptr.type.element_ty)
 
         offs_out_m = tile_m.to(tl.int64) * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
@@ -246,6 +264,7 @@ def tgmm_persistent_kernel(
     rhs_ptr,
     group_sizes_ptr,
     out_ptr,
+    bias_grad_ptr,
     # Tensor shapes:
     M: int,
     K: int,
@@ -258,6 +277,8 @@ def tgmm_persistent_kernel(
     BLOCK_SIZE_N: tl.constexpr,
     GROUP_SIZE: tl.constexpr,
     GRID_DIM: tl.constexpr,
+    COMPUTE_BIAS_GRAD: tl.constexpr,
+    ACCUMULATE: tl.constexpr,
 ):
     tl.assume(M > 0)
     tl.assume(K > 0)
@@ -334,12 +355,21 @@ def tgmm_persistent_kernel(
 
         acc = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_N), dtype=tl.float32)
 
+        # Initialize bias accumulator
+        bias_acc = tl.zeros((BLOCK_SIZE_K,), dtype=tl.float32)
+
         for _ in range(0, loop_m):
             lhs = tl.load(lhs_ptrs)
             rhs = tl.load(rhs_ptrs)
 
             acc += tl.dot(lhs, rhs, input_precision="ieee")
 
+            # Accumulate for bias gradient: sum lhs across M dimension
+            if COMPUTE_BIAS_GRAD and tile_n == 0:
+                bias_acc += tl.sum(
+                    lhs, axis=1
+                )  # Sum across M dimension [K, M] -> [K]
+
             if TRANS_LHS:
                 lhs_ptrs += BLOCK_SIZE_M * K
             else:
@@ -359,6 +389,10 @@ def tgmm_persistent_kernel(
         rhs = tl.load(rhs_ptrs, mask=offs_m[:, None] < m, other=0)
         acc += tl.dot(lhs, rhs, input_precision="ieee")
 
+        # Accumulate last chunk for bias gradient
+        if COMPUTE_BIAS_GRAD and tile_n == 0:
+            bias_acc += tl.sum(lhs, axis=1)
+
         acc = acc.to(out_ptr.type.element_ty)
 
         offs_out_k = tile_k.to(tl.int64) * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
@@ -371,11 +405,23 @@ def tgmm_persistent_kernel(
             + offs_out_n[None, :]
         )
 
-        tl.store(
-            out_ptrs,
-            acc,
-            mask=(offs_out_k[:, None] < K) & (offs_out_n[None, :] < N),
-        )
+        mask = (offs_out_k[:, None] < K) & (offs_out_n[None, :] < N)
+        if ACCUMULATE:
+            # Load existing values and add to them (like beta=1 in BLAS)
+            old_vals = tl.load(out_ptrs, mask=mask, other=0.0)
+            tl.store(out_ptrs, acc + old_vals, mask=mask)
+        else:
+            # Overwrite output (like beta=0 in BLAS)
+            tl.store(out_ptrs, acc, mask=mask)
+
+        # Store bias gradient (only for first N tile, sum across all M)
+        if COMPUTE_BIAS_GRAD and tile_n == 0:
+            # Keep as float32 for atomic_add (bf16 not supported for atomics)
+            bias_grad_ptrs = bias_grad_ptr + g.to(tl.int64) * K + offs_out_k
+            # Use atomic add since multiple K-tiles may write to same expert's bias
+            tl.atomic_add(
+                bias_grad_ptrs, bias_acc, mask=offs_out_k < K, sem="relaxed"
+            )
 
         # Go to the next tile by advancing number of programs.
         tile += GRID_DIM
@@ -405,6 +451,7 @@ def tgmm_non_persistent_kernel(
     rhs_ptr,
     group_sizes_ptr,
     out_ptr,
+    bias_grad_ptr,
     # Tensor shapes:
     M: int,
     K: int,
@@ -417,6 +464,8 @@ def tgmm_non_persistent_kernel(
     BLOCK_SIZE_K: tl.constexpr,
     BLOCK_SIZE_N: tl.constexpr,
     GROUP_SIZE: tl.constexpr,
+    COMPUTE_BIAS_GRAD: tl.constexpr,
+    ACCUMULATE: tl.constexpr,
 ):
     tl.assume(M > 0)
     tl.assume(K > 0)
@@ -477,13 +526,19 @@ def tgmm_non_persistent_kernel(
         loop_m -= 1
 
     acc = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_N), dtype=tl.float32)
+    # Initialize bias accumulator
+    bias_acc = tl.zeros((BLOCK_SIZE_K,), dtype=tl.float32)
 
     for _ in range(0, loop_m):
        lhs = tl.load(lhs_ptrs)
        rhs = tl.load(rhs_ptrs)
 
        acc += tl.dot(lhs, rhs, input_precision="ieee")
 
+       # Accumulate for bias gradient: sum lhs across M dimension
+       if COMPUTE_BIAS_GRAD and tile_n == 0:
+           bias_acc += tl.sum(lhs, axis=1)  # [K, M] -> [K]
+
        if TRANS_LHS:
            lhs_ptrs += BLOCK_SIZE_M * K
        else:
@@ -502,6 +557,9 @@ def tgmm_non_persistent_kernel(
     lhs = tl.load(lhs_ptrs, mask=offs_m[None, :] < m, other=0)
     rhs = tl.load(rhs_ptrs, mask=offs_m[:, None] < m, other=0)
     acc += tl.dot(lhs, rhs, input_precision="ieee")
+    # Accumulate last chunk for bias gradient
+    if COMPUTE_BIAS_GRAD and tile_n == 0:
+        bias_acc += tl.sum(lhs, axis=1)
 
     acc = acc.to(out_ptr.type.element_ty)
 
@@ -512,8 +570,18 @@ def tgmm_non_persistent_kernel(
         out_ptr + g.to(tl.int64) * K * N + offs_out_k[:, None] * N + offs_out_n[None, :]
     )
 
-    tl.store(
-        out_ptrs,
-        acc,
-        mask=(offs_out_k[:, None] < K) & (offs_out_n[None, :] < N),
-    )
+    mask = (offs_out_k[:, None] < K) & (offs_out_n[None, :] < N)
+    if ACCUMULATE:
+        # Load existing values and add to them (like beta=1 in BLAS)
+        old_vals = tl.load(out_ptrs, mask=mask, other=0.0)
+        tl.store(out_ptrs, acc + old_vals, mask=mask)
+    else:
+        # Overwrite output (like beta=0 in BLAS)
+        tl.store(out_ptrs, acc, mask=mask)
+
+    # Store bias gradient (only for first N tile, sum across all M)
+    if COMPUTE_BIAS_GRAD and tile_n == 0:
+        # Keep as float32 for atomic_add (bf16/fp16 not supported for atomics)
+        bias_grad_ptrs = bias_grad_ptr + g.to(tl.int64) * K + offs_out_k
+        # Use atomic add since multiple K-tiles may write to same expert's bias
+        tl.atomic_add(bias_grad_ptrs, bias_acc, mask=offs_out_k < K, sem="relaxed")

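For the TGMM changes above, here is a hedged PyTorch reference of what the kernels now compute (names, signature, and input layout are illustrative assumptions, not the library's wrapper): per group g the output tile is lhs_g transposed times rhs_g, ACCUMULATE adds into the existing output (beta=1 in BLAS terms), and COMPUTE_BIAS_GRAD additionally reduces lhs over the M dimension into a float32 [G, K] buffer, which is why the kernels keep bias_acc in float32 and use tl.atomic_add.

import torch


def tgmm_reference(lhs, rhs, group_sizes, out=None, compute_bias_grad=False,
                   accumulate=False):
    """Transposed grouped matmul with optional bias gradient (reference semantics only)."""
    # lhs: [M, K], rhs: [M, N], group_sizes: [G] -> out: [G, K, N], bias_grad: [G, K]
    G = group_sizes.numel()
    K, N = lhs.shape[1], rhs.shape[1]
    if out is None or not accumulate:
        # accumulate=False overwrites the output (beta=0); True adds into it (beta=1).
        out = torch.zeros(G, K, N, dtype=lhs.dtype, device=lhs.device)
    bias_grad = torch.zeros(G, K, dtype=torch.float32, device=lhs.device)
    start = 0
    for g, size in enumerate(group_sizes.tolist()):
        chunk_lhs = lhs[start:start + size]  # [m_g, K]
        chunk_rhs = rhs[start:start + size]  # [m_g, N]
        out[g] += (chunk_lhs.T @ chunk_rhs).to(out.dtype)
        if compute_bias_grad:
            # Sum across the M dimension, kept in float32 like the kernel's bias_acc.
            bias_grad[g] = chunk_lhs.float().sum(dim=0)
        start += size
    return (out, bias_grad) if compute_bias_grad else out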
aiter/ops/triton/configs/gfx942-GMM.json

Lines changed: 17 additions & 0 deletions
@@ -19,6 +19,15 @@
             "GRID_DIM": 304,
             "num_warps": 8,
             "num_stages": 1
+        },
+        "accumulate": {
+            "BLOCK_SIZE_M": 64,
+            "BLOCK_SIZE_K": 128,
+            "BLOCK_SIZE_N": 128,
+            "GROUP_SIZE": 1,
+            "GRID_DIM": 304,
+            "num_warps": 8,
+            "num_stages": 1
         }
     },
     "nptgmm": {
@@ -29,6 +38,14 @@
             "GROUP_SIZE": 1,
             "num_warps": 8,
             "num_stages": 1
+        },
+        "accumulate": {
+            "BLOCK_SIZE_M": 64,
+            "BLOCK_SIZE_K": 128,
+            "BLOCK_SIZE_N": 128,
+            "GROUP_SIZE": 1,
+            "num_warps": 8,
+            "num_stages": 1
         }
     }
 }

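These "accumulate" entries pair with the get_config change in gmm.py above: when accumulate=True the lookup key switches from "default" to "accumulate" (added, per the commit message, because the default configuration exceeded the 65536-byte shared-memory limit with accumulation enabled). Below is a minimal standalone sketch of that selection against a dict shaped like this file; the "default" values are placeholders, only the "accumulate" entry mirrors the JSON above.

# Illustrative only; mirrors the key selection added to get_config() in gmm.py.
_config_dict = {
    "nptgmm": {
        # Placeholder values, for illustration only.
        "default": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_K": 128, "BLOCK_SIZE_N": 128,
                    "GROUP_SIZE": 1, "num_warps": 8, "num_stages": 1},
        # Copied from the "accumulate" entry added to gfx942-GMM.json above.
        "accumulate": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_K": 128, "BLOCK_SIZE_N": 128,
                       "GROUP_SIZE": 1, "num_warps": 8, "num_stages": 1},
    },
}


def pick_config(gmm_type: str, accumulate: bool = False) -> dict[str, int]:
    # Same selection logic as the diff: fall back to "default" unless accumulating.
    key = "accumulate" if accumulate else "default"
    return _config_dict[gmm_type][key]


assert pick_config("nptgmm", accumulate=True)["BLOCK_SIZE_M"] == 64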
aiter/ops/triton/configs/gfx950-GMM.json

Lines changed: 17 additions & 0 deletions
@@ -19,6 +19,15 @@
             "GRID_DIM": 256,
             "num_warps": 8,
             "num_stages": 1
+        },
+        "accumulate": {
+            "BLOCK_SIZE_M": 64,
+            "BLOCK_SIZE_K": 128,
+            "BLOCK_SIZE_N": 128,
+            "GROUP_SIZE": 1,
+            "GRID_DIM": 256,
+            "num_warps": 8,
+            "num_stages": 1
         }
     },
     "nptgmm": {
@@ -29,6 +38,14 @@
             "GROUP_SIZE": 1,
             "num_warps": 8,
             "num_stages": 1
+        },
+        "accumulate": {
+            "BLOCK_SIZE_M": 64,
+            "BLOCK_SIZE_K": 128,
+            "BLOCK_SIZE_N": 128,
+            "GROUP_SIZE": 1,
+            "num_warps": 8,
+            "num_stages": 1
         }
     }
 }
