Skip to content

Commit 22ecf87

Browse files
authored
Disable sass checks for float16 merge sort (#8053)
1 parent b1951b2 commit 22ecf87

File tree

1 file changed

+18
-3
lines changed

1 file changed

+18
-3
lines changed

python/cuda_cccl/tests/compute/test_merge_sort.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,12 @@ def test_merge_sort_keys(dtype, num_items, op):
8888

8989

9090
@pytest.mark.parametrize("dtype,num_items,op", merge_sort_params)
91-
def test_merge_sort_pairs(dtype, num_items, op):
91+
def test_merge_sort_pairs(dtype, num_items, op, monkeypatch):
92+
if dtype == np.float16:
93+
import cuda.compute._cccl_interop
94+
95+
monkeypatch.setattr(cuda.compute._cccl_interop, "_check_sass", False)
96+
9297
h_in_keys = random_array(num_items, dtype)
9398
h_in_items = random_array(num_items, np.float32)
9499

@@ -125,7 +130,12 @@ def test_merge_sort_keys_copy(dtype, num_items, op):
125130

126131

127132
@pytest.mark.parametrize("dtype,num_items,op", merge_sort_params)
128-
def test_merge_sort_pairs_copy(dtype, num_items, op):
133+
def test_merge_sort_pairs_copy(dtype, num_items, op, monkeypatch):
134+
if dtype == np.float16:
135+
import cuda.compute._cccl_interop
136+
137+
monkeypatch.setattr(cuda.compute._cccl_interop, "_check_sass", False)
138+
129139
h_in_keys = random_array(num_items, dtype)
130140
h_in_items = random_array(num_items, np.float32)
131141
h_out_keys = np.empty(num_items, dtype=dtype)
@@ -239,7 +249,12 @@ def test_merge_sort_keys_copy_iterator_input(dtype, num_items, op):
239249

240250

241251
@pytest.mark.parametrize("dtype,num_items,op", merge_sort_params)
242-
def test_merge_sort_pairs_copy_iterator_input(dtype, num_items, op):
252+
def test_merge_sort_pairs_copy_iterator_input(dtype, num_items, op, monkeypatch):
253+
if dtype == np.float16:
254+
import cuda.compute._cccl_interop
255+
256+
monkeypatch.setattr(cuda.compute._cccl_interop, "_check_sass", False)
257+
243258
h_in_keys = random_array(num_items, dtype)
244259
h_in_items = random_array(num_items, np.float32)
245260
h_out_keys = np.empty(num_items, dtype=dtype)

0 commit comments

Comments
 (0)