Skip to content

Commit d61d915

Browse files
authored
Apply relaxed mod to streamk splitk atomic op (#3706)
1 parent 86dffe2 commit d61d915

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

benchmarks/triton_kernels_benchmark/gemm_splitk_benchmark.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def _kernel(A, B, C, #
61 61   rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
62 62   C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
63 63   mask = (rm < M)[:, None] & (rn < N)[None, :]
64    - tl.atomic_add(C, acc, mask=mask)
   64 + tl.atomic_add(C, acc, mask=mask, sem='relaxed')
65 65
66 66
67 67   class _matmul(torch.autograd.Function):

benchmarks/triton_kernels_benchmark/gemm_streamk_benchmark.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def mac_loop(
94 94     rn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
95 95     c_ptr_ = c_ptr + rm[:, None] * stride_cm + rn[None, :] * stride_cn
96 96     mask = (rm < M)[:, None] & (rn < N)[None, :]
97      - tl.atomic_add(c_ptr_, acc, mask=mask)
   97   + tl.atomic_add(c_ptr_, acc, mask=mask, sem='relaxed')
98 98
99 99
100 100   @triton.autotune(

0 commit comments

Comments
 (0)