Skip to content

Commit bf39fc5

Browse files
committed
Added tests and made changes to gs_patching.py to improve coverage
Signed-off-by: Hrishith Thadicherla <[email protected]>
1 parent f6734ec commit bf39fc5

File tree

2 files changed

+51
-4
lines changed

2 files changed

+51
-4
lines changed

modelopt/onnx/quantization/gs_patching.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,10 +70,8 @@ def _export_tensor_proto(tensor: gs.Constant) -> onnx.TensorProto:
7070
vals = tensor.values
7171
if _onnx_supports_int4() and dtype in [onnx.TensorProto.INT4, onnx.TensorProto.UINT4]:
7272
signed = dtype == onnx.TensorProto.INT4
73-
if signed:
74-
vals = pack_float32_to_4bit_cpp_based(tensor.values, signed=signed).astype(np.int8)
75-
else:
76-
vals = pack_float32_to_4bit_cpp_based(tensor.values, signed=signed).astype(np.uint8)
73+
packed_dtype = np.int8 if signed else np.uint8
74+
vals = pack_float32_to_4bit_cpp_based(tensor.values, signed=signed).astype(packed_dtype)
7775

7876
onnx_tensor = onnx.helper.make_tensor(
7977
tensor.name,

tests/gpu/onnx/test_quantize_onnx_torch_int4_awq.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,55 @@
3838
# For that, we need to investigate failure in 'pytest tests/gpu/onnx'.
3939
# test_qdq_utils_fp8.py::test_fused_q[bf16,fp16] fails if this script runs after the int4 test, but not before.
4040

41+
def test_safe_cupy_array_all_paths(monkeypatch):
42+
"""Test safe_cupy_array covering all code paths including ml_dtypes handling"""
43+
# Test 1: When ml_dtypes import fails (covers ImportError path)
44+
# Temporarily remove ml_dtypes from sys.modules
45+
import sys
46+
if 'ml_dtypes' in sys.modules:
47+
ml_dtypes_backup = sys.modules['ml_dtypes']
48+
monkeypatch.delitem(sys.modules, 'ml_dtypes')
49+
else:
50+
ml_dtypes_backup = None
51+
52+
tensor = np.array([1, 2, 3, 4], dtype=np.int8)
53+
result = int4.safe_cupy_array(tensor)
54+
assert isinstance(result, np.ndarray) # Should return numpy array
55+
56+
# Restore ml_dtypes if it existed
57+
if ml_dtypes_backup:
58+
sys.modules['ml_dtypes'] = ml_dtypes_backup
59+
60+
# Test 2: When ml_dtypes exists and tensor has ml_dtypes.int4 dtype
61+
try:
62+
import ml_dtypes
63+
# Create a mock tensor with int4 dtype
64+
class MockInt4Tensor:
65+
def __init__(self, data):
66+
self.data = data
67+
self.dtype = ml_dtypes.int4
68+
self.shape = data.shape
69+
70+
def astype(self, dtype):
71+
return self.data.astype(dtype)
72+
73+
def __array__(self):
74+
return self.data
75+
76+
mock_tensor = MockInt4Tensor(np.array([1, 2, 3, 4], dtype=np.int8))
77+
print(mock_tensor.dtype)
78+
result = int4.safe_cupy_array(mock_tensor)
79+
assert isinstance(result, np.ndarray)
80+
assert result.dtype == np.int8
81+
except ImportError:
82+
# ml_dtypes not available, skip this part
83+
pass
84+
85+
# Test 3: Normal case with regular numpy array
86+
tensor = np.array([1, 2, 3, 4], dtype=np.int8)
87+
result = int4.safe_cupy_array(tensor)
88+
# Should work normally
89+
assert isinstance(result, (np.ndarray, type(tensor)))
4190

4291
def test_int4_awq(tmp_path):
4392
def _forward_loop(model, dataloader):

0 commit comments

Comments
 (0)