File tree Expand file tree Collapse file tree 1 file changed +24
-1
lines changed
Expand file tree Collapse file tree 1 file changed +24
-1
lines changed Original file line number Diff line number Diff line change @@ -381,7 +381,30 @@ def test_ScalarLoop_while():
381381 np .testing .assert_allclose (res [1 ], np .array (expected [1 ]))
382382
383383
384- def test_ScalarLoop_Elemwise ():
384+ def test_ScalarLoop_Elemwise_single_carries ():
385+ n_steps = int64 ("n_steps" )
386+ x0 = float64 ("x0" )
387+ x = x0 * 2
388+ until = x >= 10
389+
390+ scalarop = ScalarLoop (init = [x0 ], update = [x ], until = until )
391+ op = Elemwise (scalarop )
392+
393+ n_steps = pt .scalar ("n_steps" , dtype = "int32" )
394+ x0 = pt .vector ("x0" , dtype = "float32" )
395+ state , done = op (n_steps , x0 )
396+
397+ f = FunctionGraph ([n_steps , x0 ], [state , done ])
398+ args = [
399+ np .array (10 ).astype ("int32" ),
400+ np .arange (0 , 5 ).astype ("float32" ),
401+ ]
402+ compare_pytorch_and_py (
403+ f , args , assert_fn = partial (np .testing .assert_allclose , rtol = 1e-6 )
404+ )
405+
406+
407+ def test_ScalarLoop_Elemwise_multi_carries ():
385408 n_steps = int64 ("n_steps" )
386409 x0 = float64 ("x0" )
387410 x1 = float64 ("x1" )
You can’t perform that action at this time.
0 commit comments