|
1 | 1 | from collections.abc import Callable, Iterable |
2 | 2 | from functools import partial |
| 3 | +from itertools import repeat, starmap |
| 4 | +from unittest.mock import MagicMock, call, patch |
3 | 5 |
|
4 | 6 | import numpy as np |
5 | 7 | import pytest |
@@ -452,12 +454,29 @@ def test_ScalarLoop_Elemwise(): |
452 | 454 | x0 = pt.vector("x0", dtype="float32") |
453 | 455 | state, done = op(n_steps, x0) |
454 | 456 |
|
455 | | - fn = function([n_steps, x0], [state, done], mode=pytorch_mode) |
456 | | - py_fn = function([n_steps, x0], [state, done]) |
457 | | - |
| 457 | + f = FunctionGraph([n_steps, x0], [state, done]) |
458 | 458 | args = [np.array(10).astype("int32"), np.arange(0, 5).astype("float32")] |
459 | | - torch_states, torch_dones = fn(*args) |
460 | | - py_states, py_dones = py_fn(*args) |
461 | | - |
462 | | - np.testing.assert_allclose(torch_states, py_states) |
463 | | - np.testing.assert_allclose(torch_dones, py_dones) |
| 459 | + compare_pytorch_and_py(f, args) |
| 460 | + |
| 461 | + |
| 462 | +torch_elemwise = pytest.importorskip("pytensor.link.pytorch.dispatch.elemwise") |
| 463 | + |
| 464 | + |
| 465 | +@pytest.mark.parametrize("input_shapes", [[(5, 1, 1, 8), (3, 1, 1), (8,)]]) |
| 466 | +@patch("pytensor.link.pytorch.dispatch.elemwise.Elemwise") |
| 467 | +def test_ScalarLoop_Elemwise_iteration_logic(_, input_shapes): |
| 468 | + args = [torch.ones(*s) for s in input_shapes[:-1]] + [ |
| 469 | + torch.zeros(*input_shapes[-1]) |
| 470 | + ] |
| 471 | + mock_inner_func = MagicMock() |
| 472 | + ret_value = torch.rand(2, 2).unbind(0) |
| 473 | + mock_inner_func.f.return_value = ret_value |
| 474 | + elemwise_fn = torch_elemwise.elemwise_scalar_loop(mock_inner_func.f, None, None) |
| 475 | + result = elemwise_fn(*args) |
| 476 | + for actual, expected in zip(ret_value, result): |
| 477 | + assert torch.all(torch.eq(*torch.broadcast_tensors(actual, expected))) |
| 478 | + np.testing.assert_equal(mock_inner_func.f.call_count, len(result[0])) |
| 479 | + |
| 480 | + expected_args = torch.FloatTensor([1.0] * (len(input_shapes) - 1) + [0.0]).unbind(0) |
| 481 | + expected_calls = starmap(call, repeat(expected_args, mock_inner_func.f.call_count)) |
| 482 | + mock_inner_func.f.assert_has_calls(expected_calls) |
0 commit comments