3535 use_cuda_graph = True ,
3636)
3737@triton .heuristics ({
38- 'EVEN_K' : lambda args : args ['K' ] % args ['BLOCK_SIZE_K' ] == 0 ,
38+ 'EVEN_K' :
39+ lambda args : args ['K' ] % args ['BLOCK_SIZE_K' ] == 0 , 'GRID_MN' :
40+ lambda args : triton .cdiv (args ['M' ], args ['BLOCK_SIZE_M' ]) * triton .cdiv (args ['N' ], args ['BLOCK_SIZE_N' ])
3941})
4042@triton .jit
4143def matmul_kernel (
@@ -61,11 +63,14 @@ def matmul_kernel(
6163 GROUP_SIZE_M : tl .constexpr ,
6264 APPLY_SCALE : tl .constexpr ,
6365 ACTIVATION : tl .constexpr ,
66+ GRID_MN : tl .constexpr ,
6467):
6568 """Kernel for computing the matmul C = A x B.
6669 A has shape (M, K), B has shape (K, N) and C has shape (M, N)
6770 """
6871
72+ NUM_XCDS : tl .constexpr = 8
73+
6974 tl .assume (stride_am > 0 )
7075 tl .assume (stride_ak > 0 )
7176 tl .assume (stride_bk > 0 )
@@ -80,6 +85,28 @@ def matmul_kernel(
8085 pid = tl .program_id (axis = 0 )
8186 num_pid_m = tl .cdiv (M , BLOCK_SIZE_M )
8287 num_pid_n = tl .cdiv (N , BLOCK_SIZE_N )
88+
89+ ## pid remapping on xcds
90+ # Number of pids per XCD in the new arrangement
91+ pids_per_xcd = (GRID_MN + NUM_XCDS - 1 ) // NUM_XCDS
92+ # When GRID_MN is not divisible by NUM_XCDS, some xcds will have
93+ # pids_per_xcd pids, the others will have pids_per_xcd - 1 pids.
94+ # We calculate the number of xcds that have pids_per_xcd pids as
95+ # tall_xcds.
96+ tall_xcds = GRID_MN % NUM_XCDS
97+ tall_xcds = NUM_XCDS if tall_xcds == 0 else tall_xcds
98+ # Compute current XCD and local pid within the XCD
99+ xcd = pid % NUM_XCDS
100+ local_pid = pid // NUM_XCDS
101+ # Calculate new pid based on the new grouping
102+ # Note that we need to consider the following two cases:
103+ # 1. the current pid is on a tall xcd
104+ # 2. the current pid is on a short xcd
105+ if xcd < tall_xcds :
106+ pid = xcd * pids_per_xcd + local_pid
107+ else :
108+ pid = tall_xcds * pids_per_xcd + (xcd - tall_xcds ) * (pids_per_xcd - 1 ) + local_pid
109+
83110 if GROUP_SIZE_M == 1 :
84111 pid_m = pid // num_pid_n
85112 pid_n = pid % num_pid_n
0 commit comments