12 changes: 9 additions & 3 deletions examples/onnx_ptq/torch_quant_to_onnx.py
@@ -83,12 +83,12 @@ def forward_loop(model):
return quantized_model


def get_model_input_shape(model_name):
def get_model_input_shape(model_name, batch_size):
"""Get the input shape from timm model configuration."""
model = timm.create_model(model_name, pretrained=True, num_classes=1000)
data_config = timm.data.resolve_model_data_config(model)
input_size = data_config["input_size"]
return (1, *tuple(input_size)) # Add batch dimension
return (batch_size, *tuple(input_size)) # Add batch dimension


def main():
@@ -119,11 +119,17 @@ def main():
default=512,
help="Number of images to use in calibration [1-512]",
)
parser.add_argument(
"--batch_size",
type=int,
default=1,
help="Batch size for calibration and ONNX model export.",
)

args = parser.parse_args()

# Get input shape from model config
input_shape = get_model_input_shape(args.timm_model_name)
input_shape = get_model_input_shape(args.timm_model_name, args.batch_size)

# Create model and move to appropriate device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
78 changes: 65 additions & 13 deletions modelopt/onnx/quantization/qdq_utils.py
@@ -23,7 +23,6 @@
import onnx_graphsurgeon as gs
import torch
from onnx import numpy_helper
from onnx.reference.custom_element_types import float8e4m3fn

from modelopt.onnx import utils
from modelopt.onnx.logging_config import logger
@@ -50,6 +49,7 @@
onnx_dtype_map = {
"BFloat16": onnx.TensorProto.BFLOAT16,
"Float": onnx.TensorProto.FLOAT,
"Float4": onnx.TensorProto.FLOAT4E2M1,
"Float8": onnx.TensorProto.FLOAT8E4M3FN,
"Half": onnx.TensorProto.FLOAT16,
"INT8": onnx.TensorProto.INT8,
@@ -592,7 +592,7 @@ def _convert_weight(
zp_array = zp_array.reshape(*reshape_dims)

# Convert to INT8/FP8
if zp_array.dtype == float8e4m3fn:
if zp_array.dtype == onnx_dtype_map["Float8"]:
scaled = np.asarray(weight_array / scale_array) + zp_array
else:
scaled = np.asarray((weight_array / scale_array).round())
@@ -607,17 +607,26 @@ def _cast_fp8(array: np.ndarray) -> np.ndarray:
if torch.cuda.is_available():
array_f32_t = array_f32_t.cuda()
array_f8_t = array_f32_t.clamp(min=-448, max=448).to(torch.float8_e4m3fn).view(torch.uint8)
array_f8 = array_f8_t.cpu().numpy().astype((np.uint8, [("e4m3fn", "u1")]))
array_f8 = array_f8_t.cpu().numpy().astype(np.uint8)
return array_f8
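
For reference, a minimal sketch of what this FP8 path does, assuming a PyTorch build with float8_e4m3fn support (a sketch for orientation, not part of this PR):

import torch

# Clamp to the E4M3FN representable range (+-448), cast to float8, then
# reinterpret the float8 storage as raw uint8 byte codes, as _cast_fp8 does.
x = torch.tensor([1.0, 1000.0, -1000.0])
x8 = x.clamp(min=-448, max=448).to(torch.float8_e4m3fn)
raw = x8.view(torch.uint8)                            # plain uint8 byte codes
assert raw.dtype == torch.uint8
assert x8.float().tolist() == [1.0, 448.0, -448.0]    # out-of-range values saturate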


def _cast_fp4(array: np.ndarray) -> np.ndarray:
"""Cast a numpy array to FLOAT4E2M1 using PyTorch."""
"""Cast a numpy array to FLOAT4E2M1 using PyTorch.

Note: The first dimension of the array must be divisible by 2
as two FP4 values are packed into a single byte.
"""
array_f32_t = torch.from_numpy(array)
array_f32_t_shape = array_f32_t.shape
assert array_f32_t_shape[0] % 2 == 0, "array_f32_t_shape[0] must be divisible by 2"
array_f4_t_shape = (array_f32_t_shape[0] // 2, *array_f32_t_shape[1:])
if torch.cuda.is_available():
array_f32_t = array_f32_t.cuda()
array_f4_t = NVFP4QTensor._cast_fp4(array_f32_t)
array_f4 = array_f4_t.cpu().numpy().astype((np.uint8, [("float4e2m1", "u1")]))
array_f4_t = array_f4_t.flatten()
array_f4_t_packed = (array_f4_t[::2] | (array_f4_t[1::2] << 4)).reshape(array_f4_t_shape)
array_f4 = array_f4_t_packed.cpu().numpy().astype(np.uint8)
return array_f4
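
For reference, a minimal illustration of the byte packing performed above after flattening: the even-indexed code lands in the low nibble and the odd-indexed code in the high nibble. The 4-bit values are arbitrary FP4 code points chosen only for the example (not part of this PR):

import numpy as np

codes = np.array([0x1, 0x2, 0x9, 0xF], dtype=np.uint8)    # four 4-bit FP4 codes
packed = codes[::2] | (codes[1::2] << 4)                   # -> [0x21, 0xF9], two bytes
low, high = packed & 0x0F, packed >> 4                     # unpacking recovers the codes
assert np.array_equal(np.stack([low, high], axis=1).reshape(-1), codes)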


@@ -685,7 +694,7 @@ def qdq_to_dq(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
scaled = _convert_weight(weight_array, scale_array, zp_array, quantized_node)

# Create and update new weight tensor
if zp_array.dtype == float8e4m3fn:
if zp_array.dtype == onnx_dtype_map["Float8"]:
new_weight = _create_fp8_tensor(scaled, weight_name)
logger.debug(f"Converted {weight_name} to FP8")
else:
@@ -920,6 +929,10 @@ def quantize_weights_to_int4(
assert reshape_node.op_type == "Reshape", f"Expected Reshape node for {node.name}"
reshape_node_output = reshape_node.output[0]

# Remove constant node from reshape node
shape_constant_name = next(input for input in reshape_node.input if "Constant" in input)
nodes_to_remove.append(tensor_producer_map[shape_constant_name].name)

# Get the shape of the output of the reshape node
reshape_output_value_info = value_info_map.get(reshape_node_output)
if reshape_output_value_info is not None:
@@ -937,12 +950,17 @@ def quantize_weights_to_int4(
scale_shape = [*weight_shape[:-1], weight_shape[-1] // block_size]
scale = scale.reshape(scale_shape)
reshape_child_nodes = [n for n in graph.node if reshape_node.output[0] in n.input]
# reshape_node.input = []
assert len(reshape_child_nodes) == 1, f"Expected exactly one transpose node for {node.name}"

# Remove unnecessary Cast node
cast_node = reshape_child_nodes[0]
assert cast_node.op_type == "Cast", f"Expected Cast node for {node.name}"
nodes_to_remove.append(cast_node.name)
cast_child_nodes = [n for n in graph.node if cast_node.output[0] in n.input]

# Transpose weights and scales if present
if reshape_child_nodes[0].op_type == "Transpose":
transpose_node = reshape_child_nodes[0]
if cast_child_nodes[0].op_type == "Transpose":
transpose_node = cast_child_nodes[0]
nodes_to_remove.append(transpose_node.name)
assert transpose_node.op_type == "Transpose", f"Expected Transpose node for {node.name}"
perm = None
@@ -959,7 +977,7 @@ def quantize_weights_to_int4(
)
matmul_node = transpose_child_nodes[0]
else:
matmul_node = reshape_child_nodes[0]
matmul_node = cast_child_nodes[0]
assert matmul_node.op_type in ["MatMul", "Gemm"], (
f"Expected MatMul or Gemm node for {node.name}"
)
@@ -990,6 +1008,21 @@ def quantize_weights_to_int4(
initializer_map[weight_name].CopyFrom(weights_int4_onnx)
logger.debug(f"Converted {weight_name} to INT4 precision")

def is_pre_quant_scale_node(node: onnx.NodeProto) -> bool:
has_pqs_input = any(input for input in node.input if "_pre_quant_scale" in input)
return node.op_type == "Mul" and has_pqs_input

# Remove unnecessary Cast after Pre-quant scale
for node in graph.node:
if is_pre_quant_scale_node(node):
pqs_child_nodes = [n for n in graph.node if node.output[0] in n.input]
assert len(pqs_child_nodes) == 1, f"Expected exactly one child node for {node.name}"
cast_node = pqs_child_nodes[0]
assert cast_node.op_type == "Cast", f"Expected Cast node for {node.name}"
node.output.clear()
node.output.extend(cast_node.output)
nodes_to_remove.append(cast_node.name)

Comment on lines +1011 to +1025
💡 Verification agent

🧩 Analysis chain

Pre-quant Mul detection is wrong; the naming doesn’t match inserted nodes

The scale input is named with “_awq_scale” and/or produced by a Constant/initializer, not “_pre_quant_scale”. This block won’t trigger.

Make detection structural: Mul with one initializer or Constant input.

-    def is_pre_quant_scale_node(node: onnx.NodeProto) -> bool:
-        has_pqs_input = any(input for input in node.input if "_pre_quant_scale" in input)
-        return node.op_type == "Mul" and has_pqs_input
+    def is_pre_quant_scale_node(node: onnx.NodeProto) -> bool:
+        if node.op_type != "Mul" or len(node.input) != 2:
+            return False
+        for inp in node.input:
+            if inp in initializer_map:
+                return True
+            prod = tensor_producer_map.get(inp)
+            if prod is not None and prod.op_type == "Constant":
+                return True
+        return False

Please run the unit test that exercises removal of Cast after pre-quant Mul to confirm this now triggers.


🏁 Script executed:

#!/bin/bash

echo "Occurrences of _awq_scale:"
rg -n "_awq_scale" -g "*.py" || true

echo
echo "Occurrences of pre_quant_scale:"
rg -n "pre_quant_scale" -g "*.py" || true

echo
echo "Definition of is_pre_quant_scale_node and context:"
rg -n -C3 "def is_pre_quant_scale_node" -g "modelopt/onnx/quantization/qdq_utils.py" || true

echo
echo "initializer_map occurrences in qdq_utils.py:"
rg -n "initializer_map" -g "modelopt/onnx/quantization/qdq_utils.py" || true

echo
echo "tensor_producer_map occurrences in qdq_utils.py:"
rg -n "tensor_producer_map" -g "modelopt/onnx/quantization/qdq_utils.py" || true

Length of output: 28239


Use structural detection for pre-quant Mul in qdq_utils.py
The current is_pre_quant_scale_node looks for “_pre_quant_scale” in the input name, but the inserted scale tensors are named with “_awq_scale”, so this code never matches and the Cast isn’t removed. Replace it with a structural check—i.e. a Mul with exactly two inputs, one coming from an initializer or Constant node:

-    def is_pre_quant_scale_node(node: onnx.NodeProto) -> bool:
-        has_pqs_input = any(input for input in node.input if "_pre_quant_scale" in input)
-        return node.op_type == "Mul" and has_pqs_input
+    def is_pre_quant_scale_node(node: onnx.NodeProto) -> bool:
+        if node.op_type != "Mul" or len(node.input) != 2:
+            return False
+        for inp in node.input:
+            if inp in initializer_map:
+                return True
+            prod = tensor_producer_map.get(inp)
+            if prod is not None and prod.op_type == "Constant":
+                return True
+        return False
📝 Committable suggestion


Suggested change
def is_pre_quant_scale_node(node: onnx.NodeProto) -> bool:
has_pqs_input = any(input for input in node.input if "_pre_quant_scale" in input)
return node.op_type == "Mul" and has_pqs_input
# Remove unnecessary Cast after Pre-quant scale
for node in graph.node:
if is_pre_quant_scale_node(node):
pqs_child_nodes = [n for n in graph.node if node.output[0] in n.input]
assert len(pqs_child_nodes) == 1, f"Expected exactly one child node for {node.name}"
cast_node = pqs_child_nodes[0]
assert cast_node.op_type == "Cast", f"Expected Cast node for {node.name}"
node.output.clear()
node.output.extend(cast_node.output)
nodes_to_remove.append(cast_node.name)
def is_pre_quant_scale_node(node: onnx.NodeProto) -> bool:
# A pre-quantization scale is always a Mul with one constant input
if node.op_type != "Mul" or len(node.input) != 2:
return False
for inp in node.input:
# Check if one of the inputs is a graph initializer
if inp in initializer_map:
return True
# Or produced by a Constant node
prod = tensor_producer_map.get(inp)
if prod is not None and prod.op_type == "Constant":
return True
return False
# Remove unnecessary Cast after Pre-quant scale
for node in graph.node:
if is_pre_quant_scale_node(node):
pqs_child_nodes = [n for n in graph.node if node.output[0] in n.input]
assert len(pqs_child_nodes) == 1, f"Expected exactly one child node for {node.name}"
cast_node = pqs_child_nodes[0]
assert cast_node.op_type == "Cast", f"Expected Cast node for {node.name}"
node.output.clear()
node.output.extend(cast_node.output)
nodes_to_remove.append(cast_node.name)
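
As a quick sanity check of the structural detection, a sketch on a toy graph (not one of the repository's tests; initializer_map and tensor_producer_map are rebuilt locally and the tensor names are made up):

import numpy as np
import onnx
from onnx import helper, numpy_helper

# Toy graph: Mul(x, awq-scale initializer) -> Cast -> y
scale_init = numpy_helper.from_array(np.ones((4,), dtype=np.float32), "w_awq_scale")
mul = helper.make_node("Mul", ["x", "w_awq_scale"], ["mul_out"], name="pqs_mul")
cast = helper.make_node("Cast", ["mul_out"], ["y"], to=onnx.TensorProto.FLOAT16, name="pqs_cast")
graph = helper.make_graph(
    [mul, cast],
    "toy",
    [helper.make_tensor_value_info("x", onnx.TensorProto.FLOAT, [4])],
    [helper.make_tensor_value_info("y", onnx.TensorProto.FLOAT16, [4])],
    initializer=[scale_init],
)
initializer_map = {init.name: init for init in graph.initializer}
tensor_producer_map = {out: node for node in graph.node for out in node.output}

def is_pre_quant_scale_node(node):
    # Structural check: a Mul with exactly two inputs, one of which is constant.
    if node.op_type != "Mul" or len(node.input) != 2:
        return False
    return any(
        inp in initializer_map
        or getattr(tensor_producer_map.get(inp), "op_type", None) == "Constant"
        for inp in node.input
    )

assert is_pre_quant_scale_node(mul) and not is_pre_quant_scale_node(cast)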

# Remove transpose and reshape nodes
new_nodes = [node for node in graph.node if node.name not in nodes_to_remove]
graph.node.clear()
Expand All @@ -1004,7 +1037,7 @@ def is_fp32_cast(node: onnx.NodeProto) -> bool:
for node in graph.node:
if node.op_type == "Cast":
# Skip Cast nodes that are part of normalization layers and outputs
if ("norm/Cast" in node.name and is_fp32_cast(node)) or node.name == "/Cast":
if "norm/Cast" in node.name and is_fp32_cast(node):
continue
for attr in node.attribute:
if attr.name == "to" and attr.i == onnx.TensorProto.FLOAT:
@@ -1099,7 +1132,13 @@ def quantize_weights_to_mxfp8(
# Expand block array so that it can be broadcasted with weight
se8m0_fp32 = np.repeat(se8m0_fp32, block_size, axis=quant_axis)
scaled_weight = weight / np.exp2(se8m0_fp32 - e8_m0_bias)
weights_e4m3 = onnx.numpy_helper.from_array(_cast_fp8(scaled_weight), weight_name)
weights_e4m3 = onnx.helper.make_tensor(
name=weight_name,
data_type=onnx_dtype_map["Float8"],
dims=[*scaled_weight.shape],
vals=_cast_fp8(scaled_weight).tobytes(),
raw=True,
)
initializer_map[weight_name].CopyFrom(weights_e4m3)
logger.debug(f"Converted {weight_name} to MXFP8")

@@ -1181,11 +1220,24 @@ def _add_input_value_info(graph, tensor_proto):
sw_f32_per_tensor_name = sw_f8_per_block_name + "_f32_scale"

# Create TensorProto for initializers
w_f4_proto = onnx.numpy_helper.from_array(w_f4, w_f4_name)
w_f4_proto = onnx.helper.make_tensor(
name=w_f4_name,
data_type=onnx_dtype_map["Float4"],
dims=[w_f4.shape[0] * 2, *w_f4.shape[1:]],
vals=w_f4.tobytes(),
raw=True,
)
Comment on lines +1223 to +1229
🛠️ Refactor suggestion

FP4 initializer dims should reflect packing along the last axis

After fixing _cast_fp4 to pack along the last dim, adjust dims accordingly.

-    w_f4_proto = onnx.helper.make_tensor(
-        name=w_f4_name,
-        data_type=onnx_dtype_map["Float4"],
-        dims=[w_f4.shape[0] * 2, *w_f4.shape[1:]],
-        vals=w_f4.tobytes(),
-        raw=True,
-    )
+    w_f4_proto = onnx.helper.make_tensor(
+        name=w_f4_name,
+        data_type=onnx_dtype_map["Float4"],
+        dims=[*w_f4.shape[:-1], w_f4.shape[-1] * 2],
+        vals=w_f4.tobytes(),
+        raw=True,
+    )
📝 Committable suggestion


Suggested change
w_f4_proto = onnx.helper.make_tensor(
name=w_f4_name,
data_type=onnx_dtype_map["Float4"],
dims=[w_f4.shape[0] * 2, *w_f4.shape[1:]],
vals=w_f4.tobytes(),
raw=True,
)
w_f4_proto = onnx.helper.make_tensor(
name=w_f4_name,
data_type=onnx_dtype_map["Float4"],
dims=[*w_f4.shape[:-1], w_f4.shape[-1] * 2],
vals=w_f4.tobytes(),
raw=True,
)
🤖 Prompt for AI Agents
In modelopt/onnx/quantization/qdq_utils.py around lines 1219 to 1225, the FP4
initializer currently doubles the first dimension but FP4 packing was changed to
pack along the last axis; update the dims to reflect packing along the last axis
by replacing dims=[w_f4.shape[0] * 2, *w_f4.shape[1:]] with
dims=[*w_f4.shape[:-1], w_f4.shape[-1] * 2] (or equivalent list/tuple
construction) so the last dimension is doubled instead of the first.
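
Both dim conventions describe the same number of 4-bit elements in the packed uint8 buffer; they differ only in which logical axis is reported as unpacked. A small illustration on an assumed toy shape (not part of this PR):

import numpy as np

packed = np.zeros((8, 16), dtype=np.uint8)                 # 128 bytes -> 256 FP4 elements
dims_pr = (packed.shape[0] * 2, *packed.shape[1:])         # (16, 16): first axis unpacked
dims_review = (*packed.shape[:-1], packed.shape[-1] * 2)   # (8, 32): last axis unpacked
assert np.prod(dims_pr) == np.prod(dims_review) == packed.size * 2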

sw_f32_per_tensor_proto = onnx.numpy_helper.from_array(
sw_f32_per_tensor, sw_f32_per_tensor_name
)
sw_f8_per_block_proto = onnx.numpy_helper.from_array(sw_f8_per_block, sw_f8_per_block_name)
sw_f8_per_block_proto = onnx.helper.make_tensor(
name=sw_f8_per_block_name,
data_type=onnx_dtype_map["Float8"],
dims=[*sw_f8_per_block.shape],
vals=sw_f8_per_block.tobytes(),
raw=True,
)

# Add ValueInfo for the initializers if not present
_add_input_value_info(graph, w_f4_proto)
4 changes: 2 additions & 2 deletions modelopt/torch/_deploy/utils/torch_onnx.py
@@ -485,8 +485,8 @@ def get_onnx_bytes_and_metadata(
except StopIteration:
param_dtype = torch.float32
if weights_dtype in ["fp16", "bf16"] and param_dtype == torch.float32:
if is_mxfp8_quantized(model):
assert weights_dtype == "fp16", "BF16 + MXFP8 mixed precision is not supported yet"
if is_mxfp8_quantized(model) or is_int4_quantized(model):
assert weights_dtype == "fp16", "BF16 + MXFP8/INT4 mixed precision is not supported yet"
onnx_opt_graph = convert_float_to_float16(
onnx_opt_graph,
keep_io_types=False,
4 changes: 2 additions & 2 deletions setup.py
@@ -47,8 +47,8 @@
"cupy-cuda12x; platform_machine != 'aarch64' and platform_system != 'Darwin'",
"ml_dtypes", # for bfloat16 conversion
"onnx-graphsurgeon",
"onnx~=1.18.0",
"onnxconverter-common",
"onnx~=1.19.0",
"onnxconverter-common~=1.16.0",
"onnxruntime~=1.22.0 ; platform_machine == 'aarch64' or platform_system == 'Darwin'",
"onnxruntime-gpu~=1.22.0 ; platform_machine != 'aarch64' and platform_system != 'Darwin' and platform_system != 'Windows'", # noqa: E501
"onnxruntime-directml==1.20.0; platform_system == 'Windows'",
53 changes: 34 additions & 19 deletions tests/unit/onnx/test_qdq_utils.py
@@ -33,7 +33,6 @@ def create_test_model_with_dq_reshape_transpose_matmul(constant_scale: bool = Fa

# Create reshape shape tensor
reshape_shape = np.array([16, 16], dtype=np.int64)
reshape_shape_tensor = numpy_helper.from_array(reshape_shape, "reshape_shape")

# Create input tensor for MatMul
input_tensor = helper.make_tensor_value_info("input", TensorProto.FLOAT, [None, 2])
@@ -53,16 +52,32 @@
"DequantizeLinear", inputs=dq_inputs, outputs=["dq_output"], name="weight_dq"
)

reshape_constant = helper.make_node(
"Constant",
inputs=[],
outputs=["reshape_shape_Constant"],
value=numpy_helper.from_array(reshape_shape),
name="reshape_constant",
)

reshape_node = helper.make_node(
"Reshape",
inputs=["dq_output", "reshape_shape"],
inputs=["dq_output", "reshape_shape_Constant"],
outputs=["reshape_output"],
name="weight_reshape",
)

cast_node = helper.make_node(
"Cast",
inputs=["reshape_output"],
outputs=["cast_output"],
to=TensorProto.FLOAT,
name="weight_cast",
)

transpose_node = helper.make_node(
"Transpose",
inputs=["reshape_output"],
inputs=["cast_output"],
outputs=["transpose_output"],
perm=[1, 0],
name="weight_transpose",
@@ -78,15 +93,15 @@
)

# Create graph
nodes = [dq_node, reshape_node, transpose_node, matmul_node]
nodes = [dq_node, reshape_constant, reshape_node, cast_node, transpose_node, matmul_node]
if constant_scale:
nodes.append(scale_constant)
graph = helper.make_graph(
nodes=nodes,
name="test_graph",
inputs=[input_tensor],
outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, [None, 16])],
initializer=[weight_tensor, scale_tensor, reshape_shape_tensor],
initializer=[weight_tensor, scale_tensor],
value_info=[reshape_output_info],
)

@@ -234,7 +249,7 @@ def test_cast_node_conversion(self):
if node.op_type == "Cast":
to_attr = next(attr for attr in node.attribute if attr.name == "to")

if "norm/Cast" in node.name or node.name == "/Cast":
if "norm/Cast" in node.name:
# These should remain as float32
assert to_attr.i == TensorProto.FLOAT
else:
@@ -297,39 +312,39 @@ def test_cast_fp8(self, input_array, expected_array):
[
# Basic positive values
(
np.array([0.0, 0.5, 1.0], dtype=np.float32),
np.array([0, 1, 2], dtype=(np.uint8, [("float4e2m1", "u1")])),
np.array([[0.0, 0.5], [1.0, 1.5]], dtype=np.float32),
np.array([[16, 50]], dtype=np.uint8),
),
# Basic negative values
(
np.array([-0.5, -1.0, -1.5], dtype=np.float32),
np.array([9, 10, 11], dtype=(np.uint8, [("float4e2m1", "u1")])),
np.array([[-0.5, -1.0], [-1.5, 1.75]], dtype=np.float32),
np.array([[169, 75]], dtype=np.uint8),
),
# Boundary values with rounding
(
np.array([0.75, 1.75, 3.5], dtype=np.float32),
np.array([2, 4, 6], dtype=(np.uint8, [("float4e2m1", "u1")])),
np.array([[0.0, 0.75], [1.75, 3.5]], dtype=np.float32),
np.array([[32, 100]], dtype=np.uint8),
),
# Large values (saturate to max)
(
np.array([10.0, -10.0], dtype=np.float32),
np.array([7, 15], dtype=(np.uint8, [("float4e2m1", "u1")])),
np.array([[10.0], [-10.0]], dtype=np.float32),
np.array([[247]], dtype=np.uint8),
),
# Very small values (map to zero)
(
np.array([0.1, -0.1], dtype=np.float32),
np.array([0, 8], dtype=(np.uint8, [("float4e2m1", "u1")])),
np.array([[0.1], [-0.1]], dtype=np.float32),
np.array([[128]], dtype=np.uint8),
),
# Zero and negative zero
(
np.array([0.0, -0.0], dtype=np.float32),
np.array([0, 0], dtype=(np.uint8, [("float4e2m1", "u1")])),
np.array([[0.0], [-0.0]], dtype=np.float32),
np.array([[0]], dtype=np.uint8),
),
],
)
def test_cast_fp4(self, input_array, expected_array):
"""Test FP4 casting functionality."""
result = _cast_fp4(input_array)
assert result.dtype == np.dtype((np.uint8, [("float4e2m1", "u1")]))
assert result.dtype == np.dtype(np.uint8)
assert result.shape == expected_array.shape
assert np.all(result == expected_array)