diff --git a/torchao/prototype/moe_training/kernels/mxfp8/comms.py b/torchao/prototype/moe_training/kernels/mxfp8/comms.py
index 7c6999fbf1..abd96a6be5 100644
--- a/torchao/prototype/moe_training/kernels/mxfp8/comms.py
+++ b/torchao/prototype/moe_training/kernels/mxfp8/comms.py
@@ -11,7 +11,10 @@
     blockwise_barrier,
     sync_threads,
 )
-from torchao.prototype.mx_formats.config import ScaleCalculationMode
+from torchao.prototype.mx_formats.kernels import (
+    triton_mxfp8_dequant_dim0,
+    triton_to_mxfp8_dim0,
+)
 from torchao.prototype.mx_formats.mx_tensor import to_dtype, to_mx
 
 
@@ -473,11 +476,9 @@ def forward(
         """
         # Quantize input
         block_size = 32
-        input_scales, input_data = to_mx(
+        input_data, input_scales = triton_to_mxfp8_dim0(
             input,
-            elem_dtype=torch.float8_e4m3fn,
-            block_size=block_size,
-            scaling_mode=ScaleCalculationMode.RCEIL,
+            inner_block_size=block_size,
         )
 
         # Dispatch data (async)
@@ -501,20 +502,17 @@
         output_data = torch.ops._c10d_functional.wait_tensor(output_data)
 
         # Dequantize output
-        lowp_dtype = output_data.dtype
         hp_dtype = input.dtype
-        hp_output = to_dtype(
+        triton_hp_output = triton_mxfp8_dequant_dim0(
             output_data,
-            output_scales.view(torch.float8_e8m0fnu),
-            lowp_dtype,
-            block_size,
+            output_scales,
             hp_dtype,
+            block_size,
         )
-
         ctx.input_splits = input_splits
         ctx.output_splits = output_splits
         ctx.group = group
-        return hp_output
+        return triton_hp_output
 
     @staticmethod
     def backward(ctx, grad_output_hp):
@@ -529,11 +527,9 @@
 
         # Quantize grad_output
         block_size = 32
-        grad_out_scales, grad_out_data = to_mx(
+        grad_out_data, grad_out_scales = triton_to_mxfp8_dim0(
             grad_output_hp,
-            elem_dtype=torch.float8_e4m3fn,
-            block_size=block_size,
-            scaling_mode=ScaleCalculationMode.RCEIL,
+            inner_block_size=block_size,
        )
 
         # Dispatch data (async)
@@ -557,13 +553,11 @@
         grad_input_scales = torch.ops._c10d_functional.wait_tensor(grad_input_scales)
 
         hp_dtype = grad_output_hp.dtype
-        lowp_dtype = grad_input_data.dtype
-        grad_input_hp = to_dtype(
+        grad_input_hp = triton_mxfp8_dequant_dim0(
             grad_input_data,
-            grad_input_scales.view(torch.float8_e8m0fnu),
-            lowp_dtype,
-            block_size,
+            grad_input_scales,
             hp_dtype,
+            block_size,
         )
 
         return grad_input_hp, None, None, None
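
For reviewers, a minimal single-GPU sketch of the quantize/dequantize round trip this diff switches to, with the collective dispatch omitted. Kernel names and argument orders are taken from the calls in the diff itself: `triton_to_mxfp8_dim0` returns `(data, scales)` (the reverse of `to_mx`'s `(scales, data)`), and `triton_mxfp8_dequant_dim0` takes `(data, scales, hp_dtype, block_size)`. The tensor shape and the error check are illustrative assumptions, not part of the change.

```python
import torch

from torchao.prototype.mx_formats.kernels import (
    triton_mxfp8_dequant_dim0,
    triton_to_mxfp8_dim0,
)

block_size = 32

# Illustrative input; rows are quantized along dim0 in blocks of 32 elements.
x = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16)

# Quantize: note the (data, scales) return order used in the updated
# forward/backward, vs. to_mx's (scales, data).
x_data, x_scales = triton_to_mxfp8_dim0(x, inner_block_size=block_size)

# Dequantize back to the original high-precision dtype, passing hp_dtype
# before block_size as in the updated calls above.
x_hp = triton_mxfp8_dequant_dim0(x_data, x_scales, x.dtype, block_size)

# MXFP8 is lossy, so the round trip is only approximate.
print((x_hp - x).abs().max().item())
```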