Skip to content

Commit 83d9a9e

Browse files
committed
INT4 ONNX Version Fix: Code Quality Improvements
Signed-off-by: Hrishith Thadicherla <[email protected]>
1 parent 3a98a23 commit 83d9a9e

File tree

2 files changed

+7
-5
lines changed

2 files changed

+7
-5
lines changed

modelopt/onnx/quantization/gs_patching.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,8 @@ def _export_tensor_proto(tensor: gs.Constant) -> onnx.TensorProto:
6969

7070
vals = tensor.values
7171
if _onnx_supports_int4() and dtype in [onnx.TensorProto.INT4, onnx.TensorProto.UINT4]:
72-
signed = dtype == onnx.TensorProto.INT4
73-
if(signed):
72+
signed = dtype == onnx.TensorProto.INT4
73+
if signed:
7474
vals = pack_float32_to_4bit_cpp_based(tensor.values, signed=signed).astype(np.int8)
7575
else:
7676
vals = pack_float32_to_4bit_cpp_based(tensor.values, signed=signed).astype(np.uint8)

modelopt/onnx/quantization/int4.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,11 +98,13 @@
9898
# supported and working
9999
CLIP_MIN = 1e-5
100100

101+
101102
def safe_cupy_array(tensor):
102103
"""Convert ml_dtypes.int4 tensor to numpy.int8 for CuPy compatibility.
103104
104105
In ONNX 1.19, int4 tensors use ml_dtypes.int4, which CuPy doesn't support.
105106
This function converts them to regular numpy.int8 while preserving values.
107+
106108
Args:
107109
tensor: numpy array that may have ml_dtypes.int4 dtype
108110
Returns:
@@ -111,12 +113,12 @@ def safe_cupy_array(tensor):
111113
"""
112114
try:
113115
import ml_dtypes
114-
115-
if hasattr(tensor, 'dtype') and tensor.dtype == ml_dtypes.int4:
116+
117+
if hasattr(tensor, "dtype") and tensor.dtype == ml_dtypes.int4:
116118
return np.asarray(tensor.astype(numpy.int8))
117119
except ImportError:
118120
pass
119-
121+
120122
return np.asarray(tensor)
121123

122124

0 commit comments

Comments
 (0)