|
12 | 12 | @pytorch_funcify.register(Elemwise) |
13 | 13 | def pytorch_funcify_Elemwise(op, node, **kwargs): |
14 | 14 | scalar_op = op.scalar_op |
15 | | - base_fn = pytorch_funcify(scalar_op, node=node, **kwargs) |
16 | 15 |
|
| 16 | + base_fn = pytorch_funcify(scalar_op, node=node, **kwargs) |
17 | 17 | if isinstance(scalar_op, ScalarLoop): |
18 | | - # note: scalarloop + elemwise is too common |
19 | | - # to not work, but @1031, vmap won't allow it. |
20 | | - # Instead, we will just successively unbind |
21 | | - def elemwise_fn(*inputs): |
22 | | - Elemwise._check_runtime_broadcast(node, inputs) |
23 | | - shaped_inputs = torch.broadcast_tensors(*inputs) |
24 | | - expected_size = shaped_inputs[0].numel() |
25 | | - final_inputs = [s.clone() for s in shaped_inputs] |
26 | | - for _ in range(shaped_inputs[0].dim() - 1): |
27 | | - for i, _ in enumerate(shaped_inputs): |
28 | | - layer = chain.from_iterable([s.unbind(0) for s in final_inputs[i]]) |
29 | | - final_inputs[i] = list(layer) |
30 | | - |
31 | | - # make sure we still have the same number of things |
32 | | - assert len(final_inputs) == len(shaped_inputs) |
33 | | - |
34 | | - # make sure each group of things are the expected size |
35 | | - assert all(len(x) == expected_size for x in final_inputs) |
36 | | - |
37 | | - # make sure they are all single elements |
38 | | - assert all(len(x.shape) == 0 for tensor in final_inputs for x in tensor) |
39 | | - res = [base_fn(*args) for args in zip(*final_inputs)] |
40 | | - states = torch.stack(tuple(out[0] for out in res)) |
41 | | - done = torch.stack(tuple(out[1] for out in res)) |
42 | | - return states, done |
43 | | - |
| 18 | + return elemwise_scalar_loop(base_fn, op, node, **kwargs) |
44 | 19 | else: |
45 | 20 |
|
46 | 21 | def elemwise_fn(*inputs): |
@@ -180,3 +155,56 @@ def softmax_grad(dy, sm): |
180 | 155 | return dy_times_sm - torch.sum(dy_times_sm, dim=axis, keepdim=True) * sm |
181 | 156 |
|
182 | 157 | return softmax_grad |
| 158 | + |
| 159 | + |
| 160 | +def elemwise_scalar_loop(base_fn, op, node, **kwargs): |
| 161 | + """ |
| 162 | + ScalarLoop + Elemwise is too common |
| 163 | + to not work, but @1031, vmap won't allow it. |
| 164 | + Instead, we can do the following strategy |
| 165 | + 1. `.unbind(dim)` will return a list of tensors |
| 166 | + representing `dim` but "unwrapped". e.x. |
| 167 | + ``` |
| 168 | + t = torch.ones(3, 4, 2) |
| 169 | + len(t.unbind(0)) == 3 |
| 170 | + t[0].shape == torch.Size[4, 2] |
| 171 | + 2. If we successfully apply, the length of the list will grow |
| 172 | + by the next dimension in the tensor if we flatten the previous |
| 173 | + dimension result |
| 174 | + ``` |
| 175 | + inputs = [torch.ones(3, 4, 2)] |
| 176 | + level_1 = chain.from_iterable(t.unbind(0) for t in inputs) |
| 177 | + level_2 = chain.from_iterable(t.unbind(0) for t in level_1) |
| 178 | + len(level_2) == 3 * 4 |
| 179 | + ``` |
| 180 | + 3. Eventually we'll reach single dimension tensors. At that point |
| 181 | + we can iterate over each input in an element by element manner |
| 182 | + and call some function |
| 183 | +
|
| 184 | + For scalar loop, we need to broadcast the tensors so all |
| 185 | + the necessary values are repeated, and we "evenly" iterate through everything |
| 186 | + """ |
| 187 | + |
| 188 | + def elemwise_fn(*inputs): |
| 189 | + Elemwise._check_runtime_broadcast(node, inputs) |
| 190 | + shaped_inputs = torch.broadcast_tensors(*inputs) |
| 191 | + expected_size = shaped_inputs[0].numel() |
| 192 | + final_inputs = [s.clone() for s in shaped_inputs] |
| 193 | + for _ in range(shaped_inputs[0].dim() - 1): |
| 194 | + for i, _ in enumerate(shaped_inputs): |
| 195 | + layer = chain.from_iterable([s.unbind(0) for s in final_inputs[i]]) |
| 196 | + final_inputs[i] = list(layer) |
| 197 | + |
| 198 | + # make sure we still have the same number of things |
| 199 | + assert len(final_inputs) == len(shaped_inputs) |
| 200 | + |
| 201 | + # make sure each group of things are the expected size |
| 202 | + assert all(len(x) == expected_size for x in final_inputs) |
| 203 | + |
| 204 | + # make sure they are all single elements |
| 205 | + assert all(len(x.shape) == 0 for tensor in final_inputs for x in tensor) |
| 206 | + res = [base_fn(*args) for args in zip(*final_inputs)] |
| 207 | + |
| 208 | + return [torch.stack(tuple(out[i] for out in res)) for i in range(len(res[0]))] |
| 209 | + |
| 210 | + return elemwise_fn |
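
For reference, here is a minimal standalone sketch of the successive-unbind strategy the new docstring describes. The `scalar_fn` below is a made-up stand-in for the compiled `base_fn` that `pytorch_funcify` would produce for a `ScalarLoop`, and the shapes are arbitrary; it only illustrates the broadcast, flatten, and stack flow:

```python
from itertools import chain

import torch


def scalar_fn(a, b):
    # hypothetical stand-in for base_fn: returns a (state, done) pair of 0-d tensors
    state = a + b
    done = state > 4
    return state, done


x = torch.arange(6.0).reshape(3, 2)
y = torch.ones(())  # 0-d tensor that broadcasts against x

shaped = torch.broadcast_tensors(x, y)
flat = [s.clone() for s in shaped]
for _ in range(shaped[0].dim() - 1):
    # unbind one dimension per pass and flatten, mirroring elemwise_fn
    flat = [list(chain.from_iterable(s.unbind(0) for s in group)) for group in flat]

# each entry of `flat` is now a flat sequence of 0-d tensors of equal length
results = [scalar_fn(*args) for args in zip(*flat)]
states = torch.stack(tuple(r[0] for r in results))  # shape (6,)
done = torch.stack(tuple(r[1] for r in results))  # shape (6,)
```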