Skip to content

Commit d9c3927

Browse files
ConchylicultorThe dataclass_array Authors
authored andcommitted
Replace array by asarray for better Pytorch compatibility
PiperOrigin-RevId: 515265898
1 parent 0d4cb70 commit d9c3927

File tree

2 files changed

+14
-12
lines changed

2 files changed

+14
-12
lines changed

dataclass_array/array_dataclass_test.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -280,8 +280,8 @@ def test_point_infered_np(
280280
shape: Shape,
281281
):
282282
if xnp is not None: # Normalize np arrays to test the various backend
283-
x = xnp.array(x)
284-
y = xnp.array(y)
283+
x = xnp.asarray(x)
284+
y = xnp.asarray(y)
285285
else:
286286
xnp = np
287287

@@ -565,7 +565,7 @@ def test_broadcast(xnp: enp.NpModule):
565565
p = Nested(
566566
# pt.shape broadcasted to (2, 3)
567567
pt=Point(
568-
x=xnp.array(0),
568+
x=xnp.asarray(0),
569569
y=xnp.zeros(broadcast_shape + (3,)),
570570
),
571571
# iso.shape == (), broadcasted to (2, 3)
@@ -642,13 +642,13 @@ class PointNoCast(dca.DataclassArray):
642642

643643
with pytest.raises(ValueError, match='Cannot cast float16'):
644644
PointNoCast(
645-
x=xnp.array([1, 2, 3], dtype=np.float16),
646-
y=xnp.array([1, 2, 3], dtype=np.float16),
645+
x=xnp.asarray([1, 2, 3], dtype=np.float16),
646+
y=xnp.asarray([1, 2, 3], dtype=np.float16),
647647
)
648648

649649
p = PointNoCast(
650-
x=xnp.array([1, 2, 3], dtype=np.float16),
651-
y=xnp.array([1, 2, 3], dtype=np.uint8),
650+
x=xnp.asarray([1, 2, 3], dtype=np.float16),
651+
y=xnp.asarray([1, 2, 3], dtype=np.uint8),
652652
)
653653
assert p.shape == (3,)
654654
assert enp.lazy.as_dtype(p.x.dtype) == np.float16
@@ -665,7 +665,7 @@ class PointNoList(dca.DataclassArray):
665665

666666
with pytest.raises(TypeError, match='Could not infer numpy module'):
667667
PointNoList(
668-
x=xnp.array(1, dtype=np.float16),
668+
x=xnp.asarray(1, dtype=np.float16),
669669
y=[1, 2, 3],
670670
)
671671

@@ -679,8 +679,8 @@ class PointNoBroadcast(dca.DataclassArray):
679679

680680
with pytest.raises(ValueError, match='Cannot broadcast'):
681681
PointNoBroadcast(
682-
x=xnp.array(1, dtype=np.float16),
683-
y=xnp.array([1, 2, 3], dtype=np.int32),
682+
x=xnp.asarray(1, dtype=np.float16),
683+
y=xnp.asarray([1, 2, 3], dtype=np.int32),
684684
)
685685

686686

dataclass_array/utils/np_utils_test.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,12 @@
2929
@enp.testing.parametrize_xnp()
3030
def test_get_xnp(xnp: enp.NpModule):
3131
# Dataclass array support
32-
r = dca.testing.Ray(pos=xnp.array([3.0, 0, 0]), dir=xnp.array([3.0, 0, 0]))
32+
r = dca.testing.Ray(
33+
pos=xnp.asarray([3.0, 0, 0]), dir=xnp.asarray([3.0, 0, 0])
34+
)
3335
assert np_utils.get_xnp(r) is xnp
3436
# Array support
35-
assert np_utils.get_xnp(xnp.array([3.0, 0, 0])) is xnp
37+
assert np_utils.get_xnp(xnp.asarray([3.0, 0, 0])) is xnp
3638

3739
with pytest.raises(TypeError, match='Unexpected array type'):
3840
np_utils.get_xnp('abc')

0 commit comments

Comments
 (0)