Skip to content

Commit 7ccaa53

Browse files
authored
ONNX 1.19 compatibility fix for INT4 quantization (#423)
Signed-off-by: Hrishith Thadicherla <[email protected]>
1 parent ff8a1ed commit 7ccaa53

File tree

3 files changed

+79
-11
lines changed

3 files changed

+79
-11
lines changed

modelopt/onnx/quantization/gs_patching.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,8 @@ def _export_tensor_proto(tensor: gs.Constant) -> onnx.TensorProto:
7070
vals = tensor.values
7171
if _onnx_supports_int4() and dtype in [onnx.TensorProto.INT4, onnx.TensorProto.UINT4]:
7272
signed = dtype == onnx.TensorProto.INT4
73-
np_dtype = onnx.helper.tensor_dtype_to_np_dtype(dtype)
74-
vals = pack_float32_to_4bit_cpp_based(tensor.values, signed=signed).astype(np_dtype)
73+
packed_dtype = np.int8 if signed else np.uint8
74+
vals = pack_float32_to_4bit_cpp_based(tensor.values, signed=signed).astype(packed_dtype)
7575

7676
onnx_tensor = onnx.helper.make_tensor(
7777
tensor.name,

modelopt/onnx/quantization/int4.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,29 @@
9999
CLIP_MIN = 1e-5
100100

101101

102+
def safe_cupy_array(tensor):
    """Return ``tensor`` via ``np.asarray``, casting ml_dtypes.int4 input to numpy.int8 first.

    In ONNX 1.19, int4 tensors use ml_dtypes.int4 which CuPy doesn't support.
    Such tensors are cast to regular numpy.int8 (values preserved) before being
    handed to ``np.asarray``; any other input is passed to ``np.asarray`` unchanged.

    Args:
        tensor: numpy array (or array-like) that may have ml_dtypes.int4 dtype.

    Returns:
        cupy or numpy array (if cupy is not supported) with numpy.int8 dtype if
        the input was ml_dtypes.int4, otherwise the plain ``np.asarray`` result.
    """
    # Guard only the import: an ImportError raised while checking or converting
    # the tensor must not be mistaken for "ml_dtypes is not installed".
    try:
        import ml_dtypes
    except ImportError:
        # ml_dtypes unavailable: there is no int4 dtype to check for.
        return np.asarray(tensor)

    if hasattr(tensor, "dtype") and tensor.dtype == ml_dtypes.int4:
        return np.asarray(tensor.astype(numpy.int8))

    return np.asarray(tensor)
123+
124+
102125
def _quantize_gather_nodes(
103126
graph: onnx.GraphProto,
104127
nodes_to_exclude: list[str],
@@ -271,19 +294,26 @@ def quantize_rtn(
271294
scales[name] = np.asnumpy(scales[name])
272295
gemm_weights_quantized[name] = numpy.asarray(qw)
273296
scales = reshape_scales_for_per_channel_nodes(scales, block_size, precision_info)
297+
dq_node_attributes = {"axis": 0, "block_size": block_size}
274298
qdq.insert_dq_nodes(
275299
graph,
276300
scales,
277301
quantized_weights=gemm_weights_quantized,
302+
attributes=dq_node_attributes,
278303
precision_info=precision_info,
279304
)
280305

281306
if gather_w_map is not None:
282307
assert gather_s_map is not None, "scale-map not found for quantizable gather nodes"
308+
gather_dq_node_attributes = {
309+
"axis": gather_quantize_axis,
310+
"block_size": gather_block_size,
311+
}
283312
qdq.insert_dq_nodes(
284313
graph,
285314
gather_s_map,
286315
quantized_weights=gather_w_map,
316+
attributes=gather_dq_node_attributes,
287317
precision_info=precision_info,
288318
)
289319
else:
@@ -299,7 +329,10 @@ def quantize_rtn(
299329
)
300330

301331
logger.info(f"RTN quantization completed in {time.time() - t_start:.2f} seconds")
302-
return gs.export_onnx(graph)
332+
model = gs.export_onnx(graph)
333+
model.ir_version = 10
334+
335+
return model
303336

304337

305338
class AWQClipHelper:

tests/gpu/onnx/test_quantize_onnx_torch_int4_awq.py

Lines changed: 43 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from functools import partial
2121

2222
import torch
23-
from _test_utils.import_helper import skip_if_no_libcudnn, skip_if_onnx_version_above_1_18
23+
from _test_utils.import_helper import skip_if_no_libcudnn
2424
from _test_utils.onnx_quantization.lib_test_models import SimpleMLP, export_as_onnx, find_init
2525
from _test_utils.torch_quantization.quantize_common import get_awq_config
2626

@@ -39,9 +39,45 @@
3939
# test_qdq_utils_fp8.py::test_fused_q[bf16,fp16] fails if this script runs after the int4 test, but not before.
4040

4141

42-
def test_int4_awq(tmp_path):
43-
skip_if_onnx_version_above_1_18()
42+
def test_safe_cupy_array(monkeypatch):
    """Exercise safe_cupy_array on a plain array, an ml_dtypes.int4 array, and without ml_dtypes."""
    import builtins

    import numpy  # real numpy for building inputs (module-level ``np`` is presumably cupy when available)

    # Case 1: an ordinary float32 array takes the pass-through return.
    plain_result = int4.safe_cupy_array(numpy.array([1, 2, 3, 4], dtype=numpy.float32))
    assert isinstance(plain_result, np.ndarray)

    # Case 2: a genuine ml_dtypes.int4 array takes the int8 conversion branch.
    try:
        import ml_dtypes

        source = numpy.array([1, 2, -3, 4], dtype=numpy.float32).astype(ml_dtypes.int4)
        converted = int4.safe_cupy_array(source)
        assert isinstance(converted, np.ndarray)
        assert converted.dtype == numpy.int8
        reference = source.astype(numpy.int8)
        # ``.get()`` is only called when int4.has_cupy is set.
        if int4.has_cupy:
            host_values = converted.get()
        else:
            host_values = converted
        np.testing.assert_array_equal(host_values, reference)
    except ImportError:
        pass  # ml_dtypes not available

    # Case 3: force ``import ml_dtypes`` to fail so the ImportError handler and
    # the pass-through return are exercised.
    real_import = builtins.__import__

    def _import_without_ml_dtypes(name, *args, **kwargs):
        if name != "ml_dtypes":
            return real_import(name, *args, **kwargs)
        raise ImportError("ml_dtypes not available")

    monkeypatch.setattr(builtins, "__import__", _import_without_ml_dtypes)

    # Input is again built with the real numpy.
    fallback_result = int4.safe_cupy_array(numpy.array([5, 6, 7, 8], dtype=numpy.int8))
    assert isinstance(fallback_result, np.ndarray)
78+
79+
80+
def test_int4_awq(tmp_path):
4581
def _forward_loop(model, dataloader):
4682
"""Forward loop for calibration."""
4783
for data in dataloader:
@@ -94,20 +130,19 @@ def _forward_loop(model, dataloader):
94130
scale_awq_lite = find_init(onnx_model_awq_lite, scale_names[i])
95131

96132
if int4.has_cupy:
97-
wq_onnx_awq_lite = np.array(wq_onnx_awq_lite)
98-
scale_awq_lite = np.array(scale_awq_lite)
133+
wq_onnx_awq_lite = int4.safe_cupy_array(wq_onnx_awq_lite)
134+
scale_awq_lite = int4.safe_cupy_array(scale_awq_lite)
99135

100136
wq_onnx_awq_lite = dq_tensor(wq_onnx_awq_lite, scale_awq_lite, block_size)
101-
102137
wq_torch_awq_clip = model_torch_copy.net[i * 2].weight_quantizer(
103138
model_torch_copy.net[i * 2].weight
104139
)
105140
wq_onnx_awq_clip = find_init(onnx_model_awq_clip, wq_names[i])
106141
scale_awq_clip = find_init(onnx_model_awq_clip, scale_names[i])
107142

108143
if int4.has_cupy:
109-
wq_onnx_awq_clip = np.array(wq_onnx_awq_clip)
110-
scale_awq_clip = np.array(scale_awq_clip)
144+
wq_onnx_awq_clip = int4.safe_cupy_array(wq_onnx_awq_clip)
145+
scale_awq_clip = int4.safe_cupy_array(scale_awq_clip)
111146

112147
wq_onnx_awq_clip = dq_tensor(wq_onnx_awq_clip, scale_awq_clip, block_size)
113148

0 commit comments

Comments
 (0)