@@ -384,16 +384,25 @@ def test_ScalarLoop_while():
384384def test_ScalarLoop_Elemwise ():
385385 n_steps = int64 ("n_steps" )
386386 x0 = float64 ("x0" )
387+ x1 = float64 ("x1" )
387388 x = x0 * 2
389+ x1_n = x1 * 3
388390 until = x >= 10
389391
390- scalarop = ScalarLoop (init = [x0 ], update = [x ], until = until )
392+ scalarop = ScalarLoop (init = [x0 , x1 ], update = [x , x1_n ], until = until )
391393 op = Elemwise (scalarop )
392394
393395 n_steps = pt .scalar ("n_steps" , dtype = "int32" )
394396 x0 = pt .vector ("x0" , dtype = "float32" )
395- state , done = op (n_steps , x0 )
396-
397- f = FunctionGraph ([n_steps , x0 ], [state , done ])
398- args = [np .array (10 ).astype ("int32" ), np .arange (0 , 5 ).astype ("float32" )]
399- compare_pytorch_and_py (f , args )
397+ x1 = pt .tensor ("c0" , dtype = "float32" , shape = (7 , 3 , 1 ))
398+ * states , done = op (n_steps , x0 , x1 )
399+
400+ f = FunctionGraph ([n_steps , x0 , x1 ], [* states , done ])
401+ args = [
402+ np .array (10 ).astype ("int32" ),
403+ np .arange (0 , 5 ).astype ("float32" ),
404+ np .random .rand (7 , 3 , 1 ).astype ("float32" ),
405+ ]
406+ compare_pytorch_and_py (
407+ f , args , assert_fn = partial (np .testing .assert_allclose , rtol = 1e-6 )
408+ )
0 commit comments