
Commit 52cdc00

[OMNIML-2244] Create MXFP8 quant exporter
Signed-off-by: ajrasane <[email protected]>
1 parent 7edf59c commit 52cdc00

File tree

4 files changed: +166 -122 lines changed


modelopt/onnx/export/mxfp8_exporter.py

Lines changed: 152 additions & 4 deletions
@@ -15,27 +15,175 @@
 
 """MXFP8 quantization exporter."""
 
+import numpy as np
 import onnx
+from onnx import numpy_helper
+
+from modelopt.onnx.logging_config import logger
+from modelopt.onnx.quantization.graph_utils import get_tensor_producer_nodes
+from modelopt.onnx.quantization.qdq_utils import _cast_fp8, onnx_dtype_map
+from modelopt.onnx.quantization.quant_utils import compute_e8m0, get_amax
+from modelopt.onnx.utils import get_attribute, has_attribute
 
 from .base_exporter import ONNXQuantExporter
 
+E8_M0_BIAS = 127
+DEFAULT_BLOCK_SIZE = 32
+DEFAULT_QUANT_AXIS = -1
+
+
+def _get_weight_dq_nodes(graph: onnx.GraphProto) -> list[onnx.NodeProto]:
+    """Get weight DequantizeLinear nodes from the graph."""
+    return [
+        node
+        for node in graph.node
+        if node.op_type == "TRT_MXFP8DequantizeLinear"
+        and any(".weight" in inp for inp in node.input)
+    ]
+
+
+def _get_quant_params(node: onnx.NodeProto) -> tuple[int, int]:
+    """Extract quantization axis and block size from a node."""
+    if has_attribute(node, "axis"):
+        quant_axis = int(get_attribute(node, "axis"))
+    else:
+        quant_axis = DEFAULT_QUANT_AXIS
+        logger.warning(
+            "axis attribute not found for MXFP8DequantizeLinear node. Setting axis to -1"
+        )
+
+    if has_attribute(node, "block_size"):
+        block_size = int(get_attribute(node, "block_size"))
+    else:
+        block_size = DEFAULT_BLOCK_SIZE
+        logger.warning(
+            "block_size attribute not found for MXFP8DequantizeLinear node. "
+            "Setting block_size to 32"
+        )
+
+    return quant_axis, block_size
+
 
-# TODO: Implement the MXFP8QuantExporter
 class MXFP8QuantExporter(ONNXQuantExporter):
     """Exporter for MXFP8 quantization."""
 
     @staticmethod
     def pre_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
         """Pre-processes the ONNX model for MXFP8 quantization."""
+        return onnx_model
 
     @staticmethod
     def compute_scales(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
-        """Computes the scales for the weights in the ONNX model for MXFP8 quantization."""
+        """Computes the e8m0 scales for weights in the ONNX model for MXFP8 quantization."""
+        logger.info("Computing MXFP8 scales for weights")
+        graph = onnx_model.graph
+        initializer_map = {init.name: init for init in graph.initializer}
+        tensor_producer_map = get_tensor_producer_nodes(graph)
+
+        for node in _get_weight_dq_nodes(graph):
+            weight_name = node.input[0]
+            logger.debug(f"Computing MXFP8 scale for weight {weight_name}")
+
+            weight = numpy_helper.to_array(initializer_map[weight_name])
+            quant_axis, block_size = _get_quant_params(node)
+
+            # Compute scales
+            amax = get_amax(weight, quant_axis, block_size)
+            se8m0_fp32 = compute_e8m0(amax, weight.shape, quant_axis, block_size)
+            se8m0 = se8m0_fp32.astype(np.uint8)
+
+            # Remove scale producer if it's a Constant node
+            scale_name = node.input[1]
+            scale_producer = tensor_producer_map[scale_name]
+            if scale_producer.op_type == "Constant":
+                graph.node.remove(scale_producer)
+
+            # Create and add new scale tensor
+            scale_name_new = scale_name.replace("Constant_output_0", "scale")
+            scale_tensor = onnx.numpy_helper.from_array(se8m0, scale_name_new)
+            graph.initializer.append(scale_tensor)
+            node.input[1] = scale_name_new
+
+        return onnx_model
 
     @staticmethod
     def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
-        """Compresses the weights in the ONNX model for MXFP8 quantization."""
+        """Compresses the weights in the ONNX model to FP8 format for MXFP8 quantization."""
+        logger.info("Compressing weights to MXFP8 format")
+        graph = onnx_model.graph
+        initializer_map = {init.name: init for init in graph.initializer}
+
+        for node in _get_weight_dq_nodes(graph):
+            weight_name = node.input[0]
+            scale_name = node.input[1]
+            logger.debug(f"Compressing weight {weight_name} to MXFP8")
+
+            weight = numpy_helper.to_array(initializer_map[weight_name])
+            quant_axis, block_size = _get_quant_params(node)
+
+            # Get scale and convert back to fp32 for computation
+            se8m0 = numpy_helper.to_array(initializer_map[scale_name])
+            se8m0_fp32 = se8m0.astype(np.float32)
+
+            # Expand block array so that it can be broadcasted with weight
+            se8m0_fp32_expanded = np.repeat(se8m0_fp32, block_size, axis=quant_axis)
+            scaled_weight = weight / np.exp2(se8m0_fp32_expanded - E8_M0_BIAS)
+
+            # Create FP8 weight tensor
+            weights_e4m3 = onnx.helper.make_tensor(
+                name=weight_name,
+                data_type=onnx_dtype_map["Float8"],
+                dims=[*scaled_weight.shape],
+                vals=_cast_fp8(scaled_weight).tobytes(),
+                raw=True,
+            )
+            initializer_map[weight_name].CopyFrom(weights_e4m3)
+            logger.debug(f"Converted {weight_name} to MXFP8")
+
+        return onnx_model
 
     @staticmethod
     def post_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
-        """Post-processes the ONNX model for MXFP8 quantization."""
+        """Post-processes the ONNX model for MXFP8 quantization.
+
+        Sets DQ output type to FP16 and updates GELU nodes to use tanh approximation.
+        """
+        logger.info("Post-processing MXFP8 quantized model")
+        graph = onnx_model.graph
+
+        # Set output type of DQ to FP16
+        for node in graph.node:
+            if node.op_type == "TRT_MXFP8DequantizeLinear":
+                for attr in node.attribute:
+                    if attr.name == "output_dtype":
+                        attr.i = onnx_dtype_map["Half"]
+
+        # Currently only tanh approximation is supported for Gelu
+        for node in graph.node:
+            if node.op_type == "Gelu":
+                for attr in node.attribute:
+                    if attr.name == "approximate":
+                        attr.s = b"tanh"
+                        logger.debug(f"Updated GELU node {node.name} to use tanh approximation")
+
+        def is_fp32_cast(node: onnx.NodeProto) -> bool:
+            return node.op_type == "Cast" and any(
+                attr.name == "to" and attr.i == onnx.TensorProto.FLOAT for attr in node.attribute
+            )
+
+        # Remove Cast nodes after specific operators
+        nodes_to_remove = []
+        for node in graph.node:
+            if node.op_type in ["Transpose", "Reshape", "Sqrt", "Add", "Gelu"]:
+                child_nodes = [n for n in graph.node if node.output[0] in n.input]
+                if len(child_nodes) == 1 and is_fp32_cast(child_nodes[0]):
+                    cast_node = child_nodes[0]
+                    node.output.clear()
+                    node.output.extend(cast_node.output)
+                    nodes_to_remove.append(cast_node.name)
+
+        # Remove unnecessary casts by rebuilding the node list without them
+        new_nodes = [node for node in graph.node if node.name not in nodes_to_remove]
+        del graph.node[:]
+        graph.node.extend(new_nodes)
+
+        return onnx_model
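
For reference, a minimal usage sketch of the new exporter (not part of the commit). It assumes the inherited process_model entry point exercised by the updated tests chains pre_process, compute_scales, compress_weights, and post_process; the file paths below are placeholders.

# Hypothetical usage sketch, not part of this commit. Assumes process_model
# (inherited from ONNXQuantExporter and used in the tests below) runs the four
# stages in order: pre_process -> compute_scales -> compress_weights -> post_process.
import onnx

from modelopt.onnx.export.mxfp8_exporter import MXFP8QuantExporter

# Load a model that already carries TRT_MXFP8 Q/DQ nodes, run the exporter
# pipeline, and save the result. Paths are placeholders.
model = onnx.load("model_with_trt_mxfp8_qdq.onnx")
model = MXFP8QuantExporter.process_model(model)
onnx.save(model, "model_mxfp8.onnx")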

modelopt/onnx/quantization/qdq_utils.py

Lines changed: 0 additions & 101 deletions
@@ -33,14 +33,11 @@
     remove_redundant_cast_nodes,
 )
 from modelopt.onnx.quantization.quant_utils import (
-    compute_e8m0,
-    get_amax,
     get_num_bits,
     get_weights_scaling_factor,
     get_weights_scaling_factor_2,
     quantize,
 )
-from modelopt.onnx.utils import get_attribute, has_attribute
 from modelopt.torch.quantization.qtensor import NVFP4QTensor
 
 QUANTIZE_NODE_NAME = "QuantizeLinear"
@@ -1066,104 +1063,6 @@ def cast_initializer_to_dtype(
     initializer_map[input_name].CopyFrom(input_onnx)
 
 
-def quantize_weights_to_mxfp8(
-    onnx_model: onnx.ModelProto,
-) -> onnx.ModelProto:
-    """Converts the weights to FP8 precision using MXFP8 quantization.
-
-    For TRT_MXFP8DynamicQuantize, we update the output type to FP8.
-    For TRT_MXFP8DequantizeLinear, we compute the scales in e8m0 format and saves them as a new initializer.
-    We then expand the scale to the same shape as the weight and divide the weight by the scale to get the FP8 weights.
-
-    Args:
-        graph: ONNX model protobuf.
-
-    Returns:
-        ONNX model protobuf with weights quantized to FP8 precision using MXFP8 quantization.
-    """
-    logger.info("Converting weights to MXFP8 precision")
-    graph = onnx_model.graph
-    initializer_map = {initializer.name: initializer for initializer in graph.initializer}
-    tensor_producer_map = get_tensor_producer_nodes(graph)
-    e8_m0_bias = 127
-    weight_dq_nodes = [
-        node
-        for node in graph.node
-        if node.op_type == "TRT_MXFP8DequantizeLinear"
-        and any(".weight" in input for input in node.input)
-    ]
-    gelu_nodes = [node for node in graph.node if node.op_type == "Gelu"]
-    logger.debug(f"Found {len(weight_dq_nodes)} weight DQ nodes and {len(gelu_nodes)} GELU nodes")
-
-    for node in weight_dq_nodes:
-        # Get weights and node attributes
-        weight_name = node.input[0]
-        logger.debug(f"Processing MXFP8 conversion for weight {weight_name}")
-        weight = numpy_helper.to_array(initializer_map[weight_name])
-        if has_attribute(node, "axis"):
-            quant_axis = int(get_attribute(node, "axis"))
-        else:
-            quant_axis = -1
-            logger.warning(
-                "axis attribute not found for MXFP8DequantizeLinear node. Setting axis to -1"
-            )
-
-        if has_attribute(node, "block_size"):
-            block_size = int(get_attribute(node, "block_size"))
-        else:
-            block_size = 32
-            logger.warning(
-                "block_size attribute not found for MXFP8DequantizeLinear node. Setting block_size to 32"
-            )
-
-        # Compute and save scales as uint8
-        amax = get_amax(weight, quant_axis, block_size)
-        se8m0_fp32 = compute_e8m0(amax, weight.shape, quant_axis, block_size)
-        se8m0 = se8m0_fp32.astype(np.uint8)
-
-        # Remove scale producer if it's a Constant node
-        scale_name = node.input[1]
-        scale_producer = tensor_producer_map[scale_name]
-        if scale_producer.op_type == "Constant":
-            graph.node.remove(scale_producer)
-
-        # Create a new scale tensor
-        scale_name = scale_name.replace("Constant_output_0", "scale")
-        scale_tensor = onnx.numpy_helper.from_array(se8m0, scale_name)
-        graph.initializer.append(scale_tensor)
-        node.input[1] = scale_name
-
-        # Convert weights to FP8
-        # Expand block array so that it can be broadcasted with weight
-        se8m0_fp32 = np.repeat(se8m0_fp32, block_size, axis=quant_axis)
-        scaled_weight = weight / np.exp2(se8m0_fp32 - e8_m0_bias)
-        weights_e4m3 = onnx.helper.make_tensor(
-            name=weight_name,
-            data_type=onnx_dtype_map["Float8"],
-            dims=[*scaled_weight.shape],
-            vals=_cast_fp8(scaled_weight).tobytes(),
-            raw=True,
-        )
-        initializer_map[weight_name].CopyFrom(weights_e4m3)
-        logger.debug(f"Converted {weight_name} to MXFP8")
-
-    # set output type of DQ to FP16
-    for node in graph.node:
-        if node.op_type in ["TRT_MXFP8DequantizeLinear"]:
-            for attr in node.attribute:
-                if attr.name == "output_dtype":
-                    attr.i = onnx_dtype_map["Half"]
-
-    # Currently only tanh approximation is supported for Gelu
-    for node in gelu_nodes:
-        for attr in node.attribute:
-            if attr.name == "approximate":
-                attr.s = b"tanh"
-                logger.debug(f"Updated GELU node {node.name} to use tanh approximation")
-
-    return onnx_model
-
-
 def replace_fp4qdq_with_2dq(
     graph: onnx.GraphProto,
     node: onnx.NodeProto,
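
The removed helper above is superseded by MXFP8QuantExporter: it divides each weight block by 2**(e8m0 - 127) and casts the result to E4M3. The get_amax and compute_e8m0 implementations are not shown in this diff, so the sketch below only illustrates the per-block e8m0 scale math under a common assumption: a shared exponent of floor(log2(amax)) - 8 (8 being the maximum E4M3 exponent, since 448 = 1.75 * 2**8) plus the 127 bias.

# Illustrative sketch only; not the modelopt implementation of get_amax/compute_e8m0.
import numpy as np

E8_M0_BIAS = 127
E4M3_MAX_EXP = 8  # largest E4M3 magnitude is 448 = 1.75 * 2**8


def e8m0_block_scales(weight: np.ndarray, block_size: int = 32) -> np.ndarray:
    """Per-block e8m0 scale bytes along the last axis (hypothetical helper)."""
    blocks = weight.reshape(*weight.shape[:-1], -1, block_size)
    amax = np.abs(blocks).max(axis=-1)
    amax = np.maximum(amax, np.finfo(np.float32).tiny)  # avoid log2(0)
    shared_exp = np.floor(np.log2(amax)) - E4M3_MAX_EXP
    return np.clip(shared_exp + E8_M0_BIAS, 0, 255).astype(np.uint8)


# Dequantization then multiplies each FP8 block by 2**(e8m0 - 127), which is why
# the code above divides the weight by np.exp2(se8m0_fp32 - 127) before the cast.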

modelopt/torch/_deploy/utils/torch_onnx.py

Lines changed: 6 additions & 5 deletions
@@ -43,7 +43,6 @@
 from modelopt.onnx.quantization.qdq_utils import (
     fp4qdq_to_2dq,
     qdq_to_dq,
-    quantize_weights_to_mxfp8,
     replace_zero_scale_with_smallest_nonzero,
 )
 from modelopt.onnx.utils import (
@@ -365,6 +364,11 @@ def is_fp8_quantized(model: nn.Module) -> bool:
             and hasattr(module, "input_quantizer")
             and module.weight_quantizer._num_bits == (4, 3)
             and module.input_quantizer._num_bits == (4, 3)
+            # Exclude MXFP8 which also uses (4,3) but has block_sizes with scale_bits
+            and not (
+                module.input_quantizer.block_sizes
+                and module.input_quantizer.block_sizes.get("scale_bits", None) == (8, 0)
+            )
         ):
             return True
     return False
@@ -561,14 +565,11 @@ def get_onnx_bytes_and_metadata(
 
     # Convert dummy TRT_FP4QDQ nodes to 2DQ format if the model is quantized in FP4 mode
     # Or convert weights to MXFP8 format if the model is quantized in MXFP8 mode
-    if is_int4_quantized(model):
+    if is_int4_quantized(model) or is_mxfp8_quantized(model):
         onnx_opt_graph = quantize_weights(model, onnx_opt_graph)
     elif is_fp4_quantized(model):
         # TODO: Implement the NVFP4QuantExporter
         onnx_opt_graph = fp4qdq_to_2dq(onnx_opt_graph)
-    elif is_mxfp8_quantized(model):
-        # TODO: Implement the MXFP8QuantExporter
-        onnx_opt_graph = quantize_weights_to_mxfp8(onnx_opt_graph)
 
     if dq_only:
         onnx_opt_graph = qdq_to_dq(onnx_opt_graph)
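
The new condition in is_fp8_quantized tells plain FP8 apart from MXFP8 by the presence of e8m0 block scales, so MXFP8 models now take the quantize_weights exporter path instead. Below is a small sketch of that discriminator as a standalone helper; is_mxfp8_quantized is referenced by the diff but its body is not shown, so this only restates the condition added above.

# Hypothetical helper mirroring the exclusion above; not the actual
# is_mxfp8_quantized implementation.
def _uses_e8m0_block_scales(quantizer) -> bool:
    """True if a (4, 3) quantizer carries MX-style (8, 0) scale_bits block scales."""
    block_sizes = getattr(quantizer, "block_sizes", None)
    return bool(block_sizes) and block_sizes.get("scale_bits", None) == (8, 0)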

tests/unit/onnx/test_qdq_utils.py

Lines changed: 8 additions & 12 deletions
@@ -18,12 +18,8 @@
 from onnx import TensorProto, helper, numpy_helper
 
 from modelopt.onnx.export.int4_exporter import INT4QuantExporter
-from modelopt.onnx.quantization.qdq_utils import (
-    _cast_fp4,
-    _cast_fp8,
-    fp4qdq_to_2dq,
-    quantize_weights_to_mxfp8,
-)
+from modelopt.onnx.export.mxfp8_exporter import MXFP8QuantExporter
+from modelopt.onnx.quantization.qdq_utils import _cast_fp4, _cast_fp8, fp4qdq_to_2dq
 
 
 def create_test_model_with_int4_dq_reshape_transpose_matmul(constant_scale: bool = False):
@@ -474,15 +470,15 @@ def test_cast_fp4(self, input_array, expected_array):
         assert np.all(result == expected_array)
 
 
-class TestQuantizeWeightsToMXFP8:
-    """Test suite for quantize_weights_to_mxfp8 function."""
+class TestMXFP8QuantExporter:
+    """Test suite for MXFP8QuantExporter."""
 
     def test_basic_mxfp8_quantization(self):
         """Test basic MXFP8 quantization with TRT_MXFP8DequantizeLinear nodes."""
         model = create_test_model_with_mxfp8_dq()
 
         # Run MXFP8 quantization
-        quantized_model = quantize_weights_to_mxfp8(model)
+        quantized_model = MXFP8QuantExporter.process_model(model)
 
         # Verify weight is converted to FP8
         weight_tensor = next(
@@ -513,7 +509,7 @@ def test_mxfp8_output_dtype_update(self):
         model = create_test_model_with_mxfp8_dq()
 
         # Run MXFP8 quantization
-        quantized_model = quantize_weights_to_mxfp8(model)
+        quantized_model = MXFP8QuantExporter.process_model(model)
 
         # Verify output_dtype is set to FP16
         dq_node = next(
@@ -529,7 +525,7 @@ def test_mxfp8_gelu_approximation_update(self):
         model = create_test_model_with_mxfp8_dq()
 
         # Run MXFP8 quantization
-        quantized_model = quantize_weights_to_mxfp8(model)
+        quantized_model = MXFP8QuantExporter.process_model(model)
 
         # Verify Gelu approximation is set to tanh
         gelu_node = next(node for node in quantized_model.graph.node if node.op_type == "Gelu")
@@ -577,7 +573,7 @@ def test_mxfp8_with_missing_attributes(self):
         model = helper.make_model(graph)
 
         # Run MXFP8 quantization (should use default values)
-        quantized_model = quantize_weights_to_mxfp8(model)
+        quantized_model = MXFP8QuantExporter.process_model(model)
 
         # Verify the model is still processed correctly
         weight_tensor = next(
