Skip to content

Commit 5bb32e8

Browse files
authored
Change grouping calculation in gemm.py (#732)
* Update gemm.py grouping to account for XCDs * Formatting * Address comments * Pulled in fix from #722
1 parent 83871ea commit 5bb32e8

File tree

1 file changed

+28
-1
lines changed

1 file changed

+28
-1
lines changed

python/perf-kernels/gemm.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,9 @@
3535
use_cuda_graph=True,
3636
)
3737
@triton.heuristics({
38-
'EVEN_K': lambda args: args['K'] % args['BLOCK_SIZE_K'] == 0,
38+
'EVEN_K':
39+
lambda args: args['K'] % args['BLOCK_SIZE_K'] == 0, 'GRID_MN':
40+
lambda args: triton.cdiv(args['M'], args['BLOCK_SIZE_M']) * triton.cdiv(args['N'], args['BLOCK_SIZE_N'])
3941
})
4042
@triton.jit
4143
def matmul_kernel(
@@ -61,11 +63,14 @@ def matmul_kernel(
6163
GROUP_SIZE_M: tl.constexpr,
6264
APPLY_SCALE: tl.constexpr,
6365
ACTIVATION: tl.constexpr,
66+
GRID_MN: tl.constexpr,
6467
):
6568
"""Kernel for computing the matmul C = A x B.
6669
A has shape (M, K), B has shape (K, N) and C has shape (M, N)
6770
"""
6871

72+
NUM_XCDS: tl.constexpr = 8
73+
6974
tl.assume(stride_am > 0)
7075
tl.assume(stride_ak > 0)
7176
tl.assume(stride_bk > 0)
@@ -80,6 +85,28 @@ def matmul_kernel(
8085
pid = tl.program_id(axis=0)
8186
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
8287
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
88+
89+
## pid remapping on xcds
90+
# Number of pids per XCD in the new arrangement
91+
pids_per_xcd = (GRID_MN + NUM_XCDS - 1) // NUM_XCDS
92+
# When GRID_MN cannot divide NUM_XCDS, some xcds will have
93+
# pids_per_xcd pids, the other will have pids_per_xcd - 1 pids.
94+
# We calculate the number of xcds that have pids_per_xcd pids as
95+
# tall_xcds
96+
tall_xcds = GRID_MN % NUM_XCDS
97+
tall_xcds = NUM_XCDS if tall_xcds == 0 else tall_xcds
98+
# Compute current XCD and local pid within the XCD
99+
xcd = pid % NUM_XCDS
100+
local_pid = pid // NUM_XCDS
101+
# Calculate new pid based on the new grouping
102+
# Note that we need to consider the following two cases:
103+
# 1. the current pid is on a tall xcd
104+
# 2. the current pid is on a short xcd
105+
if xcd < tall_xcds:
106+
pid = xcd * pids_per_xcd + local_pid
107+
else:
108+
pid = tall_xcds * pids_per_xcd + (xcd - tall_xcds) * (pids_per_xcd - 1) + local_pid
109+
83110
if GROUP_SIZE_M == 1:
84111
pid_m = pid // num_pid_n
85112
pid_n = pid % num_pid_n

0 commit comments

Comments
 (0)