@@ -365,7 +365,7 @@ def test_complex_shape(
365365 xnp = xnp ,
366366 )
367367 dca_cls .assert_val (
368- p .flatten ()[xnp .ones (p .size , dtype = np .bool_ )],
368+ p .flatten ()[xnp .ones (p .size , dtype = enp . lazy . as_dtype ( np .bool_ , xnp = xnp ) )],
369369 (p .size ,),
370370 xnp = xnp ,
371371 )
@@ -642,13 +642,13 @@ class PointNoCast(dca.DataclassArray):
642642
643643 with pytest .raises (ValueError , match = 'Cannot cast float16' ):
644644 PointNoCast (
645- x = xnp .asarray ([1 , 2 , 3 ], dtype = np .float16 ),
646- y = xnp .asarray ([1 , 2 , 3 ], dtype = np .float16 ),
645+ x = xnp .asarray ([1 , 2 , 3 ], dtype = xnp .float16 ),
646+ y = xnp .asarray ([1 , 2 , 3 ], dtype = xnp .float16 ),
647647 )
648648
649649 p = PointNoCast (
650- x = xnp .asarray ([1 , 2 , 3 ], dtype = np .float16 ),
651- y = xnp .asarray ([1 , 2 , 3 ], dtype = np .uint8 ),
650+ x = xnp .asarray ([1 , 2 , 3 ], dtype = xnp .float16 ),
651+ y = xnp .asarray ([1 , 2 , 3 ], dtype = xnp .uint8 ),
652652 )
653653 assert p .shape == (3 ,)
654654 assert enp .lazy .as_dtype (p .x .dtype ) == np .float16
@@ -665,7 +665,7 @@ class PointNoList(dca.DataclassArray):
665665
666666 with pytest .raises (TypeError , match = 'Could not infer numpy module' ):
667667 PointNoList (
668- x = xnp .asarray (1 , dtype = np .float16 ),
668+ x = xnp .asarray (1 , dtype = xnp .float16 ),
669669 y = [1 , 2 , 3 ],
670670 )
671671
@@ -679,8 +679,8 @@ class PointNoBroadcast(dca.DataclassArray):
679679
680680 with pytest .raises (ValueError , match = 'Cannot broadcast' ):
681681 PointNoBroadcast (
682- x = xnp .asarray (1 , dtype = np .float16 ),
683- y = xnp .asarray ([1 , 2 , 3 ], dtype = np .int32 ),
682+ x = xnp .asarray (1 , dtype = xnp .float16 ),
683+ y = xnp .asarray ([1 , 2 , 3 ], dtype = xnp .int32 ),
684684 )
685685
686686
@@ -693,16 +693,16 @@ class PointDynamicShape(dca.DataclassArray):
693693 y : IntArray ['... 3 _' ]
694694
695695 p = PointDynamicShape (
696- x = xnp .zeros (batch_shape + (2 , 3 ), dtype = np .float32 ),
697- y = xnp .zeros (batch_shape + (3 , 1 ), dtype = np .int32 ),
696+ x = xnp .zeros (batch_shape + (2 , 3 ), dtype = xnp .float32 ),
697+ y = xnp .zeros (batch_shape + (3 , 1 ), dtype = xnp .int32 ),
698698 )
699699 assert p .shape == batch_shape
700700 assert p .x .shape == batch_shape + (2 , 3 )
701701 assert p .y .shape == batch_shape + (3 , 1 )
702702
703703 p2 = PointDynamicShape (
704- x = xnp .zeros (batch_shape + (3 , 2 ), dtype = np .float32 ),
705- y = xnp .zeros (batch_shape + (3 , 1 ), dtype = np .int32 ),
704+ x = xnp .zeros (batch_shape + (3 , 2 ), dtype = xnp .float32 ),
705+ y = xnp .zeros (batch_shape + (3 , 1 ), dtype = xnp .int32 ),
706706 )
707707 assert p2 .shape == batch_shape
708708 assert p2 .x .shape == batch_shape + (3 , 2 )
@@ -727,17 +727,17 @@ class PointDynamicShape(dca.DataclassArray):
727727 err_msg = 'Shape do not match.'
728728 with pytest .raises (ValueError , match = err_msg ):
729729 PointDynamicShape (
730- x = xnp .zeros (batch_shape + (3 ,), dtype = np .float32 ), # len() != 2
731- y = xnp .zeros (batch_shape + (3 , 1 ), dtype = np .int32 ),
730+ x = xnp .zeros (batch_shape + (3 ,), dtype = xnp .float32 ), # len() != 2
731+ y = xnp .zeros (batch_shape + (3 , 1 ), dtype = xnp .int32 ),
732732 )
733733
734734 with pytest .raises (
735735 ValueError ,
736736 match = 'Shape do not match.' ,
737737 ):
738738 PointDynamicShape (
739- x = xnp .zeros (batch_shape + (2 , 3 ), dtype = np .float32 ),
740- y = xnp .zeros (batch_shape + (2 , 1 ), dtype = np .int32 ), # < 2 != 3
739+ x = xnp .zeros (batch_shape + (2 , 3 ), dtype = xnp .float32 ),
740+ y = xnp .zeros (batch_shape + (2 , 1 ), dtype = xnp .int32 ), # < 2 != 3
741741 )
742742
743743
0 commit comments