@@ -72,8 +72,8 @@ def assert_val(p: Point, shape: Shape, xnp: enp.NpModule = None):
7272 assert p .y .shape == shape
7373 assert enp .lazy .as_dtype (p .x .dtype ) == np .float32
7474 assert enp .lazy .as_dtype (p .y .dtype ) == np .float32
75- assert isinstance (p .x , xnp . ndarray )
76- assert isinstance (p .y , xnp . ndarray )
75+ assert enp . compat . is_array_xnp (p .x , xnp )
76+ assert enp . compat . is_array_xnp (p .y , xnp )
7777
7878
7979@dca .dataclass_array (broadcast = True , cast_dtype = True )
@@ -100,8 +100,8 @@ def assert_val(p: Isometrie, shape: Shape, xnp: enp.NpModule = None):
100100 assert p .t .shape == shape + (2 ,)
101101 assert enp .lazy .as_dtype (p .r .dtype ) == np .float32
102102 assert enp .lazy .as_dtype (p .t .dtype ) == np .int32
103- assert isinstance (p .r , xnp . ndarray )
104- assert isinstance (p .t , xnp . ndarray )
103+ assert enp . compat . is_array_xnp (p .r , xnp )
104+ assert enp . compat . is_array_xnp (p .t , xnp )
105105
106106
107107@dca .dataclass_array (broadcast = True , cast_dtype = True )
@@ -228,8 +228,8 @@ def assert_val(p: WithStatic, shape: Shape, xnp: enp.NpModule = None):
228228 assert p .y .shape == shape + (2 , 2 )
229229 assert enp .lazy .as_dtype (p .x .dtype ) == np .float32
230230 assert enp .lazy .as_dtype (p .y .dtype ) == np .float32
231- assert isinstance (p .x , xnp . ndarray )
232- assert isinstance (p .y , xnp . ndarray )
231+ assert enp . compat . is_array_xnp (p .x , xnp )
232+ assert enp . compat . is_array_xnp (p .y , xnp )
233233 # Static field is correctly forwarded
234234 assert isinstance (p .static , str )
235235 assert p .static == 'abc'
@@ -586,7 +586,7 @@ def test_broadcast(xnp: enp.NpModule):
586586def test_infer_np (xnp : enp .NpModule ):
587587 p = Point (x = xnp .ones ((3 ,)), y = [0 , 0 , 0 ]) # y is casted to xnp
588588 assert p .xnp is xnp
589- assert isinstance (p .y , xnp . ndarray )
589+ assert enp . compat . is_array_xnp (p .y , xnp )
590590
591591
592592@parametrize_dataclass_arrays
0 commit comments