11from collections .abc import Callable , Iterable
22from functools import partial
3- from itertools import repeat , starmap
4- from unittest .mock import MagicMock , call , patch
53
64import numpy as np
75import pytest
@@ -421,25 +419,11 @@ def test_ScalarLoop_while():
421419 for res , expected in zip (
422420 [fn (n_steps = 20 , x0 = 0 ), fn (n_steps = 20 , x0 = 1 ), fn (n_steps = 5 , x0 = 1 )],
423421 [[10 , True ], [10 , True ], [6 , False ]],
422+ strict = True ,
424423 ):
425424 np .testing .assert_allclose (res [0 ], np .array (expected [0 ]))
426425 np .testing .assert_allclose (res [1 ], np .array (expected [1 ]))
427426
428- def test_pytorch_OpFromGraph ():
429- x , y , z = matrices ("xyz" )
430- ofg_1 = OpFromGraph ([x , y ], [x + y ])
431- ofg_2 = OpFromGraph ([x , y ], [x * y , x - y ])
432-
433- o1 , o2 = ofg_2 (y , z )
434- out = ofg_1 (x , o1 ) + o2
435-
436- xv = np .ones ((2 , 2 ), dtype = config .floatX )
437- yv = np .ones ((2 , 2 ), dtype = config .floatX ) * 3
438- zv = np .ones ((2 , 2 ), dtype = config .floatX ) * 5
439-
440- f = FunctionGraph ([x , y , z ], [out ])
441- compare_pytorch_and_py (f , [xv , yv , zv ])
442-
443427
444428def test_ScalarLoop_Elemwise ():
445429 n_steps = int64 ("n_steps" )
@@ -457,26 +441,3 @@ def test_ScalarLoop_Elemwise():
457441 f = FunctionGraph ([n_steps , x0 ], [state , done ])
458442 args = [np .array (10 ).astype ("int32" ), np .arange (0 , 5 ).astype ("float32" )]
459443 compare_pytorch_and_py (f , args )
460-
461-
462- torch_elemwise = pytest .importorskip ("pytensor.link.pytorch.dispatch.elemwise" )
463-
464-
465- @pytest .mark .parametrize ("input_shapes" , [[(5 , 1 , 1 , 8 ), (3 , 1 , 1 ), (8 ,)]])
466- @patch ("pytensor.link.pytorch.dispatch.elemwise.Elemwise" )
467- def test_ScalarLoop_Elemwise_iteration_logic (_ , input_shapes ):
468- args = [torch .ones (* s ) for s in input_shapes [:- 1 ]] + [
469- torch .zeros (* input_shapes [- 1 ])
470- ]
471- mock_inner_func = MagicMock ()
472- ret_value = torch .rand (2 , 2 ).unbind (0 )
473- mock_inner_func .f .return_value = ret_value
474- elemwise_fn = torch_elemwise .elemwise_scalar_loop (mock_inner_func .f , None , None )
475- result = elemwise_fn (* args )
476- for actual , expected in zip (ret_value , result ):
477- assert torch .all (torch .eq (* torch .broadcast_tensors (actual , expected )))
478- np .testing .assert_equal (mock_inner_func .f .call_count , len (result [0 ]))
479-
480- expected_args = torch .FloatTensor ([1.0 ] * (len (input_shapes ) - 1 ) + [0.0 ]).unbind (0 )
481- expected_calls = starmap (call , repeat (expected_args , mock_inner_func .f .call_count ))
482- mock_inner_func .f .assert_has_calls (expected_calls )
0 commit comments