@@ -415,16 +415,25 @@ def test_ScalarLoop_while():
415415def test_ScalarLoop_Elemwise ():
416416 n_steps = int64 ("n_steps" )
417417 x0 = float64 ("x0" )
418+ x1 = float64 ("x1" )
418419 x = x0 * 2
420+ x1_n = x1 * 3
419421 until = x >= 10
420422
421- scalarop = ScalarLoop (init = [x0 ], update = [x ], until = until )
423+ scalarop = ScalarLoop (init = [x0 , x1 ], update = [x , x1_n ], until = until )
422424 op = Elemwise (scalarop )
423425
424426 n_steps = pt .scalar ("n_steps" , dtype = "int32" )
425427 x0 = pt .vector ("x0" , dtype = "float32" )
426- state , done = op (n_steps , x0 )
427-
428- f = FunctionGraph ([n_steps , x0 ], [state , done ])
429- args = [np .array (10 ).astype ("int32" ), np .arange (0 , 5 ).astype ("float32" )]
430- compare_pytorch_and_py (f , args )
428+ x1 = pt .tensor ("c0" , dtype = "float32" , shape = (7 , 3 , 1 ))
429+ * states , done = op (n_steps , x0 , x1 )
430+
431+ f = FunctionGraph ([n_steps , x0 , x1 ], [* states , done ])
432+ args = [
433+ np .array (10 ).astype ("int32" ),
434+ np .arange (0 , 5 ).astype ("float32" ),
435+ np .random .rand (7 , 3 , 1 ).astype ("float32" ),
436+ ]
437+ compare_pytorch_and_py (
438+ f , args , assert_fn = partial (np .testing .assert_allclose , rtol = 1e-6 )
439+ )
0 commit comments