Skip to content

Commit 29d1dff

Browse files
committed
Added tests and made changes to gs_patching.py to improve coverage
Signed-off-by: Hrishith Thadicherla <[email protected]>
1 parent f6734ec commit 29d1dff

File tree

2 files changed

+55
-4
lines changed

2 files changed

+55
-4
lines changed

modelopt/onnx/quantization/gs_patching.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,10 +70,8 @@ def _export_tensor_proto(tensor: gs.Constant) -> onnx.TensorProto:
7070
vals = tensor.values
7171
if _onnx_supports_int4() and dtype in [onnx.TensorProto.INT4, onnx.TensorProto.UINT4]:
7272
signed = dtype == onnx.TensorProto.INT4
73-
if signed:
74-
vals = pack_float32_to_4bit_cpp_based(tensor.values, signed=signed).astype(np.int8)
75-
else:
76-
vals = pack_float32_to_4bit_cpp_based(tensor.values, signed=signed).astype(np.uint8)
73+
packed_dtype = np.int8 if signed else np.uint8
74+
vals = pack_float32_to_4bit_cpp_based(tensor.values, signed=signed).astype(packed_dtype)
7775

7876
onnx_tensor = onnx.helper.make_tensor(
7977
tensor.name,

tests/gpu/onnx/test_quantize_onnx_torch_int4_awq.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,59 @@
3939
# test_qdq_utils_fp8.py::test_fused_q[bf16,fp16] fails if this script runs after the int4 test, but not before.
4040

4141

42+
def test_safe_cupy_array_all_paths(monkeypatch):
43+
"""Test safe_cupy_array covering all code paths including ml_dtypes handling"""
44+
# Test 1: When ml_dtypes import fails (covers ImportError path)
45+
# Temporarily remove ml_dtypes from sys.modules
46+
import sys
47+
48+
if "ml_dtypes" in sys.modules:
49+
ml_dtypes_backup = sys.modules["ml_dtypes"]
50+
monkeypatch.delitem(sys.modules, "ml_dtypes")
51+
else:
52+
ml_dtypes_backup = None
53+
54+
tensor = np.array([1, 2, 3, 4], dtype=np.int8)
55+
result = int4.safe_cupy_array(tensor)
56+
assert isinstance(result, np.ndarray) # Should return numpy array
57+
58+
# Restore ml_dtypes if it existed
59+
if ml_dtypes_backup:
60+
sys.modules["ml_dtypes"] = ml_dtypes_backup
61+
62+
# Test 2: When ml_dtypes exists and tensor has ml_dtypes.int4 dtype
63+
try:
64+
import ml_dtypes
65+
66+
# Create a mock tensor with int4 dtype
67+
class MockInt4Tensor:
68+
def __init__(self, data):
69+
self.data = data
70+
self.dtype = ml_dtypes.int4
71+
self.shape = data.shape
72+
73+
def astype(self, dtype):
74+
return self.data.astype(dtype)
75+
76+
def __array__(self):
77+
return self.data
78+
79+
mock_tensor = MockInt4Tensor(np.array([1, 2, 3, 4], dtype=np.int8))
80+
print(mock_tensor.dtype)
81+
result = int4.safe_cupy_array(mock_tensor)
82+
assert isinstance(result, np.ndarray)
83+
assert result.dtype == np.int8
84+
except ImportError:
85+
# ml_dtypes not available, skip this part
86+
pass
87+
88+
# Test 3: Normal case with regular numpy array
89+
tensor = np.array([1, 2, 3, 4], dtype=np.int8)
90+
result = int4.safe_cupy_array(tensor)
91+
# Should work normally
92+
assert isinstance(result, (np.ndarray, type(tensor)))
93+
94+
4295
def test_int4_awq(tmp_path):
4396
def _forward_loop(model, dataloader):
4497
"""Forward loop for calibration."""

0 commit comments

Comments
 (0)