Skip to content

Commit fb90500

Browse files
Ch0ronomatoIan Schweer
authored andcommitted
Add unit test to verify iteration
1 parent e28c3e2 commit fb90500

File tree

1 file changed

+27
-8
lines changed

1 file changed

+27
-8
lines changed

tests/link/pytorch/test_basic.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from collections.abc import Callable, Iterable
22
from functools import partial
3+
from itertools import repeat, starmap
4+
from unittest.mock import MagicMock, call, patch
35

46
import numpy as np
57
import pytest
@@ -452,12 +454,29 @@ def test_ScalarLoop_Elemwise():
452454
x0 = pt.vector("x0", dtype="float32")
453455
state, done = op(n_steps, x0)
454456

455-
fn = function([n_steps, x0], [state, done], mode=pytorch_mode)
456-
py_fn = function([n_steps, x0], [state, done])
457-
457+
f = FunctionGraph([n_steps, x0], [state, done])
458458
args = [np.array(10).astype("int32"), np.arange(0, 5).astype("float32")]
459-
torch_states, torch_dones = fn(*args)
460-
py_states, py_dones = py_fn(*args)
461-
462-
np.testing.assert_allclose(torch_states, py_states)
463-
np.testing.assert_allclose(torch_dones, py_dones)
459+
compare_pytorch_and_py(f, args)
460+
461+
462+
torch_elemwise = pytest.importorskip("pytensor.link.pytorch.dispatch.elemwise")
463+
464+
465+
@pytest.mark.parametrize("input_shapes", [[(5, 1, 1, 8), (3, 1, 1), (8,)]])
466+
@patch("pytensor.link.pytorch.dispatch.elemwise.Elemwise")
467+
def test_ScalarLoop_Elemwise_iteration_logic(_, input_shapes):
468+
args = [torch.ones(*s) for s in input_shapes[:-1]] + [
469+
torch.zeros(*input_shapes[-1])
470+
]
471+
mock_inner_func = MagicMock()
472+
ret_value = torch.rand(2, 2).unbind(0)
473+
mock_inner_func.f.return_value = ret_value
474+
elemwise_fn = torch_elemwise.elemwise_scalar_loop(mock_inner_func.f, None, None)
475+
result = elemwise_fn(*args)
476+
for actual, expected in zip(ret_value, result):
477+
assert torch.all(torch.eq(*torch.broadcast_tensors(actual, expected)))
478+
np.testing.assert_equal(mock_inner_func.f.call_count, len(result[0]))
479+
480+
expected_args = torch.FloatTensor([1.0] * (len(input_shapes) - 1) + [0.0]).unbind(0)
481+
expected_calls = starmap(call, repeat(expected_args, mock_inner_func.f.call_count))
482+
mock_inner_func.f.assert_has_calls(expected_calls)

0 commit comments

Comments
 (0)