Skip to content

Commit a18fc87

Browse files
authored
Fix pid remapping logic when GRID_MN cannot divide NUM_XCDS (#722)
1 parent c7fea1b commit a18fc87

File tree

2 files changed

+15
-3
lines changed

2 files changed

+15
-3
lines changed

python/perf-kernels/tools/occ.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,8 @@ get_occ_per_CU() {
3939

4040
$1 > output.mlir 2>&1
4141

42-
LDS_line=$(sed -n '/triton_gpu\.shared\ /p' output.mlir | tail -n 1 | grep -o 'triton_gpu.shared = [0-9]*')
43-
numWarps_line=$(sed -n '/triton_gpu\.num-warps/p' output.mlir | tail -n 1 | grep -o 'triton_gpu.num-warps. = [0-9]*')
42+
LDS_line=$(sed -n '/ttg\.shared\ /p' output.mlir | tail -n 1 | grep -o 'ttg.shared = [0-9]*')
43+
numWarps_line=$(sed -n '/ttg\.num-warps/p' output.mlir | tail -n 1 | grep -o 'ttg.num-warps. = [0-9]*')
4444

4545
LDS=${LDS_line##*=}
4646
num_warps=${numWarps_line##*=}

python/perf-kernels/tools/tune_gemm/matmul_kernel.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,23 @@ def matmul_kernel(a_ptr, b_ptr, c_ptr, bias_ptr, M, N, K, stride_am, stride_ak,
2525
## pid remapping on xcds
2626
# Number of pids per XCD in the new arrangement
2727
pids_per_xcd = (GRID_MN + NUM_XCDS - 1) // NUM_XCDS
28+
# When GRID_MN is not divisible by NUM_XCDS, some xcds will have
29+
# pids_per_xcd pids, the others will have pids_per_xcd - 1 pids.
30+
# We calculate the number of xcds that have pids_per_xcd pids as
31+
# tall_xcds
32+
tall_xcds = GRID_MN % NUM_XCDS
33+
tall_xcds = NUM_XCDS if tall_xcds == 0 else tall_xcds
2834
# Compute current XCD and local pid within the XCD
2935
xcd = pid % NUM_XCDS
3036
local_pid = pid // NUM_XCDS
3137
# Calculate new pid based on the new grouping
32-
pid = xcd * pids_per_xcd + local_pid
38+
# Note that we need to consider the following two cases:
39+
# 1. the current pid is on a tall xcd
40+
# 2. the current pid is on a short xcd
41+
if xcd < tall_xcds:
42+
pid = xcd * pids_per_xcd + local_pid
43+
else:
44+
pid = tall_xcds * pids_per_xcd + (xcd - tall_xcds) * (pids_per_xcd - 1) + local_pid
3345

3446
if GROUP_SIZE_M == 1:
3547
pid_m = pid // num_pid_n

0 commit comments

Comments
 (0)