
Commit cd678ef

Author: Ian Schweer (committed)
Refactor to ravel method
1 parent fb90500 commit cd678ef

2 files changed: +25 −84 lines


pytensor/link/pytorch/dispatch/elemwise.py

Lines changed: 24 additions & 44 deletions
@@ -1,6 +1,4 @@
 import importlib
-from itertools import chain
-
 import torch

 from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
@@ -189,50 +187,32 @@ def elemwise_scalar_loop(base_fn, op, node, **kwargs):
     """
     ScalarLoop + Elemwise is too common
     to not work, but @1031, vmap won't allow it.
-    Instead, we can do the following strategy
-    1. `.unbind(dim)` will return a list of tensors
-       representing `dim` but "unwrapped". e.x.
-    ```
-    t = torch.ones(3, 4, 2)
-    len(t.unbind(0)) == 3
-    t[0].shape == torch.Size[4, 2]
-    2. If we successfully apply, the length of the list will grow
-       by the next dimension in the tensor if we flatten the previous
-       dimension result
-    ```
-    inputs = [torch.ones(3, 4, 2)]
-    level_1 = chain.from_iterable(t.unbind(0) for t in inputs)
-    level_2 = chain.from_iterable(t.unbind(0) for t in level_1)
-    len(level_2) == 3 * 4
-    ```
-    3. Eventually we'll reach single dimension tensors. At that point
-       we can iterate over each input in an element by element manner
-       and call some function
-
-    For scalar loop, we need to broadcast the tensors so all
-    the necessary values are repeated, and we "evenly" iterate through everything
+    Instead, we can ravel all the inputs, broadcasted
+    according to torch
     """

+    n_outputs = len(node.outputs)
+
     def elemwise_fn(*inputs):
-        Elemwise._check_runtime_broadcast(node, inputs)
-        shaped_inputs = torch.broadcast_tensors(*inputs)
-        expected_size = shaped_inputs[0].numel()
-        final_inputs = [s.clone() for s in shaped_inputs]
-        for _ in range(shaped_inputs[0].dim() - 1):
-            for i, _ in enumerate(shaped_inputs):
-                layer = chain.from_iterable([s.unbind(0) for s in final_inputs[i]])
-                final_inputs[i] = list(layer)
-
-        # make sure we still have the same number of things
-        assert len(final_inputs) == len(shaped_inputs)
-
-        # make sure each group of things are the expected size
-        assert all(len(x) == expected_size for x in final_inputs)
-
-        # make sure they are all single elements
-        assert all(len(x.shape) == 0 for tensor in final_inputs for x in tensor)
-        res = [base_fn(*args) for args in zip(*final_inputs)]
-
-        return [torch.stack(tuple(out[i] for out in res)) for i in range(len(res[0]))]
+        bcasted_inputs = torch.broadcast_tensors(*inputs)
+        raveled_inputs = [inp.ravel() for inp in bcasted_inputs]
+
+        out_shape = bcasted_inputs[0].size()
+        out_size = out_shape.numel()
+        raveled_outputs = [torch.zeros(out_size) for out in node.outputs]
+
+        for i in range(out_size):
+            core_outs = base_fn(*(inp[i] for inp in raveled_inputs))
+            if n_outputs == 1:
+                raveled_outputs[0][i] = core_outs
+            else:
+                for o in range(n_outputs):
+                    raveled_outputs[o][i] = core_outs[o]
+
+        outputs = tuple(out.view(out_shape) for out in raveled_outputs)
+        if n_outputs == 1:
+            return outputs[0]
+        else:
+            return outputs

     return elemwise_fn
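For readers skimming the diff, the new strategy can be tried in isolation. The following standalone sketch is not part of this commit; `scalar_loop_elemwise` and `square_and_double` are hypothetical stand-ins for the dispatch wrapper and the compiled `base_fn`. It broadcasts the inputs, ravels them to 1-D, applies the scalar function element by element, and reshapes the flat results back to the broadcasted shape, mirroring the new `elemwise_fn`.

```python
import torch


def scalar_loop_elemwise(base_fn, n_outputs, *inputs):
    # Broadcast all inputs to a common shape, then flatten each to 1-D.
    bcasted = torch.broadcast_tensors(*inputs)
    raveled = [inp.ravel() for inp in bcasted]

    out_shape = bcasted[0].size()
    out_size = out_shape.numel()
    flat_outputs = [torch.zeros(out_size) for _ in range(n_outputs)]

    # Apply the scalar function one element at a time over the raveled inputs.
    for i in range(out_size):
        res = base_fn(*(inp[i] for inp in raveled))
        if n_outputs == 1:
            flat_outputs[0][i] = res
        else:
            for o in range(n_outputs):
                flat_outputs[o][i] = res[o]

    # Reshape each flat output back to the broadcasted shape.
    return tuple(out.view(out_shape) for out in flat_outputs)


# Hypothetical scalar function standing in for the compiled ScalarLoop body.
def square_and_double(x, y):
    return x * x, 2 * y


a = torch.arange(6.0).reshape(3, 2)
b = torch.tensor([10.0, 20.0])  # broadcasts against shape (3, 2)
sq, dbl = scalar_loop_elemwise(square_and_double, 2, a, b)
print(sq.shape, dbl.shape)  # torch.Size([3, 2]) torch.Size([3, 2])
```

Compared with the removed `unbind`/`chain` approach, this only allocates one flat buffer per output instead of building nested Python lists of 0-d tensors.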

tests/link/pytorch/test_basic.py

Lines changed: 1 addition & 40 deletions
@@ -1,7 +1,5 @@
 from collections.abc import Callable, Iterable
 from functools import partial
-from itertools import repeat, starmap
-from unittest.mock import MagicMock, call, patch

 import numpy as np
 import pytest
@@ -421,25 +419,11 @@ def test_ScalarLoop_while():
     for res, expected in zip(
         [fn(n_steps=20, x0=0), fn(n_steps=20, x0=1), fn(n_steps=5, x0=1)],
         [[10, True], [10, True], [6, False]],
+        strict=True,
     ):
         np.testing.assert_allclose(res[0], np.array(expected[0]))
         np.testing.assert_allclose(res[1], np.array(expected[1]))

-def test_pytorch_OpFromGraph():
-    x, y, z = matrices("xyz")
-    ofg_1 = OpFromGraph([x, y], [x + y])
-    ofg_2 = OpFromGraph([x, y], [x * y, x - y])
-
-    o1, o2 = ofg_2(y, z)
-    out = ofg_1(x, o1) + o2
-
-    xv = np.ones((2, 2), dtype=config.floatX)
-    yv = np.ones((2, 2), dtype=config.floatX) * 3
-    zv = np.ones((2, 2), dtype=config.floatX) * 5
-
-    f = FunctionGraph([x, y, z], [out])
-    compare_pytorch_and_py(f, [xv, yv, zv])
-

 def test_ScalarLoop_Elemwise():
     n_steps = int64("n_steps")
@@ -457,26 +441,3 @@ def test_ScalarLoop_Elemwise():
     f = FunctionGraph([n_steps, x0], [state, done])
     args = [np.array(10).astype("int32"), np.arange(0, 5).astype("float32")]
     compare_pytorch_and_py(f, args)
-
-
-torch_elemwise = pytest.importorskip("pytensor.link.pytorch.dispatch.elemwise")
-
-
-@pytest.mark.parametrize("input_shapes", [[(5, 1, 1, 8), (3, 1, 1), (8,)]])
-@patch("pytensor.link.pytorch.dispatch.elemwise.Elemwise")
-def test_ScalarLoop_Elemwise_iteration_logic(_, input_shapes):
-    args = [torch.ones(*s) for s in input_shapes[:-1]] + [
-        torch.zeros(*input_shapes[-1])
-    ]
-    mock_inner_func = MagicMock()
-    ret_value = torch.rand(2, 2).unbind(0)
-    mock_inner_func.f.return_value = ret_value
-    elemwise_fn = torch_elemwise.elemwise_scalar_loop(mock_inner_func.f, None, None)
-    result = elemwise_fn(*args)
-    for actual, expected in zip(ret_value, result):
-        assert torch.all(torch.eq(*torch.broadcast_tensors(actual, expected)))
-    np.testing.assert_equal(mock_inner_func.f.call_count, len(result[0]))
-
-    expected_args = torch.FloatTensor([1.0] * (len(input_shapes) - 1) + [0.0]).unbind(0)
-    expected_calls = starmap(call, repeat(expected_args, mock_inner_func.f.call_count))
-    mock_inner_func.f.assert_has_calls(expected_calls)
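As an aside (not part of the diff), the `strict=True` flag added to `zip` in the hunk above is standard Python 3.10+ behavior: it raises instead of silently truncating when the iterables differ in length. A minimal sketch:

```python
results = [10, 10, 6]
expected = [[10, True], [10, True]]  # one entry short

try:
    for r, e in zip(results, expected, strict=True):
        pass
except ValueError as err:
    # e.g. "zip() argument 2 is shorter than argument 1"
    print(err)
```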
