|
1 | 1 | from collections.abc import Callable, Iterable |
2 | 2 | from functools import partial |
| 3 | +from itertools import repeat, starmap |
| 4 | +from unittest.mock import MagicMock, call, patch |
3 | 5 |
|
4 | 6 | import numpy as np |
5 | 7 | import pytest |
@@ -407,12 +409,29 @@ def test_ScalarLoop_Elemwise(): |
407 | 409 | x0 = pt.vector("x0", dtype="float32") |
408 | 410 | state, done = op(n_steps, x0) |
409 | 411 |
|
410 | | - fn = function([n_steps, x0], [state, done], mode=pytorch_mode) |
411 | | - py_fn = function([n_steps, x0], [state, done]) |
412 | | - |
| 412 | + f = FunctionGraph([n_steps, x0], [state, done]) |
413 | 413 | args = [np.array(10).astype("int32"), np.arange(0, 5).astype("float32")] |
414 | | - torch_states, torch_dones = fn(*args) |
415 | | - py_states, py_dones = py_fn(*args) |
416 | | - |
417 | | - np.testing.assert_allclose(torch_states, py_states) |
418 | | - np.testing.assert_allclose(torch_dones, py_dones) |
| 414 | + compare_pytorch_and_py(f, args) |
| 415 | + |
| 416 | + |
| 417 | +torch_elemwise = pytest.importorskip("pytensor.link.pytorch.dispatch.elemwise") |
| 418 | + |
| 419 | + |
| 420 | +@pytest.mark.parametrize("input_shapes", [[(5, 1, 1, 8), (3, 1, 1), (8,)]]) |
| 421 | +@patch("pytensor.link.pytorch.dispatch.elemwise.Elemwise") |
| 422 | +def test_ScalarLoop_Elemwise_iteration_logic(_, input_shapes): |
| 423 | + args = [torch.ones(*s) for s in input_shapes[:-1]] + [ |
| 424 | + torch.zeros(*input_shapes[-1]) |
| 425 | + ] |
| 426 | + mock_inner_func = MagicMock() |
| 427 | + ret_value = torch.rand(2, 2).unbind(0) |
| 428 | + mock_inner_func.f.return_value = ret_value |
| 429 | + elemwise_fn = torch_elemwise.elemwise_scalar_loop(mock_inner_func.f, None, None) |
| 430 | + result = elemwise_fn(*args) |
| 431 | + for actual, expected in zip(ret_value, result): |
| 432 | + assert torch.all(torch.eq(*torch.broadcast_tensors(actual, expected))) |
| 433 | + np.testing.assert_equal(mock_inner_func.f.call_count, len(result[0])) |
| 434 | + |
| 435 | + expected_args = torch.FloatTensor([1.0] * (len(input_shapes) - 1) + [0.0]).unbind(0) |
| 436 | + expected_calls = starmap(call, repeat(expected_args, mock_inner_func.f.call_count)) |
| 437 | + mock_inner_func.f.assert_has_calls(expected_calls) |
0 commit comments