
Commit 17d59a4

[OMNIML-2244] Create MXFP8 quant exporter
Signed-off-by: ajrasane <[email protected]>
1 parent 3ef9e39 commit 17d59a4

File tree: 4 files changed, +167 -116 lines changed

modelopt/onnx/export/mxfp8_exporter.py
modelopt/onnx/quantization/qdq_utils.py
modelopt/torch/_deploy/utils/torch_onnx.py
tests/unit/onnx/test_qdq_utils.py

modelopt/onnx/export/mxfp8_exporter.py

Lines changed: 152 additions & 4 deletions
@@ -15,27 +15,175 @@
 
 """MXFP8 quantization exporter."""
 
+import numpy as np
 import onnx
+from onnx import numpy_helper
+
+from modelopt.onnx.logging_config import logger
+from modelopt.onnx.quantization.graph_utils import get_tensor_producer_nodes
+from modelopt.onnx.quantization.qdq_utils import _cast_fp8, onnx_dtype_map
+from modelopt.onnx.quantization.quant_utils import compute_e8m0, get_amax
+from modelopt.onnx.utils import get_attribute, has_attribute
 
 from .base_exporter import ONNXQuantExporter
 
+E8_M0_BIAS = 127
+DEFAULT_BLOCK_SIZE = 32
+DEFAULT_QUANT_AXIS = -1
+
+
+def _get_weight_dq_nodes(graph: onnx.GraphProto) -> list[onnx.NodeProto]:
+    """Get weight DequantizeLinear nodes from the graph."""
+    return [
+        node
+        for node in graph.node
+        if node.op_type == "TRT_MXFP8DequantizeLinear"
+        and any(".weight" in inp for inp in node.input)
+    ]
+
+
+def _get_quant_params(node: onnx.NodeProto) -> tuple[int, int]:
+    """Extract quantization axis and block size from a node."""
+    if has_attribute(node, "axis"):
+        quant_axis = int(get_attribute(node, "axis"))
+    else:
+        quant_axis = DEFAULT_QUANT_AXIS
+        logger.warning(
+            "axis attribute not found for MXFP8DequantizeLinear node. Setting axis to -1"
+        )
+
+    if has_attribute(node, "block_size"):
+        block_size = int(get_attribute(node, "block_size"))
+    else:
+        block_size = DEFAULT_BLOCK_SIZE
+        logger.warning(
+            "block_size attribute not found for MXFP8DequantizeLinear node. "
+            "Setting block_size to 32"
+        )
+
+    return quant_axis, block_size
+
 
-# TODO: Implement the MXFP8QuantExporter
 class MXFP8QuantExporter(ONNXQuantExporter):
     """Exporter for MXFP8 quantization."""
 
     @staticmethod
     def pre_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
         """Pre-processes the ONNX model for MXFP8 quantization."""
+        return onnx_model
 
     @staticmethod
     def compute_scales(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
-        """Computes the scales for the weights in the ONNX model for MXFP8 quantization."""
+        """Computes the e8m0 scales for weights in the ONNX model for MXFP8 quantization."""
+        logger.info("Computing MXFP8 scales for weights")
+        graph = onnx_model.graph
+        initializer_map = {init.name: init for init in graph.initializer}
+        tensor_producer_map = get_tensor_producer_nodes(graph)
+
+        for node in _get_weight_dq_nodes(graph):
+            weight_name = node.input[0]
+            logger.debug(f"Computing MXFP8 scale for weight {weight_name}")
+
+            weight = numpy_helper.to_array(initializer_map[weight_name])
+            quant_axis, block_size = _get_quant_params(node)
+
+            # Compute scales
+            amax = get_amax(weight, quant_axis, block_size)
+            se8m0_fp32 = compute_e8m0(amax, weight.shape, quant_axis, block_size)
+            se8m0 = se8m0_fp32.astype(np.uint8)
+
+            # Remove scale producer if it's a Constant node
+            scale_name = node.input[1]
+            scale_producer = tensor_producer_map[scale_name]
+            if scale_producer.op_type == "Constant":
+                graph.node.remove(scale_producer)
+
+            # Create and add new scale tensor
+            scale_name_new = scale_name.replace("Constant_output_0", "scale")
+            scale_tensor = onnx.numpy_helper.from_array(se8m0, scale_name_new)
+            graph.initializer.append(scale_tensor)
+            node.input[1] = scale_name_new
+
+        return onnx_model
 
     @staticmethod
     def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
-        """Compresses the weights in the ONNX model for MXFP8 quantization."""
+        """Compresses the weights in the ONNX model to FP8 format for MXFP8 quantization."""
+        logger.info("Compressing weights to MXFP8 format")
+        graph = onnx_model.graph
+        initializer_map = {init.name: init for init in graph.initializer}
+
+        for node in _get_weight_dq_nodes(graph):
+            weight_name = node.input[0]
+            scale_name = node.input[1]
+            logger.debug(f"Compressing weight {weight_name} to MXFP8")
+
+            weight = numpy_helper.to_array(initializer_map[weight_name])
+            quant_axis, block_size = _get_quant_params(node)
+
+            # Get scale and convert back to fp32 for computation
+            se8m0 = numpy_helper.to_array(initializer_map[scale_name])
+            se8m0_fp32 = se8m0.astype(np.float32)
+
+            # Expand block array so that it can be broadcasted with weight
+            se8m0_fp32_expanded = np.repeat(se8m0_fp32, block_size, axis=quant_axis)
+            scaled_weight = weight / np.exp2(se8m0_fp32_expanded - E8_M0_BIAS)
+
+            # Create FP8 weight tensor
+            weights_e4m3 = onnx.helper.make_tensor(
+                name=weight_name,
+                data_type=onnx_dtype_map["Float8"],
+                dims=[*scaled_weight.shape],
+                vals=_cast_fp8(scaled_weight).tobytes(),
+                raw=True,
+            )
+            initializer_map[weight_name].CopyFrom(weights_e4m3)
+            logger.debug(f"Converted {weight_name} to MXFP8")
+
+        return onnx_model
 
     @staticmethod
     def post_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
-        """Post-processes the ONNX model for MXFP8 quantization."""
+        """Post-processes the ONNX model for MXFP8 quantization.
+
+        Sets DQ output type to FP16 and updates GELU nodes to use tanh approximation.
+        """
+        logger.info("Post-processing MXFP8 quantized model")
+        graph = onnx_model.graph
+
+        # Set output type of DQ to FP16
+        for node in graph.node:
+            if node.op_type == "TRT_MXFP8DequantizeLinear":
+                for attr in node.attribute:
+                    if attr.name == "output_dtype":
+                        attr.i = onnx_dtype_map["Half"]
+
+        # Currently only tanh approximation is supported for Gelu
+        for node in graph.node:
+            if node.op_type == "Gelu":
+                for attr in node.attribute:
+                    if attr.name == "approximate":
+                        attr.s = b"tanh"
+                        logger.debug(f"Updated GELU node {node.name} to use tanh approximation")
+
+        def is_fp32_cast(node: onnx.NodeProto) -> bool:
+            return node.op_type == "Cast" and any(
+                attr.name == "to" and attr.i == onnx.TensorProto.FLOAT for attr in node.attribute
+            )
+
+        # Remove Cast nodes after specific operators
+        nodes_to_remove = []
+        for node in graph.node:
+            if node.op_type in ["Transpose", "Reshape", "Sqrt", "Add", "Gelu"]:
+                child_nodes = [n for n in graph.node if node.output[0] in n.input]
+                if len(child_nodes) == 1 and is_fp32_cast(child_nodes[0]):
+                    cast_node = child_nodes[0]
+                    node.output.clear()
+                    node.output.extend(cast_node.output)
+                    nodes_to_remove.append(cast_node.name)
+
+        # Remove unnecessary casts
+        new_nodes = [node for node in graph.node if node.name not in nodes_to_remove]
+        graph.node.extend(new_nodes)
+
+        return onnx_model
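For orientation, here is a minimal driver sketch showing how the new exporter would be invoked end to end. It assumes MXFP8QuantExporter.process_model (the entry point used by the updated unit tests, presumably provided by ONNXQuantExporter) chains pre_process, compute_scales, compress_weights, and post_process; the file paths are placeholders, not part of this commit.

# Hypothetical usage sketch; paths and the exact chaining behavior of
# process_model are assumptions, not part of this commit.
import onnx

from modelopt.onnx.export.mxfp8_exporter import MXFP8QuantExporter

# Model exported with TRT_MXFP8 Q/DQ nodes (placeholder path)
model = onnx.load("model_with_mxfp8_qdq.onnx")

# Expected to run pre_process -> compute_scales -> compress_weights -> post_process
quantized = MXFP8QuantExporter.process_model(model)

onnx.save(quantized, "model_mxfp8.onnx")  # placeholder output path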

modelopt/onnx/quantization/qdq_utils.py

Lines changed: 2 additions & 100 deletions
@@ -31,8 +31,8 @@
     get_tensor_producer_nodes,
     remove_redundant_cast_nodes,
 )
-from modelopt.onnx.quantization.quant_utils import compute_e8m0, get_amax, get_num_bits
-from modelopt.onnx.utils import get_attribute, has_attribute
+from modelopt.onnx.quantization.quant_utils import get_num_bits
+
 
 QUANTIZE_NODE_NAME = "QuantizeLinear"
 DEQUANTIZE_NODE_NAME = "DequantizeLinear"
@@ -1036,101 +1036,3 @@ def cast_initializer_to_dtype(
     input_onnx = onnx.numpy_helper.from_array(input, input_name)
     input_onnx.data_type = onnx_dtype_map[dtype]
     initializer_map[input_name].CopyFrom(input_onnx)
-
-
-def quantize_weights_to_mxfp8(
-    onnx_model: onnx.ModelProto,
-) -> onnx.ModelProto:
-    """Converts the weights to FP8 precision using MXFP8 quantization.
-
-    For TRT_MXFP8DynamicQuantize, we update the output type to FP8.
-    For TRT_MXFP8DequantizeLinear, we compute the scales in e8m0 format and saves them as a new initializer.
-    We then expand the scale to the same shape as the weight and divide the weight by the scale to get the FP8 weights.
-
-    Args:
-        graph: ONNX model protobuf.
-
-    Returns:
-        ONNX model protobuf with weights quantized to FP8 precision using MXFP8 quantization.
-    """
-    logger.info("Converting weights to MXFP8 precision")
-    graph = onnx_model.graph
-    initializer_map = {initializer.name: initializer for initializer in graph.initializer}
-    tensor_producer_map = get_tensor_producer_nodes(graph)
-    e8_m0_bias = 127
-    weight_dq_nodes = [
-        node
-        for node in graph.node
-        if node.op_type == "TRT_MXFP8DequantizeLinear"
-        and any(".weight" in input for input in node.input)
-    ]
-    gelu_nodes = [node for node in graph.node if node.op_type == "Gelu"]
-    logger.debug(f"Found {len(weight_dq_nodes)} weight DQ nodes and {len(gelu_nodes)} GELU nodes")
-
-    for node in weight_dq_nodes:
-        # Get weights and node attributes
-        weight_name = node.input[0]
-        logger.debug(f"Processing MXFP8 conversion for weight {weight_name}")
-        weight = numpy_helper.to_array(initializer_map[weight_name])
-        if has_attribute(node, "axis"):
-            quant_axis = int(get_attribute(node, "axis"))
-        else:
-            quant_axis = -1
-            logger.warning(
-                "axis attribute not found for MXFP8DequantizeLinear node. Setting axis to -1"
-            )
-
-        if has_attribute(node, "block_size"):
-            block_size = int(get_attribute(node, "block_size"))
-        else:
-            block_size = 32
-            logger.warning(
-                "block_size attribute not found for MXFP8DequantizeLinear node. Setting block_size to 32"
-            )
-
-        # Compute and save scales as uint8
-        amax = get_amax(weight, quant_axis, block_size)
-        se8m0_fp32 = compute_e8m0(amax, weight.shape, quant_axis, block_size)
-        se8m0 = se8m0_fp32.astype(np.uint8)
-
-        # Remove scale producer if it's a Constant node
-        scale_name = node.input[1]
-        scale_producer = tensor_producer_map[scale_name]
-        if scale_producer.op_type == "Constant":
-            graph.node.remove(scale_producer)
-
-        # Create a new scale tensor
-        scale_name = scale_name.replace("Constant_output_0", "scale")
-        scale_tensor = onnx.numpy_helper.from_array(se8m0, scale_name)
-        graph.initializer.append(scale_tensor)
-        node.input[1] = scale_name
-
-        # Convert weights to FP8
-        # Expand block array so that it can be broadcasted with weight
-        se8m0_fp32 = np.repeat(se8m0_fp32, block_size, axis=quant_axis)
-        scaled_weight = weight / np.exp2(se8m0_fp32 - e8_m0_bias)
-        weights_e4m3 = onnx.helper.make_tensor(
-            name=weight_name,
-            data_type=onnx_dtype_map["Float8"],
-            dims=[*scaled_weight.shape],
-            vals=_cast_fp8(scaled_weight).tobytes(),
-            raw=True,
-        )
-        initializer_map[weight_name].CopyFrom(weights_e4m3)
-        logger.debug(f"Converted {weight_name} to MXFP8")
-
-    # set output type of DQ to FP16
-    for node in graph.node:
-        if node.op_type in ["TRT_MXFP8DequantizeLinear"]:
-            for attr in node.attribute:
-                if attr.name == "output_dtype":
-                    attr.i = onnx_dtype_map["Half"]
-
-    # Currently only tanh approximation is supported for Gelu
-    for node in gelu_nodes:
-        for attr in node.attribute:
-            if attr.name == "approximate":
-                attr.s = b"tanh"
-                logger.debug(f"Updated GELU node {node.name} to use tanh approximation")
-
-    return onnx_model
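The function removed here (its logic now lives in MXFP8QuantExporter.compute_scales and compress_weights) documents the core MXFP8 arithmetic: take a per-block amax over blocks of 32 along the quantization axis, encode a power-of-two scale as an e8m0 exponent, and divide the weight by 2^(e8m0 - 127) before casting to FP8 E4M3. The NumPy sketch below illustrates that block math on a toy weight; the ceil-log2 encoding is an assumption about what compute_e8m0 does internally (the real helper may clamp or round differently).

import numpy as np

E8M0_BIAS = 127   # exponent bias used by the exporter (E8_M0_BIAS)
BLOCK = 32        # default MXFP8 block size
E4M3_MAX = 448.0  # largest finite FP8 E4M3 magnitude

w = np.random.randn(64, 128).astype(np.float32)  # toy weight, quantized along the last axis

# Per-block amax over contiguous blocks of 32 along the last axis
amax = np.abs(w).reshape(w.shape[0], -1, BLOCK).max(axis=-1)

# Assumed e8m0 encoding: a power-of-two scale that brings each block into E4M3 range
e8m0 = np.ceil(np.log2(np.maximum(amax, np.finfo(np.float32).tiny) / E4M3_MAX)) + E8M0_BIAS

# Broadcast the block scale back to the weight shape, mirroring np.repeat(..., block_size)
scale = np.exp2(np.repeat(e8m0, BLOCK, axis=-1) - E8M0_BIAS)
w_scaled = w / scale  # this is what _cast_fp8 would then convert to FP8 E4M3

Under this assumed encoding each block lands in the E4M3 range; in the exporter, the uint8 e8m0 exponents become the new scale initializer feeding the TRT_MXFP8DequantizeLinear node.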

modelopt/torch/_deploy/utils/torch_onnx.py

Lines changed: 6 additions & 5 deletions
@@ -42,7 +42,6 @@
 )
 from modelopt.onnx.quantization.qdq_utils import (
     qdq_to_dq,
-    quantize_weights_to_mxfp8,
     replace_zero_scale_with_smallest_nonzero,
 )
 from modelopt.onnx.utils import (
@@ -364,6 +363,11 @@ def is_fp8_quantized(model: nn.Module) -> bool:
             and hasattr(module, "input_quantizer")
             and module.weight_quantizer._num_bits == (4, 3)
             and module.input_quantizer._num_bits == (4, 3)
+            # Exclude MXFP8 which also uses (4,3) but has block_sizes with scale_bits
+            and not (
+                module.input_quantizer.block_sizes
+                and module.input_quantizer.block_sizes.get("scale_bits", None) == (8, 0)
+            )
         ):
             return True
     return False
@@ -560,11 +564,8 @@ def get_onnx_bytes_and_metadata(
 
     # Convert dummy TRT_FP4QDQ nodes to 2DQ format if the model is quantized in FP4 mode
     # Or convert weights to MXFP8 format if the model is quantized in MXFP8 mode
-    if is_int4_quantized(model) or is_fp4_quantized(model):
+    if is_int4_quantized(model) or is_fp4_quantized(model) or is_mxfp8_quantized(model):
         onnx_opt_graph = quantize_weights(model, onnx_opt_graph)
-    elif is_mxfp8_quantized(model):
-        # TODO: Implement the MXFP8QuantExporter
-        onnx_opt_graph = quantize_weights_to_mxfp8(onnx_opt_graph)
 
     if dq_only:
         onnx_opt_graph = qdq_to_dq(onnx_opt_graph)
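The new exclusion in is_fp8_quantized relies on MXFP8 quantizers reporting element bits of (4, 3) plus a block_sizes entry with scale_bits == (8, 0). A standalone sketch of that discrimination is below; the helper name is illustrative, and the quantizer is any object exposing _num_bits and block_sizes the way the modelopt quantizers appear to here.

def looks_like_mxfp8(quantizer) -> bool:
    """Illustrative check mirroring the condition added above.

    FP8 (E4M3) element bits combined with an E8M0 block scale is treated as MXFP8;
    plain per-tensor FP8 quantization has no such block_sizes entry.
    """
    block_sizes = getattr(quantizer, "block_sizes", None) or {}
    return (
        getattr(quantizer, "_num_bits", None) == (4, 3)
        and block_sizes.get("scale_bits", None) == (8, 0)
    )

With that distinction in place, get_onnx_bytes_and_metadata can route MXFP8 models through the same quantize_weights path as INT4 and FP4 instead of the removed quantize_weights_to_mxfp8.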

tests/unit/onnx/test_qdq_utils.py

Lines changed: 7 additions & 7 deletions
@@ -20,7 +20,7 @@
 from modelopt.onnx.export import NVFP4QuantExporter
 from modelopt.onnx.export.int4_exporter import INT4QuantExporter
 from modelopt.onnx.export.nvfp4_exporter import _cast_fp4, _cast_fp8
-from modelopt.onnx.quantization.qdq_utils import quantize_weights_to_mxfp8
+from modelopt.onnx.export.mxfp8_exporter import MXFP8QuantExporter
 
 
 def create_test_model_with_int4_dq_reshape_transpose_matmul(constant_scale: bool = False):
@@ -471,15 +471,15 @@ def test_cast_fp4(self, input_array, expected_array):
         assert np.all(result == expected_array)
 
 
-class TestQuantizeWeightsToMXFP8:
-    """Test suite for quantize_weights_to_mxfp8 function."""
+class TestMXFP8QuantExporter:
+    """Test suite for MXFP8QuantExporter."""
 
     def test_basic_mxfp8_quantization(self):
         """Test basic MXFP8 quantization with TRT_MXFP8DequantizeLinear nodes."""
         model = create_test_model_with_mxfp8_dq()
 
         # Run MXFP8 quantization
-        quantized_model = quantize_weights_to_mxfp8(model)
+        quantized_model = MXFP8QuantExporter.process_model(model)
 
         # Verify weight is converted to FP8
         weight_tensor = next(
@@ -510,7 +510,7 @@ def test_mxfp8_output_dtype_update(self):
         model = create_test_model_with_mxfp8_dq()
 
         # Run MXFP8 quantization
-        quantized_model = quantize_weights_to_mxfp8(model)
+        quantized_model = MXFP8QuantExporter.process_model(model)
 
         # Verify output_dtype is set to FP16
         dq_node = next(
@@ -526,7 +526,7 @@ def test_mxfp8_gelu_approximation_update(self):
         model = create_test_model_with_mxfp8_dq()
 
         # Run MXFP8 quantization
-        quantized_model = quantize_weights_to_mxfp8(model)
+        quantized_model = MXFP8QuantExporter.process_model(model)
 
         # Verify Gelu approximation is set to tanh
         gelu_node = next(node for node in quantized_model.graph.node if node.op_type == "Gelu")
@@ -574,7 +574,7 @@ def test_mxfp8_with_missing_attributes(self):
         model = helper.make_model(graph)
 
         # Run MXFP8 quantization (should use default values)
-        quantized_model = quantize_weights_to_mxfp8(model)
+        quantized_model = MXFP8QuantExporter.process_model(model)
 
         # Verify the model is still processed correctly
         weight_tensor = next(
