@@ -428,16 +428,25 @@ def test_ScalarLoop_while():
428428def test_ScalarLoop_Elemwise ():
429429 n_steps = int64 ("n_steps" )
430430 x0 = float64 ("x0" )
431+ x1 = float64 ("x1" )
431432 x = x0 * 2
433+ x1_n = x1 * 3
432434 until = x >= 10
433435
434- scalarop = ScalarLoop (init = [x0 ], update = [x ], until = until )
436+ scalarop = ScalarLoop (init = [x0 , x1 ], update = [x , x1_n ], until = until )
435437 op = Elemwise (scalarop )
436438
437439 n_steps = pt .scalar ("n_steps" , dtype = "int32" )
438440 x0 = pt .vector ("x0" , dtype = "float32" )
439- state , done = op (n_steps , x0 )
440-
441- f = FunctionGraph ([n_steps , x0 ], [state , done ])
442- args = [np .array (10 ).astype ("int32" ), np .arange (0 , 5 ).astype ("float32" )]
443- compare_pytorch_and_py (f , args )
441+ x1 = pt .tensor ("c0" , dtype = "float32" , shape = (7 , 3 , 1 ))
442+ * states , done = op (n_steps , x0 , x1 )
443+
444+ f = FunctionGraph ([n_steps , x0 , x1 ], [* states , done ])
445+ args = [
446+ np .array (10 ).astype ("int32" ),
447+ np .arange (0 , 5 ).astype ("float32" ),
448+ np .random .rand (7 , 3 , 1 ).astype ("float32" ),
449+ ]
450+ compare_pytorch_and_py (
451+ f , args , assert_fn = partial (np .testing .assert_allclose , rtol = 1e-6 )
452+ )
0 commit comments