11from collections .abc import Callable , Iterable
22from functools import partial
3- from itertools import repeat , starmap
4- from unittest .mock import MagicMock , call , patch
53
64import numpy as np
75import pytest
@@ -350,6 +348,7 @@ def test_pytorch_OpFromGraph():
350348 f = FunctionGraph ([x , y , z ], [out ])
351349 compare_pytorch_and_py (f , [xv , yv , zv ])
352350
351+
353352def test_ScalarLoop ():
354353 n_steps = int64 ("n_steps" )
355354 x0 = float64 ("x0" )
@@ -376,25 +375,11 @@ def test_ScalarLoop_while():
376375 for res , expected in zip (
377376 [fn (n_steps = 20 , x0 = 0 ), fn (n_steps = 20 , x0 = 1 ), fn (n_steps = 5 , x0 = 1 )],
378377 [[10 , True ], [10 , True ], [6 , False ]],
378+ strict = True ,
379379 ):
380380 np .testing .assert_allclose (res [0 ], np .array (expected [0 ]))
381381 np .testing .assert_allclose (res [1 ], np .array (expected [1 ]))
382382
383- def test_pytorch_OpFromGraph ():
384- x , y , z = matrices ("xyz" )
385- ofg_1 = OpFromGraph ([x , y ], [x + y ])
386- ofg_2 = OpFromGraph ([x , y ], [x * y , x - y ])
387-
388- o1 , o2 = ofg_2 (y , z )
389- out = ofg_1 (x , o1 ) + o2
390-
391- xv = np .ones ((2 , 2 ), dtype = config .floatX )
392- yv = np .ones ((2 , 2 ), dtype = config .floatX ) * 3
393- zv = np .ones ((2 , 2 ), dtype = config .floatX ) * 5
394-
395- f = FunctionGraph ([x , y , z ], [out ])
396- compare_pytorch_and_py (f , [xv , yv , zv ])
397-
398383
399384def test_ScalarLoop_Elemwise ():
400385 n_steps = int64 ("n_steps" )
@@ -412,26 +397,3 @@ def test_ScalarLoop_Elemwise():
412397 f = FunctionGraph ([n_steps , x0 ], [state , done ])
413398 args = [np .array (10 ).astype ("int32" ), np .arange (0 , 5 ).astype ("float32" )]
414399 compare_pytorch_and_py (f , args )
415-
416-
417- torch_elemwise = pytest .importorskip ("pytensor.link.pytorch.dispatch.elemwise" )
418-
419-
420- @pytest .mark .parametrize ("input_shapes" , [[(5 , 1 , 1 , 8 ), (3 , 1 , 1 ), (8 ,)]])
421- @patch ("pytensor.link.pytorch.dispatch.elemwise.Elemwise" )
422- def test_ScalarLoop_Elemwise_iteration_logic (_ , input_shapes ):
423- args = [torch .ones (* s ) for s in input_shapes [:- 1 ]] + [
424- torch .zeros (* input_shapes [- 1 ])
425- ]
426- mock_inner_func = MagicMock ()
427- ret_value = torch .rand (2 , 2 ).unbind (0 )
428- mock_inner_func .f .return_value = ret_value
429- elemwise_fn = torch_elemwise .elemwise_scalar_loop (mock_inner_func .f , None , None )
430- result = elemwise_fn (* args )
431- for actual , expected in zip (ret_value , result ):
432- assert torch .all (torch .eq (* torch .broadcast_tensors (actual , expected )))
433- np .testing .assert_equal (mock_inner_func .f .call_count , len (result [0 ]))
434-
435- expected_args = torch .FloatTensor ([1.0 ] * (len (input_shapes ) - 1 ) + [0.0 ]).unbind (0 )
436- expected_calls = starmap (call , repeat (expected_args , mock_inner_func .f .call_count ))
437- mock_inner_func .f .assert_has_calls (expected_calls )
0 commit comments