|
1 | 1 | from collections.abc import Iterable, Sequence |
| 2 | +from typing import cast |
2 | 3 |
|
3 | 4 | import numpy as np |
4 | 5 | from numpy.core.numeric import ( # type: ignore |
|
7 | 8 | ) |
8 | 9 |
|
9 | 10 | from pytensor import Variable |
10 | | -from pytensor.graph import Constant, node_rewriter |
11 | | -from pytensor.graph.rewriting.basic import copy_stack_trace |
| 11 | +from pytensor.graph import Constant, FunctionGraph, node_rewriter |
| 12 | +from pytensor.graph.rewriting.basic import NodeRewriter, copy_stack_trace |
12 | 13 | from pytensor.scalar import basic as ps |
| 14 | +from pytensor.tensor import TensorVariable |
13 | 15 | from pytensor.tensor.basic import ( |
14 | 16 | Alloc, |
| 17 | + Join, |
15 | 18 | MakeVector, |
16 | 19 | alloc, |
17 | 20 | as_tensor, |
18 | 21 | expand_dims, |
19 | 22 | get_underlying_scalar_constant_value, |
| 23 | + join, |
20 | 24 | register_infer_shape, |
21 | 25 | ) |
22 | 26 | from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise |
@@ -71,6 +75,41 @@ def _axis_is_indexed_by_basic_index( |
71 | 75 | return any(ax < len(idxs) and not is_full_slice(idxs[ax]) for ax in axis) |
72 | 76 |
|
73 | 77 |
|
def _lift_subtensor_non_axis(
    local_subtensor_lift_rewrite: NodeRewriter,
    fgraph: FunctionGraph,
    variable: TensorVariable,
    idx_tuple: tuple[int | slice, ...],
    axis: int,
    old_subtensor_variable: TensorVariable,
) -> list[TensorVariable] | None:
    """Apply a generic subtensor-lift rewrite along the "non-axis" dimensions.

    Splits ``variable[idx_tuple]`` in two: the indices on every dimension
    other than `axis` are lifted through `variable` by re-invoking
    `local_subtensor_lift_rewrite` on the intermediate graph, and the index
    on `axis` is then re-applied on top of the lifted result.

    Parameters
    ----------
    local_subtensor_lift_rewrite
        The node rewriter to re-invoke for the non-axis indices.
    fgraph
        Function graph passed through to the rewriter.
    variable
        The variable being indexed (the input of the op whose axis must be
        preserved).
    idx_tuple
        Normalized basic indices, one entry per leading dimension.
    axis
        The dimension whose index must stay *below* the lifted op.
    old_subtensor_variable
        The original Subtensor output, used for stack-trace propagation.

    Returns
    -------
    A one-element list with the replacement variable, or ``None`` when there
    is nothing to split (fewer than two effective indices, or `variable` is
    not multidimensional).
    """
    # Full slices are no-ops; only count indices that actually do something.
    real_indices = [idx for idx in idx_tuple if not is_full_slice(idx)]
    if len(real_indices) > 1 and variable.type.ndim > 1:
        # Split the subtensor: keep the axis index, lift the rest.
        idx_to_keep = idx_tuple[axis]
        idxs_to_lift = (*idx_tuple[:axis], slice(None), *idx_tuple[axis + 1 :])

        # Lift the non-axis indexes by calling the rewrite itself.
        # NOTE(review): assumes the recursive transform always succeeds here
        # (returns a one-element list); a None/False return would raise at the
        # unpacking below — TODO confirm this invariant holds for all callers.
        indexed_variable = variable[idxs_to_lift]
        [indexed_variable] = cast(
            list[TensorVariable],
            local_subtensor_lift_rewrite.transform(fgraph, indexed_variable.owner),
        )
        copy_stack_trace([old_subtensor_variable, indexed_variable], indexed_variable)

        # Then reintroduce the axis index, shifted left by the number of
        # dimensions the lifted integer indices dropped before `axis`.
        ndim_reduced_left = _ndim_dropped_left_of_axis_by_basic_index(idx_tuple, axis)
        new_axis = axis - ndim_reduced_left
        idxs_to_keep = (*(slice(None),) * new_axis, idx_to_keep)
        new_out = indexed_variable[idxs_to_keep]
        copy_stack_trace(old_subtensor_variable, new_out)
        return [new_out]

    else:
        return None
| 112 | + |
74 | 113 | @register_canonicalize |
75 | 114 | @register_stabilize |
76 | 115 | @register_specialize |
@@ -302,29 +341,14 @@ def local_subtensor_of_softmax(fgraph, node): |
302 | 341 | if _axis_is_indexed_by_basic_index(idx_tuple, axis): |
303 | 342 | # If there are more dimensions being indexed, we can split them |
304 | 343 | # And lift the non-axis indexes while keeping the axis index |
305 | | - real_indices = [idx for idx in idx_tuple if not is_full_slice(idx)] |
306 | | - if len(real_indices) > 1 and sm.type.ndim > 1: |
307 | | - # Split the subtensor |
308 | | - idx_to_keep = idx_tuple[axis] |
309 | | - idxs_to_lift = (*idx_tuple[:axis], slice(None), *idx_tuple[axis + 1 :]) |
310 | | - |
311 | | - # Lift the non-axis indexes by calling the rewrite itself |
312 | | - opt_sm = sm[idxs_to_lift] |
313 | | - [opt_sm] = local_subtensor_of_softmax.transform(fgraph, opt_sm.owner) |
314 | | - copy_stack_trace([old_out, sm], opt_sm) |
315 | | - |
316 | | - # Then reintroduce the axis index |
317 | | - ndim_reduced_left = _ndim_dropped_left_of_axis_by_basic_index( |
318 | | - idx_tuple, axis |
319 | | - ) |
320 | | - new_axis = axis - ndim_reduced_left |
321 | | - idxs_to_keep = (*(slice(None),) * new_axis, idx_to_keep) |
322 | | - new_out = opt_sm[idxs_to_keep] |
323 | | - copy_stack_trace(old_out, new_out) |
324 | | - return [new_out] |
325 | | - |
326 | | - else: |
327 | | - return None |
| 344 | + return _lift_subtensor_non_axis( |
| 345 | + local_subtensor_lift_rewrite=local_subtensor_of_softmax, |
| 346 | + fgraph=fgraph, |
| 347 | + variable=sm, |
| 348 | + idx_tuple=idx_tuple, |
| 349 | + axis=axis, |
| 350 | + old_subtensor_variable=old_out, |
| 351 | + ) |
328 | 352 |
|
329 | 353 | # Index input to softmax |
330 | 354 | x_sub = x[idx_tuple] |
@@ -695,6 +719,52 @@ def local_subtensor_make_vector(fgraph, node): |
695 | 719 | pass |
696 | 720 |
|
697 | 721 |
|
@register_canonicalize
@register_specialize
@node_rewriter([Subtensor])
def local_subtensor_of_join(fgraph, node):
    """Lift a Subtensor through a Join.

    join(axis=1, x, y)[0] -> join(axis=0, x[0], y[0])
    join(axis=1, x, y)[:, 0, -1] -> join(axis=1, x[:, :, -1], y[:, :, -1])[:, 0]

    """
    join_var, *idx = node.inputs

    if not (join_var.owner and isinstance(join_var.owner.op, Join)):
        return None

    if len(fgraph.clients[join_var]) > 1:
        # Join involves a full_copy, so we don't want to do it twice
        return None

    join_axis, *join_components = join_var.owner.inputs

    # Rewrite only works when the join axis is a constant along a non-indexed dimension
    if not isinstance(join_axis, Constant):
        return None

    [old_out] = node.outputs
    axis = normalize_axis_index(join_axis.data, join_components[0].type.ndim)
    idx_tuple = indices_from_subtensor(idx, node.op.idx_list)
    if _axis_is_indexed_by_basic_index(idx_tuple, axis):
        # The join axis itself is indexed: lift only the non-axis indices and
        # keep the axis index on top of the new Join.
        return _lift_subtensor_non_axis(
            local_subtensor_lift_rewrite=local_subtensor_of_join,
            fgraph=fgraph,
            variable=join_var,
            idx_tuple=idx_tuple,
            axis=axis,
            old_subtensor_variable=old_out,
        )

    # Lift index to the Join components
    indexed_components = [component[idx_tuple] for component in join_components]
    # Integer indices before `axis` drop dimensions, so the join axis shifts left.
    new_axis = axis - _ndim_dropped_left_of_axis_by_basic_index(idx_tuple, axis)
    out = join(new_axis, *indexed_components)
    # Preserve debug info on the replacement, consistent with the other
    # subtensor-lift rewrites in this file.
    copy_stack_trace(old_out, out)

    return [out]
| 767 | + |
698 | 768 | @register_specialize |
699 | 769 | @register_canonicalize |
700 | 770 | @node_rewriter([Subtensor]) |
|
0 commit comments