Skip to content

Commit 907a02c

Browse files
Ch0ronomatoIan Schweer
authored andcommitted
Add unit test to verify iteration
1 parent 947d024 commit 907a02c

File tree

1 file changed

+27
-8
lines changed

1 file changed

+27
-8
lines changed

tests/link/pytorch/test_basic.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from collections.abc import Callable, Iterable
22
from functools import partial
3+
from itertools import repeat, starmap
4+
from unittest.mock import MagicMock, call, patch
35

46
import numpy as np
57
import pytest
@@ -407,12 +409,29 @@ def test_ScalarLoop_Elemwise():
407409
x0 = pt.vector("x0", dtype="float32")
408410
state, done = op(n_steps, x0)
409411

410-
fn = function([n_steps, x0], [state, done], mode=pytorch_mode)
411-
py_fn = function([n_steps, x0], [state, done])
412-
412+
f = FunctionGraph([n_steps, x0], [state, done])
413413
args = [np.array(10).astype("int32"), np.arange(0, 5).astype("float32")]
414-
torch_states, torch_dones = fn(*args)
415-
py_states, py_dones = py_fn(*args)
416-
417-
np.testing.assert_allclose(torch_states, py_states)
418-
np.testing.assert_allclose(torch_dones, py_dones)
414+
compare_pytorch_and_py(f, args)
415+
416+
417+
torch_elemwise = pytest.importorskip("pytensor.link.pytorch.dispatch.elemwise")
418+
419+
420+
@pytest.mark.parametrize("input_shapes", [[(5, 1, 1, 8), (3, 1, 1), (8,)]])
421+
@patch("pytensor.link.pytorch.dispatch.elemwise.Elemwise")
422+
def test_ScalarLoop_Elemwise_iteration_logic(_, input_shapes):
423+
args = [torch.ones(*s) for s in input_shapes[:-1]] + [
424+
torch.zeros(*input_shapes[-1])
425+
]
426+
mock_inner_func = MagicMock()
427+
ret_value = torch.rand(2, 2).unbind(0)
428+
mock_inner_func.f.return_value = ret_value
429+
elemwise_fn = torch_elemwise.elemwise_scalar_loop(mock_inner_func.f, None, None)
430+
result = elemwise_fn(*args)
431+
for actual, expected in zip(ret_value, result):
432+
assert torch.all(torch.eq(*torch.broadcast_tensors(actual, expected)))
433+
np.testing.assert_equal(mock_inner_func.f.call_count, len(result[0]))
434+
435+
expected_args = torch.FloatTensor([1.0] * (len(input_shapes) - 1) + [0.0]).unbind(0)
436+
expected_calls = starmap(call, repeat(expected_args, mock_inner_func.f.call_count))
437+
mock_inner_func.f.assert_has_calls(expected_calls)

0 commit comments

Comments
 (0)