 import pytensor.tensor as pt
 from pytensor.compile import get_mode
 from pytensor.configdefaults import config
-from pytensor.graph.fg import FunctionGraph
-from pytensor.graph.op import get_test_value
 from pytensor.tensor import elemwise as pt_elemwise
 from pytensor.tensor.math import all as pt_all
 from pytensor.tensor.math import prod
@@ -26,87 +24,81 @@ def test_jax_Dimshuffle():
     a_pt = matrix("a")
 
     x = a_pt.T
-    x_fg = FunctionGraph([a_pt], [x])
-    compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX)])
+    compare_jax_and_py(
+        [a_pt], [x], [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX)]
+    )
 
     x = a_pt.dimshuffle([0, 1, "x"])
-    x_fg = FunctionGraph([a_pt], [x])
-    compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX)])
+    compare_jax_and_py(
+        [a_pt], [x], [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX)]
+    )
 
     a_pt = tensor(dtype=config.floatX, shape=(None, 1))
     x = a_pt.dimshuffle((0,))
-    x_fg = FunctionGraph([a_pt], [x])
-    compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)])
+    compare_jax_and_py([a_pt], [x], [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)])
 
     a_pt = tensor(dtype=config.floatX, shape=(None, 1))
     x = pt_elemwise.DimShuffle(input_ndim=2, new_order=(0,))(a_pt)
-    x_fg = FunctionGraph([a_pt], [x])
-    compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)])
+    compare_jax_and_py([a_pt], [x], [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)])
 
 
 def test_jax_CAReduce():
     a_pt = vector("a")
     a_pt.tag.test_value = np.r_[1, 2, 3].astype(config.floatX)
 
     x = pt_sum(a_pt, axis=None)
-    x_fg = FunctionGraph([a_pt], [x])
 
-    compare_jax_and_py(x_fg, [np.r_[1, 2, 3].astype(config.floatX)])
+    compare_jax_and_py([a_pt], [x], [np.r_[1, 2, 3].astype(config.floatX)])
 
     a_pt = matrix("a")
     a_pt.tag.test_value = np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)
 
     x = pt_sum(a_pt, axis=0)
-    x_fg = FunctionGraph([a_pt], [x])
 
-    compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)])
+    compare_jax_and_py([a_pt], [x], [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)])
 
     x = pt_sum(a_pt, axis=1)
-    x_fg = FunctionGraph([a_pt], [x])
 
-    compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)])
+    compare_jax_and_py([a_pt], [x], [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)])
 
     a_pt = matrix("a")
     a_pt.tag.test_value = np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)
 
     x = prod(a_pt, axis=0)
-    x_fg = FunctionGraph([a_pt], [x])
 
-    compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)])
+    compare_jax_and_py([a_pt], [x], [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)])
 
     x = pt_all(a_pt)
-    x_fg = FunctionGraph([a_pt], [x])
 
-    compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)])
+    compare_jax_and_py([a_pt], [x], [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)])
 
 
 @pytest.mark.parametrize("axis", [None, 0, 1])
 def test_softmax(axis):
     x = matrix("x")
-    x.tag.test_value = np.arange(6, dtype=config.floatX).reshape(2, 3)
+    x_test_value = np.arange(6, dtype=config.floatX).reshape(2, 3)
     out = softmax(x, axis=axis)
-    fgraph = FunctionGraph([x], [out])
-    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
+    compare_jax_and_py([x], [out], [x_test_value])
 
 
 @pytest.mark.parametrize("axis", [None, 0, 1])
 def test_logsoftmax(axis):
     x = matrix("x")
-    x.tag.test_value = np.arange(6, dtype=config.floatX).reshape(2, 3)
+    x_test_value = np.arange(6, dtype=config.floatX).reshape(2, 3)
     out = log_softmax(x, axis=axis)
-    fgraph = FunctionGraph([x], [out])
-    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
+
+    compare_jax_and_py([x], [out], [x_test_value])
 
 
 @pytest.mark.parametrize("axis", [None, 0, 1])
 def test_softmax_grad(axis):
     dy = matrix("dy")
-    dy.tag.test_value = np.array([[1, 1, 1], [0, 0, 0]], dtype=config.floatX)
+    dy_test_value = np.array([[1, 1, 1], [0, 0, 0]], dtype=config.floatX)
     sm = matrix("sm")
-    sm.tag.test_value = np.arange(6, dtype=config.floatX).reshape(2, 3)
+    sm_test_value = np.arange(6, dtype=config.floatX).reshape(2, 3)
     out = SoftmaxGrad(axis=axis)(dy, sm)
-    fgraph = FunctionGraph([dy, sm], [out])
-    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
+
+    compare_jax_and_py([dy, sm], [out], [dy_test_value, sm_test_value])
 
 
 @pytest.mark.parametrize("size", [(10, 10), (1000, 1000)])
@@ -134,6 +126,4 @@ def test_logsumexp_benchmark(size, axis, benchmark):
 def test_multiple_input_multiply():
     x, y, z = vectors("xyz")
     out = pt.mul(x, y, z)
-
-    fg = FunctionGraph(outputs=[out], clone=False)
-    compare_jax_and_py(fg, [[1.5], [2.5], [3.5]])
+    compare_jax_and_py([x, y, z], [out], test_inputs=[[1.5], [2.5], [3.5]])