Skip to content

Commit 27b11c8

Browse files
committed
INT4 ONNX Version Fix: Code Quality Improvements
Signed-off-by: Hrishith Thadicherla <[email protected]>
1 parent e09ec22 commit 27b11c8

File tree

2 files changed

+7
-11
lines changed

2 files changed

+7
-11
lines changed

modelopt/onnx/quantization/int4.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -98,21 +98,20 @@
9898
# supported and working
9999
CLIP_MIN = 1e-5
100100

101-
def safe_cupy_array(tensor):
102-
"""
103-
Convert ml_dtypes.int4 tensor to numpy.int8 for CuPy compatibility.
101+
def safe_cupy_array(tensor):
102+
"""Convert ml_dtypes.int4 tensor to numpy.int8 for CuPy compatibility.
104103
105104
In ONNX 1.19, int4 tensors use ml_dtypes.int4 which CuPy doesn't support.
106105
This function converts them to regular numpy.int8 while preserving values.
107-
108106
Args:
109-
tensor: numpy array that may have ml_dtypes.int4 dtype
110-
107+
tensor: numpy array that may have ml_dtypes.int4 dtype
111108
Returns:
112-
cupy or numpy array (if cupy is not supported) with numpy.int8 dtype if input was ml_dtypes.int4, otherwise unchanged
109+
cupy or numpy array (if cupy is not supported) with numpy.int8 dtype if input was ml_dtypes.int4,
110+
otherwise unchanged
113111
"""
114112
try:
115113
import ml_dtypes
114+
116115
if hasattr(tensor, 'dtype') and tensor.dtype == ml_dtypes.int4:
117116
return np.asarray(tensor.astype(numpy.int8))
118117
except ImportError:

tests/gpu/onnx/test_quantize_onnx_torch_int4_awq.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from functools import partial
2121

2222
import torch
23-
from _test_utils.import_helper import skip_if_no_libcudnn, skip_if_onnx_version_above_1_18
23+
from _test_utils.import_helper import skip_if_no_libcudnn
2424
from _test_utils.onnx_quantization.lib_test_models import SimpleMLP, export_as_onnx, find_init
2525
from _test_utils.torch_quantization.quantize_common import get_awq_config
2626

@@ -40,8 +40,6 @@
4040

4141

4242
def test_int4_awq(tmp_path):
43-
# skip_if_onnx_version_above_1_18()
44-
4543
def _forward_loop(model, dataloader):
4644
"""Forward loop for calibration."""
4745
for data in dataloader:
@@ -115,7 +113,6 @@ def _forward_loop(model, dataloader):
115113

116114

117115
def test_int4_awq_cuda(tmp_path):
118-
# skip_if_onnx_version_above_1_18()
119116
skip_if_no_libcudnn()
120117
block_size = 128
121118

0 commit comments

Comments (0)