Skip to content

Commit 6ff58ff

Browse files
ConchylicultorThe dataclass_array Authors
authored andcommitted
Remove xnp.ndarray calls
PiperOrigin-RevId: 515296466
1 parent d9c3927 commit 6ff58ff

File tree

3 files changed

+9
-9
lines changed

3 files changed

+9
-9
lines changed

dataclass_array/array_dataclass_test.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,8 @@ def assert_val(p: Point, shape: Shape, xnp: enp.NpModule = None):
7272
assert p.y.shape == shape
7373
assert enp.lazy.as_dtype(p.x.dtype) == np.float32
7474
assert enp.lazy.as_dtype(p.y.dtype) == np.float32
75-
assert isinstance(p.x, xnp.ndarray)
76-
assert isinstance(p.y, xnp.ndarray)
75+
assert enp.compat.is_array_xnp(p.x, xnp)
76+
assert enp.compat.is_array_xnp(p.y, xnp)
7777

7878

7979
@dca.dataclass_array(broadcast=True, cast_dtype=True)
@@ -100,8 +100,8 @@ def assert_val(p: Isometrie, shape: Shape, xnp: enp.NpModule = None):
100100
assert p.t.shape == shape + (2,)
101101
assert enp.lazy.as_dtype(p.r.dtype) == np.float32
102102
assert enp.lazy.as_dtype(p.t.dtype) == np.int32
103-
assert isinstance(p.r, xnp.ndarray)
104-
assert isinstance(p.t, xnp.ndarray)
103+
assert enp.compat.is_array_xnp(p.r, xnp)
104+
assert enp.compat.is_array_xnp(p.t, xnp)
105105

106106

107107
@dca.dataclass_array(broadcast=True, cast_dtype=True)
@@ -228,8 +228,8 @@ def assert_val(p: WithStatic, shape: Shape, xnp: enp.NpModule = None):
228228
assert p.y.shape == shape + (2, 2)
229229
assert enp.lazy.as_dtype(p.x.dtype) == np.float32
230230
assert enp.lazy.as_dtype(p.y.dtype) == np.float32
231-
assert isinstance(p.x, xnp.ndarray)
232-
assert isinstance(p.y, xnp.ndarray)
231+
assert enp.compat.is_array_xnp(p.x, xnp)
232+
assert enp.compat.is_array_xnp(p.y, xnp)
233233
# Static field is correctly forwarded
234234
assert isinstance(p.static, str)
235235
assert p.static == 'abc'
@@ -586,7 +586,7 @@ def test_broadcast(xnp: enp.NpModule):
586586
def test_infer_np(xnp: enp.NpModule):
587587
p = Point(x=xnp.ones((3,)), y=[0, 0, 0]) # y is casted to xnp
588588
assert p.xnp is xnp
589-
assert isinstance(p.y, xnp.ndarray)
589+
assert enp.compat.is_array_xnp(p.y, xnp)
590590

591591

592592
@parametrize_dataclass_arrays

dataclass_array/vectorization.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,7 @@ def _broacast_and_flatten_to(
294294
final_shape = batch_shape + inner_shape
295295
if isinstance(array, array_dataclass.DataclassArray):
296296
array = array.broadcast_to(final_shape)
297-
elif isinstance(array, xnp.ndarray):
297+
elif enp.compat.is_array_xnp(array, xnp):
298298
array = xnp.broadcast_to(array, final_shape)
299299
else:
300300
raise TypeError(f'Unexpected array type: {type(array)}')

dataclass_array/vectorization_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def test_broadcast_args(
7070
def fn(self, arg_dc, arg_array):
7171
assert isinstance(self, dca.testing.Ray)
7272
assert isinstance(arg_dc, dca.testing.Ray)
73-
assert isinstance(arg_array, xnp.ndarray)
73+
assert enp.compat.is_array_xnp(arg_array, xnp)
7474
assert self.shape == () # pylint: disable=g-explicit-bool-comparison
7575
assert arg_dc.shape == expected_arg_shape[1:]
7676
assert arg_array.shape == expected_arg_shape[1:] + (3,)

0 commit comments

Comments
 (0)