
Commit 561301b

Author: Ian Schweer (committed)

Do iteration instead of vmap for elemwise

1 parent 07e4520 · commit 561301b

File tree: 2 files changed (+45, -5 lines)
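Background for this change: the PyTorch backend lowers Elemwise ops by wrapping the scalar function in torch.vmap, but a ScalarLoop body contains data-dependent control flow that vmap cannot trace. The snippet below is not part of the commit, only a minimal illustration of the limitation that motivates it; double_until_ten is a hypothetical stand-in for a compiled ScalarLoop body.

import torch

def double_until_ten(x, n_steps=10):
    # Scalar recurrence with data-dependent control flow.
    for _ in range(n_steps):
        x = x * 2
        if x >= 10:  # needs a concrete bool per element
            break
    return x

xs = torch.arange(1.0, 6.0)

# Per-element iteration works fine:
print(torch.stack([double_until_ten(x) for x in xs.unbind(0)]))

# vmap does not: the `if` forces a data-dependent branch on a batched tensor.
try:
    torch.vmap(double_until_ten)(xs)
except RuntimeError as err:
    print("vmap rejected the loop:", err)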

pytensor/link/pytorch/dispatch/elemwise.py (30 additions, 0 deletions)
@@ -1,8 +1,10 @@
 import importlib
+from itertools import chain
 
 import torch
 
 from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
+from pytensor.scalar import ScalarLoop
 from pytensor.tensor.elemwise import DimShuffle, Elemwise
 from pytensor.tensor.math import All, Any, Max, Min, Prod, Sum
 from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
@@ -33,6 +35,33 @@ def elemwise_fn(*inputs):
         Elemwise._check_runtime_broadcast(node, inputs)
         return base_fn(*inputs)
 
+    elif isinstance(scalar_op, ScalarLoop):
+        # note: ScalarLoop + Elemwise is too common a combination
+        # not to support, but per @1031, vmap won't allow it.
+        # Instead, we successively unbind the inputs down to scalars.
+        def elemwise_fn(*inputs):
+            Elemwise._check_runtime_broadcast(node, inputs)
+            shaped_inputs = torch.broadcast_tensors(*inputs)
+            expected_size = shaped_inputs[0].numel()
+            final_inputs = [s.clone() for s in shaped_inputs]
+            for _ in range(shaped_inputs[0].dim() - 1):
+                for i, _ in enumerate(shaped_inputs):
+                    layer = chain.from_iterable([s.unbind(0) for s in final_inputs[i]])
+                    final_inputs[i] = list(layer)
+
+            # make sure we still have the same number of inputs
+            assert len(final_inputs) == len(shaped_inputs)
+
+            # make sure each flattened input has the expected number of elements
+            assert all(len(x) == expected_size for x in final_inputs)
+
+            # make sure every element is a 0-d (scalar) tensor
+            assert all(len(x.shape) == 0 for tensor in final_inputs for x in tensor)
+            res = [base_fn(*args) for args in zip(*final_inputs)]
+            states = torch.stack(tuple(out[0] for out in res))
+            done = torch.stack(tuple(out[1] for out in res))
+            return states, done
+
     else:
 
         def elemwise_fn(*inputs):
@@ -42,6 +71,7 @@ def elemwise_fn(*inputs):
             for _ in range(broadcast_inputs[0].dim()):
                 ufunc = torch.vmap(ufunc)
             return ufunc(*broadcast_inputs)
+            return base_fn(*inputs)
 
     return elemwise_fn
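To see the unbind strategy in isolation, here is a self-contained sketch of the same idea; flatten_to_scalars and iterated_elemwise are hypothetical names, not the committed API. Broadcast all inputs to a common shape, successively unbind dimension 0 until only 0-d tensors remain, call the scalar function on each tuple of elements, then stack the results.

import torch
from itertools import chain

def flatten_to_scalars(t):
    # Repeatedly unbind dim 0 until only 0-d tensors remain,
    # mirroring the inner loop of the dispatch above.
    elems = [t]
    for _ in range(t.dim()):
        elems = list(chain.from_iterable(e.unbind(0) for e in elems))
    return elems

def iterated_elemwise(scalar_fn, *inputs):
    shaped = torch.broadcast_tensors(*inputs)
    columns = [flatten_to_scalars(s) for s in shaped]
    results = [scalar_fn(*args) for args in zip(*columns)]
    return torch.stack(results).reshape(shaped[0].shape)

# Example: apply a scalar function over broadcasted inputs without vmap.
out = iterated_elemwise(torch.add, torch.arange(6.0).reshape(2, 3), torch.tensor(1.0))
print(out)  # tensor([[1., 2., 3.], [4., 5., 6.]])

The committed version additionally splits each result into its (state, done) pair before stacking, since a ScalarLoop returns both.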

tests/link/pytorch/test_basic.py (15 additions, 5 deletions)
@@ -4,6 +4,7 @@
 import numpy as np
 import pytest
 
+import pytensor.tensor as pt
 import pytensor.tensor.basic as ptb
 from pytensor.compile.builders import OpFromGraph
 from pytensor.compile.function import function
@@ -444,10 +445,19 @@ def test_ScalarLoop_Elemwise():
     x = x0 * 2
     until = x >= 10
 
-    op = ScalarLoop(init=[x0], update=[x], until=until)
-    fn = function([n_steps, x0], Elemwise(op)(n_steps, x0), mode=pytorch_mode)
+    scalarop = ScalarLoop(init=[x0], update=[x], until=until)
+    op = Elemwise(scalarop)
+
+    n_steps = pt.scalar("n_steps", dtype="int32")
+    x0 = pt.vector("x0", dtype="float32")
+    state, done = op(n_steps, x0)
+
+    fn = function([n_steps, x0], [state, done], mode=pytorch_mode)
+    py_fn = function([n_steps, x0], [state, done])
 
-    states, dones = fn(10, np.array(range(5)))
+    args = [np.array(10).astype("int32"), np.arange(0, 5).astype("float32")]
+    torch_states, torch_dones = fn(*args)
+    py_states, py_dones = py_fn(*args)
 
-    np.testing.assert_allclose(states, [0, 4, 8, 12, 16])
-    np.testing.assert_allclose(dones, [False, False, False, True, True])
+    np.testing.assert_allclose(torch_states, py_states)
+    np.testing.assert_allclose(torch_dones, py_dones)
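For intuition about what this test computes: each element of x0 is doubled until it reaches 10 or n_steps iterations elapse. A rough pure-Python reference follows; the stopping behavior is an assumption about ScalarLoop's until semantics, which is presumably why the rewritten test compares the PyTorch backend against the Python backend instead of hardcoding expected values.

import numpy as np

def scalar_loop_reference(n_steps, x0):
    # Approximates ScalarLoop(init=[x0], update=[x0 * 2], until=x >= 10):
    # run at most n_steps updates, stopping once the until condition holds.
    x, done = x0, False
    for _ in range(n_steps):
        x = x * 2
        done = x >= 10
        if done:
            break
    return x, done

states, dones = zip(*(scalar_loop_reference(10, x) for x in np.arange(5, dtype="float32")))
print(states, dones)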
