11from collections .abc import Callable , Iterable
22from functools import partial
3- from itertools import repeat , starmap
4- from unittest .mock import MagicMock , call , patch
53
64import numpy as np
75import pytest
@@ -408,25 +406,11 @@ def test_ScalarLoop_while():
408406 for res , expected in zip (
409407 [fn (n_steps = 20 , x0 = 0 ), fn (n_steps = 20 , x0 = 1 ), fn (n_steps = 5 , x0 = 1 )],
410408 [[10 , True ], [10 , True ], [6 , False ]],
409+ strict = True ,
411410 ):
412411 np .testing .assert_allclose (res [0 ], np .array (expected [0 ]))
413412 np .testing .assert_allclose (res [1 ], np .array (expected [1 ]))
414413
415- def test_pytorch_OpFromGraph ():
416- x , y , z = matrices ("xyz" )
417- ofg_1 = OpFromGraph ([x , y ], [x + y ])
418- ofg_2 = OpFromGraph ([x , y ], [x * y , x - y ])
419-
420- o1 , o2 = ofg_2 (y , z )
421- out = ofg_1 (x , o1 ) + o2
422-
423- xv = np .ones ((2 , 2 ), dtype = config .floatX )
424- yv = np .ones ((2 , 2 ), dtype = config .floatX ) * 3
425- zv = np .ones ((2 , 2 ), dtype = config .floatX ) * 5
426-
427- f = FunctionGraph ([x , y , z ], [out ])
428- compare_pytorch_and_py (f , [xv , yv , zv ])
429-
430414
431415def test_ScalarLoop_Elemwise ():
432416 n_steps = int64 ("n_steps" )
@@ -444,26 +428,3 @@ def test_ScalarLoop_Elemwise():
444428 f = FunctionGraph ([n_steps , x0 ], [state , done ])
445429 args = [np .array (10 ).astype ("int32" ), np .arange (0 , 5 ).astype ("float32" )]
446430 compare_pytorch_and_py (f , args )
447-
448-
449- torch_elemwise = pytest .importorskip ("pytensor.link.pytorch.dispatch.elemwise" )
450-
451-
452- @pytest .mark .parametrize ("input_shapes" , [[(5 , 1 , 1 , 8 ), (3 , 1 , 1 ), (8 ,)]])
453- @patch ("pytensor.link.pytorch.dispatch.elemwise.Elemwise" )
454- def test_ScalarLoop_Elemwise_iteration_logic (_ , input_shapes ):
455- args = [torch .ones (* s ) for s in input_shapes [:- 1 ]] + [
456- torch .zeros (* input_shapes [- 1 ])
457- ]
458- mock_inner_func = MagicMock ()
459- ret_value = torch .rand (2 , 2 ).unbind (0 )
460- mock_inner_func .f .return_value = ret_value
461- elemwise_fn = torch_elemwise .elemwise_scalar_loop (mock_inner_func .f , None , None )
462- result = elemwise_fn (* args )
463- for actual , expected in zip (ret_value , result ):
464- assert torch .all (torch .eq (* torch .broadcast_tensors (actual , expected )))
465- np .testing .assert_equal (mock_inner_func .f .call_count , len (result [0 ]))
466-
467- expected_args = torch .FloatTensor ([1.0 ] * (len (input_shapes ) - 1 ) + [0.0 ]).unbind (0 )
468- expected_calls = starmap (call , repeat (expected_args , mock_inner_func .f .call_count ))
469- mock_inner_func .f .assert_has_calls (expected_calls )
0 commit comments