
Commit f2cb3b6

ajrasane authored and jingyu-ml committed
Upgrade to ONNX 1.19.0 (#289)
Signed-off-by: ajrasane <[email protected]>
Signed-off-by: Jingyu Xin <[email protected]>
1 parent 94e85b9 commit f2cb3b6

File tree: 8 files changed, +134 -40 lines changed

examples/onnx_ptq/torch_quant_to_onnx.py

Lines changed: 9 additions & 3 deletions

@@ -83,12 +83,12 @@ def forward_loop(model):
     return quantized_model
 
 
-def get_model_input_shape(model_name):
+def get_model_input_shape(model_name, batch_size):
     """Get the input shape from timm model configuration."""
     model = timm.create_model(model_name, pretrained=True, num_classes=1000)
     data_config = timm.data.resolve_model_data_config(model)
     input_size = data_config["input_size"]
-    return (1, *tuple(input_size))  # Add batch dimension
+    return (batch_size, *tuple(input_size))  # Add batch dimension
 
 
 def main():
@@ -119,11 +119,17 @@ def main():
         default=512,
         help="Number of images to use in calibration [1-512]",
     )
+    parser.add_argument(
+        "--batch_size",
+        type=int,
+        default=1,
+        help="Batch size for calibration and ONNX model export.",
+    )
 
     args = parser.parse_args()
 
     # Get input shape from model config
-    input_shape = get_model_input_shape(args.timm_model_name)
+    input_shape = get_model_input_shape(args.timm_model_name, args.batch_size)
 
     # Create model and move to appropriate device
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
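
As a quick illustration of the new batch_size plumbing, here is a minimal sketch of the reworked helper (not part of the commit; "resnet50" and pretrained=False are assumptions made here to keep the example light, while the example script itself loads pretrained weights):

import timm


def get_model_input_shape(model_name: str, batch_size: int) -> tuple:
    # Resolve the model's expected input size from its timm data config.
    model = timm.create_model(model_name, pretrained=False, num_classes=1000)
    data_config = timm.data.resolve_model_data_config(model)
    input_size = data_config["input_size"]  # e.g. (3, 224, 224)
    return (batch_size, *tuple(input_size))  # prepend the batch dimension


print(get_model_input_shape("resnet50", batch_size=8))  # expected: (8, 3, 224, 224)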
Lines changed: 1 addition & 0 deletions

@@ -1,3 +1,4 @@
 datasets>=2.14.5
+onnx==1.18.0
 torch==2.6.0
 transformers==4.49.0

modelopt/onnx/quantization/qdq_utils.py

Lines changed: 65 additions & 13 deletions
@@ -23,7 +23,6 @@
 import onnx_graphsurgeon as gs
 import torch
 from onnx import numpy_helper
-from onnx.reference.custom_element_types import float8e4m3fn
 
 from modelopt.onnx import utils
 from modelopt.onnx.logging_config import logger
@@ -50,6 +49,7 @@
 onnx_dtype_map = {
     "BFloat16": onnx.TensorProto.BFLOAT16,
     "Float": onnx.TensorProto.FLOAT,
+    "Float4": onnx.TensorProto.FLOAT4E2M1,
     "Float8": onnx.TensorProto.FLOAT8E4M3FN,
     "Half": onnx.TensorProto.FLOAT16,
     "INT8": onnx.TensorProto.INT8,
@@ -592,7 +592,7 @@ def _convert_weight(
     zp_array = zp_array.reshape(*reshape_dims)
 
     # Convert to INT8/FP8
-    if zp_array.dtype == float8e4m3fn:
+    if zp_array.dtype == onnx_dtype_map["Float8"]:
         scaled = np.asarray(weight_array / scale_array) + zp_array
     else:
         scaled = np.asarray((weight_array / scale_array).round())
@@ -607,17 +607,26 @@ def _cast_fp8(array: np.ndarray) -> np.ndarray:
     if torch.cuda.is_available():
         array_f32_t = array_f32_t.cuda()
     array_f8_t = array_f32_t.clamp(min=-448, max=448).to(torch.float8_e4m3fn).view(torch.uint8)
-    array_f8 = array_f8_t.cpu().numpy().astype((np.uint8, [("e4m3fn", "u1")]))
+    array_f8 = array_f8_t.cpu().numpy().astype(np.uint8)
     return array_f8
 
 
 def _cast_fp4(array: np.ndarray) -> np.ndarray:
-    """Cast a numpy array to FLOAT4E2M1 using PyTorch."""
+    """Cast a numpy array to FLOAT4E2M1 using PyTorch.
+
+    Note: The first dimension of the array must be divisible by 2
+    as two FP4 values are packed into a single byte.
+    """
     array_f32_t = torch.from_numpy(array)
+    array_f32_t_shape = array_f32_t.shape
+    assert array_f32_t_shape[0] % 2 == 0, "array_f32_t_shape[0] must be divisible by 2"
+    array_f4_t_shape = (array_f32_t_shape[0] // 2, *array_f32_t_shape[1:])
     if torch.cuda.is_available():
         array_f32_t = array_f32_t.cuda()
     array_f4_t = NVFP4QTensor._cast_fp4(array_f32_t)
-    array_f4 = array_f4_t.cpu().numpy().astype((np.uint8, [("float4e2m1", "u1")]))
+    array_f4_t = array_f4_t.flatten()
+    array_f4_t_packed = (array_f4_t[::2] | (array_f4_t[1::2] << 4)).reshape(array_f4_t_shape)
+    array_f4 = array_f4_t_packed.cpu().numpy().astype(np.uint8)
     return array_f4
 
 
@@ -685,7 +694,7 @@ def qdq_to_dq(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
         scaled = _convert_weight(weight_array, scale_array, zp_array, quantized_node)
 
         # Create and update new weight tensor
-        if zp_array.dtype == float8e4m3fn:
+        if zp_array.dtype == onnx_dtype_map["Float8"]:
             new_weight = _create_fp8_tensor(scaled, weight_name)
             logger.debug(f"Converted {weight_name} to FP8")
         else:
@@ -925,6 +934,10 @@ def quantize_weights_to_int4(
         assert reshape_node.op_type == "Reshape", f"Expected Reshape node for {node.name}"
         reshape_node_output = reshape_node.output[0]
 
+        # Remove constant node from reshape node
+        shape_constant_name = next(input for input in reshape_node.input if "Constant" in input)
+        nodes_to_remove.append(tensor_producer_map[shape_constant_name].name)
+
         # Get the shape of the output of the reshape node
         reshape_output_value_info = value_info_map.get(reshape_node_output)
         if reshape_output_value_info is not None:
@@ -942,12 +955,17 @@ def quantize_weights_to_int4(
         scale_shape = [*weight_shape[:-1], weight_shape[-1] // block_size]
         scale = scale.reshape(scale_shape)
         reshape_child_nodes = [n for n in graph.node if reshape_node.output[0] in n.input]
-        # reshape_node.input = []
         assert len(reshape_child_nodes) == 1, f"Expected exactly one transpose node for {node.name}"
 
+        # Remove unnecessary Cast node
+        cast_node = reshape_child_nodes[0]
+        assert cast_node.op_type == "Cast", f"Expected Cast node for {node.name}"
+        nodes_to_remove.append(cast_node.name)
+        cast_child_nodes = [n for n in graph.node if cast_node.output[0] in n.input]
+
         # Transpose weights and scales if present
-        if reshape_child_nodes[0].op_type == "Transpose":
-            transpose_node = reshape_child_nodes[0]
+        if cast_child_nodes[0].op_type == "Transpose":
+            transpose_node = cast_child_nodes[0]
             nodes_to_remove.append(transpose_node.name)
             assert transpose_node.op_type == "Transpose", f"Expected Transpose node for {node.name}"
             perm = None
@@ -964,7 +982,7 @@
             )
             matmul_node = transpose_child_nodes[0]
         else:
-            matmul_node = reshape_child_nodes[0]
+            matmul_node = cast_child_nodes[0]
         assert matmul_node.op_type in ["MatMul", "Gemm"], (
             f"Expected MatMul or Gemm node for {node.name}"
         )
@@ -995,6 +1013,21 @@
         initializer_map[weight_name].CopyFrom(weights_int4_onnx)
         logger.debug(f"Converted {weight_name} to INT4 precision")
 
+    def is_pre_quant_scale_node(node: onnx.NodeProto) -> bool:
+        has_pqs_input = any(input for input in node.input if "_pre_quant_scale" in input)
+        return node.op_type == "Mul" and has_pqs_input
+
+    # Remove unnecessary Cast after pre-quant scale
+    for node in graph.node:
+        if is_pre_quant_scale_node(node):
+            pqs_child_nodes = [n for n in graph.node if node.output[0] in n.input]
+            assert len(pqs_child_nodes) == 1, f"Expected exactly one child node for {node.name}"
+            cast_node = pqs_child_nodes[0]
+            assert cast_node.op_type == "Cast", f"Expected Cast node for {node.name}"
+            node.output.clear()
+            node.output.extend(cast_node.output)
+            nodes_to_remove.append(cast_node.name)
+
     # Remove transpose and reshape nodes
     new_nodes = [node for node in graph.node if node.name not in nodes_to_remove]
     graph.node.clear()
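
The pre-quant-scale cleanup simply splices the redundant Cast out of the graph: the Mul takes over the Cast's output name, so downstream consumers are untouched, and the Cast is dropped. A hedged toy reproduction of that rewiring on a hand-built graph (node and tensor names such as "pqs_Mul" are illustrative, not from the codebase):

import onnx
from onnx import TensorProto, helper

# Toy graph: Mul -> (no-op) Cast -> Identity.
mul = helper.make_node("Mul", ["x", "x_pre_quant_scale"], ["mul_out"], name="pqs_Mul")
cast = helper.make_node("Cast", ["mul_out"], ["cast_out"], to=TensorProto.FLOAT, name="pqs_Cast")
consumer = helper.make_node("Identity", ["cast_out"], ["y"], name="consumer")
graph = helper.make_graph(
    [mul, cast, consumer],
    "toy",
    [
        helper.make_tensor_value_info("x", TensorProto.FLOAT, [2]),
        helper.make_tensor_value_info("x_pre_quant_scale", TensorProto.FLOAT, [2]),
    ],
    [helper.make_tensor_value_info("y", TensorProto.FLOAT, [2])],
)

# Splice out the Cast: Mul now writes directly to the Cast's output name.
mul_node = next(n for n in graph.node if n.name == "pqs_Mul")
cast_node = next(n for n in graph.node if n.name == "pqs_Cast")
del mul_node.output[:]
mul_node.output.extend(cast_node.output)
kept = [n for n in graph.node if n.name != "pqs_Cast"]
del graph.node[:]
graph.node.extend(kept)

onnx.checker.check_model(helper.make_model(graph))
print([n.name for n in graph.node])  # ['pqs_Mul', 'consumer']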
@@ -1009,7 +1042,7 @@ def is_fp32_cast(node: onnx.NodeProto) -> bool:
     for node in graph.node:
         if node.op_type == "Cast":
             # Skip Cast nodes that are part of normalization layers and outputs
-            if ("norm/Cast" in node.name and is_fp32_cast(node)) or node.name == "/Cast":
+            if "norm/Cast" in node.name and is_fp32_cast(node):
                 continue
             for attr in node.attribute:
                 if attr.name == "to" and attr.i == onnx.TensorProto.FLOAT:
@@ -1104,7 +1137,13 @@ def quantize_weights_to_mxfp8(
         # Expand block array so that it can be broadcasted with weight
         se8m0_fp32 = np.repeat(se8m0_fp32, block_size, axis=quant_axis)
         scaled_weight = weight / np.exp2(se8m0_fp32 - e8_m0_bias)
-        weights_e4m3 = onnx.numpy_helper.from_array(_cast_fp8(scaled_weight), weight_name)
+        weights_e4m3 = onnx.helper.make_tensor(
+            name=weight_name,
+            data_type=onnx_dtype_map["Float8"],
+            dims=[*scaled_weight.shape],
+            vals=_cast_fp8(scaled_weight).tobytes(),
+            raw=True,
+        )
         initializer_map[weight_name].CopyFrom(weights_e4m3)
         logger.debug(f"Converted {weight_name} to MXFP8")
 
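
The switch from onnx.numpy_helper.from_array to onnx.helper.make_tensor avoids the structured-dtype trick that relied on the now-removed onnx.reference.custom_element_types: the FP8 payload travels as raw bytes, and only data_type records the element type. A minimal sketch of the same construction (toy values, quantization math omitted; the helper name is an assumption made for illustration):

import numpy as np
import onnx
import torch


def make_fp8_initializer(name: str, values: np.ndarray) -> onnx.TensorProto:
    """Build a FLOAT8E4M3FN initializer by handing make_tensor the raw FP8 bytes."""
    f8_bytes = (
        torch.from_numpy(values.astype(np.float32))
        .clamp(min=-448, max=448)  # FP8 E4M3FN representable range
        .to(torch.float8_e4m3fn)
        .view(torch.uint8)  # reinterpret the FP8 payload as bytes
        .numpy()
        .tobytes()
    )
    return onnx.helper.make_tensor(
        name=name,
        data_type=onnx.TensorProto.FLOAT8E4M3FN,
        dims=list(values.shape),
        vals=f8_bytes,
        raw=True,
    )


tensor = make_fp8_initializer("w_f8", np.ones((4, 4), dtype=np.float32))
print(tensor.data_type == onnx.TensorProto.FLOAT8E4M3FN, len(tensor.raw_data))  # True 16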

@@ -1186,11 +1225,24 @@ def _add_input_value_info(graph, tensor_proto):
     sw_f32_per_tensor_name = sw_f8_per_block_name + "_f32_scale"
 
     # Create TensorProto for initializers
-    w_f4_proto = onnx.numpy_helper.from_array(w_f4, w_f4_name)
+    w_f4_proto = onnx.helper.make_tensor(
+        name=w_f4_name,
+        data_type=onnx_dtype_map["Float4"],
+        dims=[w_f4.shape[0] * 2, *w_f4.shape[1:]],
+        vals=w_f4.tobytes(),
+        raw=True,
+    )
     sw_f32_per_tensor_proto = onnx.numpy_helper.from_array(
         sw_f32_per_tensor, sw_f32_per_tensor_name
     )
     sw_f8_per_block_proto = onnx.numpy_helper.from_array(sw_f8_per_block, sw_f8_per_block_name)
+    sw_f8_per_block_proto = onnx.helper.make_tensor(
+        name=sw_f8_per_block_name,
+        data_type=onnx_dtype_map["Float8"],
+        dims=[*sw_f8_per_block.shape],
+        vals=sw_f8_per_block.tobytes(),
+        raw=True,
+    )
 
     # Add ValueInfo for the initializers if not present
     _add_input_value_info(graph, w_f4_proto)
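
One subtlety worth spelling out: w_f4 holds the packed bytes, so its first dimension is half that of the logical tensor, and make_tensor is given the logical shape (w_f4.shape[0] * 2, ...) alongside the raw packed buffer. A toy illustration of that bookkeeping (shapes and names are made up for the example):

import numpy as np
import onnx

# Packed buffer: 2 rows of bytes stand in for a logical 4-row FLOAT4E2M1 tensor
# (the packed buffer stores two FP4 code points per byte).
packed = np.zeros((2, 8), dtype=np.uint8)  # 16 bytes -> 32 FP4 elements
logical_dims = [packed.shape[0] * 2, *packed.shape[1:]]  # [4, 8]

w_f4_proto = onnx.helper.make_tensor(
    name="weight_fp4",
    data_type=onnx.TensorProto.FLOAT4E2M1,
    dims=logical_dims,
    vals=packed.tobytes(),
    raw=True,
)
print(list(w_f4_proto.dims), len(w_f4_proto.raw_data))  # [4, 8] 16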

modelopt/torch/_deploy/utils/torch_onnx.py

Lines changed: 2 additions & 2 deletions
@@ -485,8 +485,8 @@ def get_onnx_bytes_and_metadata(
         except StopIteration:
             param_dtype = torch.float32
         if weights_dtype in ["fp16", "bf16"] and param_dtype == torch.float32:
-            if is_mxfp8_quantized(model):
-                assert weights_dtype == "fp16", "BF16 + MXFP8 mixed precision is not supported yet"
+            if is_mxfp8_quantized(model) or is_int4_quantized(model):
+                assert weights_dtype == "fp16", "BF16 + MXFP8/INT4 mixed precision is not supported yet"
             onnx_opt_graph = convert_float_to_float16(
                 onnx_opt_graph,
                 keep_io_types=False,

setup.py

Lines changed: 2 additions & 2 deletions
@@ -47,8 +47,8 @@
     "cupy-cuda12x; platform_machine != 'aarch64' and platform_system != 'Darwin'",
     "ml_dtypes",  # for bfloat16 conversion
     "onnx-graphsurgeon",
-    "onnx~=1.18.0",
-    "onnxconverter-common",
+    "onnx~=1.19.0",
+    "onnxconverter-common~=1.16.0",
     "onnxruntime~=1.22.0 ; platform_machine == 'aarch64' or platform_system == 'Darwin'",
     "onnxruntime-gpu~=1.22.0 ; platform_machine != 'aarch64' and platform_system != 'Darwin' and platform_system != 'Windows'",  # noqa: E501
     "onnxruntime-directml==1.20.0; platform_system == 'Windows'",

tests/_test_utils/import_helper.py

Lines changed: 17 additions & 0 deletions
@@ -13,9 +13,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import importlib.metadata
 import shutil
 
 import pytest
+from packaging import version
 
 
 def skip_if_no_tensorrt():
@@ -73,3 +75,18 @@ def skip_if_no_megatron(apex_or_te_required: bool = False, mamba_required: bool
 
     if mamba_required and not has_mamba:
         pytest.skip("Mamba required for Megatron test", allow_module_level=True)
+
+
+def skip_if_onnx_version_above_1_18():
+    package_name = "onnx"
+    required_version = "1.18.0"
+
+    try:
+        installed_version = importlib.metadata.version(package_name)
+    except importlib.metadata.PackageNotFoundError:
+        pytest.skip(f"{package_name} is not installed")
+
+    if version.parse(installed_version) > version.parse(required_version):
+        pytest.skip(
+            f"{package_name} version {installed_version} is above the supported maximum {required_version}"
+        )
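
The same gate could also be phrased as a module-level skipif marker; a hedged alternative sketch (not part of the commit, and the reason text is an assumption):

import importlib.metadata

import pytest
from packaging import version

# Evaluate once at import time; mirrors skip_if_onnx_version_above_1_18.
ONNX_ABOVE_1_18 = version.parse(importlib.metadata.version("onnx")) > version.parse("1.18.0")


@pytest.mark.skipif(ONNX_ABOVE_1_18, reason="test not yet validated with onnx > 1.18")
def test_example():
    ...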

tests/gpu/onnx/test_quantize_onnx_torch_int4_awq.py

Lines changed: 4 additions & 1 deletion
@@ -20,7 +20,7 @@
 from functools import partial
 
 import torch
-from _test_utils.import_helper import skip_if_no_libcudnn
+from _test_utils.import_helper import skip_if_no_libcudnn, skip_if_onnx_version_above_1_18
 from _test_utils.onnx_quantization.lib_test_models import SimpleMLP, export_as_onnx, find_init
 from _test_utils.torch_quantization.quantize_common import get_awq_config
 
@@ -40,6 +40,8 @@
 
 
 def test_int4_awq(tmp_path):
+    skip_if_onnx_version_above_1_18()
+
     def _forward_loop(model, dataloader):
         """Forward loop for calibration."""
         for data in dataloader:
@@ -114,6 +116,7 @@ def _forward_loop(model, dataloader):
 
 
 def test_int4_awq_cuda(tmp_path):
+    skip_if_onnx_version_above_1_18()
     skip_if_no_libcudnn()
     block_size = 128
 
