Commit c467a34

Author: Ian Schweer (committed)

Do iteration instead of vmap for elemwise

1 parent: beb4440

2 files changed: +50 -8 lines


pytensor/link/pytorch/dispatch/elemwise.py

Lines changed: 35 additions & 3 deletions
@@ -1,6 +1,9 @@
+from itertools import chain
+
 import torch
 
 from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
+from pytensor.scalar import ScalarLoop
 from pytensor.tensor.elemwise import DimShuffle, Elemwise
 from pytensor.tensor.math import All, Any, Max, Min, Prod, Sum
 from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
@@ -11,9 +14,38 @@ def pytorch_funcify_Elemwise(op, node, **kwargs):
     scalar_op = op.scalar_op
     base_fn = pytorch_funcify(scalar_op, node=node, **kwargs)
 
-    def elemwise_fn(*inputs):
-        Elemwise._check_runtime_broadcast(node, inputs)
-        return base_fn(*inputs)
+    if isinstance(scalar_op, ScalarLoop):
+        # Note: ScalarLoop + Elemwise is too common a combination not to
+        # support, but vmap won't allow it (see @1031). Instead, we
+        # successively unbind the broadcast inputs down to 0-d tensors.
+        def elemwise_fn(*inputs):
+            Elemwise._check_runtime_broadcast(node, inputs)
+            shaped_inputs = torch.broadcast_tensors(*inputs)
+            expected_size = shaped_inputs[0].numel()
+            final_inputs = [s.clone() for s in shaped_inputs]
+            for _ in range(shaped_inputs[0].dim() - 1):
+                for i, _ in enumerate(shaped_inputs):
+                    layer = chain.from_iterable([s.unbind(0) for s in final_inputs[i]])
+                    final_inputs[i] = list(layer)
+
+            # make sure we still have the same number of input groups
+            assert len(final_inputs) == len(shaped_inputs)
+
+            # make sure each group has the expected number of elements
+            assert all(len(x) == expected_size for x in final_inputs)
+
+            # make sure every element is a 0-d tensor
+            assert all(len(x.shape) == 0 for tensor in final_inputs for x in tensor)
+            res = [base_fn(*args) for args in zip(*final_inputs)]
+            states = torch.stack(tuple(out[0] for out in res))
+            done = torch.stack(tuple(out[1] for out in res))
+            return states, done
+
+    else:
+
+        def elemwise_fn(*inputs):
+            Elemwise._check_runtime_broadcast(node, inputs)
+            return base_fn(*inputs)
 
     return elemwise_fn

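For reference, here is a minimal standalone sketch of the unbind-then-stack iteration this commit uses in place of torch.vmap. It is illustration only, not part of the commit: it uses plain PyTorch, and scalar_fn and apply_elemwise are made-up names standing in for the compiled ScalarLoop and the new elemwise_fn branch.

from itertools import chain

import torch


def scalar_fn(a, b):
    # Made-up stand-in for base_fn: consumes 0-d tensors and returns a
    # (state, done) pair of 0-d tensors, like a compiled ScalarLoop would.
    s = a + b
    return s, s >= 3


def apply_elemwise(fn, *inputs):
    # Broadcast the inputs, then unbind one leading dimension per pass
    # until only 0-d elements are left, mirroring the committed branch.
    shaped = torch.broadcast_tensors(*inputs)
    final = [s.clone() for s in shaped]
    for _ in range(shaped[0].dim() - 1):
        final = [list(chain.from_iterable(s.unbind(0) for s in t)) for t in final]
    # Call the scalar function once per element, then stack the two outputs.
    results = [fn(*args) for args in zip(*final)]
    states = torch.stack(tuple(r[0] for r in results))
    dones = torch.stack(tuple(r[1] for r in results))
    return states, dones


states, dones = apply_elemwise(
    scalar_fn, torch.arange(4.0).reshape(2, 2), torch.tensor(1.0)
)
print(states)  # tensor([1., 2., 3., 4.])
print(dones)   # tensor([False, False, True, True])

Note that, like the new elemwise_fn branch, the sketch stacks the per-element results into a flat 1-d tensor rather than reshaping back to the broadcast shape; for the 1-d inputs exercised by the test below, the two coincide.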
tests/link/pytorch/test_basic.py

Lines changed: 15 additions & 5 deletions
@@ -4,6 +4,7 @@
 import numpy as np
 import pytest
 
+import pytensor.tensor as pt
 import pytensor.tensor.basic as ptb
 from pytensor.compile.builders import OpFromGraph
 from pytensor.compile.function import function
@@ -385,10 +386,19 @@ def test_ScalarLoop_Elemwise():
     x = x0 * 2
     until = x >= 10
 
-    op = ScalarLoop(init=[x0], update=[x], until=until)
-    fn = function([n_steps, x0], Elemwise(op)(n_steps, x0), mode=pytorch_mode)
+    scalarop = ScalarLoop(init=[x0], update=[x], until=until)
+    op = Elemwise(scalarop)
+
+    n_steps = pt.scalar("n_steps", dtype="int32")
+    x0 = pt.vector("x0", dtype="float32")
+    state, done = op(n_steps, x0)
+
+    fn = function([n_steps, x0], [state, done], mode=pytorch_mode)
+    py_fn = function([n_steps, x0], [state, done])
 
-    states, dones = fn(10, np.array(range(5)))
+    args = [np.array(10).astype("int32"), np.arange(0, 5).astype("float32")]
+    torch_states, torch_dones = fn(*args)
+    py_states, py_dones = py_fn(*args)
 
-    np.testing.assert_allclose(states, [0, 4, 8, 12, 16])
-    np.testing.assert_allclose(dones, [False, False, False, True, True])
+    np.testing.assert_allclose(torch_states, py_states)
+    np.testing.assert_allclose(torch_dones, py_dones)
