
Commit c06fcac

Upgrade to ONNX 1.19.0
Signed-off-by: ajrasane <[email protected]>
Parent: 1cf78b2

File tree: 5 files changed (+106, -37 lines)

examples/onnx_ptq/torch_quant_to_onnx.py

Lines changed: 9 additions & 3 deletions
@@ -83,12 +83,12 @@ def forward_loop(model):
     return quantized_model
 
 
-def get_model_input_shape(model_name):
+def get_model_input_shape(model_name, batch_size):
     """Get the input shape from timm model configuration."""
     model = timm.create_model(model_name, pretrained=True, num_classes=1000)
     data_config = timm.data.resolve_model_data_config(model)
     input_size = data_config["input_size"]
-    return (1, *tuple(input_size))  # Add batch dimension
+    return (batch_size, *tuple(input_size))  # Add batch dimension
 
 
 def main():
@@ -119,11 +119,17 @@ def main():
         default=512,
         help="Number of images to use in calibration [1-512]",
     )
+    parser.add_argument(
+        "--batch_size",
+        type=int,
+        default=1,
+        help="Batch size for calibration.",
+    )
 
     args = parser.parse_args()
 
     # Get input shape from model config
-    input_shape = get_model_input_shape(args.timm_model_name)
+    input_shape = get_model_input_shape(args.timm_model_name, args.batch_size)
 
     # Create model and move to appropriate device
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
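
Note: with this change the calibration batch dimension comes from the command line instead of being hard-coded to 1. A quick sketch of the effect, assuming a hypothetical timm model whose data config reports input_size (3, 224, 224):

    # Illustrative values only, not from the commit.
    input_size = (3, 224, 224)          # timm data_config["input_size"]
    batch_size = 8                      # --batch_size 8
    input_shape = (batch_size, *tuple(input_size))
    assert input_shape == (8, 3, 224, 224)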

modelopt/onnx/quantization/qdq_utils.py

Lines changed: 60 additions & 12 deletions
@@ -23,7 +23,6 @@
 import onnx_graphsurgeon as gs
 import torch
 from onnx import numpy_helper
-from onnx.reference.custom_element_types import float8e4m3fn
 
 from modelopt.onnx import utils
 from modelopt.onnx.logging_config import logger
@@ -50,6 +49,7 @@
 onnx_dtype_map = {
     "BFloat16": onnx.TensorProto.BFLOAT16,
     "Float": onnx.TensorProto.FLOAT,
+    "Float4": onnx.TensorProto.FLOAT4E2M1,
     "Float8": onnx.TensorProto.FLOAT8E4M3FN,
     "Half": onnx.TensorProto.FLOAT16,
     "INT8": onnx.TensorProto.INT8,
@@ -592,7 +592,7 @@ def _convert_weight(
         zp_array = zp_array.reshape(*reshape_dims)
 
     # Convert to INT8/FP8
-    if zp_array.dtype == float8e4m3fn:
+    if zp_array.dtype == onnx_dtype_map["Float8"]:
         scaled = np.asarray(weight_array / scale_array) + zp_array
     else:
         scaled = np.asarray((weight_array / scale_array).round())
@@ -607,17 +607,22 @@ def _cast_fp8(array: np.ndarray) -> np.ndarray:
     if torch.cuda.is_available():
         array_f32_t = array_f32_t.cuda()
     array_f8_t = array_f32_t.clamp(min=-448, max=448).to(torch.float8_e4m3fn).view(torch.uint8)
-    array_f8 = array_f8_t.cpu().numpy().astype((np.uint8, [("e4m3fn", "u1")]))
+    array_f8 = array_f8_t.cpu().numpy().astype(np.uint8)
     return array_f8
 
 
 def _cast_fp4(array: np.ndarray) -> np.ndarray:
     """Cast a numpy array to FLOAT4E2M1 using PyTorch."""
     array_f32_t = torch.from_numpy(array)
+    array_f32_t_shape = array_f32_t.shape
+    assert array_f32_t_shape[0] % 2 == 0, "array_f32_t_shape[0] must be divisible by 2"
+    array_f4_t_shape = (array_f32_t_shape[0] // 2, *array_f32_t_shape[1:])
     if torch.cuda.is_available():
         array_f32_t = array_f32_t.cuda()
     array_f4_t = NVFP4QTensor._cast_fp4(array_f32_t)
-    array_f4 = array_f4_t.cpu().numpy().astype((np.uint8, [("float4e2m1", "u1")]))
+    array_f4_t = array_f4_t.flatten()
+    array_f4_t_packed = (array_f4_t[::2] | (array_f4_t[1::2] << 4)).reshape(array_f4_t_shape)
+    array_f4 = array_f4_t_packed.cpu().numpy().astype(np.uint8)
     return array_f4
 
 
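
Note on the FP4 path: with the structured numpy dtype ((np.uint8, [("float4e2m1", "u1")])) dropped, _cast_fp4 now returns plain uint8 bytes with two FLOAT4E2M1 codes packed per byte. After flattening, the even-indexed element goes into the low nibble and the odd-indexed element into the high nibble, which is why the first dimension must be even. A minimal sketch of just the packing step on already-quantized 4-bit codes (values are illustrative):

    import torch

    codes = torch.tensor([0, 1, 2, 3], dtype=torch.uint8)  # one FP4E2M1 code per element
    packed = codes[::2] | (codes[1::2] << 4)               # even index -> low nibble, odd index -> high nibble
    print(packed)                                          # tensor([16, 50], dtype=torch.uint8)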

@@ -685,7 +690,7 @@ def qdq_to_dq(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
         scaled = _convert_weight(weight_array, scale_array, zp_array, quantized_node)
 
         # Create and update new weight tensor
-        if zp_array.dtype == float8e4m3fn:
+        if zp_array.dtype == onnx_dtype_map["Float8"]:
             new_weight = _create_fp8_tensor(scaled, weight_name)
             logger.debug(f"Converted {weight_name} to FP8")
         else:
@@ -920,6 +925,10 @@ def quantize_weights_to_int4(
         assert reshape_node.op_type == "Reshape", f"Expected Reshape node for {node.name}"
         reshape_node_output = reshape_node.output[0]
 
+        # Remove constant node from reshape node
+        shape_constant_name = next(input for input in reshape_node.input if "Constant" in input)
+        nodes_to_remove.append(tensor_producer_map[shape_constant_name].name)
+
         # Get the shape of the output of the reshape node
         reshape_output_value_info = value_info_map.get(reshape_node_output)
         if reshape_output_value_info is not None:
@@ -937,12 +946,17 @@ def quantize_weights_to_int4(
         scale_shape = [*weight_shape[:-1], weight_shape[-1] // block_size]
         scale = scale.reshape(scale_shape)
         reshape_child_nodes = [n for n in graph.node if reshape_node.output[0] in n.input]
-        # reshape_node.input = []
         assert len(reshape_child_nodes) == 1, f"Expected exactly one transpose node for {node.name}"
 
+        # Remove unnecessary Cast node
+        cast_node = reshape_child_nodes[0]
+        assert cast_node.op_type == "Cast", f"Expected Cast node for {node.name}"
+        nodes_to_remove.append(cast_node.name)
+        cast_child_nodes = [n for n in graph.node if cast_node.output[0] in n.input]
+
         # Transpose weights and scales if present
-        if reshape_child_nodes[0].op_type == "Transpose":
-            transpose_node = reshape_child_nodes[0]
+        if cast_child_nodes[0].op_type == "Transpose":
+            transpose_node = cast_child_nodes[0]
             nodes_to_remove.append(transpose_node.name)
             assert transpose_node.op_type == "Transpose", f"Expected Transpose node for {node.name}"
             perm = None
@@ -959,7 +973,7 @@ def quantize_weights_to_int4(
             )
             matmul_node = transpose_child_nodes[0]
         else:
-            matmul_node = reshape_child_nodes[0]
+            matmul_node = cast_child_nodes[0]
         assert matmul_node.op_type in ["MatMul", "Gemm"], (
             f"Expected MatMul or Gemm node for {node.name}"
         )
@@ -990,6 +1004,21 @@ def quantize_weights_to_int4(
         initializer_map[weight_name].CopyFrom(weights_int4_onnx)
         logger.debug(f"Converted {weight_name} to INT4 precision")
 
+    def is_pre_quant_scale_node(node: onnx.NodeProto) -> bool:
+        has_pqs_input = any(input for input in node.input if "_pre_quant_scale" in input)
+        return node.op_type == "Mul" and has_pqs_input
+
+    # Remove unnecessary Cast after pre-quant scale
+    for node in graph.node:
+        if is_pre_quant_scale_node(node):
+            pqs_child_nodes = [n for n in graph.node if node.output[0] in n.input]
+            assert len(pqs_child_nodes) == 1, f"Expected exactly one child node for {node.name}"
+            cast_node = pqs_child_nodes[0]
+            assert cast_node.op_type == "Cast", f"Expected Cast node for {node.name}"
+            node.output.clear()
+            node.output.extend(cast_node.output)
+            nodes_to_remove.append(cast_node.name)
+
     # Remove transpose and reshape nodes
     new_nodes = [node for node in graph.node if node.name not in nodes_to_remove]
     graph.node.clear()
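
The pre-quant-scale cleanup bypasses the Cast by handing the Mul node the Cast's output name, so downstream consumers keep reading the same tensor name while the Cast becomes dead and is queued in nodes_to_remove. A toy illustration of that rewiring (node and tensor names here are made up for the example):

    from onnx import TensorProto, helper

    mul = helper.make_node("Mul", ["x", "w_pre_quant_scale"], ["mul_out"], name="pqs_mul")
    cast = helper.make_node("Cast", ["mul_out"], ["cast_out"], to=TensorProto.FLOAT16, name="pqs_cast")

    # Re-point the producer at the Cast's output name; the Cast can then be dropped.
    del mul.output[:]
    mul.output.extend(cast.output)
    assert list(mul.output) == ["cast_out"]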
@@ -1004,7 +1033,7 @@ def is_fp32_cast(node: onnx.NodeProto) -> bool:
     for node in graph.node:
         if node.op_type == "Cast":
             # Skip Cast nodes that are part of normalization layers and outputs
-            if ("norm/Cast" in node.name and is_fp32_cast(node)) or node.name == "/Cast":
+            if "norm/Cast" in node.name and is_fp32_cast(node):
                 continue
             for attr in node.attribute:
                 if attr.name == "to" and attr.i == onnx.TensorProto.FLOAT:
@@ -1099,7 +1128,13 @@ def quantize_weights_to_mxfp8(
         # Expand block array so that it can be broadcasted with weight
         se8m0_fp32 = np.repeat(se8m0_fp32, block_size, axis=quant_axis)
         scaled_weight = weight / np.exp2(se8m0_fp32 - e8_m0_bias)
-        weights_e4m3 = onnx.numpy_helper.from_array(_cast_fp8(scaled_weight), weight_name)
+        weights_e4m3 = onnx.helper.make_tensor(
+            name=weight_name,
+            data_type=onnx_dtype_map["Float8"],
+            dims=[*scaled_weight.shape],
+            vals=_cast_fp8(scaled_weight).tobytes(),
+            raw=True,
+        )
         initializer_map[weight_name].CopyFrom(weights_e4m3)
         logger.debug(f"Converted {weight_name} to MXFP8")
 
@@ -1181,11 +1216,24 @@ def _add_input_value_info(graph, tensor_proto):
     sw_f32_per_tensor_name = sw_f8_per_block_name + "_f32_scale"
 
     # Create TensorProto for initializers
-    w_f4_proto = onnx.numpy_helper.from_array(w_f4, w_f4_name)
+    w_f4_proto = onnx.helper.make_tensor(
+        name=w_f4_name,
+        data_type=onnx_dtype_map["Float4"],
+        dims=[w_f4.shape[0] * 2, *w_f4.shape[1:]],
+        vals=w_f4.tobytes(),
+        raw=True,
+    )
     sw_f32_per_tensor_proto = onnx.numpy_helper.from_array(
         sw_f32_per_tensor, sw_f32_per_tensor_name
     )
     sw_f8_per_block_proto = onnx.numpy_helper.from_array(sw_f8_per_block, sw_f8_per_block_name)
+    sw_f8_per_block_proto = onnx.helper.make_tensor(
+        name=sw_f8_per_block_name,
+        data_type=onnx_dtype_map["Float8"],
+        dims=[*sw_f8_per_block.shape],
+        vals=sw_f8_per_block.tobytes(),
+        raw=True,
+    )
 
     # Add ValueInfo for the initializers if not present
     _add_input_value_info(graph, w_f4_proto)
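
Because the FP8/FP4 payloads are now held as raw uint8 bytes, the initializers are built with onnx.helper.make_tensor, which lets the ONNX element type be declared explicitly; for FLOAT4E2M1 the dims give the logical (unpacked) element count, hence the first dimension is doubled relative to the packed byte array. A small sketch of the same pattern (tensor name and values are illustrative, not from the commit):

    import numpy as np
    import onnx

    packed = np.array([16, 50], dtype=np.uint8)  # 4 FP4 codes packed into 2 bytes
    t = onnx.helper.make_tensor(
        name="example_w_f4",
        data_type=onnx.TensorProto.FLOAT4E2M1,
        dims=[4],                # logical element count, not the packed byte count
        vals=packed.tobytes(),
        raw=True,
    )
    assert t.data_type == onnx.TensorProto.FLOAT4E2M1 and list(t.dims) == [4]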

modelopt/torch/_deploy/utils/torch_onnx.py

Lines changed: 1 addition & 1 deletion
@@ -485,7 +485,7 @@ def get_onnx_bytes_and_metadata(
     except StopIteration:
         param_dtype = torch.float32
     if weights_dtype in ["fp16", "bf16"] and param_dtype == torch.float32:
-        if is_mxfp8_quantized(model):
+        if is_mxfp8_quantized(model) or is_int4_quantized(model):
             assert weights_dtype == "fp16", "BF16 + MXFP8 mixed precision is not supported yet"
             onnx_opt_graph = convert_float_to_float16(
                 onnx_opt_graph,

setup.py

Lines changed: 2 additions & 2 deletions
@@ -47,8 +47,8 @@
     "cupy-cuda12x; platform_machine != 'aarch64' and platform_system != 'Darwin'",
     "ml_dtypes",  # for bfloat16 conversion
     "onnx-graphsurgeon",
-    "onnx~=1.18.0",
-    "onnxconverter-common",
+    "onnx~=1.19.0",
+    "onnxconverter-common~=1.16.0",
     "onnxruntime~=1.22.0 ; platform_machine == 'aarch64' or platform_system == 'Darwin'",
     "onnxruntime-gpu~=1.22.0 ; platform_machine != 'aarch64' and platform_system != 'Darwin' and platform_system != 'Windows'",  # noqa: E501
     "onnxruntime-directml==1.20.0; platform_system == 'Windows'",

tests/unit/onnx/test_qdq_utils.py

Lines changed: 34 additions & 19 deletions
@@ -33,7 +33,6 @@ def create_test_model_with_dq_reshape_transpose_matmul(constant_scale: bool = False):
 
     # Create reshape shape tensor
     reshape_shape = np.array([16, 16], dtype=np.int64)
-    reshape_shape_tensor = numpy_helper.from_array(reshape_shape, "reshape_shape")
 
     # Create input tensor for MatMul
     input_tensor = helper.make_tensor_value_info("input", TensorProto.FLOAT, [None, 2])
@@ -53,16 +52,32 @@ def create_test_model_with_dq_reshape_transpose_matmul(constant_scale: bool = False):
         "DequantizeLinear", inputs=dq_inputs, outputs=["dq_output"], name="weight_dq"
     )
 
+    reshape_constant = helper.make_node(
+        "Constant",
+        inputs=[],
+        outputs=["reshape_shape_Constant"],
+        value=numpy_helper.from_array(reshape_shape),
+        name="reshape_constant",
+    )
+
     reshape_node = helper.make_node(
         "Reshape",
-        inputs=["dq_output", "reshape_shape"],
+        inputs=["dq_output", "reshape_shape_Constant"],
         outputs=["reshape_output"],
         name="weight_reshape",
     )
 
+    cast_node = helper.make_node(
+        "Cast",
+        inputs=["reshape_output"],
+        outputs=["cast_output"],
+        to=TensorProto.FLOAT,
+        name="weight_cast",
+    )
+
     transpose_node = helper.make_node(
         "Transpose",
-        inputs=["reshape_output"],
+        inputs=["cast_output"],
         outputs=["transpose_output"],
         perm=[1, 0],
         name="weight_transpose",
@@ -78,15 +93,15 @@ def create_test_model_with_dq_reshape_transpose_matmul(constant_scale: bool = False):
     )
 
     # Create graph
-    nodes = [dq_node, reshape_node, transpose_node, matmul_node]
+    nodes = [dq_node, reshape_constant, reshape_node, cast_node, transpose_node, matmul_node]
     if constant_scale:
         nodes.append(scale_constant)
     graph = helper.make_graph(
         nodes=nodes,
         name="test_graph",
         inputs=[input_tensor],
         outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, [None, 16])],
-        initializer=[weight_tensor, scale_tensor, reshape_shape_tensor],
+        initializer=[weight_tensor, scale_tensor],
         value_info=[reshape_output_info],
     )
 
@@ -234,7 +249,7 @@ def test_cast_node_conversion(self):
             if node.op_type == "Cast":
                 to_attr = next(attr for attr in node.attribute if attr.name == "to")
 
-                if "norm/Cast" in node.name or node.name == "/Cast":
+                if "norm/Cast" in node.name:
                     # These should remain as float32
                     assert to_attr.i == TensorProto.FLOAT
                 else:
@@ -297,39 +312,39 @@ def test_cast_fp8(self, input_array, expected_array):
         [
             # Basic positive values
             (
-                np.array([0.0, 0.5, 1.0], dtype=np.float32),
-                np.array([0, 1, 2], dtype=(np.uint8, [("float4e2m1", "u1")])),
+                np.array([[0.0, 0.5], [1.0, 1.5]], dtype=np.float32),
+                np.array([[16, 50]], dtype=np.uint8),
             ),
             # Basic negative values
             (
-                np.array([-0.5, -1.0, -1.5], dtype=np.float32),
-                np.array([9, 10, 11], dtype=(np.uint8, [("float4e2m1", "u1")])),
+                np.array([[-0.5, -1.0], [-1.5, 1.75]], dtype=np.float32),
+                np.array([[169, 75]], dtype=np.uint8),
             ),
             # Boundary values with rounding
             (
-                np.array([0.75, 1.75, 3.5], dtype=np.float32),
-                np.array([2, 4, 6], dtype=(np.uint8, [("float4e2m1", "u1")])),
+                np.array([[0.0, 0.75], [1.75, 3.5]], dtype=np.float32),
+                np.array([[32, 100]], dtype=np.uint8),
            ),
             # Large values (saturate to max)
             (
-                np.array([10.0, -10.0], dtype=np.float32),
-                np.array([7, 15], dtype=(np.uint8, [("float4e2m1", "u1")])),
+                np.array([[10.0], [-10.0]], dtype=np.float32),
+                np.array([[247]], dtype=np.uint8),
             ),
             # Very small values (map to zero)
             (
-                np.array([0.1, -0.1], dtype=np.float32),
-                np.array([0, 8], dtype=(np.uint8, [("float4e2m1", "u1")])),
+                np.array([[0.1], [-0.1]], dtype=np.float32),
+                np.array([[128]], dtype=np.uint8),
             ),
             # Zero and negative zero
             (
-                np.array([0.0, -0.0], dtype=np.float32),
-                np.array([0, 0], dtype=(np.uint8, [("float4e2m1", "u1")])),
+                np.array([[0.0], [-0.0]], dtype=np.float32),
+                np.array([[0]], dtype=np.uint8),
             ),
         ],
     )
     def test_cast_fp4(self, input_array, expected_array):
         """Test FP4 casting functionality."""
         result = _cast_fp4(input_array)
-        assert result.dtype == np.dtype((np.uint8, [("float4e2m1", "u1")]))
+        assert result.dtype == np.dtype(np.uint8)
         assert result.shape == expected_array.shape
         assert np.all(result == expected_array)
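
The new expected bytes can be checked by hand from the FP4E2M1 code table (0.0→0, 0.5→1, 1.0→2, 1.5→3, 2.0→4, 3.0→5, 4.0→6, 6.0→7, plus 8 for negative values) and the packing order used by _cast_fp4: after flattening, element 2k fills the low nibble of byte k and element 2k+1 the high nibble, and the packed array has half as many rows. A worked check of the first case, assuming that packing order:

    # [[0.0, 0.5], [1.0, 1.5]] flattens to codes [0, 1, 2, 3]
    low, high = 0, 1   # pair (0.0, 0.5)
    assert (low | (high << 4)) == 16
    low, high = 2, 3   # pair (1.0, 1.5)
    assert (low | (high << 4)) == 50
    # hence the expected packed array [[16, 50]] with shape (1, 2)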
