|
1 | 1 | import torch |
| 2 | +import torch.compiler |
2 | 3 |
|
3 | 4 | from pytensor.graph import FunctionGraph |
4 | 5 | from pytensor.link.pytorch.dispatch import pytorch_funcify |
5 | 6 | from pytensor.tensor.blockwise import Blockwise |
6 | | -from pytensor.tensor.random.utils import params_broadcast_shapes |
7 | 7 |
|
8 | 8 |
|
9 | 9 | @pytorch_funcify.register(Blockwise) |
10 | 10 | def funcify_Blockwise(op: Blockwise, node, *args, **kwargs): |
11 | 11 | batched_dims = op.batch_ndim(node) |
12 | 12 | core_node = op._create_dummy_core_node(node.inputs) |
13 | 13 | core_fgraph = FunctionGraph(inputs=core_node.inputs, outputs=core_node.outputs) |
14 | | - core_func = pytorch_funcify(core_fgraph) |
15 | | - if len(node.outputs) == 1: |
16 | | - |
17 | | - def inner_func(*inputs): |
18 | | - return core_func(*inputs)[0] |
19 | | - else: |
20 | | - inner_func = core_func |
| 14 | + inner_func = pytorch_funcify(core_fgraph) |
21 | 15 |
|
22 | 16 | for _ in range(batched_dims): |
23 | 17 | inner_func = torch.vmap(inner_func) |
24 | 18 |
|
| 19 | + @torch.compiler.disable(recursive=False) |
25 | 20 | def batcher(*inputs): |
26 | 21 | op._check_runtime_broadcast(node, inputs) |
27 | 22 | # broadcast on batched_dims |
28 | | - all_batched_dims = tuple(tuple(t.shape) for t in inputs) |
29 | | - new_shapes = params_broadcast_shapes( |
30 | | - all_batched_dims, |
31 | | - ndims_params=[batched_dims] * len(inputs), |
32 | | - use_pytensor=False, |
33 | | - ) |
| 23 | + all_batched_dims = tuple(t.shape[:batched_dims] for t in inputs) |
| 24 | + batched_shape = torch.broadcast_shapes(*all_batched_dims) |
34 | 25 | broadcast_inputs = [ |
35 | | - torch.broadcast_to(i, s) for i, s in zip(inputs, new_shapes) |
| 26 | + torch.broadcast_to(i, batched_shape + i.shape[batched_dims:]) |
| 27 | + for i in inputs |
36 | 28 | ] |
37 | | - return inner_func(*broadcast_inputs) |
| 29 | + res = inner_func(*broadcast_inputs) |
| 30 | + if len(node.outputs) == 1: |
| 31 | + return res[0] |
| 32 | + else: |
| 33 | + return res |
38 | 34 |
|
39 | 35 | return batcher |
0 commit comments