@@ -471,20 +471,21 @@ def isPowerofTwo(x):
     min_chunk_size = 32 if a.dtype in DTYPE_8BIT else 16
 
     # because min k (chunk size in this case) for fp16/bf16 is 16, if smaller is needed, we could
-    # insert 0s in between elements, i.e. pad [m,k] -> [m,2k], [k,n]->[2k,n], out=[m,n] unchanged.
-    # Do not support I8 or F8 for now. (as F8/FP24 simulation is treated as BF16 currently)
-    if chunk_size == 8 and a.dtype in [torch.float16, torch.bfloat16]:
+    # insert 0s in between elements, e.g. pad [m,k] -> [m,2k], [k,n]->[2k,n], out=[m,n] unchanged.
+    # Do not support INT8 for now.
+    if chunk_size == 8 and a.dtype in [torch.float8_e4m3fn, torch.float16, torch.bfloat16]:
+        exp_ratio = min_chunk_size // chunk_size
         a_padded = torch.zeros(
-            a.shape[0], a.shape[1] * 2, dtype=a.dtype, device=a.device
+            a.shape[0], a.shape[1] * exp_ratio, dtype=a.dtype, device=a.device
         )
-        a_padded[:, ::2] = a
+        a_padded[:, ::exp_ratio] = a
         a = a_padded
         b_padded = torch.zeros(
-            b.shape[0] * 2, b.shape[1], dtype=b.dtype, device=b.device
+            b.shape[0] * exp_ratio, b.shape[1], dtype=b.dtype, device=b.device
         )
-        b_padded[::2, :] = b
+        b_padded[::exp_ratio, :] = b
         b = b_padded
-        chunk_size = 16
+        chunk_size = min_chunk_size
     else:
         chunk_size = (
             max(chunk_size, min_chunk_size)
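For reference, here is a minimal standalone sketch (hypothetical, not from the patch; it assumes only that PyTorch is installed, and the names m, k, n, min_chunk_size, and exp_ratio mirror the diff above) of why the zero-interleave padding is safe: every zero column inserted into a lines up with a zero row inserted into b, so each extra product term is zero and the [m, n] output is unchanged.

    # Hypothetical check of the padding identity used in the hunk above.
    # float32 is used here only to keep the demo simple; the identity itself
    # is dtype-independent.
    import torch

    m, k, n = 4, 8, 5
    min_chunk_size = 16              # fp16/bf16 minimum from the patch
    exp_ratio = min_chunk_size // k  # 2: stretch k=8 up to k=16

    a = torch.randn(m, k)
    b = torch.randn(k, n)

    # Interleave zeros: the original columns of a and rows of b land on
    # every exp_ratio-th index; everything else stays zero.
    a_padded = torch.zeros(m, k * exp_ratio)
    a_padded[:, ::exp_ratio] = a
    b_padded = torch.zeros(k * exp_ratio, n)
    b_padded[::exp_ratio, :] = b

    # allclose rather than equal: the zero terms contribute nothing, but a
    # BLAS kernel may reduce k=8 and k=16 in different summation orders.
    assert torch.allclose(a @ b, a_padded @ b_padded)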