Skip to content

Commit 920f5a4

Browse files
committed
Add unit test to verify iteration
1 parent 067761f commit 920f5a4

File tree

1 file changed

+27
-8
lines changed

1 file changed

+27
-8
lines changed

tests/link/pytorch/test_basic.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from collections.abc import Callable, Iterable
22
from functools import partial
3+
from itertools import repeat, starmap
4+
from unittest.mock import MagicMock, call, patch
35

46
import numpy as np
57
import pytest
@@ -393,12 +395,29 @@ def test_ScalarLoop_Elemwise():
393395
x0 = pt.vector("x0", dtype="float32")
394396
state, done = op(n_steps, x0)
395397

396-
fn = function([n_steps, x0], [state, done], mode=pytorch_mode)
397-
py_fn = function([n_steps, x0], [state, done])
398-
398+
f = FunctionGraph([n_steps, x0], [state, done])
399399
args = [np.array(10).astype("int32"), np.arange(0, 5).astype("float32")]
400-
torch_states, torch_dones = fn(*args)
401-
py_states, py_dones = py_fn(*args)
402-
403-
np.testing.assert_allclose(torch_states, py_states)
404-
np.testing.assert_allclose(torch_dones, py_dones)
400+
compare_pytorch_and_py(f, args)
401+
402+
403+
torch_elemwise = pytest.importorskip("pytensor.link.pytorch.dispatch.elemwise")
404+
405+
406+
@pytest.mark.parametrize("input_shapes", [[(5, 1, 1, 8), (3, 1, 1), (8,)]])
407+
@patch("pytensor.link.pytorch.dispatch.elemwise.Elemwise")
408+
def test_ScalarLoop_Elemwise_iteration_logic(_, input_shapes):
409+
args = [torch.ones(*s) for s in input_shapes[:-1]] + [
410+
torch.zeros(*input_shapes[-1])
411+
]
412+
mock_inner_func = MagicMock()
413+
ret_value = torch.rand(2, 2).unbind(0)
414+
mock_inner_func.f.return_value = ret_value
415+
elemwise_fn = torch_elemwise.elemwise_scalar_loop(mock_inner_func.f, None, None)
416+
result = elemwise_fn(*args)
417+
for actual, expected in zip(ret_value, result):
418+
assert torch.all(torch.eq(*torch.broadcast_tensors(actual, expected)))
419+
np.testing.assert_equal(mock_inner_func.f.call_count, len(result[0]))
420+
421+
expected_args = torch.FloatTensor([1.0] * (len(input_shapes) - 1) + [0.0]).unbind(0)
422+
expected_calls = starmap(call, repeat(expected_args, mock_inner_func.f.call_count))
423+
mock_inner_func.f.assert_has_calls(expected_calls)

0 commit comments

Comments
 (0)