|
4 | 4 | from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple |
5 | 5 |
|
6 | 6 | from pytensor import Variable |
7 | | -from pytensor.graph import Constant, node_rewriter |
8 | | -from pytensor.graph.rewriting.basic import copy_stack_trace |
| 7 | +from pytensor.graph import Constant, FunctionGraph, node_rewriter |
| 8 | +from pytensor.graph.rewriting.basic import NodeRewriter, copy_stack_trace |
9 | 9 | from pytensor.scalar import basic as ps |
10 | 10 | from pytensor.tensor.basic import ( |
11 | 11 | Alloc, |
| 12 | + Join, |
12 | 13 | MakeVector, |
13 | 14 | alloc, |
14 | 15 | as_tensor, |
15 | 16 | expand_dims, |
16 | 17 | get_underlying_scalar_constant_value, |
| 18 | + join, |
17 | 19 | register_infer_shape, |
18 | 20 | ) |
19 | 21 | from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise |
@@ -66,6 +68,40 @@ def _axis_is_indexed_by_basic_index( |
66 | 68 | return any(ax < len(idxs) and not is_full_slice(idxs[ax]) for ax in axis) |
67 | 69 |
|
68 | 70 |
|
def _lift_subtensor_non_axis(
    local_subtensor_lift_rewrite: NodeRewriter,
    fgraph: FunctionGraph,
    variable: Variable,
    idx_tuple: tuple[int | slice, ...],
    axis: int,
    old_subtensor_variable: Variable,
) -> None | list[Variable]:
    """Split a Subtensor and lift the "non-axis" indices through ``variable``.

    Used by rewrites that can lift a Subtensor through an Op only along
    dimensions other than ``axis`` (e.g. the softmax/join axis). The index on
    ``axis`` is kept outside, while the remaining indices are lifted by
    recursively invoking ``local_subtensor_lift_rewrite`` on a split graph.

    Parameters
    ----------
    local_subtensor_lift_rewrite
        The rewrite to re-apply on the subtensor containing only non-axis indices.
    fgraph
        The function graph being rewritten.
    variable
        The output of the Op the Subtensor should be lifted through.
    idx_tuple
        The normalized basic indices applied to ``variable``.
    axis
        The (normalized, non-negative) axis that must not be lifted.
    old_subtensor_variable
        The original Subtensor output, used for stack-trace propagation.

    Returns
    -------
    None | list[Variable]
        The replacement output, or ``None`` if there is nothing to split
        (only one real index, or ``variable`` is 1d).
    """
    real_indices = [idx for idx in idx_tuple if not is_full_slice(idx)]
    if len(real_indices) > 1 and variable.type.ndim > 1:
        # Split the subtensor: keep the axis index aside, full-slice it in the
        # indices that will be lifted.
        idx_to_keep = idx_tuple[axis]
        idxs_to_lift = (*idx_tuple[:axis], slice(None), *idx_tuple[axis + 1 :])

        # Lift the non-axis indexes by calling the rewrite itself.
        # NOTE: the recursive call is expected to always apply here, since the
        # axis dimension is now a full slice — unpacking would fail otherwise.
        indexed_variable = variable[idxs_to_lift]
        [indexed_variable] = local_subtensor_lift_rewrite.transform(
            fgraph, indexed_variable.owner
        )
        copy_stack_trace([old_subtensor_variable, indexed_variable], indexed_variable)

        # Then reintroduce the axis index, shifted left by however many
        # dimensions the lifted integer indices dropped before ``axis``.
        ndim_reduced_left = _ndim_dropped_left_of_axis_by_basic_index(idx_tuple, axis)
        new_axis = axis - ndim_reduced_left
        idxs_to_keep = (*(slice(None),) * new_axis, idx_to_keep)
        new_out = indexed_variable[idxs_to_keep]
        copy_stack_trace(old_subtensor_variable, new_out)
        return [new_out]

    else:
        return None
| 104 | + |
69 | 105 | @register_canonicalize |
70 | 106 | @register_stabilize |
71 | 107 | @register_specialize |
@@ -297,29 +333,14 @@ def local_subtensor_of_softmax(fgraph, node): |
297 | 333 | if _axis_is_indexed_by_basic_index(idx_tuple, axis): |
298 | 334 | # If there are more dimensions being indexed, we can split them |
299 | 335 | # And lift the non-axis indexes while keeping the axis index |
300 | | - real_indices = [idx for idx in idx_tuple if not is_full_slice(idx)] |
301 | | - if len(real_indices) > 1 and sm.type.ndim > 1: |
302 | | - # Split the subtensor |
303 | | - idx_to_keep = idx_tuple[axis] |
304 | | - idxs_to_lift = (*idx_tuple[:axis], slice(None), *idx_tuple[axis + 1 :]) |
305 | | - |
306 | | - # Lift the non-axis indexes by calling the rewrite itself |
307 | | - opt_sm = sm[idxs_to_lift] |
308 | | - [opt_sm] = local_subtensor_of_softmax.transform(fgraph, opt_sm.owner) |
309 | | - copy_stack_trace([old_out, sm], opt_sm) |
310 | | - |
311 | | - # Then reintroduce the axis index |
312 | | - ndim_reduced_left = _ndim_dropped_left_of_axis_by_basic_index( |
313 | | - idx_tuple, axis |
314 | | - ) |
315 | | - new_axis = axis - ndim_reduced_left |
316 | | - idxs_to_keep = (*(slice(None),) * new_axis, idx_to_keep) |
317 | | - new_out = opt_sm[idxs_to_keep] |
318 | | - copy_stack_trace(old_out, new_out) |
319 | | - return [new_out] |
320 | | - |
321 | | - else: |
322 | | - return None |
| 336 | + return _lift_subtensor_non_axis( |
| 337 | + local_subtensor_lift_rewrite=local_subtensor_of_softmax, |
| 338 | + fgraph=fgraph, |
| 339 | + variable=sm, |
| 340 | + idx_tuple=idx_tuple, |
| 341 | + axis=axis, |
| 342 | + old_subtensor_variable=old_out, |
| 343 | + ) |
323 | 344 |
|
324 | 345 | # Index input to softmax |
325 | 346 | x_sub = x[idx_tuple] |
@@ -690,6 +711,52 @@ def local_subtensor_make_vector(fgraph, node): |
690 | 711 | pass |
691 | 712 |
|
692 | 713 |
|
@register_canonicalize
@register_specialize
@node_rewriter([Subtensor])
def local_subtensor_of_join(fgraph, node):
    """Lift a Subtensor through a Join.

    join(axis=1, x, y)[0] -> join(axis=0, x[0], y[0])
    join(axis=1, x, y)[:, 0, -1] -> join(axis=1, x[:, :, -1], y[:, :, -1])[:, 0]

    """
    join_var, *idx = node.inputs

    if not (join_var.owner and isinstance(join_var.owner.op, Join)):
        return None

    if len(fgraph.clients[join_var]) > 1:
        # Join involves a full_copy, so we don't want to do it twice
        return None

    join_axis, *join_components = join_var.owner.inputs

    # Rewrite only works when the join axis is a constant along a non-indexed dimension
    if not isinstance(join_axis, Constant):
        return None

    [old_out] = node.outputs
    axis = normalize_axis_index(join_axis.data, join_components[0].type.ndim)
    idx_tuple = indices_from_subtensor(idx, node.op.idx_list)
    if _axis_is_indexed_by_basic_index(idx_tuple, axis):
        # The join axis itself is indexed: split the subtensor and lift only
        # the non-axis indices, keeping the axis index outside the Join.
        return _lift_subtensor_non_axis(
            local_subtensor_lift_rewrite=local_subtensor_of_join,
            fgraph=fgraph,
            variable=join_var,
            idx_tuple=idx_tuple,
            axis=axis,
            old_subtensor_variable=old_out,
        )

    # Lift index to the Join components; the join axis shifts left by the
    # number of dimensions dropped by integer indices before it.
    indexed_components = [component[idx_tuple] for component in join_components]
    new_axis = axis - _ndim_dropped_left_of_axis_by_basic_index(idx_tuple, axis)
    out = join(new_axis, *indexed_components)
    # Propagate debugging info to the replacement, consistent with the
    # other rewrites in this file.
    copy_stack_trace(old_out, out)

    return [out]
| 759 | + |
693 | 760 | @register_specialize |
694 | 761 | @register_canonicalize |
695 | 762 | @node_rewriter([Subtensor]) |
|
0 commit comments