
Commit 262db69

Bug fix: M, N incorrect when using chunk_size 8 padding/dilation
1 parent: 0f7201e

1 file changed: +3 −3 lines

fms_mo/custom_ext_kernels/triton_kernels.py

Lines changed: 3 additions & 3 deletions
@@ -314,8 +314,6 @@ def tl_matmul_chunk_truncate(
     if cuda_cc[0] >= 9 or cuda_cc == (8, 9):
         allowed_dtypes += DTYPE_F8
     assert a.dtype in allowed_dtypes, "Input dtype is not supported"
-    M, K = a.shape
-    K, N = b.shape
 
     # Allocates output, always accumulate in FP32 (if floats) or INT32 then cast
     def isPowerofTwo(x):
@@ -325,7 +323,7 @@ def isPowerofTwo(x):
     min_chunk_size = 32 if a.dtype in DTYPE_8BIT else 16
 
     # because min k (chunk size in this case) for fp16/bf16 is 16, if smaller is needed, we could
-    # insert 0s in between elements, i.e. pad [m,k] -> [m,2k], [k,n]->[k,2n], out=[m,n] unchanged.
+    # insert 0s in between elements, i.e. pad [m,k] -> [m,2k], [k,n]->[2k,n], out=[m,n] unchanged.
     # Do not support I8 or F8 for now. (as F8/FP24 simulation is treated as BF16 currently)
     if chunk_size == 8 and a.dtype in [torch.float16, torch.bfloat16]:
         a_padded = torch.zeros(a.shape[0], a.shape[1]*2, dtype=a.dtype, device=a.device)
@@ -338,6 +336,8 @@ def isPowerofTwo(x):
     else:
         chunk_size = max(chunk_size, min_chunk_size) if isPowerofTwo(chunk_size) else min_chunk_size
 
+    M, K = a.shape
+    K, N = b.shape
     if a.dtype in DTYPE_I8:
         acc_dtype = torch.int32
         mm_kernel = imatmul_kernel
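
For context, below is a minimal standalone PyTorch sketch (not the kernel code itself; the variable names m, k, n and the even-index placement of the interleaved zeros are assumptions for illustration) of the chunk_size 8 dilation trick and of why the shape reads had to move below the padding branch: the padding doubles the K dimension, so M, K, N must describe the padded tensors.

import torch

# Standalone sketch of the "insert 0s in between elements" trick from the diff
# comment above: pad a [m, k] -> [m, 2k] and b [k, n] -> [2k, n] so the
# effective chunk along K is 16, while a @ b stays [m, n] and numerically unchanged.
m, k, n, chunk_size = 4, 8, 5, 8
a = torch.randn(m, k, dtype=torch.bfloat16)
b = torch.randn(k, n, dtype=torch.bfloat16)
ref = a @ b  # reference result before any padding

if chunk_size == 8 and a.dtype in (torch.float16, torch.bfloat16):
    a_padded = torch.zeros(a.shape[0], a.shape[1] * 2, dtype=a.dtype, device=a.device)
    a_padded[:, ::2] = a            # zeros interleaved along K (assumed pattern)
    b_padded = torch.zeros(b.shape[0] * 2, b.shape[1], dtype=b.dtype, device=b.device)
    b_padded[::2, :] = b
    a, b = a_padded, b_padded
    chunk_size = 16

# The point of this commit: read the shapes only AFTER the optional padding,
# otherwise K (and anything derived from it) would still describe the unpadded tensors.
M, K = a.shape
K, N = b.shape

out = a @ b
print(M, K, N)                    # 4 16 5 -> K reflects the padded width
print(torch.allclose(out, ref))   # True: the interleaved zeros add nothing to the product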
