Skip to content

Commit 07db172

Browse files
committed
Adding unit test for safe_cupy_array function
Signed-off-by: Hrishith Thadicherla <[email protected]>
1 parent 1635bab commit 07db172

File tree

1 file changed

+29
-47
lines changed

1 file changed

+29
-47
lines changed

tests/gpu/onnx/test_quantize_onnx_torch_int4_awq.py

Lines changed: 29 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -39,56 +39,38 @@
3939
# test_qdq_utils_fp8.py::test_fused_q[bf16,fp16] fails if this script runs after the int4 test, but not before.
4040

4141

42-
def test_safe_cupy_array_all_paths(monkeypatch):
43-
"""Test safe_cupy_array covering all code paths including ml_dtypes handling"""
44-
# Test 1: When ml_dtypes import fails (covers ImportError path)
45-
# Temporarily remove ml_dtypes from sys.modules
46-
import sys
47-
48-
if "ml_dtypes" in sys.modules:
49-
ml_dtypes_backup = sys.modules["ml_dtypes"]
50-
monkeypatch.delitem(sys.modules, "ml_dtypes")
51-
else:
52-
ml_dtypes_backup = None
53-
54-
tensor = np.array([1, 2, 3, 4], dtype=np.int8)
55-
result = int4.safe_cupy_array(tensor)
56-
assert isinstance(result, np.ndarray) # Should return numpy array
57-
58-
# Restore ml_dtypes if it existed
59-
if ml_dtypes_backup:
60-
sys.modules["ml_dtypes"] = ml_dtypes_backup
61-
62-
# Test 2: When ml_dtypes exists and tensor has ml_dtypes.int4 dtype
42+
def test_safe_cupy_array(monkeypatch):
43+
"""Comprehensive test for safe_cupy_array covering all code paths."""
44+
import builtins
45+
import numpy # Import actual numpy for creating int4 tensors
46+
47+
# Test 1: Regular numpy array (should hit line 122)
48+
result = int4.safe_cupy_array(numpy.array([1, 2, 3, 4], dtype=numpy.float32))
49+
assert isinstance(result, np.ndarray)
50+
51+
# Test 2: With real ml_dtypes.int4 (covers lines 117-118)
6352
try:
6453
import ml_dtypes
65-
66-
# Create a mock tensor with int4 dtype
67-
class MockInt4Tensor:
68-
def __init__(self, data):
69-
self.data = data
70-
self.dtype = ml_dtypes.int4
71-
self.shape = data.shape
72-
73-
def astype(self, dtype):
74-
return self.data.astype(dtype)
75-
76-
def __array__(self):
77-
return self.data
78-
79-
mock_tensor = MockInt4Tensor(np.array([1, 2, 3, 4], dtype=np.int8))
80-
result = int4.safe_cupy_array(mock_tensor)
81-
assert isinstance(result, np.ndarray)
82-
assert result.dtype == np.int8
54+
int4_tensor = numpy.array([1, 2, -3, 4], dtype=numpy.float32).astype(ml_dtypes.int4)
55+
result = int4.safe_cupy_array(int4_tensor)
56+
assert isinstance(result, np.ndarray) and result.dtype == numpy.int8
57+
expected = int4_tensor.astype(numpy.int8)
58+
actual = result.get() if int4.has_cupy else result
59+
np.testing.assert_array_equal(actual, expected)
8360
except ImportError:
84-
# ml_dtypes not available, skip this part
85-
pass
86-
87-
# Test 3: Normal case with regular numpy array
88-
tensor = np.array([1, 2, 3, 4], dtype=np.int8)
89-
result = int4.safe_cupy_array(tensor)
90-
# Should work normally
91-
assert isinstance(result, (np.ndarray, type(tensor)))
61+
pass # ml_dtypes not available
62+
63+
# Test 3: When ml_dtypes import fails (covers ImportError catch and line 122)
64+
def mock_import(name, *args, **kwargs):
65+
if name == "ml_dtypes":
66+
raise ImportError("ml_dtypes not available")
67+
return builtins.__import__(name, *args, **kwargs)
68+
69+
monkeypatch.setattr(builtins, "__import__", mock_import)
70+
71+
# Use actual numpy for creating the array
72+
result = int4.safe_cupy_array(numpy.array([5, 6, 7, 8], dtype=numpy.int8))
73+
assert isinstance(result, np.ndarray)
9274

9375

9476
def test_int4_awq(tmp_path):

0 commit comments

Comments
 (0)