
Commit 4d62165

Fix INT4 ONNX quantization issue for ONNX versions > 1.18
Signed-off-by: Hrishith Thadicherla <[email protected]>
1 parent 4df4091 commit 4d62165
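
For context: per this commit, ONNX 1.19 materializes 4-bit tensors as ml_dtypes.int4/uint4 arrays rather than int8/uint8-packed buffers, which CuPy cannot consume. A minimal sketch of the kind of version gate involved, assuming the packaging library is available; the helper name is illustrative, not the repository's actual _onnx_supports_int4 or skip_if_onnx_version_above_1_18:

# Illustrative version gate (assumption: `packaging` is installed);
# not the repository's actual helper.
import onnx
from packaging.version import Version

def onnx_returns_ml_dtypes_int4() -> bool:
    # Per this commit's docstring, ONNX 1.19+ hands back ml_dtypes.int4
    # arrays for 4-bit tensors instead of int8/uint8-packed data.
    return Version(onnx.__version__) >= Version("1.19")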

3 files changed: +34 −11 lines changed

modelopt/onnx/quantization/gs_patching.py

Lines changed: 5 additions & 3 deletions
@@ -69,9 +69,11 @@ def _export_tensor_proto(tensor: gs.Constant) -> onnx.TensorProto:

     vals = tensor.values
     if _onnx_supports_int4() and dtype in [onnx.TensorProto.INT4, onnx.TensorProto.UINT4]:
-        signed = dtype == onnx.TensorProto.INT4
-        np_dtype = onnx.helper.tensor_dtype_to_np_dtype(dtype)
-        vals = pack_float32_to_4bit_cpp_based(tensor.values, signed=signed).astype(np_dtype)
+        signed = dtype == onnx.TensorProto.INT4
+        if signed:
+            vals = pack_float32_to_4bit_cpp_based(tensor.values, signed=signed).astype(np.int8)
+        else:
+            vals = pack_float32_to_4bit_cpp_based(tensor.values, signed=signed).astype(np.uint8)

     onnx_tensor = onnx.helper.make_tensor(
         tensor.name,
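
Why the branch: on newer ONNX, onnx.helper.tensor_dtype_to_np_dtype maps INT4/UINT4 to ml_dtypes types rather than the int8/uint8 containers that make_tensor expects for packed nibbles, so the container dtype is now chosen explicitly. A minimal NumPy sketch of the packing itself, standing in for pack_float32_to_4bit_cpp_based (assumption: that helper packs two 4-bit values per byte, low nibble first):

# Minimal NumPy sketch of 4-bit packing: two nibbles per byte, low nibble
# first. A stand-in for pack_float32_to_4bit_cpp_based, not its actual code.
import numpy as np

def pack_int4_pairs(values: np.ndarray, signed: bool = True) -> np.ndarray:
    container = np.int8 if signed else np.uint8
    v = np.asarray(values).astype(np.int64).ravel()
    if v.size % 2:  # pad odd-length input with a zero nibble
        v = np.append(v, 0)
    lo, hi = v[0::2] & 0x0F, v[1::2] & 0x0F
    return ((hi << 4) | lo).astype(np.uint8).view(container)

print(pack_int4_pairs(np.array([1, -2, 7, -8])))  # two packed int8 bytes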

modelopt/onnx/quantization/int4.py

Lines changed: 22 additions & 0 deletions
@@ -98,6 +98,28 @@
 # supported and working
 CLIP_MIN = 1e-5

+def convert_ml_dtypes_int4_to_int8_format(tensor):
+    """Convert an ml_dtypes.int4 tensor to numpy.int8 for CuPy compatibility.
+
+    In ONNX 1.19, int4 tensors use ml_dtypes.int4, which CuPy does not support.
+    This function converts them to regular numpy.int8 while preserving values.
+
+    Args:
+        tensor: numpy array that may have ml_dtypes.int4 dtype.
+
+    Returns:
+        A cupy array (or numpy array if CuPy is unavailable) with numpy.int8
+        dtype if the input was ml_dtypes.int4, otherwise the values unchanged.
+    """
+    try:
+        import ml_dtypes
+
+        if hasattr(tensor, "dtype") and tensor.dtype == ml_dtypes.int4:
+            return np.asarray(tensor.astype(numpy.int8))
+    except ImportError:
+        pass
+
+    return np.asarray(tensor)
+

 def _quantize_gather_nodes(
     graph: onnx.GraphProto,
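
A quick usage sketch of the new helper (assumes ml_dtypes and modelopt are installed; the int4 values are arbitrary):

# Usage sketch for convert_ml_dtypes_int4_to_int8_format. Assumes ml_dtypes
# is installed; int4.np resolves to CuPy or NumPy inside the module.
import ml_dtypes
import numpy

from modelopt.onnx.quantization import int4

t = numpy.array([1, -2, 7, -8], dtype=ml_dtypes.int4)
out = int4.convert_ml_dtypes_int4_to_int8_format(t)
print(out.dtype)  # int8, now safe to hand to CuPy-based dequantization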

tests/gpu/onnx/test_quantize_onnx_torch_int4_awq.py

Lines changed: 7 additions & 8 deletions
@@ -40,7 +40,7 @@


 def test_int4_awq(tmp_path):
-    skip_if_onnx_version_above_1_18()
+    # skip_if_onnx_version_above_1_18()

     def _forward_loop(model, dataloader):
         """Forward loop for calibration."""

@@ -94,20 +94,19 @@ def _forward_loop(model, dataloader):
         scale_awq_lite = find_init(onnx_model_awq_lite, scale_names[i])

         if int4.has_cupy:
-            wq_onnx_awq_lite = np.array(wq_onnx_awq_lite)
-            scale_awq_lite = np.array(scale_awq_lite)
+            wq_onnx_awq_lite = int4.convert_ml_dtypes_int4_to_int8_format(wq_onnx_awq_lite)
+            scale_awq_lite = int4.convert_ml_dtypes_int4_to_int8_format(scale_awq_lite)

         wq_onnx_awq_lite = dq_tensor(wq_onnx_awq_lite, scale_awq_lite, block_size)
-
         wq_torch_awq_clip = model_torch_copy.net[i * 2].weight_quantizer(
             model_torch_copy.net[i * 2].weight
         )
         wq_onnx_awq_clip = find_init(onnx_model_awq_clip, wq_names[i])
         scale_awq_clip = find_init(onnx_model_awq_clip, scale_names[i])
-
+
         if int4.has_cupy:
-            wq_onnx_awq_clip = np.array(wq_onnx_awq_clip)
-            scale_awq_clip = np.array(scale_awq_clip)
+            wq_onnx_awq_clip = int4.convert_ml_dtypes_int4_to_int8_format(wq_onnx_awq_clip)
+            scale_awq_clip = int4.convert_ml_dtypes_int4_to_int8_format(scale_awq_clip)

         wq_onnx_awq_clip = dq_tensor(wq_onnx_awq_clip, scale_awq_clip, block_size)

@@ -116,7 +115,7 @@ def _forward_loop(model, dataloader):


 def test_int4_awq_cuda(tmp_path):
-    skip_if_onnx_version_above_1_18()
+    # skip_if_onnx_version_above_1_18()
     skip_if_no_libcudnn()
     block_size = 128