2828import tensorflow as tf
2929
3030# Activate the fixture
31- set_tnp = enp .testing .set_tnp
31+ enable_torch_tf_np_mode = enp .testing .enable_torch_tf_np_mode
3232
3333# TODO(epot): Test dtype `complex`, `str`
3434
@@ -70,8 +70,8 @@ def assert_val(p: Point, shape: Shape, xnp: enp.NpModule = None):
7070 _assert_common (p , shape = shape , xnp = xnp )
7171 assert p .x .shape == shape
7272 assert p .y .shape == shape
73- assert p .x .dtype == np .float32
74- assert p .y .dtype == np .float32
73+ assert enp . lazy . as_dtype ( p .x .dtype ) == np .float32
74+ assert enp . lazy . as_dtype ( p .y .dtype ) == np .float32
7575 assert isinstance (p .x , xnp .ndarray )
7676 assert isinstance (p .y , xnp .ndarray )
7777
@@ -98,8 +98,8 @@ def assert_val(p: Isometrie, shape: Shape, xnp: enp.NpModule = None):
9898 _assert_common (p , shape = shape , xnp = xnp )
9999 assert p .r .shape == shape + (3 , 3 )
100100 assert p .t .shape == shape + (2 ,)
101- assert p .r .dtype == np .float32
102- assert p .t .dtype == np .int32
101+ assert enp . lazy . as_dtype ( p .r .dtype ) == np .float32
102+ assert enp . lazy . as_dtype ( p .t .dtype ) == np .int32
103103 assert isinstance (p .r , xnp .ndarray )
104104 assert isinstance (p .t , xnp .ndarray )
105105
@@ -226,8 +226,8 @@ def assert_val(p: WithStatic, shape: Shape, xnp: enp.NpModule = None):
226226 NestedOnlyStatic .assert_val (p .nested_static , shape , xnp = xnp )
227227 assert p .x .shape == shape + (3 ,)
228228 assert p .y .shape == shape + (2 , 2 )
229- assert p .x .dtype == np .float32
230- assert p .y .dtype == np .float32
229+ assert enp . lazy . as_dtype ( p .x .dtype ) == np .float32
230+ assert enp . lazy . as_dtype ( p .y .dtype ) == np .float32
231231 assert isinstance (p .x , xnp .ndarray )
232232 assert isinstance (p .y , xnp .ndarray )
233233 # Static field is correctly forwarded
@@ -546,12 +546,15 @@ def test_convert(
546546):
547547 p = dca_cls .make (xnp = xnp , shape = (2 ,))
548548 assert p .xnp is xnp
549+
549550 assert p .as_np ().xnp is enp .lazy .np
550551 assert p .as_jax ().xnp is enp .lazy .jnp
551552 assert p .as_tf ().xnp is enp .lazy .tnp
553+ assert p .as_torch ().xnp is enp .lazy .torch
552554 assert p .as_xnp (np ).xnp is enp .lazy .np
553555 assert p .as_xnp (enp .lazy .jnp ).xnp is enp .lazy .jnp
554556 assert p .as_xnp (enp .lazy .tnp ).xnp is enp .lazy .tnp
557+ assert p .as_xnp (enp .lazy .torch ).xnp is enp .lazy .torch
555558 # Make sure the nested class are also updated
556559 dca_cls .assert_val (p .as_jax (), (2 ,), xnp = enp .lazy .jnp )
557560
@@ -587,24 +590,44 @@ def test_infer_np(xnp: enp.NpModule):
587590
588591
589592@parametrize_dataclass_arrays
590- def test_jax_tree_map (dca_cls : DcaTest ):
593+ @pytest .mark .parametrize (
594+ 'tree_map' ,
595+ [
596+ enp .lazy .jax .tree_map ,
597+ enp .lazy .torch .utils ._pytree .tree_map ,
598+ ],
599+ )
600+ def test_torch_tree_map (tree_map , dca_cls : DcaTest ):
591601 p = dca_cls .make (shape = (3 ,), xnp = np )
592- p = enp . lazy . jax . tree_map (lambda x : x [None , ...], p )
602+ p = tree_map (lambda x : x [None , ...], p )
593603 dca_cls .assert_val (p , (1 , 3 ), xnp = np )
594604
595605
596- def test_jax_vmap ():
606+ @enp .testing .parametrize_xnp (
607+ restrict = [
608+ 'jnp' ,
609+ 'torch' ,
610+ ]
611+ )
612+ def test_vmap (xnp : enp .NpModule ):
613+ import functorch
614+
615+ vmap_fn = {
616+ enp .lazy .jnp : enp .lazy .jax .vmap ,
617+ enp .lazy .torch : functorch .vmap ,
618+ }[xnp ]
619+
597620 batch_shape = 3
598621
599- @enp . lazy . jax . vmap
622+ @vmap_fn
600623 def fn (p : WithStatic ) -> WithStatic :
601624 assert isinstance (p , WithStatic )
602625 assert p .shape == () # pylint:disable=g-explicit-bool-comparison
603626 return p .replace (x = p .x + 1 )
604627
605- x = WithStatic .make ((batch_shape ,), xnp = enp . lazy . jnp )
628+ x = WithStatic .make ((batch_shape ,), xnp = xnp )
606629 y = fn (x )
607- WithStatic .assert_val (y , (batch_shape ,), xnp = enp . lazy . jnp )
630+ WithStatic .assert_val (y , (batch_shape ,), xnp = xnp )
608631 # pos was updated
609632 np .testing .assert_allclose (y .x , np .ones ((batch_shape , 3 )))
610633 np .testing .assert_allclose (y .y , np .zeros ((batch_shape , 2 , 2 )))
@@ -628,8 +651,8 @@ class PointNoCast(dca.DataclassArray):
628651 y = xnp .array ([1 , 2 , 3 ], dtype = np .uint8 ),
629652 )
630653 assert p .shape == (3 ,)
631- assert p .x .dtype == np .float16
632- assert p .y .dtype == np .uint8
654+ assert enp . lazy . as_dtype ( p .x .dtype ) == np .float16
655+ assert enp . lazy . as_dtype ( p .y .dtype ) == np .uint8
633656
634657
635658@enp .testing .parametrize_xnp ()
@@ -689,7 +712,13 @@ class PointDynamicShape(dca.DataclassArray):
689712 assert dca .stack ([p , p ]).shape == (2 ,) + batch_shape
690713
691714 # Incompatible shape will raise an error
692- with pytest .raises ((ValueError , tf .errors .InvalidArgumentError )):
715+ expected_exception_cls = {
716+ enp .lazy .np : ValueError ,
717+ enp .lazy .jnp : ValueError ,
718+ enp .lazy .tnp : tf .errors .InvalidArgumentError ,
719+ enp .lazy .torch : RuntimeError ,
720+ }
721+ with pytest .raises (expected_exception_cls [xnp ]):
693722 dca .stack ([p , p2 ])
694723
695724 if batch_shape :
0 commit comments