|
1 | 1 | from collections.abc import Callable, Iterable |
2 | 2 | from functools import partial |
| 3 | +from itertools import repeat, starmap |
| 4 | +from unittest.mock import MagicMock, call, patch |
3 | 5 |
|
4 | 6 | import numpy as np |
5 | 7 | import pytest |
@@ -393,12 +395,29 @@ def test_ScalarLoop_Elemwise(): |
393 | 395 | x0 = pt.vector("x0", dtype="float32") |
394 | 396 | state, done = op(n_steps, x0) |
395 | 397 |
|
396 | | - fn = function([n_steps, x0], [state, done], mode=pytorch_mode) |
397 | | - py_fn = function([n_steps, x0], [state, done]) |
398 | | - |
| 398 | + f = FunctionGraph([n_steps, x0], [state, done]) |
399 | 399 | args = [np.array(10).astype("int32"), np.arange(0, 5).astype("float32")] |
400 | | - torch_states, torch_dones = fn(*args) |
401 | | - py_states, py_dones = py_fn(*args) |
402 | | - |
403 | | - np.testing.assert_allclose(torch_states, py_states) |
404 | | - np.testing.assert_allclose(torch_dones, py_dones) |
| 400 | + compare_pytorch_and_py(f, args) |
| 401 | + |
| 402 | + |
| 403 | +torch_elemwise = pytest.importorskip("pytensor.link.pytorch.dispatch.elemwise") |
| 404 | + |
| 405 | + |
| 406 | +@pytest.mark.parametrize("input_shapes", [[(5, 1, 1, 8), (3, 1, 1), (8,)]]) |
| 407 | +@patch("pytensor.link.pytorch.dispatch.elemwise.Elemwise") |
| 408 | +def test_ScalarLoop_Elemwise_iteration_logic(_, input_shapes): |
| 409 | + args = [torch.ones(*s) for s in input_shapes[:-1]] + [ |
| 410 | + torch.zeros(*input_shapes[-1]) |
| 411 | + ] |
| 412 | + mock_inner_func = MagicMock() |
| 413 | + ret_value = torch.rand(2, 2).unbind(0) |
| 414 | + mock_inner_func.f.return_value = ret_value |
| 415 | + elemwise_fn = torch_elemwise.elemwise_scalar_loop(mock_inner_func.f, None, None) |
| 416 | + result = elemwise_fn(*args) |
| 417 | + for actual, expected in zip(ret_value, result): |
| 418 | + assert torch.all(torch.eq(*torch.broadcast_tensors(actual, expected))) |
| 419 | + np.testing.assert_equal(mock_inner_func.f.call_count, len(result[0])) |
| 420 | + |
| 421 | + expected_args = torch.FloatTensor([1.0] * (len(input_shapes) - 1) + [0.0]).unbind(0) |
| 422 | + expected_calls = starmap(call, repeat(expected_args, mock_inner_func.f.call_count)) |
| 423 | + mock_inner_func.f.assert_has_calls(expected_calls) |
0 commit comments