@@ -4,6 +4,7 @@
 import numpy as np
 
 from pytensor import Variable
+from pytensor.compile import optdb
 from pytensor.graph import Constant, FunctionGraph, node_rewriter
 from pytensor.graph.rewriting.basic import NodeRewriter, copy_stack_trace
 from pytensor.npy_2_compat import normalize_axis_index, normalize_axis_tuple
@@ -39,16 +40,18 @@
 )
 from pytensor.tensor.special import Softmax, softmax
 from pytensor.tensor.subtensor import (
+    AdvancedSubtensor,
     AdvancedSubtensor1,
     Subtensor,
+    _non_consecutive_adv_indexing,
     as_index_literal,
     get_canonical_form_slice,
     get_constant_idx,
     get_idx_list,
     indices_from_subtensor,
 )
 from pytensor.tensor.type import TensorType
-from pytensor.tensor.type_other import SliceType
+from pytensor.tensor.type_other import NoneTypeT, SliceType
 from pytensor.tensor.variable import TensorVariable
 
 
@@ -816,3 +819,79 @@ def local_subtensor_shape_constant(fgraph, node):
         return [as_tensor([1] * len(shape_parts), dtype=np.int64, ndim=1)]
     elif shape_parts:
         return [as_tensor(1, dtype=np.int64)]
+
+
+@node_rewriter([Subtensor])
+def local_subtensor_of_adv_subtensor(fgraph, node):
+    """Lift a simple Subtensor through an AdvancedSubtensor, when basic index dimensions are to the left of any advanced ones.
+
+    x[:, :, vec_idx][i, j] -> x[i, j][vec_idx]
+    x[:, vec_idx][i, j, k] -> x[i][vec_idx][j, k]
+
+    Restricted to a single advanced indexing dimension.
+
+    An alternative approach could have fused the basic and advanced indices,
+    so it is not clear this rewrite should be canonical or a specialization.
+    Users must include it manually if it fits their use case.
+    """
+    adv_subtensor, *idxs = node.inputs
+
+    if not (
+        adv_subtensor.owner and isinstance(adv_subtensor.owner.op, AdvancedSubtensor)
+    ):
+        return None
+
+    if len(fgraph.clients[adv_subtensor]) > 1:
+        # AdvancedSubtensor involves a full_copy, so we don't want to do it twice
+        return None
+
+    x, *adv_idxs = adv_subtensor.owner.inputs
+
+    # Advanced indexing is a minefield, avoid all cases except for consecutive integer indices
+    if any(
+        (
+            isinstance(adv_idx.type, NoneTypeT)
+            or (isinstance(adv_idx.type, TensorType) and adv_idx.type.dtype == "bool")
+            or (isinstance(adv_idx.type, SliceType) and not is_full_slice(adv_idx))
+        )
+        for adv_idx in adv_idxs
+    ) or _non_consecutive_adv_indexing(adv_idxs):
+        return None
+
+    for first_adv_idx_dim, adv_idx in enumerate(adv_idxs):
+        # We already made sure there were only None slices besides integer indexes
+        if isinstance(adv_idx.type, TensorType):
+            break
+    else:  # no-break
+        # Not sure if this should ever happen, but better safe than sorry
+        return None
+
+    basic_idxs = indices_from_subtensor(idxs, node.op.idx_list)
+    basic_idxs_lifted = basic_idxs[:first_adv_idx_dim]
+    basic_idxs_kept = ((slice(None),) * len(basic_idxs_lifted)) + basic_idxs[
+        first_adv_idx_dim:
+    ]
+
+    if all(basic_idx == slice(None) for basic_idx in basic_idxs_lifted):
+        # All basic indices happen to the right of the advanced indices
+        return None
+
+    [basic_subtensor] = node.outputs
+    dropped_dims = _dims_dropped_by_basic_index(basic_idxs_lifted)
+
+    x_indexed = x[basic_idxs_lifted]
+    copy_stack_trace([basic_subtensor, adv_subtensor], x_indexed)
+
+    x_after_index_lift = expand_dims(x_indexed, dropped_dims)
+    x_after_adv_idx = adv_subtensor.owner.op(x_after_index_lift, *adv_idxs)
+    copy_stack_trace([basic_subtensor, adv_subtensor], x_after_adv_idx)
+
+    new_out = squeeze(x_after_adv_idx[basic_idxs_kept], dropped_dims)
+    return [new_out]
+
+
+# Rewrite will only be included if tagged by name
+r = local_subtensor_of_adv_subtensor
+optdb["canonicalize"].register(r.__name__, r, use_db_name_as_tag=False)
+optdb["specialize"].register(r.__name__, r, use_db_name_as_tag=False)
+del r
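
Since the rewrite is registered under its own name only (`use_db_name_as_tag=False`), it is not part of the default canonicalize/specialize passes and has to be requested explicitly. Below is a minimal sketch, not part of the diff, of how one might opt in by name; the graph, variable names, and test values are illustrative.

```python
import numpy as np

import pytensor
import pytensor.tensor as pt
from pytensor.compile.mode import get_default_mode

x = pt.tensor3("x")
vec_idx = pt.lvector("vec_idx")

# The pattern the rewrite targets: basic indices (0, 1) applied on top of an
# AdvancedSubtensor (the vector index on the last dimension).
out = x[:, :, vec_idx][0, 1]

# Request the rewrite explicitly by its name tag.
mode = get_default_mode().including("local_subtensor_of_adv_subtensor")
fn = pytensor.function([x, vec_idx], out, mode=mode)

x_val = np.random.default_rng(0).normal(size=(2, 3, 4)).astype(x.dtype)
idx_val = np.array([0, 2])
np.testing.assert_allclose(fn(x_val, idx_val), x_val[:, :, idx_val][0, 1])
```

With the rewrite included, the compiled graph should apply the basic index `[0, 1]` to `x` first and only then the vector index, instead of advanced-indexing the full tensor and slicing the copied result.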