166 changes: 162 additions & 4 deletions modelopt/onnx/export/mxfp8_exporter.py
@@ -15,27 +15,185 @@

"""MXFP8 quantization exporter."""

import numpy as np
import onnx
from onnx import numpy_helper

from modelopt.onnx.logging_config import logger
from modelopt.onnx.quantization.graph_utils import get_tensor_producer_nodes
from modelopt.onnx.quantization.qdq_utils import _cast_fp8, onnx_dtype_map
from modelopt.onnx.quantization.quant_utils import compute_e8m0, get_amax
from modelopt.onnx.utils import get_attribute, has_attribute

from .base_exporter import ONNXQuantExporter

E8_M0_BIAS = 127
DEFAULT_BLOCK_SIZE = 32
DEFAULT_QUANT_AXIS = -1


def _get_weight_dq_nodes(graph: onnx.GraphProto) -> list[onnx.NodeProto]:
"""Get weight DequantizeLinear nodes from the graph."""
return [
node
for node in graph.node
if node.op_type == "TRT_MXFP8DequantizeLinear"
and any(".weight" in inp for inp in node.input)
]


def _get_quant_params(node: onnx.NodeProto) -> tuple[int, int]:
"""Extract quantization axis and block size from a node."""
if has_attribute(node, "axis"):
quant_axis = int(get_attribute(node, "axis"))
else:
quant_axis = DEFAULT_QUANT_AXIS
logger.warning(
"axis attribute not found for MXFP8DequantizeLinear node. Setting axis to -1"
)

if has_attribute(node, "block_size"):
block_size = int(get_attribute(node, "block_size"))
else:
block_size = DEFAULT_BLOCK_SIZE
logger.warning(
"block_size attribute not found for MXFP8DequantizeLinear node. "
"Setting block_size to 32"
)

return quant_axis, block_size


class MXFP8QuantExporter(ONNXQuantExporter):
"""Exporter for MXFP8 quantization."""

@staticmethod
def pre_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
"""Pre-processes the ONNX model for MXFP8 quantization."""
return onnx_model

@staticmethod
def compute_scales(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
"""Computes the scales for the weights in the ONNX model for MXFP8 quantization."""
"""Computes the e8m0 scales for weights in the ONNX model for MXFP8 quantization."""
logger.info("Computing MXFP8 scales for weights")
graph = onnx_model.graph
initializer_map = {init.name: init for init in graph.initializer}
tensor_producer_map = get_tensor_producer_nodes(graph)

for node in _get_weight_dq_nodes(graph):
weight_name = node.input[0]
logger.debug(f"Computing MXFP8 scale for weight {weight_name}")

weight = numpy_helper.to_array(initializer_map[weight_name])
quant_axis, block_size = _get_quant_params(node)

# Compute scales
amax = get_amax(weight, quant_axis, block_size)
se8m0_fp32 = compute_e8m0(amax, weight.shape, quant_axis, block_size)
se8m0 = se8m0_fp32.astype(np.uint8)

# Remove scale producer if it's a Constant node
scale_name = node.input[1]
scale_producer = tensor_producer_map[scale_name]
if scale_producer.op_type == "Constant":
graph.node.remove(scale_producer)

# Create and add new scale tensor
scale_name_new = scale_name.replace("Constant_output_0", "scale")
scale_tensor = onnx.numpy_helper.from_array(se8m0, scale_name_new)
graph.initializer.append(scale_tensor)
node.input[1] = scale_name_new

return onnx_model
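# --- Editorial aside (not part of this diff): a rough standalone sketch of the
# block-scale math that get_amax/compute_e8m0 in quant_utils are assumed to
# perform, under the common MXFP8 convention of 32-element blocks along the
# last axis and a power-of-two scale chosen so the block amax fits the FP8
# E4M3 range (max representable value 448). The helper name is hypothetical.
import numpy as np

def _sketch_block_e8m0(weight: np.ndarray, block_size: int = 32) -> np.ndarray:
    # Group the last axis into blocks and take the per-block absolute maximum.
    blocked = weight.reshape(*weight.shape[:-1], -1, block_size)
    amax = np.abs(blocked).max(axis=-1)
    # Smallest power-of-two exponent that maps the block amax into E4M3 range,
    # stored with the e8m0 bias of 127 (decoded as scale = 2 ** (e8m0 - 127)).
    exp = np.ceil(np.log2(np.maximum(amax, np.finfo(np.float32).tiny) / 448.0))
    return np.clip(exp + 127.0, 0, 255).astype(np.uint8)
# --- end aside ---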

@staticmethod
def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
"""Compresses the weights in the ONNX model for MXFP8 quantization."""
"""Compresses the weights in the ONNX model to FP8 format for MXFP8 quantization."""
logger.info("Compressing weights to MXFP8 format")
graph = onnx_model.graph
initializer_map = {init.name: init for init in graph.initializer}

for node in _get_weight_dq_nodes(graph):
weight_name = node.input[0]
scale_name = node.input[1]
logger.debug(f"Compressing weight {weight_name} to MXFP8")

weight = numpy_helper.to_array(initializer_map[weight_name])
quant_axis, block_size = _get_quant_params(node)

# Get scale and convert back to fp32 for computation
se8m0 = numpy_helper.to_array(initializer_map[scale_name])
se8m0_fp32 = se8m0.astype(np.float32)

# Expand the block scales so they can be broadcast against the weight
se8m0_fp32_expanded = np.repeat(se8m0_fp32, block_size, axis=quant_axis)
scaled_weight = weight / np.exp2(se8m0_fp32_expanded - E8_M0_BIAS)

# Create FP8 weight tensor
weights_e4m3 = onnx.helper.make_tensor(
name=weight_name,
data_type=onnx_dtype_map["Float8"],
dims=[*scaled_weight.shape],
vals=_cast_fp8(scaled_weight).tobytes(),
raw=True,
)
initializer_map[weight_name].CopyFrom(weights_e4m3)
logger.debug(f"Converted {weight_name} to MXFP8")

return onnx_model
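# --- Editorial aside (not part of this diff): host-side round-trip check for
# the compression above, assuming the DQ node reconstructs weights as
# fp8_value * 2 ** (e8m0 - 127), i.e. the inverse of the division performed in
# compress_weights. ml_dtypes is an assumed extra dependency used only to
# decode E4M3 bytes on the host; axis=-1 blocking is assumed.
import ml_dtypes
import numpy as np

def _sketch_dequantize(fp8_bytes: bytes, se8m0: np.ndarray, shape: tuple, block_size: int = 32) -> np.ndarray:
    w_fp8 = np.frombuffer(fp8_bytes, dtype=ml_dtypes.float8_e4m3fn).reshape(shape).astype(np.float32)
    scale = np.exp2(se8m0.astype(np.float32) - 127.0)
    scale = np.repeat(scale, block_size, axis=-1)  # undo the per-block folding
    return w_fp8 * scale  # should approximate the original fp32/fp16 weight
# --- end aside ---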

@staticmethod
def post_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
"""Post-processes the ONNX model for MXFP8 quantization."""
"""Post-processes the ONNX model for MXFP8 quantization.

Sets the DQ output type to FP16, switches GELU nodes to the tanh approximation, and inserts FP16 casts after Sqrt nodes.
"""
logger.info("Post-processing MXFP8 quantized model")
graph = onnx_model.graph

# Set output type of DQ to FP16
for node in graph.node:
if node.op_type == "TRT_MXFP8DequantizeLinear":
for attr in node.attribute:
if attr.name == "output_dtype":
attr.i = onnx_dtype_map["Half"]

# Currently only tanh approximation is supported for Gelu
for node in graph.node:
if node.op_type == "Gelu":
for attr in node.attribute:
if attr.name == "approximate":
attr.s = b"tanh"
logger.debug(f"Updated GELU node {node.name} to use tanh approximation")

# Insert cast to fp16 after Sqrt nodes
cast_nodes_to_insert = []
for idx, node in enumerate(graph.node):
if node.op_type == "Sqrt":
sqrt_output = node.output[0]
cast_output = f"{sqrt_output}_cast_fp16"

# Create Cast node
cast_node = onnx.helper.make_node(
"Cast",
inputs=[sqrt_output],
outputs=[cast_output],
to=onnx_dtype_map["Half"],
name=f"{node.name}_cast_fp16",
)
cast_nodes_to_insert.append((idx + 1, cast_node))

# Update consumers to use cast output
for consumer in graph.node:
if consumer == node:
continue
for i, inp in enumerate(consumer.input):
if inp == sqrt_output:
consumer.input[i] = cast_output

# Insert Cast nodes in order, offsetting positions to account for earlier insertions
for offset, (pos, cast_node) in enumerate(cast_nodes_to_insert):
graph.node.insert(pos + offset, cast_node)
logger.debug(f"Inserted Cast to FP16 after {cast_node.input[0]}")

return onnx_model
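For orientation, a minimal sketch of how the four exporter stages could be chained on an ONNX model that carries TRT_MXFP8 Q/DQ nodes. In the PR the orchestration presumably goes through the shared quantize_weights path in torch_onnx.py; the file names below are placeholders.

import onnx

from modelopt.onnx.export.mxfp8_exporter import MXFP8QuantExporter

model = onnx.load("model_mxfp8_qdq.onnx")           # hypothetical input path
model = MXFP8QuantExporter.pre_process(model)        # currently a pass-through
model = MXFP8QuantExporter.compute_scales(model)     # add uint8 e8m0 scale initializers
model = MXFP8QuantExporter.compress_weights(model)   # rewrite weights as FP8 (E4M3) payloads
model = MXFP8QuantExporter.post_process(model)       # FP16 DQ outputs, tanh GELU, Sqrt casts
onnx.save(model, "model_mxfp8_exported.onnx")        # hypothetical output path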
101 changes: 1 addition & 100 deletions modelopt/onnx/quantization/qdq_utils.py
@@ -31,8 +31,7 @@
get_tensor_producer_nodes,
remove_redundant_cast_nodes,
)
from modelopt.onnx.quantization.quant_utils import compute_e8m0, get_amax, get_num_bits
from modelopt.onnx.utils import get_attribute, has_attribute
from modelopt.onnx.quantization.quant_utils import get_num_bits

QUANTIZE_NODE_NAME = "QuantizeLinear"
DEQUANTIZE_NODE_NAME = "DequantizeLinear"
@@ -1036,101 +1035,3 @@ def cast_initializer_to_dtype(
input_onnx = onnx.numpy_helper.from_array(input, input_name)
input_onnx.data_type = onnx_dtype_map[dtype]
initializer_map[input_name].CopyFrom(input_onnx)


def quantize_weights_to_mxfp8(
onnx_model: onnx.ModelProto,
) -> onnx.ModelProto:
"""Converts the weights to FP8 precision using MXFP8 quantization.

For TRT_MXFP8DynamicQuantize, we update the output type to FP8.
For TRT_MXFP8DequantizeLinear, we compute the scales in e8m0 format and saves them as a new initializer.
We then expand the scale to the same shape as the weight and divide the weight by the scale to get the FP8 weights.

Args:
graph: ONNX model protobuf.

Returns:
ONNX model protobuf with weights quantized to FP8 precision using MXFP8 quantization.
"""
logger.info("Converting weights to MXFP8 precision")
graph = onnx_model.graph
initializer_map = {initializer.name: initializer for initializer in graph.initializer}
tensor_producer_map = get_tensor_producer_nodes(graph)
e8_m0_bias = 127
weight_dq_nodes = [
node
for node in graph.node
if node.op_type == "TRT_MXFP8DequantizeLinear"
and any(".weight" in input for input in node.input)
]
gelu_nodes = [node for node in graph.node if node.op_type == "Gelu"]
logger.debug(f"Found {len(weight_dq_nodes)} weight DQ nodes and {len(gelu_nodes)} GELU nodes")

for node in weight_dq_nodes:
# Get weights and node attributes
weight_name = node.input[0]
logger.debug(f"Processing MXFP8 conversion for weight {weight_name}")
weight = numpy_helper.to_array(initializer_map[weight_name])
if has_attribute(node, "axis"):
quant_axis = int(get_attribute(node, "axis"))
else:
quant_axis = -1
logger.warning(
"axis attribute not found for MXFP8DequantizeLinear node. Setting axis to -1"
)

if has_attribute(node, "block_size"):
block_size = int(get_attribute(node, "block_size"))
else:
block_size = 32
logger.warning(
"block_size attribute not found for MXFP8DequantizeLinear node. Setting block_size to 32"
)

# Compute and save scales as uint8
amax = get_amax(weight, quant_axis, block_size)
se8m0_fp32 = compute_e8m0(amax, weight.shape, quant_axis, block_size)
se8m0 = se8m0_fp32.astype(np.uint8)

# Remove scale producer if it's a Constant node
scale_name = node.input[1]
scale_producer = tensor_producer_map[scale_name]
if scale_producer.op_type == "Constant":
graph.node.remove(scale_producer)

# Create a new scale tensor
scale_name = scale_name.replace("Constant_output_0", "scale")
scale_tensor = onnx.numpy_helper.from_array(se8m0, scale_name)
graph.initializer.append(scale_tensor)
node.input[1] = scale_name

# Convert weights to FP8
# Expand block array so that it can be broadcasted with weight
se8m0_fp32 = np.repeat(se8m0_fp32, block_size, axis=quant_axis)
scaled_weight = weight / np.exp2(se8m0_fp32 - e8_m0_bias)
weights_e4m3 = onnx.helper.make_tensor(
name=weight_name,
data_type=onnx_dtype_map["Float8"],
dims=[*scaled_weight.shape],
vals=_cast_fp8(scaled_weight).tobytes(),
raw=True,
)
initializer_map[weight_name].CopyFrom(weights_e4m3)
logger.debug(f"Converted {weight_name} to MXFP8")

# set output type of DQ to FP16
for node in graph.node:
if node.op_type in ["TRT_MXFP8DequantizeLinear"]:
for attr in node.attribute:
if attr.name == "output_dtype":
attr.i = onnx_dtype_map["Half"]

# Currently only tanh approximation is supported for Gelu
for node in gelu_nodes:
for attr in node.attribute:
if attr.name == "approximate":
attr.s = b"tanh"
logger.debug(f"Updated GELU node {node.name} to use tanh approximation")

return onnx_model
18 changes: 8 additions & 10 deletions modelopt/torch/_deploy/utils/torch_onnx.py
@@ -40,11 +40,7 @@
NVFP4QuantExporter,
ONNXQuantExporter,
)
from modelopt.onnx.quantization.qdq_utils import (
qdq_to_dq,
quantize_weights_to_mxfp8,
replace_zero_scale_with_smallest_nonzero,
)
from modelopt.onnx.quantization.qdq_utils import qdq_to_dq, replace_zero_scale_with_smallest_nonzero
from modelopt.onnx.utils import (
get_input_names,
get_input_shapes,
@@ -364,6 +360,11 @@ def is_fp8_quantized(model: nn.Module) -> bool:
and hasattr(module, "input_quantizer")
and module.weight_quantizer._num_bits == (4, 3)
and module.input_quantizer._num_bits == (4, 3)
# Exclude MXFP8 which also uses (4,3) but has block_sizes with scale_bits
and not (
module.input_quantizer.block_sizes
and module.input_quantizer.block_sizes.get("scale_bits", None) == (8, 0)
)
):
return True
return False
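# --- Editorial aside (not part of this diff): the exclusion above hinges on
# both FP8 and MXFP8 quantizers reporting num_bits == (4, 3), with MXFP8
# additionally carrying block_sizes whose scale_bits is (8, 0) (e8m0 block
# scales). The dicts below are only an approximate illustration of that
# difference, not the exact modelopt quantizer configs.
fp8_like = {"num_bits": (4, 3), "block_sizes": None}
mxfp8_like = {"num_bits": (4, 3), "block_sizes": {-1: 32, "type": "dynamic", "scale_bits": (8, 0)}}
# --- end aside ---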
@@ -560,11 +561,8 @@ def get_onnx_bytes_and_metadata(

# Convert dummy TRT_FP4QDQ nodes to 2DQ format if the model is quantized in FP4 mode
# Or convert weights to MXFP8 format if the model is quantized in MXFP8 mode
if is_int4_quantized(model) or is_fp4_quantized(model):
if is_int4_quantized(model) or is_fp4_quantized(model) or is_mxfp8_quantized(model):
onnx_opt_graph = quantize_weights(model, onnx_opt_graph)
elif is_mxfp8_quantized(model):
# TODO: Implement the MXFP8QuantExporter
onnx_opt_graph = quantize_weights_to_mxfp8(onnx_opt_graph)

if dq_only:
onnx_opt_graph = qdq_to_dq(onnx_opt_graph)
@@ -575,7 +573,7 @@
except StopIteration:
param_dtype = torch.float32
if weights_dtype in ["fp16", "bf16"] and param_dtype == torch.float32:
if is_mxfp8_quantized(model) or is_int4_quantized(model):
if is_int4_quantized(model) or is_mxfp8_quantized(model):
assert weights_dtype == "fp16", "BF16 + MXFP8/INT4 mixed precision is not supported yet"
onnx_opt_graph = convert_float_to_float16(
onnx_opt_graph,