@@ -142,6 +142,35 @@ def test_isin_strided_bool():
142142 assert r2 .shape == x_s .shape
143143
144144
145+ @pytest .mark .parametrize ("dt1" , _numeric_dtypes )
146+ @pytest .mark .parametrize ("dt2" , _numeric_dtypes )
147+ def test_isin_dtype_matrix (dt1 , dt2 ):
148+ q = get_queue_or_skip ()
149+ skip_if_dtype_not_supported (dt1 , q )
150+ skip_if_dtype_not_supported (dt2 , q )
151+
152+ sz = 10
153+ x = dpt .asarray ([0 , 1 , 11 ], dtype = dt1 , sycl_queue = q )
154+ test1 = dpt .arange (sz , dtype = dt2 , sycl_queue = q )
155+
156+ r1 = dpt .isin (x , test1 )
157+ assert isinstance (r1 , dpt .usm_ndarray )
158+ assert r1 .dtype == dpt .bool
159+ assert r1 .shape == x .shape
160+ assert not r1 [- 1 ]
161+ assert dpt .all (r1 [0 :- 1 ])
162+ assert r1 .sycl_queue == x .sycl_queue
163+
164+ test2 = dpt .tile (dpt .asarray ([[0 , 1 ]], dtype = dt2 , sycl_queue = q ).mT , 2 )
165+ r2 = dpt .isin (x , test2 )
166+ assert isinstance (r2 , dpt .usm_ndarray )
167+ assert r2 .dtype == dpt .bool
168+ assert r2 .shape == x .shape
169+ assert not r2 [- 1 ]
170+ assert dpt .all (r1 [0 :- 1 ])
171+ assert r2 .sycl_queue == x .sycl_queue
172+
173+
145174def test_isin_empty_inputs ():
146175 get_queue_or_skip ()
147176
0 commit comments