diff --git a/modelopt/onnx/export/mxfp8_exporter.py b/modelopt/onnx/export/mxfp8_exporter.py
index 360e02b4f..8c1e1f4df 100644
--- a/modelopt/onnx/export/mxfp8_exporter.py
+++ b/modelopt/onnx/export/mxfp8_exporter.py
@@ -15,27 +15,185 @@
 
 """MXFP8 quantization exporter."""
 
+import numpy as np
 import onnx
+from onnx import numpy_helper
+
+from modelopt.onnx.logging_config import logger
+from modelopt.onnx.quantization.graph_utils import get_tensor_producer_nodes
+from modelopt.onnx.quantization.qdq_utils import _cast_fp8, onnx_dtype_map
+from modelopt.onnx.quantization.quant_utils import compute_e8m0, get_amax
+from modelopt.onnx.utils import get_attribute, has_attribute
 
 from .base_exporter import ONNXQuantExporter
 
+E8_M0_BIAS = 127
+DEFAULT_BLOCK_SIZE = 32
+DEFAULT_QUANT_AXIS = -1
+
+
+def _get_weight_dq_nodes(graph: onnx.GraphProto) -> list[onnx.NodeProto]:
+    """Get weight DequantizeLinear nodes from the graph."""
+    return [
+        node
+        for node in graph.node
+        if node.op_type == "TRT_MXFP8DequantizeLinear"
+        and any(".weight" in inp for inp in node.input)
+    ]
+
+
+def _get_quant_params(node: onnx.NodeProto) -> tuple[int, int]:
+    """Extract quantization axis and block size from a node."""
+    if has_attribute(node, "axis"):
+        quant_axis = int(get_attribute(node, "axis"))
+    else:
+        quant_axis = DEFAULT_QUANT_AXIS
+        logger.warning(
+            "axis attribute not found for MXFP8DequantizeLinear node. Setting axis to -1"
+        )
+
+    if has_attribute(node, "block_size"):
+        block_size = int(get_attribute(node, "block_size"))
+    else:
+        block_size = DEFAULT_BLOCK_SIZE
+        logger.warning(
+            "block_size attribute not found for MXFP8DequantizeLinear node. "
+            "Setting block_size to 32"
+        )
+
+    return quant_axis, block_size
+
 
-# TODO: Implement the MXFP8QuantExporter
 class MXFP8QuantExporter(ONNXQuantExporter):
     """Exporter for MXFP8 quantization."""
 
     @staticmethod
     def pre_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
         """Pre-processes the ONNX model for MXFP8 quantization."""
+        return onnx_model
 
     @staticmethod
     def compute_scales(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
-        """Computes the scales for the weights in the ONNX model for MXFP8 quantization."""
+        """Computes the e8m0 scales for weights in the ONNX model for MXFP8 quantization."""
+        logger.info("Computing MXFP8 scales for weights")
+        graph = onnx_model.graph
+        initializer_map = {init.name: init for init in graph.initializer}
+        tensor_producer_map = get_tensor_producer_nodes(graph)
+
+        for node in _get_weight_dq_nodes(graph):
+            weight_name = node.input[0]
+            logger.debug(f"Computing MXFP8 scale for weight {weight_name}")
+
+            weight = numpy_helper.to_array(initializer_map[weight_name])
+            quant_axis, block_size = _get_quant_params(node)
+
+            # Compute scales
+            amax = get_amax(weight, quant_axis, block_size)
+            se8m0_fp32 = compute_e8m0(amax, weight.shape, quant_axis, block_size)
+            se8m0 = se8m0_fp32.astype(np.uint8)
+
+            # Remove scale producer if it's a Constant node
+            scale_name = node.input[1]
+            scale_producer = tensor_producer_map[scale_name]
+            if scale_producer.op_type == "Constant":
+                graph.node.remove(scale_producer)
+
+            # Create and add new scale tensor
+            scale_name_new = scale_name.replace("Constant_output_0", "scale")
+            scale_tensor = onnx.numpy_helper.from_array(se8m0, scale_name_new)
+            graph.initializer.append(scale_tensor)
+            node.input[1] = scale_name_new
+
+        return onnx_model
 
     @staticmethod
     def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
-        """Compresses the weights in the ONNX model for MXFP8 quantization."""
+        """Compresses the weights in the ONNX model to FP8 format for MXFP8 quantization."""
+        logger.info("Compressing weights to MXFP8 format")
+        graph = onnx_model.graph
+        initializer_map = {init.name: init for init in graph.initializer}
+
+        for node in _get_weight_dq_nodes(graph):
+            weight_name = node.input[0]
+            scale_name = node.input[1]
+            logger.debug(f"Compressing weight {weight_name} to MXFP8")
+
+            weight = numpy_helper.to_array(initializer_map[weight_name])
+            quant_axis, block_size = _get_quant_params(node)
+
+            # Get scale and convert back to fp32 for computation
+            se8m0 = numpy_helper.to_array(initializer_map[scale_name])
+            se8m0_fp32 = se8m0.astype(np.float32)
+
+            # Expand block array so that it can be broadcasted with weight
+            se8m0_fp32_expanded = np.repeat(se8m0_fp32, block_size, axis=quant_axis)
+            scaled_weight = weight / np.exp2(se8m0_fp32_expanded - E8_M0_BIAS)
+
+            # Create FP8 weight tensor
+            weights_e4m3 = onnx.helper.make_tensor(
+                name=weight_name,
+                data_type=onnx_dtype_map["Float8"],
+                dims=[*scaled_weight.shape],
+                vals=_cast_fp8(scaled_weight).tobytes(),
+                raw=True,
+            )
+            initializer_map[weight_name].CopyFrom(weights_e4m3)
+            logger.debug(f"Converted {weight_name} to MXFP8")
+
+        return onnx_model
 
     @staticmethod
     def post_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
-        """Post-processes the ONNX model for MXFP8 quantization."""
+        """Post-processes the ONNX model for MXFP8 quantization.
+
+        Sets DQ output dtype to FP16, GELU approximation to tanh, and adds FP16 casts after Sqrt.
+        """
+        logger.info("Post-processing MXFP8 quantized model")
+        graph = onnx_model.graph
+
+        # Set output type of DQ to FP16
+        for node in graph.node:
+            if node.op_type == "TRT_MXFP8DequantizeLinear":
+                for attr in node.attribute:
+                    if attr.name == "output_dtype":
+                        attr.i = onnx_dtype_map["Half"]
+
+        # Currently only tanh approximation is supported for Gelu
+        for node in graph.node:
+            if node.op_type == "Gelu":
+                for attr in node.attribute:
+                    if attr.name == "approximate":
+                        attr.s = b"tanh"
+                        logger.debug(f"Updated GELU node {node.name} to use tanh approximation")
+
+        # Insert cast to fp16 after Sqrt nodes
+        cast_nodes_to_insert = []
+        for idx, node in enumerate(graph.node):
+            if node.op_type == "Sqrt":
+                sqrt_output = node.output[0]
+                cast_output = f"{sqrt_output}_cast_fp16"
+
+                # Create Cast node
+                cast_node = onnx.helper.make_node(
+                    "Cast",
+                    inputs=[sqrt_output],
+                    outputs=[cast_output],
+                    to=onnx_dtype_map["Half"],
+                    name=f"{node.name}_cast_fp16",
+                )
+                cast_nodes_to_insert.append((idx + 1, cast_node))
+
+                # Update consumers to use cast output
+                for consumer in graph.node:
+                    if consumer == node:
+                        continue
+                    for i, inp in enumerate(consumer.input):
+                        if inp == sqrt_output:
+                            consumer.input[i] = cast_output
+
+        # Insert Cast nodes, offsetting positions to account for earlier insertions
+        for offset, (pos, cast_node) in enumerate(cast_nodes_to_insert):
+            graph.node.insert(pos + offset, cast_node)
+            logger.debug(f"Inserted Cast to FP16 after {cast_node.input[0]}")
+
+        return onnx_model
diff --git a/modelopt/onnx/quantization/qdq_utils.py b/modelopt/onnx/quantization/qdq_utils.py
index 0bdd62948..026b8d062 100644
--- a/modelopt/onnx/quantization/qdq_utils.py
+++ b/modelopt/onnx/quantization/qdq_utils.py
@@ -31,8 +31,7 @@
     get_tensor_producer_nodes,
     remove_redundant_cast_nodes,
 )
-from modelopt.onnx.quantization.quant_utils import compute_e8m0, get_amax, get_num_bits
-from modelopt.onnx.utils import get_attribute, has_attribute
+from modelopt.onnx.quantization.quant_utils import get_num_bits
 
 QUANTIZE_NODE_NAME = "QuantizeLinear"
 DEQUANTIZE_NODE_NAME = "DequantizeLinear"
@@ -1036,101 +1035,3 @@ def cast_initializer_to_dtype(
     input_onnx = onnx.numpy_helper.from_array(input, input_name)
     input_onnx.data_type = onnx_dtype_map[dtype]
     initializer_map[input_name].CopyFrom(input_onnx)
-
-
-def quantize_weights_to_mxfp8(
-    onnx_model: onnx.ModelProto,
-) -> onnx.ModelProto:
-    """Converts the weights to FP8 precision using MXFP8 quantization.
-
-    For TRT_MXFP8DynamicQuantize, we update the output type to FP8.
-    For TRT_MXFP8DequantizeLinear, we compute the scales in e8m0 format and saves them as a new initializer.
-    We then expand the scale to the same shape as the weight and divide the weight by the scale to get the FP8 weights.
-
-    Args:
-        graph: ONNX model protobuf.
-
-    Returns:
-        ONNX model protobuf with weights quantized to FP8 precision using MXFP8 quantization.
-    """
-    logger.info("Converting weights to MXFP8 precision")
-    graph = onnx_model.graph
-    initializer_map = {initializer.name: initializer for initializer in graph.initializer}
-    tensor_producer_map = get_tensor_producer_nodes(graph)
-    e8_m0_bias = 127
-    weight_dq_nodes = [
-        node
-        for node in graph.node
-        if node.op_type == "TRT_MXFP8DequantizeLinear"
-        and any(".weight" in input for input in node.input)
-    ]
-    gelu_nodes = [node for node in graph.node if node.op_type == "Gelu"]
-    logger.debug(f"Found {len(weight_dq_nodes)} weight DQ nodes and {len(gelu_nodes)} GELU nodes")
-
-    for node in weight_dq_nodes:
-        # Get weights and node attributes
-        weight_name = node.input[0]
-        logger.debug(f"Processing MXFP8 conversion for weight {weight_name}")
-        weight = numpy_helper.to_array(initializer_map[weight_name])
-        if has_attribute(node, "axis"):
-            quant_axis = int(get_attribute(node, "axis"))
-        else:
-            quant_axis = -1
-            logger.warning(
-                "axis attribute not found for MXFP8DequantizeLinear node. Setting axis to -1"
-            )
-
-        if has_attribute(node, "block_size"):
-            block_size = int(get_attribute(node, "block_size"))
-        else:
-            block_size = 32
-            logger.warning(
-                "block_size attribute not found for MXFP8DequantizeLinear node. "
-                "Setting block_size to 32"
-            )
-
-        # Compute and save scales as uint8
-        amax = get_amax(weight, quant_axis, block_size)
-        se8m0_fp32 = compute_e8m0(amax, weight.shape, quant_axis, block_size)
-        se8m0 = se8m0_fp32.astype(np.uint8)
-
-        # Remove scale producer if it's a Constant node
-        scale_name = node.input[1]
-        scale_producer = tensor_producer_map[scale_name]
-        if scale_producer.op_type == "Constant":
-            graph.node.remove(scale_producer)
-
-        # Create a new scale tensor
-        scale_name = scale_name.replace("Constant_output_0", "scale")
-        scale_tensor = onnx.numpy_helper.from_array(se8m0, scale_name)
-        graph.initializer.append(scale_tensor)
-        node.input[1] = scale_name
-
-        # Convert weights to FP8
-        # Expand block array so that it can be broadcasted with weight
-        se8m0_fp32 = np.repeat(se8m0_fp32, block_size, axis=quant_axis)
-        scaled_weight = weight / np.exp2(se8m0_fp32 - e8_m0_bias)
-        weights_e4m3 = onnx.helper.make_tensor(
-            name=weight_name,
-            data_type=onnx_dtype_map["Float8"],
-            dims=[*scaled_weight.shape],
-            vals=_cast_fp8(scaled_weight).tobytes(),
-            raw=True,
-        )
-        initializer_map[weight_name].CopyFrom(weights_e4m3)
-        logger.debug(f"Converted {weight_name} to MXFP8")
-
-    # set output type of DQ to FP16
-    for node in graph.node:
-        if node.op_type in ["TRT_MXFP8DequantizeLinear"]:
-            for attr in node.attribute:
-                if attr.name == "output_dtype":
-                    attr.i = onnx_dtype_map["Half"]
-
-    # Currently only tanh approximation is supported for Gelu
-    for node in gelu_nodes:
-        for attr in node.attribute:
-            if attr.name == "approximate":
-                attr.s = b"tanh"
-                logger.debug(f"Updated GELU node {node.name} to use tanh approximation")
-
-    return onnx_model
diff --git a/modelopt/torch/_deploy/utils/torch_onnx.py b/modelopt/torch/_deploy/utils/torch_onnx.py
index 9bfce35a9..ba1c6f56b 100644
--- a/modelopt/torch/_deploy/utils/torch_onnx.py
+++ b/modelopt/torch/_deploy/utils/torch_onnx.py
@@ -40,11 +40,7 @@
     NVFP4QuantExporter,
    ONNXQuantExporter,
 )
-from modelopt.onnx.quantization.qdq_utils import (
-    qdq_to_dq,
-    quantize_weights_to_mxfp8,
-    replace_zero_scale_with_smallest_nonzero,
-)
+from modelopt.onnx.quantization.qdq_utils import qdq_to_dq, replace_zero_scale_with_smallest_nonzero
 from modelopt.onnx.utils import (
     get_input_names,
     get_input_shapes,
@@ -364,6 +360,11 @@ def is_fp8_quantized(model: nn.Module) -> bool:
             and hasattr(module, "input_quantizer")
             and module.weight_quantizer._num_bits == (4, 3)
             and module.input_quantizer._num_bits == (4, 3)
+            # Exclude MXFP8 which also uses (4,3) but has block_sizes with scale_bits
+            and not (
+                module.input_quantizer.block_sizes
+                and module.input_quantizer.block_sizes.get("scale_bits", None) == (8, 0)
+            )
         ):
             return True
     return False
@@ -560,11 +561,8 @@ def get_onnx_bytes_and_metadata(
 
     # Convert dummy TRT_FP4QDQ nodes to 2DQ format if the model is quantized in FP4 mode
     # Or convert weights to MXFP8 format if the model is quantized in MXFP8 mode
-    if is_int4_quantized(model) or is_fp4_quantized(model):
+    if is_int4_quantized(model) or is_fp4_quantized(model) or is_mxfp8_quantized(model):
         onnx_opt_graph = quantize_weights(model, onnx_opt_graph)
-    elif is_mxfp8_quantized(model):
-        # TODO: Implement the MXFP8QuantExporter
-        onnx_opt_graph = quantize_weights_to_mxfp8(onnx_opt_graph)
 
     if dq_only:
         onnx_opt_graph = qdq_to_dq(onnx_opt_graph)
@@ -575,7 +573,7 @@ def get_onnx_bytes_and_metadata(
     except StopIteration:
         param_dtype = torch.float32
     if weights_dtype in ["fp16", "bf16"] and param_dtype == torch.float32:
-        if is_mxfp8_quantized(model) or is_int4_quantized(model):
+        if is_int4_quantized(model) or is_mxfp8_quantized(model):
             assert weights_dtype == "fp16", "BF16 + MXFP8/INT4 mixed precision is not supported yet"
             onnx_opt_graph = convert_float_to_float16(
                 onnx_opt_graph,
diff --git a/tests/unit/onnx/test_qdq_utils.py b/tests/unit/onnx/test_qdq_utils.py
index a05d794c3..2acc4046a 100644
--- a/tests/unit/onnx/test_qdq_utils.py
+++ b/tests/unit/onnx/test_qdq_utils.py
@@ -17,10 +17,8 @@
 import pytest
 from onnx import TensorProto, helper, numpy_helper
 
-from modelopt.onnx.export import NVFP4QuantExporter
-from modelopt.onnx.export.int4_exporter import INT4QuantExporter
+from modelopt.onnx.export import INT4QuantExporter, MXFP8QuantExporter, NVFP4QuantExporter
 from modelopt.onnx.export.nvfp4_exporter import _cast_fp4, _cast_fp8
-from modelopt.onnx.quantization.qdq_utils import quantize_weights_to_mxfp8
 
 
 def create_test_model_with_int4_dq_reshape_transpose_matmul(constant_scale: bool = False):
@@ -471,15 +469,15 @@ def test_cast_fp4(self, input_array, expected_array):
         assert np.all(result == expected_array)
 
 
-class TestQuantizeWeightsToMXFP8:
-    """Test suite for quantize_weights_to_mxfp8 function."""
+class TestMXFP8QuantExporter:
+    """Test suite for MXFP8QuantExporter."""
 
     def test_basic_mxfp8_quantization(self):
         """Test basic MXFP8 quantization with TRT_MXFP8DequantizeLinear nodes."""
         model = create_test_model_with_mxfp8_dq()
 
         # Run MXFP8 quantization
-        quantized_model = quantize_weights_to_mxfp8(model)
+        quantized_model = MXFP8QuantExporter.process_model(model)
 
         # Verify weight is converted to FP8
         weight_tensor = next(
@@ -510,7 +508,7 @@ def test_mxfp8_output_dtype_update(self):
         model = create_test_model_with_mxfp8_dq()
 
         # Run MXFP8 quantization
-        quantized_model = quantize_weights_to_mxfp8(model)
+        quantized_model = MXFP8QuantExporter.process_model(model)
 
         # Verify output_dtype is set to FP16
         dq_node = next(
@@ -526,7 +524,7 @@ def test_mxfp8_gelu_approximation_update(self):
         model = create_test_model_with_mxfp8_dq()
 
         # Run MXFP8 quantization
-        quantized_model = quantize_weights_to_mxfp8(model)
+        quantized_model = MXFP8QuantExporter.process_model(model)
 
         # Verify Gelu approximation is set to tanh
         gelu_node = next(node for node in quantized_model.graph.node if node.op_type == "Gelu")
@@ -574,7 +572,7 @@ def test_mxfp8_with_missing_attributes(self):
         model = helper.make_model(graph)
 
         # Run MXFP8 quantization (should use default values)
-        quantized_model = quantize_weights_to_mxfp8(model)
+        quantized_model = MXFP8QuantExporter.process_model(model)
 
         # Verify the model is still processed correctly
         weight_tensor = next(
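
Usage sketch: the updated tests drive the new exporter through MXFP8QuantExporter.process_model(model). Assuming process_model, defined on ONNXQuantExporter in base_exporter.py and not shown in this diff, chains pre_process -> compute_scales -> compress_weights -> post_process, the exporter can be used roughly as follows (file paths are illustrative):

    import onnx

    from modelopt.onnx.export import MXFP8QuantExporter

    # Load a model that already contains TRT_MXFP8DequantizeLinear nodes
    # produced by the MXFP8 QDQ export path.
    onnx_model = onnx.load("model_mxfp8_qdq.onnx")

    # Assumed to chain the four stages above: compute uint8 e8m0 scales,
    # compress weights to FP8 (E4M3), and patch DQ/Gelu/Sqrt nodes.
    quantized_model = MXFP8QuantExporter.process_model(onnx_model)

    onnx.save(quantized_model, "model_mxfp8.onnx")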