@@ -258,15 +258,16 @@ def matmul_kernel_persistent(a_ptr, b_ptr, c_ptr, #
258258 if start_pid < num_tiles % NUM_SMS :
259259 tiles_per_SM += 1
260260
261+ # NOTE: There is currently a bug in Blackwell pipelining that prevents it from handling a value that is
262+ # used in both the prologue and the epilogue, so we duplicate the counters as a workaround.
261263 tile_id = start_pid - NUM_SMS
264+ tile_id_c = start_pid - NUM_SMS
262265 ki = - 1
263266
264267 offs_k_for_mask = tl .arange (0 , BLOCK_SIZE_K )
265268
266269 num_pid_in_group = GROUP_SIZE_M * num_pid_n
267270
268- pid_m = 0
269- pid_n = 0
270271 offs_am = tl .arange (0 , BLOCK_SIZE_M )
271272 offs_bn = tl .arange (0 , BLOCK_SIZE_N )
272273
@@ -293,6 +294,8 @@ def matmul_kernel_persistent(a_ptr, b_ptr, c_ptr, #
293294 accumulator = tl .dot (a , b , accumulator )
294295
295296 if ki == k_tiles - 1 :
297+ tile_id_c , pid_m , pid_n = _compute_tile_and_pid (tile_id_c , num_pid_in_group , num_pid_m , GROUP_SIZE_M ,
298+ NUM_SMS )
296299 offs_cm = pid_m * BLOCK_SIZE_M + tl .arange (0 , BLOCK_SIZE_M )
297300 offs_cn = pid_n * BLOCK_SIZE_N + tl .arange (0 , BLOCK_SIZE_N )
298301 c_ptrs = c_ptr + stride_cm * offs_cm [:, None ] + stride_cn * offs_cn [None , :]
@@ -366,8 +369,6 @@ def matmul_kernel_tma_persistent(a_desc_ptr, b_desc_ptr, c_desc_ptr, #
366369 tiles_per_SM += 1
367370
368371 tile_id = start_pid - NUM_SMS
369- # tile_id_c is used in the epilogue to break the dependency between
370- # the prologue and the epilogue
371372 tile_id_c = start_pid - NUM_SMS
372373
373374 ki = - 1
0 commit comments