
Commit aba51f8

enable "dilation" for fp8, if chunk_size<32
Signed-off-by: cliu-us <[email protected]>
1 parent e2960d5

File tree: 1 file changed (+9 -8 lines)

fms_mo/custom_ext_kernels/triton_kernels.py

Lines changed: 9 additions & 8 deletions
@@ -471,20 +471,21 @@ def isPowerofTwo(x):
     min_chunk_size = 32 if a.dtype in DTYPE_8BIT else 16
 
     # because min k (chunk size in this case) for fp16/bf16 is 16, if smaller is needed, we could
-    # insert 0s in between elements, i.e. pad [m,k] -> [m,2k], [k,n]->[2k,n], out=[m,n] unchanged.
-    # Do not support I8 or F8 for now. (as F8/FP24 simulation is treated as BF16 currently)
-    if chunk_size == 8 and a.dtype in [torch.float16, torch.bfloat16]:
+    # insert 0s in between elements, e.g. pad [m,k] -> [m,2k], [k,n]->[2k,n], out=[m,n] unchanged.
+    # Do not support INT8 for now.
+    if chunk_size == 8 and a.dtype in [torch.float8_e4m3fn, torch.float16, torch.bfloat16]:
+        exp_ratio = min_chunk_size//chunk_size
         a_padded = torch.zeros(
-            a.shape[0], a.shape[1] * 2, dtype=a.dtype, device=a.device
+            a.shape[0], a.shape[1] * exp_ratio, dtype=a.dtype, device=a.device
         )
-        a_padded[:, ::2] = a
+        a_padded[:, ::exp_ratio] = a
         a = a_padded
         b_padded = torch.zeros(
-            b.shape[0] * 2, b.shape[1], dtype=b.dtype, device=b.device
+            b.shape[0] * exp_ratio, b.shape[1], dtype=b.dtype, device=b.device
         )
-        b_padded[::2, :] = b
+        b_padded[::exp_ratio, :] = b
         b = b_padded
-        chunk_size = 16
+        chunk_size = min_chunk_size
     else:
         chunk_size = (
             max(chunk_size, min_chunk_size)
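
The commit generalizes the existing fp16/bf16 zero-insertion ("dilation") trick to fp8 inputs: when the requested chunk size is below the kernel's minimum k, zeros are interleaved along the shared k dimension of both operands, so the matmul output stays identical while the padded k meets the minimum. A minimal standalone sketch of why the padding preserves the result (not part of the commit; it uses plain fp32 tensors since fp8 matmul is not generally available on CPU, and k stands in for the chunk dimension):

    import torch

    m, k, n = 4, 8, 4              # k == 8 is below the 8-bit minimum of 32
    min_chunk_size = 32            # minimum k for 8-bit dtypes, per the diff
    exp_ratio = min_chunk_size // k  # expansion ratio, here 4

    a = torch.randn(m, k)
    b = torch.randn(k, n)

    # Interleave zeros: a's columns land on every exp_ratio-th slot,
    # and b's rows land on the matching slots.
    a_padded = torch.zeros(m, k * exp_ratio)
    a_padded[:, ::exp_ratio] = a
    b_padded = torch.zeros(k * exp_ratio, n)
    b_padded[::exp_ratio, :] = b

    # The product is unchanged: out = [m, n] either way.
    assert torch.allclose(a @ b, a_padded @ b_padded, atol=1e-5)

Every inserted zero column of a_padded pairs with an inserted zero row of b_padded at the same k index, so the extra terms in each dot product vanish and out=[m,n] is unchanged, exactly as the in-code comment claims.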
