11import numpy as np
22import pytest
3- from packaging .version import parse as version_parse
43
54import pytensor .tensor .basic as ptb
65from pytensor .configdefaults import config
76from pytensor .graph .fg import FunctionGraph
87from pytensor .graph .op import get_test_value
98from pytensor .tensor import extra_ops as pt_extra_ops
10- from pytensor .tensor .type import matrix
9+ from pytensor .tensor .type import matrix , tensor
1110from tests .link .jax .test_basic import compare_jax_and_py
1211
1312
1413jax = pytest .importorskip ("jax" )
1514
1615
17- def set_test_value (x , v ):
18- x .tag .test_value = v
19- return x
20-
21-
2216def test_extra_ops ():
2317 a = matrix ("a" )
24- a . tag . test_value = np .arange (6 , dtype = config .floatX ).reshape ((3 , 2 ))
18+ a_test = np .arange (6 , dtype = config .floatX ).reshape ((3 , 2 ))
2519
2620 out = pt_extra_ops .cumsum (a , axis = 0 )
2721 fgraph = FunctionGraph ([a ], [out ])
28- compare_jax_and_py (fgraph , [get_test_value ( i ) for i in fgraph . inputs ])
22+ compare_jax_and_py (fgraph , [a_test ])
2923
3024 out = pt_extra_ops .cumprod (a , axis = 1 )
3125 fgraph = FunctionGraph ([a ], [out ])
32- compare_jax_and_py (fgraph , [get_test_value ( i ) for i in fgraph . inputs ])
26+ compare_jax_and_py (fgraph , [a_test ])
3327
3428 out = pt_extra_ops .diff (a , n = 2 , axis = 1 )
3529 fgraph = FunctionGraph ([a ], [out ])
36- compare_jax_and_py (fgraph , [get_test_value ( i ) for i in fgraph . inputs ])
30+ compare_jax_and_py (fgraph , [a_test ])
3731
3832 out = pt_extra_ops .repeat (a , (3 , 3 ), axis = 1 )
3933 fgraph = FunctionGraph ([a ], [out ])
40- compare_jax_and_py (fgraph , [get_test_value ( i ) for i in fgraph . inputs ])
34+ compare_jax_and_py (fgraph , [a_test ])
4135
4236 c = ptb .as_tensor (5 )
43-
4437 out = pt_extra_ops .fill_diagonal (a , c )
4538 fgraph = FunctionGraph ([a ], [out ])
46- compare_jax_and_py (fgraph , [get_test_value ( i ) for i in fgraph . inputs ])
39+ compare_jax_and_py (fgraph , [a_test ])
4740
4841 with pytest .raises (NotImplementedError ):
4942 out = pt_extra_ops .fill_diagonal_offset (a , c , c )
5043 fgraph = FunctionGraph ([a ], [out ])
51- compare_jax_and_py (fgraph , [get_test_value ( i ) for i in fgraph . inputs ])
44+ compare_jax_and_py (fgraph , [a_test ])
5245
5346 with pytest .raises (NotImplementedError ):
5447 out = pt_extra_ops .Unique (axis = 1 )(a )
5548 fgraph = FunctionGraph ([a ], [out ])
56- compare_jax_and_py (fgraph , [get_test_value ( i ) for i in fgraph . inputs ])
49+ compare_jax_and_py (fgraph , [a_test ])
5750
5851 indices = np .arange (np .prod ((3 , 4 )))
5952 out = pt_extra_ops .unravel_index (indices , (3 , 4 ), order = "C" )
@@ -63,40 +56,30 @@ def test_extra_ops():
6356 )
6457
6558
66- @pytest .mark .xfail (
67- version_parse (jax .__version__ ) >= version_parse ("0.2.12" ),
68- reason = "JAX Numpy API does not support dynamic shapes" ,
69- )
70- def test_extra_ops_dynamic_shapes ():
71- a = matrix ("a" )
72- a .tag .test_value = np .arange (6 , dtype = config .floatX ).reshape ((3 , 2 ))
73-
74- # This function also cannot take symbolic input.
75- c = ptb .as_tensor (5 )
59+ @pytest .mark .xfail (reason = "Jitted JAX does not support dynamic shapes" )
60+ def test_bartlett_dynamic_shape ():
61+ c = tensor (shape = (), dtype = int )
7662 out = pt_extra_ops .bartlett (c )
7763 fgraph = FunctionGraph ([], [out ])
78- compare_jax_and_py (fgraph , [get_test_value ( i ) for i in fgraph . inputs ])
64+ compare_jax_and_py (fgraph , [np . array ( 5 ) ])
7965
80- multi_index = np .unravel_index (np .arange (np .prod ((3 , 4 ))), (3 , 4 ))
81- out = pt_extra_ops .ravel_multi_index (multi_index , (3 , 4 ))
82- fgraph = FunctionGraph ([], [out ])
83- compare_jax_and_py (
84- fgraph , [get_test_value (i ) for i in fgraph .inputs ], must_be_device_array = False
85- )
8666
87- # The inputs are "concrete", yet it still has problems?
88- out = pt_extra_ops .Unique ()(
89- ptb .as_tensor (np .arange (6 , dtype = config .floatX ).reshape ((3 , 2 )))
90- )
67+ @pytest .mark .xfail (reason = "Jitted JAX does not support dynamic shapes" )
68+ def test_ravel_multi_index_dynamic_shape ():
69+ x_test , y_test = np .unravel_index (np .arange (np .prod ((3 , 4 ))), (3 , 4 ))
70+
71+ x = tensor (shape = (None ,), dtype = int )
72+ y = tensor (shape = (None ,), dtype = int )
73+ out = pt_extra_ops .ravel_multi_index ((x , y ), (3 , 4 ))
9174 fgraph = FunctionGraph ([], [out ])
92- compare_jax_and_py (fgraph , [])
75+ compare_jax_and_py (fgraph , [x_test , y_test ])
9376
9477
95- @pytest .mark .xfail (reason = "jax.numpy.arange requires concrete inputs " )
96- def test_unique_nonconcrete ():
78+ @pytest .mark .xfail (reason = "Jitted JAX does not support dynamic shapes " )
79+ def test_unique_dynamic_shape ():
9780 a = matrix ("a" )
98- a . tag . test_value = np .arange (6 , dtype = config .floatX ).reshape ((3 , 2 ))
81+ a_test = np .arange (6 , dtype = config .floatX ).reshape ((3 , 2 ))
9982
10083 out = pt_extra_ops .Unique ()(a )
10184 fgraph = FunctionGraph ([a ], [out ])
102- compare_jax_and_py (fgraph , [get_test_value ( i ) for i in fgraph . inputs ])
85+ compare_jax_and_py (fgraph , [a_test ])
0 commit comments