
Commit 067761f

Clean up and add description
1 parent c467a34

File tree

1 file changed: +55 -27 lines changed

pytensor/link/pytorch/dispatch/elemwise.py

Lines changed: 55 additions & 27 deletions
@@ -12,35 +12,10 @@
 @pytorch_funcify.register(Elemwise)
 def pytorch_funcify_Elemwise(op, node, **kwargs):
     scalar_op = op.scalar_op
-    base_fn = pytorch_funcify(scalar_op, node=node, **kwargs)
 
+    base_fn = pytorch_funcify(scalar_op, node=node, **kwargs)
     if isinstance(scalar_op, ScalarLoop):
-        # note: scalarloop + elemwise is too common
-        # to not work, but @1031, vmap won't allow it.
-        # Instead, we will just successively unbind
-        def elemwise_fn(*inputs):
-            Elemwise._check_runtime_broadcast(node, inputs)
-            shaped_inputs = torch.broadcast_tensors(*inputs)
-            expected_size = shaped_inputs[0].numel()
-            final_inputs = [s.clone() for s in shaped_inputs]
-            for _ in range(shaped_inputs[0].dim() - 1):
-                for i, _ in enumerate(shaped_inputs):
-                    layer = chain.from_iterable([s.unbind(0) for s in final_inputs[i]])
-                    final_inputs[i] = list(layer)
-
-            # make sure we still have the same number of things
-            assert len(final_inputs) == len(shaped_inputs)
-
-            # make sure each group of things are the expected size
-            assert all(len(x) == expected_size for x in final_inputs)
-
-            # make sure they are all single elements
-            assert all(len(x.shape) == 0 for tensor in final_inputs for x in tensor)
-            res = [base_fn(*args) for args in zip(*final_inputs)]
-            states = torch.stack(tuple(out[0] for out in res))
-            done = torch.stack(tuple(out[1] for out in res))
-            return states, done
-
+        return elemwise_scalar_loop(base_fn, op, node, **kwargs)
     else:
 
         def elemwise_fn(*inputs):
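For context on the removed comment's claim that "vmap won't allow it": `torch.vmap` rejects data-dependent Python control flow, which is exactly what a `ScalarLoop` executes per element. A minimal sketch of that limitation (illustrative only; the toy function below is not from this commit):

```python
import torch

# Toy stand-in for a ScalarLoop body: the trip count depends on the data.
def halve_until_small(x):
    steps = x.new_zeros(())
    while x > 1.0:  # data-dependent condition: fine eagerly, rejected by vmap
        x = x / 2.0
        steps = steps + 1
    return x, steps

print(halve_until_small(torch.tensor(8.0)))  # works on a single element

try:
    torch.vmap(halve_until_small)(torch.tensor([8.0, 0.5]))
except RuntimeError as err:
    print(f"vmap rejects data-dependent control flow: {err}")
```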
@@ -180,3 +155,56 @@ def softmax_grad(dy, sm):
         return dy_times_sm - torch.sum(dy_times_sm, dim=axis, keepdim=True) * sm
 
     return softmax_grad
+
+
+def elemwise_scalar_loop(base_fn, op, node, **kwargs):
+    """
+    ScalarLoop + Elemwise is too common to not work, but per @1031,
+    vmap won't allow it. Instead, we use the following strategy:
+    1. `.unbind(dim)` returns a tuple of tensors with `dim` "unwrapped", e.g.:
+       ```
+       t = torch.ones(3, 4, 2)
+       len(t.unbind(0)) == 3
+       t.unbind(0)[0].shape == torch.Size([4, 2])
+       ```
+    2. Each successive application grows the list by the next dimension
+       of the tensor, if we flatten the previous dimension's result:
+       ```
+       inputs = [torch.ones(3, 4, 2)]
+       level_1 = list(chain.from_iterable(t.unbind(0) for t in inputs))
+       level_2 = list(chain.from_iterable(t.unbind(0) for t in level_1))
+       len(level_2) == 3 * 4
+       ```
+    3. Eventually we reach scalar (0-d) tensors. At that point we can
+       iterate over the inputs element by element and call the scalar
+       function on each group.
+
+    For ScalarLoop we broadcast the tensors first, so all the necessary
+    values are repeated and we iterate through everything "evenly".
+    """
+
+    def elemwise_fn(*inputs):
+        Elemwise._check_runtime_broadcast(node, inputs)
+        shaped_inputs = torch.broadcast_tensors(*inputs)
+        expected_size = shaped_inputs[0].numel()
+        final_inputs = [s.clone() for s in shaped_inputs]
+        for _ in range(shaped_inputs[0].dim() - 1):
+            for i, _ in enumerate(shaped_inputs):
+                layer = chain.from_iterable([s.unbind(0) for s in final_inputs[i]])
+                final_inputs[i] = list(layer)
+
+        # we should still have one flattened group per input
+        assert len(final_inputs) == len(shaped_inputs)
+
+        # each group should have one element per broadcast position
+        assert all(len(x) == expected_size for x in final_inputs)
+
+        # and every element should be a scalar (0-d) tensor
+        assert all(len(x.shape) == 0 for tensor in final_inputs for x in tensor)
+        res = [base_fn(*args) for args in zip(*final_inputs)]
+
+        return [torch.stack(tuple(out[i] for out in res)) for i in range(len(res[0]))]
+
+    return elemwise_fn
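The docstring's unbind walkthrough can be checked end to end; here is a self-contained sketch (variable names are illustrative, not from the diff) of how repeated unbind-and-flatten passes reduce a `(3, 4, 2)` tensor to 24 scalar tensors:

```python
import torch
from itertools import chain

inputs = [torch.ones(3, 4, 2)]
level_1 = list(chain.from_iterable(t.unbind(0) for t in inputs))   # 3 tensors of shape (4, 2)
level_2 = list(chain.from_iterable(t.unbind(0) for t in level_1))  # 12 tensors of shape (2,)
level_3 = list(chain.from_iterable(t.unbind(0) for t in level_2))  # 24 scalar (0-d) tensors

assert len(level_2) == 3 * 4
assert len(level_3) == inputs[0].numel()
assert all(t.shape == () for t in level_3)
```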

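Finally, a toy demonstration of what `elemwise_fn` computes once the inputs are flattened: broadcast, call the scalar function on each element group, then stack each output slot. The `base_fn` below is a hypothetical stand-in for the compiled ScalarLoop, returning a `(state, done)` pair per call:

```python
import torch

def base_fn(a, b):  # hypothetical scalar stand-in: returns (state, done)
    return a + b, a > b

a = torch.arange(6.0).reshape(3, 2)
b = torch.tensor([10.0, -10.0])                # broadcasts against (3, 2)
shaped = torch.broadcast_tensors(a, b)         # both now shape (3, 2)
flat = [list(s.reshape(-1)) for s in shaped]   # 6 scalar tensors per input
res = [base_fn(*args) for args in zip(*flat)]  # 6 (state, done) pairs
# stack each output slot across elements, as elemwise_fn does
states, done = (torch.stack([out[i] for out in res]) for i in range(len(res[0])))
print(states)  # tensor([10., -9., 12., -7., 14., -5.])
print(done)    # tensor([False,  True, False,  True, False,  True])
```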