|
4 | 4 | from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple |
5 | 5 |
|
6 | 6 | from pytensor import Variable |
7 | | -from pytensor.graph import Constant, node_rewriter |
8 | | -from pytensor.graph.rewriting.basic import copy_stack_trace |
| 7 | +from pytensor.graph import Constant, FunctionGraph, node_rewriter |
| 8 | +from pytensor.graph.rewriting.basic import NodeRewriter, copy_stack_trace |
9 | 9 | from pytensor.scalar import basic as ps |
10 | 10 | from pytensor.tensor import ( |
11 | 11 | TensorType, |
|
17 | 17 | specify_shape, |
18 | 18 | squeeze, |
19 | 19 | ) |
20 | | -from pytensor.tensor.basic import Alloc, MakeVector, expand_dims, register_infer_shape |
| 20 | +from pytensor.tensor.basic import ( |
| 21 | + Alloc, |
| 22 | + Join, |
| 23 | + MakeVector, |
| 24 | + expand_dims, |
| 25 | + join, |
| 26 | + register_infer_shape, |
| 27 | +) |
21 | 28 | from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise |
22 | 29 | from pytensor.tensor.exceptions import NotScalarConstantError |
23 | 30 | from pytensor.tensor.math import Dot |
@@ -60,6 +67,40 @@ def _axis_is_indexed_by_basic_index( |
60 | 67 | return any(ax < len(idxs) and not is_full_slice(idxs[ax]) for ax in axis) |
61 | 68 |
|
62 | 69 |
|
def _lift_subtensor_non_axis(
    local_subtensor_lift_rewrite: NodeRewriter,
    fgraph: FunctionGraph,
    variable: Variable,
    idx_tuple: tuple[int | slice, ...],
    axis: int,
    old_subtensor_variable: Variable,
) -> None | list[Variable]:
    """Lift the indices on every dimension other than `axis` above `variable`.

    Splits ``variable[idx_tuple]`` in two: the "non-axis" indices are lifted by
    recursively invoking `local_subtensor_lift_rewrite` on a Subtensor that
    keeps `axis` fully sliced, and the index on `axis` is then re-applied on
    top of the rewritten graph.

    Parameters
    ----------
    local_subtensor_lift_rewrite
        The node rewriter to re-apply for the non-axis indices.
    fgraph
        The function graph being rewritten.
    variable
        The variable being indexed (e.g. the output of a Softmax or Join).
    idx_tuple
        The basic indices applied to `variable`.
    axis
        The dimension whose index must stay below the lifted operation.
    old_subtensor_variable
        The original Subtensor output being replaced (for stack traces).

    Returns
    -------
    The replacement outputs, or None when there is nothing to split
    (fewer than two real indices, or `variable` is at most 1d).
    """
    # Apply generic subtensor lift rewrite along "non-axis" dimensions
    real_indices = [idx for idx in idx_tuple if not is_full_slice(idx)]
    if len(real_indices) > 1 and variable.type.ndim > 1:
        # Split the subtensor: keep the axis index aside, lift the rest
        idx_to_keep = idx_tuple[axis]
        idxs_to_lift = (*idx_tuple[:axis], slice(None), *idx_tuple[axis + 1 :])

        # Lift the non-axis indexes by calling the rewrite itself
        # NOTE(review): assumes the recursive transform always succeeds for
        # this constructed node (never returns None) — confirm.
        indexed_variable = variable[idxs_to_lift]
        [indexed_variable] = local_subtensor_lift_rewrite.transform(
            fgraph, indexed_variable.owner
        )
        # Copy traces from the replaced output and the lifted-through variable,
        # mirroring the pre-refactor `copy_stack_trace([old_out, sm], opt_sm)`.
        copy_stack_trace([old_subtensor_variable, variable], indexed_variable)

        # Then reintroduce the axis index, shifted left by however many
        # dimensions the lifted integer indices dropped before `axis`
        ndim_reduced_left = _ndim_dropped_left_of_axis_by_basic_index(idx_tuple, axis)
        new_axis = axis - ndim_reduced_left
        idxs_to_keep = (*(slice(None),) * new_axis, idx_to_keep)
        new_out = indexed_variable[idxs_to_keep]
        copy_stack_trace(old_subtensor_variable, new_out)
        return [new_out]

    else:
        return None
| 102 | + |
| 103 | + |
63 | 104 | @register_canonicalize |
64 | 105 | @register_stabilize |
65 | 106 | @register_specialize |
@@ -291,29 +332,14 @@ def local_subtensor_of_softmax(fgraph, node): |
291 | 332 | if _axis_is_indexed_by_basic_index(idx_tuple, axis): |
292 | 333 | # If there are more dimensions being indexed, we can split them |
293 | 334 | # And lift the non-axis indexes while keeping the axis index |
294 | | - real_indices = [idx for idx in idx_tuple if not is_full_slice(idx)] |
295 | | - if len(real_indices) > 1 and sm.type.ndim > 1: |
296 | | - # Split the subtensor |
297 | | - idx_to_keep = idx_tuple[axis] |
298 | | - idxs_to_lift = (*idx_tuple[:axis], slice(None), *idx_tuple[axis + 1 :]) |
299 | | - |
300 | | - # Lift the non-axis indexes by calling the rewrite itself |
301 | | - opt_sm = sm[idxs_to_lift] |
302 | | - [opt_sm] = local_subtensor_of_softmax.transform(fgraph, opt_sm.owner) |
303 | | - copy_stack_trace([old_out, sm], opt_sm) |
304 | | - |
305 | | - # Then reintroduce the axis index |
306 | | - ndim_reduced_left = _ndim_dropped_left_of_axis_by_basic_index( |
307 | | - idx_tuple, axis |
308 | | - ) |
309 | | - new_axis = axis - ndim_reduced_left |
310 | | - idxs_to_keep = (*(slice(None),) * new_axis, idx_to_keep) |
311 | | - new_out = opt_sm[idxs_to_keep] |
312 | | - copy_stack_trace(old_out, new_out) |
313 | | - return [new_out] |
314 | | - |
315 | | - else: |
316 | | - return None |
| 335 | + return _lift_subtensor_non_axis( |
| 336 | + local_subtensor_lift_rewrite=local_subtensor_of_softmax, |
| 337 | + fgraph=fgraph, |
| 338 | + variable=sm, |
| 339 | + idx_tuple=idx_tuple, |
| 340 | + axis=axis, |
| 341 | + old_subtensor_variable=old_out, |
| 342 | + ) |
317 | 343 |
|
318 | 344 | # Index input to softmax |
319 | 345 | x_sub = x[idx_tuple] |
@@ -684,6 +710,52 @@ def local_subtensor_make_vector(fgraph, node): |
684 | 710 | pass |
685 | 711 |
|
686 | 712 |
|
@register_canonicalize
@register_specialize
@node_rewriter([Subtensor])
def local_subtensor_of_join(fgraph, node):
    """Lift a Subtensor through a Join.

    join(axis=1, x, y)[0] -> join(axis=0, x[0], y[0])
    join(axis=1, x, y)[:, 0, -1] -> join(axis=1, x[:, :, -1], y[:, :, -1])[:, 0]

    """
    joined, *subtensor_idxs = node.inputs

    if joined.owner is None or not isinstance(joined.owner.op, Join):
        return None

    # Join involves a full_copy, so we don't want to do it twice
    if len(fgraph.clients[joined]) > 1:
        return None

    axis_input, *components = joined.owner.inputs

    # Rewrite only works when the join axis is a constant along a non-indexed dimension
    if not isinstance(axis_input, Constant):
        return None

    [old_out] = node.outputs
    axis = normalize_axis_index(axis_input.data, components[0].type.ndim)
    idx_tuple = indices_from_subtensor(subtensor_idxs, node.op.idx_list)

    if _axis_is_indexed_by_basic_index(idx_tuple, axis):
        # The join axis itself is indexed: split off the non-axis indices
        # and lift those via the shared helper, re-applying the axis index.
        return _lift_subtensor_non_axis(
            local_subtensor_lift_rewrite=local_subtensor_of_join,
            fgraph=fgraph,
            variable=joined,
            idx_tuple=idx_tuple,
            axis=axis,
            old_subtensor_variable=old_out,
        )

    # Otherwise index each Join component directly and join along the axis,
    # shifted left by the dimensions dropped by integer indices before it.
    sliced_components = [component[idx_tuple] for component in components]
    new_axis = axis - _ndim_dropped_left_of_axis_by_basic_index(idx_tuple, axis)
    return [join(new_axis, *sliced_components)]
| 757 | + |
| 758 | + |
687 | 759 | @register_specialize |
688 | 760 | @register_canonicalize |
689 | 761 | @node_rewriter([Subtensor]) |
|
0 commit comments