From 17d59a4f61286670509fb4e50c0f70c302cf6a02 Mon Sep 17 00:00:00 2001 From: ajrasane <131806219+ajrasane@users.noreply.github.com> Date: Mon, 1 Dec 2025 17:55:16 +0000 Subject: [PATCH 1/6] [OMNIML-2244] Create MXFP8 quant exporter Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com> --- modelopt/onnx/export/mxfp8_exporter.py | 156 ++++++++++++++++++++- modelopt/onnx/quantization/qdq_utils.py | 102 +------------- modelopt/torch/_deploy/utils/torch_onnx.py | 11 +- tests/unit/onnx/test_qdq_utils.py | 14 +- 4 files changed, 167 insertions(+), 116 deletions(-) diff --git a/modelopt/onnx/export/mxfp8_exporter.py b/modelopt/onnx/export/mxfp8_exporter.py index 360e02b4f..99cdba989 100644 --- a/modelopt/onnx/export/mxfp8_exporter.py +++ b/modelopt/onnx/export/mxfp8_exporter.py @@ -15,27 +15,175 @@ """MXFP8 quantization exporter.""" +import numpy as np import onnx +from onnx import numpy_helper + +from modelopt.onnx.logging_config import logger +from modelopt.onnx.quantization.graph_utils import get_tensor_producer_nodes +from modelopt.onnx.quantization.qdq_utils import _cast_fp8, onnx_dtype_map +from modelopt.onnx.quantization.quant_utils import compute_e8m0, get_amax +from modelopt.onnx.utils import get_attribute, has_attribute from .base_exporter import ONNXQuantExporter +E8_M0_BIAS = 127 +DEFAULT_BLOCK_SIZE = 32 +DEFAULT_QUANT_AXIS = -1 + + +def _get_weight_dq_nodes(graph: onnx.GraphProto) -> list[onnx.NodeProto]: + """Get weight DequantizeLinear nodes from the graph.""" + return [ + node + for node in graph.node + if node.op_type == "TRT_MXFP8DequantizeLinear" + and any(".weight" in inp for inp in node.input) + ] + + +def _get_quant_params(node: onnx.NodeProto) -> tuple[int, int]: + """Extract quantization axis and block size from a node.""" + if has_attribute(node, "axis"): + quant_axis = int(get_attribute(node, "axis")) + else: + quant_axis = DEFAULT_QUANT_AXIS + logger.warning( + "axis attribute not found for MXFP8DequantizeLinear node. Setting axis to -1" + ) + + if has_attribute(node, "block_size"): + block_size = int(get_attribute(node, "block_size")) + else: + block_size = DEFAULT_BLOCK_SIZE + logger.warning( + "block_size attribute not found for MXFP8DequantizeLinear node. 
" + "Setting block_size to 32" + ) + + return quant_axis, block_size + -# TODO: Implement the MXFP8QuantExporter class MXFP8QuantExporter(ONNXQuantExporter): """Exporter for MXFP8 quantization.""" @staticmethod def pre_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto: """Pre-processes the ONNX model for MXFP8 quantization.""" + return onnx_model @staticmethod def compute_scales(onnx_model: onnx.ModelProto) -> onnx.ModelProto: - """Computes the scales for the weights in the ONNX model for MXFP8 quantization.""" + """Computes the e8m0 scales for weights in the ONNX model for MXFP8 quantization.""" + logger.info("Computing MXFP8 scales for weights") + graph = onnx_model.graph + initializer_map = {init.name: init for init in graph.initializer} + tensor_producer_map = get_tensor_producer_nodes(graph) + + for node in _get_weight_dq_nodes(graph): + weight_name = node.input[0] + logger.debug(f"Computing MXFP8 scale for weight {weight_name}") + + weight = numpy_helper.to_array(initializer_map[weight_name]) + quant_axis, block_size = _get_quant_params(node) + + # Compute scales + amax = get_amax(weight, quant_axis, block_size) + se8m0_fp32 = compute_e8m0(amax, weight.shape, quant_axis, block_size) + se8m0 = se8m0_fp32.astype(np.uint8) + + # Remove scale producer if it's a Constant node + scale_name = node.input[1] + scale_producer = tensor_producer_map[scale_name] + if scale_producer.op_type == "Constant": + graph.node.remove(scale_producer) + + # Create and add new scale tensor + scale_name_new = scale_name.replace("Constant_output_0", "scale") + scale_tensor = onnx.numpy_helper.from_array(se8m0, scale_name_new) + graph.initializer.append(scale_tensor) + node.input[1] = scale_name_new + + return onnx_model @staticmethod def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto: - """Compresses the weights in the ONNX model for MXFP8 quantization.""" + """Compresses the weights in the ONNX model to FP8 format for MXFP8 quantization.""" + logger.info("Compressing weights to MXFP8 format") + graph = onnx_model.graph + initializer_map = {init.name: init for init in graph.initializer} + + for node in _get_weight_dq_nodes(graph): + weight_name = node.input[0] + scale_name = node.input[1] + logger.debug(f"Compressing weight {weight_name} to MXFP8") + + weight = numpy_helper.to_array(initializer_map[weight_name]) + quant_axis, block_size = _get_quant_params(node) + + # Get scale and convert back to fp32 for computation + se8m0 = numpy_helper.to_array(initializer_map[scale_name]) + se8m0_fp32 = se8m0.astype(np.float32) + + # Expand block array so that it can be broadcasted with weight + se8m0_fp32_expanded = np.repeat(se8m0_fp32, block_size, axis=quant_axis) + scaled_weight = weight / np.exp2(se8m0_fp32_expanded - E8_M0_BIAS) + + # Create FP8 weight tensor + weights_e4m3 = onnx.helper.make_tensor( + name=weight_name, + data_type=onnx_dtype_map["Float8"], + dims=[*scaled_weight.shape], + vals=_cast_fp8(scaled_weight).tobytes(), + raw=True, + ) + initializer_map[weight_name].CopyFrom(weights_e4m3) + logger.debug(f"Converted {weight_name} to MXFP8") + + return onnx_model @staticmethod def post_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto: - """Post-processes the ONNX model for MXFP8 quantization.""" + """Post-processes the ONNX model for MXFP8 quantization. + + Sets DQ output type to FP16 and updates GELU nodes to use tanh approximation. 
+ """ + logger.info("Post-processing MXFP8 quantized model") + graph = onnx_model.graph + + # Set output type of DQ to FP16 + for node in graph.node: + if node.op_type == "TRT_MXFP8DequantizeLinear": + for attr in node.attribute: + if attr.name == "output_dtype": + attr.i = onnx_dtype_map["Half"] + + # Currently only tanh approximation is supported for Gelu + for node in graph.node: + if node.op_type == "Gelu": + for attr in node.attribute: + if attr.name == "approximate": + attr.s = b"tanh" + logger.debug(f"Updated GELU node {node.name} to use tanh approximation") + + def is_fp32_cast(node: onnx.NodeProto) -> bool: + return node.op_type == "Cast" and any( + attr.name == "to" and attr.i == onnx.TensorProto.FLOAT for attr in node.attribute + ) + + # Remove Cast nodes after specific operators + nodes_to_remove = [] + for node in graph.node: + if node.op_type in ["Transpose", "Reshape", "Sqrt", "Add", "Gelu"]: + child_nodes = [n for n in graph.node if node.output[0] in n.input] + if len(child_nodes) == 1 and is_fp32_cast(child_nodes[0]): + cast_node = child_nodes[0] + node.output.clear() + node.output.extend(cast_node.output) + nodes_to_remove.append(cast_node.name) + + # Remove unnecessary casts + new_nodes = [node for node in graph.node if node.name not in nodes_to_remove] + graph.node.extend(new_nodes) + + return onnx_model diff --git a/modelopt/onnx/quantization/qdq_utils.py b/modelopt/onnx/quantization/qdq_utils.py index 0bdd62948..ed3ac171c 100644 --- a/modelopt/onnx/quantization/qdq_utils.py +++ b/modelopt/onnx/quantization/qdq_utils.py @@ -31,8 +31,8 @@ get_tensor_producer_nodes, remove_redundant_cast_nodes, ) -from modelopt.onnx.quantization.quant_utils import compute_e8m0, get_amax, get_num_bits -from modelopt.onnx.utils import get_attribute, has_attribute +from modelopt.onnx.quantization.quant_utils import get_num_bits + QUANTIZE_NODE_NAME = "QuantizeLinear" DEQUANTIZE_NODE_NAME = "DequantizeLinear" @@ -1036,101 +1036,3 @@ def cast_initializer_to_dtype( input_onnx = onnx.numpy_helper.from_array(input, input_name) input_onnx.data_type = onnx_dtype_map[dtype] initializer_map[input_name].CopyFrom(input_onnx) - - -def quantize_weights_to_mxfp8( - onnx_model: onnx.ModelProto, -) -> onnx.ModelProto: - """Converts the weights to FP8 precision using MXFP8 quantization. - - For TRT_MXFP8DynamicQuantize, we update the output type to FP8. - For TRT_MXFP8DequantizeLinear, we compute the scales in e8m0 format and saves them as a new initializer. - We then expand the scale to the same shape as the weight and divide the weight by the scale to get the FP8 weights. - - Args: - graph: ONNX model protobuf. - - Returns: - ONNX model protobuf with weights quantized to FP8 precision using MXFP8 quantization. 
- """ - logger.info("Converting weights to MXFP8 precision") - graph = onnx_model.graph - initializer_map = {initializer.name: initializer for initializer in graph.initializer} - tensor_producer_map = get_tensor_producer_nodes(graph) - e8_m0_bias = 127 - weight_dq_nodes = [ - node - for node in graph.node - if node.op_type == "TRT_MXFP8DequantizeLinear" - and any(".weight" in input for input in node.input) - ] - gelu_nodes = [node for node in graph.node if node.op_type == "Gelu"] - logger.debug(f"Found {len(weight_dq_nodes)} weight DQ nodes and {len(gelu_nodes)} GELU nodes") - - for node in weight_dq_nodes: - # Get weights and node attributes - weight_name = node.input[0] - logger.debug(f"Processing MXFP8 conversion for weight {weight_name}") - weight = numpy_helper.to_array(initializer_map[weight_name]) - if has_attribute(node, "axis"): - quant_axis = int(get_attribute(node, "axis")) - else: - quant_axis = -1 - logger.warning( - "axis attribute not found for MXFP8DequantizeLinear node. Setting axis to -1" - ) - - if has_attribute(node, "block_size"): - block_size = int(get_attribute(node, "block_size")) - else: - block_size = 32 - logger.warning( - "block_size attribute not found for MXFP8DequantizeLinear node. Setting block_size to 32" - ) - - # Compute and save scales as uint8 - amax = get_amax(weight, quant_axis, block_size) - se8m0_fp32 = compute_e8m0(amax, weight.shape, quant_axis, block_size) - se8m0 = se8m0_fp32.astype(np.uint8) - - # Remove scale producer if it's a Constant node - scale_name = node.input[1] - scale_producer = tensor_producer_map[scale_name] - if scale_producer.op_type == "Constant": - graph.node.remove(scale_producer) - - # Create a new scale tensor - scale_name = scale_name.replace("Constant_output_0", "scale") - scale_tensor = onnx.numpy_helper.from_array(se8m0, scale_name) - graph.initializer.append(scale_tensor) - node.input[1] = scale_name - - # Convert weights to FP8 - # Expand block array so that it can be broadcasted with weight - se8m0_fp32 = np.repeat(se8m0_fp32, block_size, axis=quant_axis) - scaled_weight = weight / np.exp2(se8m0_fp32 - e8_m0_bias) - weights_e4m3 = onnx.helper.make_tensor( - name=weight_name, - data_type=onnx_dtype_map["Float8"], - dims=[*scaled_weight.shape], - vals=_cast_fp8(scaled_weight).tobytes(), - raw=True, - ) - initializer_map[weight_name].CopyFrom(weights_e4m3) - logger.debug(f"Converted {weight_name} to MXFP8") - - # set output type of DQ to FP16 - for node in graph.node: - if node.op_type in ["TRT_MXFP8DequantizeLinear"]: - for attr in node.attribute: - if attr.name == "output_dtype": - attr.i = onnx_dtype_map["Half"] - - # Currently only tanh approximation is supported for Gelu - for node in gelu_nodes: - for attr in node.attribute: - if attr.name == "approximate": - attr.s = b"tanh" - logger.debug(f"Updated GELU node {node.name} to use tanh approximation") - - return onnx_model diff --git a/modelopt/torch/_deploy/utils/torch_onnx.py b/modelopt/torch/_deploy/utils/torch_onnx.py index 9bfce35a9..9a0a229bf 100644 --- a/modelopt/torch/_deploy/utils/torch_onnx.py +++ b/modelopt/torch/_deploy/utils/torch_onnx.py @@ -42,7 +42,6 @@ ) from modelopt.onnx.quantization.qdq_utils import ( qdq_to_dq, - quantize_weights_to_mxfp8, replace_zero_scale_with_smallest_nonzero, ) from modelopt.onnx.utils import ( @@ -364,6 +363,11 @@ def is_fp8_quantized(model: nn.Module) -> bool: and hasattr(module, "input_quantizer") and module.weight_quantizer._num_bits == (4, 3) and module.input_quantizer._num_bits == (4, 3) + # Exclude MXFP8 which also 
uses (4,3) but has block_sizes with scale_bits + and not ( + module.input_quantizer.block_sizes + and module.input_quantizer.block_sizes.get("scale_bits", None) == (8, 0) + ) ): return True return False @@ -560,11 +564,8 @@ def get_onnx_bytes_and_metadata( # Convert dummy TRT_FP4QDQ nodes to 2DQ format if the model is quantized in FP4 mode # Or convert weights to MXFP8 format if the model is quantized in MXFP8 mode - if is_int4_quantized(model) or is_fp4_quantized(model): + if is_int4_quantized(model) or is_fp4_quantized(model) or is_mxfp8_quantized(model): onnx_opt_graph = quantize_weights(model, onnx_opt_graph) - elif is_mxfp8_quantized(model): - # TODO: Implement the MXFP8QuantExporter - onnx_opt_graph = quantize_weights_to_mxfp8(onnx_opt_graph) if dq_only: onnx_opt_graph = qdq_to_dq(onnx_opt_graph) diff --git a/tests/unit/onnx/test_qdq_utils.py b/tests/unit/onnx/test_qdq_utils.py index a05d794c3..1fb05d75e 100644 --- a/tests/unit/onnx/test_qdq_utils.py +++ b/tests/unit/onnx/test_qdq_utils.py @@ -20,7 +20,7 @@ from modelopt.onnx.export import NVFP4QuantExporter from modelopt.onnx.export.int4_exporter import INT4QuantExporter from modelopt.onnx.export.nvfp4_exporter import _cast_fp4, _cast_fp8 -from modelopt.onnx.quantization.qdq_utils import quantize_weights_to_mxfp8 +from modelopt.onnx.export.mxfp8_exporter import MXFP8QuantExporter def create_test_model_with_int4_dq_reshape_transpose_matmul(constant_scale: bool = False): @@ -471,15 +471,15 @@ def test_cast_fp4(self, input_array, expected_array): assert np.all(result == expected_array) -class TestQuantizeWeightsToMXFP8: - """Test suite for quantize_weights_to_mxfp8 function.""" +class TestMXFP8QuantExporter: + """Test suite for MXFP8QuantExporter.""" def test_basic_mxfp8_quantization(self): """Test basic MXFP8 quantization with TRT_MXFP8DequantizeLinear nodes.""" model = create_test_model_with_mxfp8_dq() # Run MXFP8 quantization - quantized_model = quantize_weights_to_mxfp8(model) + quantized_model = MXFP8QuantExporter.process_model(model) # Verify weight is converted to FP8 weight_tensor = next( @@ -510,7 +510,7 @@ def test_mxfp8_output_dtype_update(self): model = create_test_model_with_mxfp8_dq() # Run MXFP8 quantization - quantized_model = quantize_weights_to_mxfp8(model) + quantized_model = MXFP8QuantExporter.process_model(model) # Verify output_dtype is set to FP16 dq_node = next( @@ -526,7 +526,7 @@ def test_mxfp8_gelu_approximation_update(self): model = create_test_model_with_mxfp8_dq() # Run MXFP8 quantization - quantized_model = quantize_weights_to_mxfp8(model) + quantized_model = MXFP8QuantExporter.process_model(model) # Verify Gelu approximation is set to tanh gelu_node = next(node for node in quantized_model.graph.node if node.op_type == "Gelu") @@ -574,7 +574,7 @@ def test_mxfp8_with_missing_attributes(self): model = helper.make_model(graph) # Run MXFP8 quantization (should use default values) - quantized_model = quantize_weights_to_mxfp8(model) + quantized_model = MXFP8QuantExporter.process_model(model) # Verify the model is still processed correctly weight_tensor = next( From 918d081f925ff9abe2a8ddfab3f90fc30a6e8745 Mon Sep 17 00:00:00 2001 From: ajrasane <131806219+ajrasane@users.noreply.github.com> Date: Tue, 2 Dec 2025 21:28:08 +0000 Subject: [PATCH 2/6] Integrate autocast for mxfp8 Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com> --- modelopt/onnx/export/mxfp8_exporter.py | 20 -------------------- modelopt/onnx/trt_utils.py | 1 + modelopt/torch/_deploy/utils/torch_onnx.py | 2 +- 3 files changed, 
2 insertions(+), 21 deletions(-) diff --git a/modelopt/onnx/export/mxfp8_exporter.py b/modelopt/onnx/export/mxfp8_exporter.py index 99cdba989..f959d821e 100644 --- a/modelopt/onnx/export/mxfp8_exporter.py +++ b/modelopt/onnx/export/mxfp8_exporter.py @@ -166,24 +166,4 @@ def post_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto: attr.s = b"tanh" logger.debug(f"Updated GELU node {node.name} to use tanh approximation") - def is_fp32_cast(node: onnx.NodeProto) -> bool: - return node.op_type == "Cast" and any( - attr.name == "to" and attr.i == onnx.TensorProto.FLOAT for attr in node.attribute - ) - - # Remove Cast nodes after specific operators - nodes_to_remove = [] - for node in graph.node: - if node.op_type in ["Transpose", "Reshape", "Sqrt", "Add", "Gelu"]: - child_nodes = [n for n in graph.node if node.output[0] in n.input] - if len(child_nodes) == 1 and is_fp32_cast(child_nodes[0]): - cast_node = child_nodes[0] - node.output.clear() - node.output.extend(cast_node.output) - nodes_to_remove.append(cast_node.name) - - # Remove unnecessary casts - new_nodes = [node for node in graph.node if node.name not in nodes_to_remove] - graph.node.extend(new_nodes) - return onnx_model diff --git a/modelopt/onnx/trt_utils.py b/modelopt/onnx/trt_utils.py index 5fde66778..1fc21c8d4 100644 --- a/modelopt/onnx/trt_utils.py +++ b/modelopt/onnx/trt_utils.py @@ -140,6 +140,7 @@ def _map_trt_to_onnx_type(trt_type: trt.DataType): trt.bool: onnx.TensorProto.BOOL, trt.fp8: onnx.TensorProto.FLOAT8E4M3FN, trt.fp4: onnx.TensorProto.FLOAT4E2M1, + trt.e8m0: onnx.TensorProto.UINT8, } try: return trt_to_onnx_dtype_mapping[trt_type] diff --git a/modelopt/torch/_deploy/utils/torch_onnx.py b/modelopt/torch/_deploy/utils/torch_onnx.py index 9a0a229bf..fa09a12a4 100644 --- a/modelopt/torch/_deploy/utils/torch_onnx.py +++ b/modelopt/torch/_deploy/utils/torch_onnx.py @@ -576,7 +576,7 @@ def get_onnx_bytes_and_metadata( except StopIteration: param_dtype = torch.float32 if weights_dtype in ["fp16", "bf16"] and param_dtype == torch.float32: - if is_mxfp8_quantized(model) or is_int4_quantized(model): + if is_int4_quantized(model): assert weights_dtype == "fp16", "BF16 + MXFP8/INT4 mixed precision is not supported yet" onnx_opt_graph = convert_float_to_float16( onnx_opt_graph, From 2146a4e8c3ab280e27ff4309ef2278f169b7eeb0 Mon Sep 17 00:00:00 2001 From: ajrasane <131806219+ajrasane@users.noreply.github.com> Date: Wed, 3 Dec 2025 02:27:55 +0000 Subject: [PATCH 3/6] Update tests container Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com> --- .github/workflows/gpu_tests.yml | 2 +- .gitlab/tests.yml | 41 +++++++++++++++++++++++++++++++++ 2 files changed, 42 insertions(+), 1 deletion(-) diff --git a/.github/workflows/gpu_tests.yml b/.github/workflows/gpu_tests.yml index be7e88dfc..aa0928c67 100644 --- a/.github/workflows/gpu_tests.yml +++ b/.github/workflows/gpu_tests.yml @@ -62,7 +62,7 @@ jobs: runs-on: linux-amd64-gpu-l4-latest-1 timeout-minutes: 120 container: &gpu_container - image: nvcr.io/nvidia/pytorch:25.06-py3 + image: nvcr.io/nvidia/pytorch:25.08-py3 env: GIT_DEPTH: 1000 # For correct version for tests/gpu/torch/quantization/plugins/test_megatron.py PIP_CONSTRAINT: "" # Disable pip constraint for upgrading packages diff --git a/.gitlab/tests.yml b/.gitlab/tests.yml index b83db6f57..96a9d582e 100644 --- a/.gitlab/tests.yml +++ b/.gitlab/tests.yml @@ -13,6 +13,47 @@ example-onnx-bash: extends: .tests-default timeout: 90m + image: nvcr.io/nvidia/pytorch:25.08-py3 + variables: + GIT_DEPTH: 1000 # For 
correct version for tests/gpu/torch/quantization/plugins/test_megatron.py + tags: [docker, linux, 2-gpu] + before_script: + # Add libcudnn*.so and libnv*.so to path + - export LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:/usr/include:/usr/lib/x86_64-linux-gnu" + # Install git-lfs for Daring-Anteater dataset + - apt-get update && apt-get install -y git-lfs + - git lfs install --system + +multi-gpu: + extends: .multi-gpu-tests-default + script: + # Use pre-installed packages without a new venv with tox-current-env + - pip install tox-current-env + - tox -e py312-cuda12-gpu --current-env + +##### Example Tests ##### +example-torch: + extends: .multi-gpu-tests-default + timeout: 30m + parallel: + matrix: + - EXAMPLE: [llm_distill, llm_qat, llm_sparsity, speculative_decoding] + script: + - pip install ".[hf,dev-test]" + - find examples/$EXAMPLE -name "requirements.txt" | while read req_file; do pip install -r "$req_file" || exit 1; done + - pytest -s tests/examples/$EXAMPLE + +example-trtllm: + extends: example-torch + timeout: 60m + image: nvcr.io/nvidia/tensorrt-llm/release:1.1.0rc2.post2 + tags: [docker, linux, 2-gpu, sm>=89] + parallel: + matrix: + - EXAMPLE: [llm_autodeploy, llm_eval, llm_ptq, vlm_ptq] + +example-onnx: + extends: example-torch image: nvcr.io/nvidia/tensorrt:25.08-py3 tags: [docker, linux, 2-gpu, sm>=89] parallel: From 9b30e178b4aed05d1d45bc13471514b9dd203ca1 Mon Sep 17 00:00:00 2001 From: ajrasane <131806219+ajrasane@users.noreply.github.com> Date: Wed, 3 Dec 2025 21:48:01 +0000 Subject: [PATCH 4/6] Do not use autocast for mxfp8 Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com> --- .github/workflows/gpu_tests.yml | 2 +- .gitlab/tests.yml | 2 +- modelopt/onnx/export/mxfp8_exporter.py | 30 ++++++++++++++++++++++ modelopt/torch/_deploy/utils/torch_onnx.py | 2 +- 4 files changed, 33 insertions(+), 3 deletions(-) diff --git a/.github/workflows/gpu_tests.yml b/.github/workflows/gpu_tests.yml index aa0928c67..be7e88dfc 100644 --- a/.github/workflows/gpu_tests.yml +++ b/.github/workflows/gpu_tests.yml @@ -62,7 +62,7 @@ jobs: runs-on: linux-amd64-gpu-l4-latest-1 timeout-minutes: 120 container: &gpu_container - image: nvcr.io/nvidia/pytorch:25.08-py3 + image: nvcr.io/nvidia/pytorch:25.06-py3 env: GIT_DEPTH: 1000 # For correct version for tests/gpu/torch/quantization/plugins/test_megatron.py PIP_CONSTRAINT: "" # Disable pip constraint for upgrading packages diff --git a/.gitlab/tests.yml b/.gitlab/tests.yml index 96a9d582e..59bf1c2f7 100644 --- a/.gitlab/tests.yml +++ b/.gitlab/tests.yml @@ -54,7 +54,7 @@ example-trtllm: example-onnx: extends: example-torch - image: nvcr.io/nvidia/tensorrt:25.08-py3 + image: nvcr.io/nvidia/tensorrt:25.06-py3 tags: [docker, linux, 2-gpu, sm>=89] parallel: matrix: diff --git a/modelopt/onnx/export/mxfp8_exporter.py b/modelopt/onnx/export/mxfp8_exporter.py index f959d821e..8c1e1f4df 100644 --- a/modelopt/onnx/export/mxfp8_exporter.py +++ b/modelopt/onnx/export/mxfp8_exporter.py @@ -166,4 +166,34 @@ def post_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto: attr.s = b"tanh" logger.debug(f"Updated GELU node {node.name} to use tanh approximation") + # Insert cast to fp16 after Sqrt nodes + cast_nodes_to_insert = [] + for idx, node in enumerate(graph.node): + if node.op_type == "Sqrt": + sqrt_output = node.output[0] + cast_output = f"{sqrt_output}_cast_fp16" + + # Create Cast node + cast_node = onnx.helper.make_node( + "Cast", + inputs=[sqrt_output], + outputs=[cast_output], + to=onnx_dtype_map["Half"], + name=f"{node.name}_cast_fp16", + ) 
+ cast_nodes_to_insert.append((idx + 1, cast_node)) + + # Update consumers to use cast output + for consumer in graph.node: + if consumer == node: + continue + for i, inp in enumerate(consumer.input): + if inp == sqrt_output: + consumer.input[i] = cast_output + + # Insert Cast nodes in reverse order to preserve indices + for offset, (pos, cast_node) in enumerate(cast_nodes_to_insert): + graph.node.insert(pos + offset, cast_node) + logger.debug(f"Inserted Cast to FP16 after {cast_node.input[0]}") + return onnx_model diff --git a/modelopt/torch/_deploy/utils/torch_onnx.py b/modelopt/torch/_deploy/utils/torch_onnx.py index fa09a12a4..e7602b30a 100644 --- a/modelopt/torch/_deploy/utils/torch_onnx.py +++ b/modelopt/torch/_deploy/utils/torch_onnx.py @@ -576,7 +576,7 @@ def get_onnx_bytes_and_metadata( except StopIteration: param_dtype = torch.float32 if weights_dtype in ["fp16", "bf16"] and param_dtype == torch.float32: - if is_int4_quantized(model): + if is_int4_quantized(model) or is_mxfp8_quantized(model): assert weights_dtype == "fp16", "BF16 + MXFP8/INT4 mixed precision is not supported yet" onnx_opt_graph = convert_float_to_float16( onnx_opt_graph, From b6850193f1b844603e47b3491b34a0bef0cfb366 Mon Sep 17 00:00:00 2001 From: ajrasane <131806219+ajrasane@users.noreply.github.com> Date: Thu, 4 Dec 2025 21:06:09 +0000 Subject: [PATCH 5/6] Rebase latest changes Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com> --- .gitlab/tests.yml | 43 +------------------------------------------ 1 file changed, 1 insertion(+), 42 deletions(-) diff --git a/.gitlab/tests.yml b/.gitlab/tests.yml index 59bf1c2f7..b83db6f57 100644 --- a/.gitlab/tests.yml +++ b/.gitlab/tests.yml @@ -13,48 +13,7 @@ example-onnx-bash: extends: .tests-default timeout: 90m - image: nvcr.io/nvidia/pytorch:25.08-py3 - variables: - GIT_DEPTH: 1000 # For correct version for tests/gpu/torch/quantization/plugins/test_megatron.py - tags: [docker, linux, 2-gpu] - before_script: - # Add libcudnn*.so and libnv*.so to path - - export LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:/usr/include:/usr/lib/x86_64-linux-gnu" - # Install git-lfs for Daring-Anteater dataset - - apt-get update && apt-get install -y git-lfs - - git lfs install --system - -multi-gpu: - extends: .multi-gpu-tests-default - script: - # Use pre-installed packages without a new venv with tox-current-env - - pip install tox-current-env - - tox -e py312-cuda12-gpu --current-env - -##### Example Tests ##### -example-torch: - extends: .multi-gpu-tests-default - timeout: 30m - parallel: - matrix: - - EXAMPLE: [llm_distill, llm_qat, llm_sparsity, speculative_decoding] - script: - - pip install ".[hf,dev-test]" - - find examples/$EXAMPLE -name "requirements.txt" | while read req_file; do pip install -r "$req_file" || exit 1; done - - pytest -s tests/examples/$EXAMPLE - -example-trtllm: - extends: example-torch - timeout: 60m - image: nvcr.io/nvidia/tensorrt-llm/release:1.1.0rc2.post2 - tags: [docker, linux, 2-gpu, sm>=89] - parallel: - matrix: - - EXAMPLE: [llm_autodeploy, llm_eval, llm_ptq, vlm_ptq] - -example-onnx: - extends: example-torch - image: nvcr.io/nvidia/tensorrt:25.06-py3 + image: nvcr.io/nvidia/tensorrt:25.08-py3 tags: [docker, linux, 2-gpu, sm>=89] parallel: matrix: From 92ff47a60748bd2d48251eb333877452f546636c Mon Sep 17 00:00:00 2001 From: ajrasane <131806219+ajrasane@users.noreply.github.com> Date: Fri, 5 Dec 2025 20:14:30 +0000 Subject: [PATCH 6/6] Fix imports after rebase Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com> --- 
modelopt/onnx/quantization/qdq_utils.py | 1 - modelopt/onnx/trt_utils.py | 1 - modelopt/torch/_deploy/utils/torch_onnx.py | 5 +---- tests/unit/onnx/test_qdq_utils.py | 4 +--- 4 files changed, 2 insertions(+), 9 deletions(-) diff --git a/modelopt/onnx/quantization/qdq_utils.py b/modelopt/onnx/quantization/qdq_utils.py index ed3ac171c..026b8d062 100644 --- a/modelopt/onnx/quantization/qdq_utils.py +++ b/modelopt/onnx/quantization/qdq_utils.py @@ -33,7 +33,6 @@ ) from modelopt.onnx.quantization.quant_utils import get_num_bits - QUANTIZE_NODE_NAME = "QuantizeLinear" DEQUANTIZE_NODE_NAME = "DequantizeLinear" diff --git a/modelopt/onnx/trt_utils.py b/modelopt/onnx/trt_utils.py index 1fc21c8d4..5fde66778 100644 --- a/modelopt/onnx/trt_utils.py +++ b/modelopt/onnx/trt_utils.py @@ -140,7 +140,6 @@ def _map_trt_to_onnx_type(trt_type: trt.DataType): trt.bool: onnx.TensorProto.BOOL, trt.fp8: onnx.TensorProto.FLOAT8E4M3FN, trt.fp4: onnx.TensorProto.FLOAT4E2M1, - trt.e8m0: onnx.TensorProto.UINT8, } try: return trt_to_onnx_dtype_mapping[trt_type] diff --git a/modelopt/torch/_deploy/utils/torch_onnx.py b/modelopt/torch/_deploy/utils/torch_onnx.py index e7602b30a..ba1c6f56b 100644 --- a/modelopt/torch/_deploy/utils/torch_onnx.py +++ b/modelopt/torch/_deploy/utils/torch_onnx.py @@ -40,10 +40,7 @@ NVFP4QuantExporter, ONNXQuantExporter, ) -from modelopt.onnx.quantization.qdq_utils import ( - qdq_to_dq, - replace_zero_scale_with_smallest_nonzero, -) +from modelopt.onnx.quantization.qdq_utils import qdq_to_dq, replace_zero_scale_with_smallest_nonzero from modelopt.onnx.utils import ( get_input_names, get_input_shapes, diff --git a/tests/unit/onnx/test_qdq_utils.py b/tests/unit/onnx/test_qdq_utils.py index 1fb05d75e..2acc4046a 100644 --- a/tests/unit/onnx/test_qdq_utils.py +++ b/tests/unit/onnx/test_qdq_utils.py @@ -17,10 +17,8 @@ import pytest from onnx import TensorProto, helper, numpy_helper -from modelopt.onnx.export import NVFP4QuantExporter -from modelopt.onnx.export.int4_exporter import INT4QuantExporter +from modelopt.onnx.export import INT4QuantExporter, MXFP8QuantExporter, NVFP4QuantExporter from modelopt.onnx.export.nvfp4_exporter import _cast_fp4, _cast_fp8 -from modelopt.onnx.export.mxfp8_exporter import MXFP8QuantExporter def create_test_model_with_int4_dq_reshape_transpose_matmul(constant_scale: bool = False):
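
For reference, the block-scale arithmetic that MXFP8QuantExporter.compute_scales and compress_weights implement in the patches above can be sketched in plain NumPy as follows. This is a minimal, self-contained illustration of the MX-style math only: the helper name mxfp8_quantize_blocks is invented for the example, the ceil-based e8m0 rounding convention is an assumption (the library's compute_e8m0 may round differently), and the final clip merely stands in for a real E4M3 cast such as _cast_fp8.

import numpy as np

E8M0_BIAS = 127      # e8m0 scales are stored as uint8 biased exponents
E4M3_MAX = 448.0     # largest finite FP8 E4M3 magnitude
BLOCK_SIZE = 32      # MXFP8 shares one scale per 32 elements along the quant axis


def mxfp8_quantize_blocks(weight):
    """Quantize a 2-D float32 weight along its last axis.

    Returns (FP8-range values, uint8 e8m0 scales), one scale per 32-element block.
    """
    rows, cols = weight.shape
    assert cols % BLOCK_SIZE == 0, "last dim must be a multiple of the block size"

    blocks = weight.reshape(rows, cols // BLOCK_SIZE, BLOCK_SIZE)
    amax = np.abs(blocks).max(axis=-1, keepdims=True)  # per-block absolute max

    # Smallest power-of-two scale that brings every element into the E4M3 range.
    # (Assumed ceil convention; the library's compute_e8m0 may differ.)
    exp = np.ceil(np.log2(np.maximum(amax, np.finfo(np.float32).tiny) / E4M3_MAX))
    se8m0 = np.clip(exp + E8M0_BIAS, 0, 255).astype(np.uint8)

    scaled = blocks / np.exp2(se8m0.astype(np.float32) - E8M0_BIAS)
    fp8_like = np.clip(scaled, -E4M3_MAX, E4M3_MAX)  # stand-in for a real E4M3 cast
    return fp8_like.reshape(rows, cols), se8m0.reshape(rows, cols // BLOCK_SIZE)


if __name__ == "__main__":
    w = np.random.randn(4, 64).astype(np.float32)
    q, scales = mxfp8_quantize_blocks(w)
    # Dequantize by repeating each scale over its 32-element block, inverting the
    # division that compress_weights performs on the stored weights.
    deq = q * np.exp2(np.repeat(scales.astype(np.float32), BLOCK_SIZE, axis=-1) - E8M0_BIAS)
    print("max abs reconstruction error:", np.abs(w - deq).max())

The scale expansion with np.repeat mirrors the broadcast step in compress_weights, and the exporter's DequantizeLinear nodes apply the same multiplication at load time. With a real E4M3 cast the reconstruction error would be nonzero due to mantissa rounding; the clamp stand-in above keeps the round trip exact, which is why the printed error is zero.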