File tree Expand file tree Collapse file tree 1 file changed +24
-1
lines changed
Expand file tree Collapse file tree 1 file changed +24
-1
lines changed Original file line number Diff line number Diff line change @@ -425,7 +425,30 @@ def test_ScalarLoop_while():
425425 np .testing .assert_allclose (res [1 ], np .array (expected [1 ]))
426426
427427
428- def test_ScalarLoop_Elemwise ():
428+ def test_ScalarLoop_Elemwise_single_carries ():
429+ n_steps = int64 ("n_steps" )
430+ x0 = float64 ("x0" )
431+ x = x0 * 2
432+ until = x >= 10
433+
434+ scalarop = ScalarLoop (init = [x0 ], update = [x ], until = until )
435+ op = Elemwise (scalarop )
436+
437+ n_steps = pt .scalar ("n_steps" , dtype = "int32" )
438+ x0 = pt .vector ("x0" , dtype = "float32" )
439+ state , done = op (n_steps , x0 )
440+
441+ f = FunctionGraph ([n_steps , x0 ], [state , done ])
442+ args = [
443+ np .array (10 ).astype ("int32" ),
444+ np .arange (0 , 5 ).astype ("float32" ),
445+ ]
446+ compare_pytorch_and_py (
447+ f , args , assert_fn = partial (np .testing .assert_allclose , rtol = 1e-6 )
448+ )
449+
450+
451+ def test_ScalarLoop_Elemwise_multi_carries ():
429452 n_steps = int64 ("n_steps" )
430453 x0 = float64 ("x0" )
431454 x1 = float64 ("x1" )
You can’t perform that action at this time.
0 commit comments