
Commit e28c3e2

Ch0ronomato (Ian Schweer) authored and committed
Clean up and add description
1 parent 561301b commit e28c3e2

File tree

1 file changed: +55 -25 lines


pytensor/link/pytorch/dispatch/elemwise.py

Lines changed: 55 additions & 25 deletions
@@ -13,6 +13,7 @@
 @pytorch_funcify.register(Elemwise)
 def pytorch_funcify_Elemwise(op, node, **kwargs):
     scalar_op = op.scalar_op
+
     base_fn = pytorch_funcify(scalar_op, node=node, **kwargs)
 
     def check_special_scipy(func_name):
@@ -36,31 +37,7 @@ def elemwise_fn(*inputs):
             return base_fn(*inputs)
 
     elif isinstance(scalar_op, ScalarLoop):
-        # note: scalarloop + elemwise is too common
-        # to not work, but @1031, vmap won't allow it.
-        # Instead, we will just successively unbind
-        def elemwise_fn(*inputs):
-            Elemwise._check_runtime_broadcast(node, inputs)
-            shaped_inputs = torch.broadcast_tensors(*inputs)
-            expected_size = shaped_inputs[0].numel()
-            final_inputs = [s.clone() for s in shaped_inputs]
-            for _ in range(shaped_inputs[0].dim() - 1):
-                for i, _ in enumerate(shaped_inputs):
-                    layer = chain.from_iterable([s.unbind(0) for s in final_inputs[i]])
-                    final_inputs[i] = list(layer)
-
-            # make sure we still have the same number of things
-            assert len(final_inputs) == len(shaped_inputs)
-
-            # make sure each group of things are the expected size
-            assert all(len(x) == expected_size for x in final_inputs)
-
-            # make sure they are all single elements
-            assert all(len(x.shape) == 0 for tensor in final_inputs for x in tensor)
-            res = [base_fn(*args) for args in zip(*final_inputs)]
-            states = torch.stack(tuple(out[0] for out in res))
-            done = torch.stack(tuple(out[1] for out in res))
-            return states, done
+        return elemwise_scalar_loop(base_fn, op, node, **kwargs)
 
     else:

@@ -206,3 +183,56 @@ def softmax_grad(dy, sm):
         return dy_times_sm - torch.sum(dy_times_sm, dim=axis, keepdim=True) * sm
 
     return softmax_grad
+
+
+def elemwise_scalar_loop(base_fn, op, node, **kwargs):
+    """
+    ScalarLoop + Elemwise is too common a combination not to support,
+    but, per @1031, vmap won't allow it.
+    Instead, we use the following strategy:
+    1. `.unbind(dim)` returns a list of tensors with `dim` "unwrapped", e.g.
+       ```
+       t = torch.ones(3, 4, 2)
+       len(t.unbind(0)) == 3
+       t[0].shape == torch.Size([4, 2])
+       ```
+    2. Applied successively, flattening the previous level's result grows
+       the list by the size of the next dimension:
+       ```
+       inputs = [torch.ones(3, 4, 2)]
+       level_1 = chain.from_iterable(t.unbind(0) for t in inputs)
+       level_2 = chain.from_iterable(t.unbind(0) for t in level_1)
+       len(list(level_2)) == 3 * 4
+       ```
+    3. Eventually we reach scalar (zero-dimensional) tensors. At that
+       point we can iterate over the inputs element by element and call
+       the scalar function.
+
+    For ScalarLoop, we first broadcast the tensors so all the necessary
+    values are repeated, and we iterate through everything "evenly".
+    """
+
+    def elemwise_fn(*inputs):
+        Elemwise._check_runtime_broadcast(node, inputs)
+        shaped_inputs = torch.broadcast_tensors(*inputs)
+        expected_size = shaped_inputs[0].numel()
+        final_inputs = [s.clone() for s in shaped_inputs]
+        for _ in range(shaped_inputs[0].dim() - 1):
+            for i, _ in enumerate(shaped_inputs):
+                layer = chain.from_iterable([s.unbind(0) for s in final_inputs[i]])
+                final_inputs[i] = list(layer)
+
+        # make sure we still have the same number of things
+        assert len(final_inputs) == len(shaped_inputs)
+
+        # make sure each group of things are the expected size
+        assert all(len(x) == expected_size for x in final_inputs)
+
+        # make sure they are all single elements
+        assert all(len(x.shape) == 0 for tensor in final_inputs for x in tensor)
+        res = [base_fn(*args) for args in zip(*final_inputs)]
+
+        return [torch.stack(tuple(out[i] for out in res)) for i in range(len(res[0]))]
+
+    return elemwise_fn
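For reference, here is a minimal standalone sketch of the strategy the new docstring describes: broadcast the inputs, successively unbind until only scalar tensors remain, then apply a scalar function element by element. Variable names and the toy inputs are illustrative, not part of the commit:

```
import torch
from itertools import chain

# Step 0: make shapes identical, repeating values as needed.
a, b = torch.broadcast_tensors(
    torch.arange(24.0).reshape(3, 4, 2), torch.ones(1, 4, 2)
)

# Steps 1-2: successively unbind until each group holds only 0-d tensors.
groups = [a, b]
for _ in range(a.dim() - 1):
    groups = [list(chain.from_iterable(s.unbind(0) for s in g)) for g in groups]

assert all(len(g) == a.numel() for g in groups)      # one entry per element
assert all(s.dim() == 0 for g in groups for s in g)  # all zero-dimensional

# Step 3: apply a scalar function element by element (stand-in for base_fn).
res = [x + y for x, y in zip(*groups)]
out = torch.stack(res).reshape(a.shape)
assert torch.equal(out, a + b)
```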

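The final comprehension in `elemwise_fn` also generalizes the old hard-coded `(states, done)` pair: it stacks the i-th output of every scalar call, one tensor per output slot, for any number of loop outputs. A toy illustration of just that line, with made-up per-element results:

```
import torch

# Pretend each scalar call returned two outputs, e.g. (state, done).
res = [(torch.tensor(float(i)), torch.tensor(i % 2 == 0)) for i in range(4)]

# Stack the i-th output of every call, one tensor per output slot.
outputs = [torch.stack(tuple(out[i] for out in res)) for i in range(len(res[0]))]
states, done = outputs
assert states.shape == (4,) and done.shape == (4,)
```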