Skip to content

Commit 845f715

Browse files
ConchylicultorThe dataclass_array Authors
authored andcommitted
Use xnp.dtype instead of np.dtype
PiperOrigin-RevId: 515337644
1 parent 6ff58ff commit 845f715

File tree

1 file changed

+16
-16
lines changed

1 file changed

+16
-16
lines changed

dataclass_array/array_dataclass_test.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -365,7 +365,7 @@ def test_complex_shape(
365365
xnp=xnp,
366366
)
367367
dca_cls.assert_val(
368-
p.flatten()[xnp.ones(p.size, dtype=np.bool_)],
368+
p.flatten()[xnp.ones(p.size, dtype=enp.lazy.as_dtype(np.bool_, xnp=xnp))],
369369
(p.size,),
370370
xnp=xnp,
371371
)
@@ -642,13 +642,13 @@ class PointNoCast(dca.DataclassArray):
642642

643643
with pytest.raises(ValueError, match='Cannot cast float16'):
644644
PointNoCast(
645-
x=xnp.asarray([1, 2, 3], dtype=np.float16),
646-
y=xnp.asarray([1, 2, 3], dtype=np.float16),
645+
x=xnp.asarray([1, 2, 3], dtype=xnp.float16),
646+
y=xnp.asarray([1, 2, 3], dtype=xnp.float16),
647647
)
648648

649649
p = PointNoCast(
650-
x=xnp.asarray([1, 2, 3], dtype=np.float16),
651-
y=xnp.asarray([1, 2, 3], dtype=np.uint8),
650+
x=xnp.asarray([1, 2, 3], dtype=xnp.float16),
651+
y=xnp.asarray([1, 2, 3], dtype=xnp.uint8),
652652
)
653653
assert p.shape == (3,)
654654
assert enp.lazy.as_dtype(p.x.dtype) == np.float16
@@ -665,7 +665,7 @@ class PointNoList(dca.DataclassArray):
665665

666666
with pytest.raises(TypeError, match='Could not infer numpy module'):
667667
PointNoList(
668-
x=xnp.asarray(1, dtype=np.float16),
668+
x=xnp.asarray(1, dtype=xnp.float16),
669669
y=[1, 2, 3],
670670
)
671671

@@ -679,8 +679,8 @@ class PointNoBroadcast(dca.DataclassArray):
679679

680680
with pytest.raises(ValueError, match='Cannot broadcast'):
681681
PointNoBroadcast(
682-
x=xnp.asarray(1, dtype=np.float16),
683-
y=xnp.asarray([1, 2, 3], dtype=np.int32),
682+
x=xnp.asarray(1, dtype=xnp.float16),
683+
y=xnp.asarray([1, 2, 3], dtype=xnp.int32),
684684
)
685685

686686

@@ -693,16 +693,16 @@ class PointDynamicShape(dca.DataclassArray):
693693
y: IntArray['... 3 _']
694694

695695
p = PointDynamicShape(
696-
x=xnp.zeros(batch_shape + (2, 3), dtype=np.float32),
697-
y=xnp.zeros(batch_shape + (3, 1), dtype=np.int32),
696+
x=xnp.zeros(batch_shape + (2, 3), dtype=xnp.float32),
697+
y=xnp.zeros(batch_shape + (3, 1), dtype=xnp.int32),
698698
)
699699
assert p.shape == batch_shape
700700
assert p.x.shape == batch_shape + (2, 3)
701701
assert p.y.shape == batch_shape + (3, 1)
702702

703703
p2 = PointDynamicShape(
704-
x=xnp.zeros(batch_shape + (3, 2), dtype=np.float32),
705-
y=xnp.zeros(batch_shape + (3, 1), dtype=np.int32),
704+
x=xnp.zeros(batch_shape + (3, 2), dtype=xnp.float32),
705+
y=xnp.zeros(batch_shape + (3, 1), dtype=xnp.int32),
706706
)
707707
assert p2.shape == batch_shape
708708
assert p2.x.shape == batch_shape + (3, 2)
@@ -727,17 +727,17 @@ class PointDynamicShape(dca.DataclassArray):
727727
err_msg = 'Shape do not match.'
728728
with pytest.raises(ValueError, match=err_msg):
729729
PointDynamicShape(
730-
x=xnp.zeros(batch_shape + (3,), dtype=np.float32), # len() != 2
731-
y=xnp.zeros(batch_shape + (3, 1), dtype=np.int32),
730+
x=xnp.zeros(batch_shape + (3,), dtype=xnp.float32), # len() != 2
731+
y=xnp.zeros(batch_shape + (3, 1), dtype=xnp.int32),
732732
)
733733

734734
with pytest.raises(
735735
ValueError,
736736
match='Shape do not match.',
737737
):
738738
PointDynamicShape(
739-
x=xnp.zeros(batch_shape + (2, 3), dtype=np.float32),
740-
y=xnp.zeros(batch_shape + (2, 1), dtype=np.int32), # < 2 != 3
739+
x=xnp.zeros(batch_shape + (2, 3), dtype=xnp.float32),
740+
y=xnp.zeros(batch_shape + (2, 1), dtype=xnp.int32), # < 2 != 3
741741
)
742742

743743

0 commit comments

Comments
 (0)