@@ -314,8 +314,6 @@ def tl_matmul_chunk_truncate(
     if cuda_cc[0] >= 9 or cuda_cc == (8, 9):
         allowed_dtypes += DTYPE_F8
     assert a.dtype in allowed_dtypes, "Input dtype is not supported"
-    M, K = a.shape
-    K, N = b.shape
 
     # Allocates output, always accumulate in FP32 (if floats) or INT32 then cast
     def isPowerofTwo(x):
@@ -325,7 +323,7 @@ def isPowerofTwo(x):
     min_chunk_size = 32 if a.dtype in DTYPE_8BIT else 16
 
     # because min k (chunk size in this case) for fp16/bf16 is 16, if smaller is needed, we could
-    # insert 0s in between elements, i.e. pad [m,k] -> [m,2k], [k,n]->[k,2n], out=[m,n] unchanged.
+    # insert 0s in between elements, i.e. pad [m,k] -> [m,2k], [k,n]->[2k,n], out=[m,n] unchanged.
     # Do not support I8 or F8 for now. (as F8/FP24 simulation is treated as BF16 currently)
     if chunk_size == 8 and a.dtype in [torch.float16, torch.bfloat16]:
         a_padded = torch.zeros(a.shape[0], a.shape[1]*2, dtype=a.dtype, device=a.device)
@@ -338,6 +336,8 @@ def isPowerofTwo(x):
     else:
         chunk_size = max(chunk_size, min_chunk_size) if isPowerofTwo(chunk_size) else min_chunk_size
 
+    M, K = a.shape
+    K, N = b.shape
     if a.dtype in DTYPE_I8:
         acc_dtype = torch.int32
         mm_kernel = imatmul_kernel
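
A minimal sketch (not part of the commit) of the zero-interleaving trick the comment describes: padding A from [m,k] to [m,2k] and B from [k,n] to [2k,n] with zeros in the odd positions leaves A @ B unchanged, while every 16-wide K chunk of the padded operands carries only 8 real elements, emulating chunk_size == 8. The helper name `pad_for_half_chunk` is hypothetical, used only for illustration.

```python
import torch

def pad_for_half_chunk(a: torch.Tensor, b: torch.Tensor):
    """Hypothetical helper: interleave zeros so [m,k] -> [m,2k] and [k,n] -> [2k,n]."""
    m, k = a.shape
    _, n = b.shape
    a_padded = torch.zeros(m, 2 * k, dtype=a.dtype, device=a.device)
    b_padded = torch.zeros(2 * k, n, dtype=b.dtype, device=b.device)
    a_padded[:, ::2] = a  # real values in even columns of A, zeros in odd columns
    b_padded[::2, :] = b  # real values in even rows of B, zeros in odd rows
    return a_padded, b_padded

# float32 here only so the check runs on any backend; the kernel path applies
# the same idea to fp16/bf16 inputs.
a = torch.randn(4, 8)
b = torch.randn(8, 5)
a_p, b_p = pad_for_half_chunk(a, b)

# Every zero column of a_p lines up with a zero row of b_p, so the extra
# product terms are all zero and the padded matmul equals the unpadded one.
assert torch.allclose(a_p @ b_p, a @ b)

# K doubles after padding, which is why M, K / K, N are now read from the
# (possibly padded) shapes after the padding block rather than before it.
print(a.shape, a_p.shape)  # torch.Size([4, 8]) torch.Size([4, 16])
```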