|
1 | | -from collections.abc import Iterable |
| 1 | +from collections.abc import Iterable, Sequence |
2 | 2 |
|
3 | 3 | import numpy as np |
4 | | -from numpy.core.numeric import normalize_axis_tuple |
| 4 | +from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple |
5 | 5 |
|
6 | 6 | from pytensor import Variable |
7 | 7 | from pytensor.graph import Constant, node_rewriter |
|
29 | 29 | from pytensor.tensor.rewriting.elemwise import local_dimshuffle_lift |
30 | 30 | from pytensor.tensor.rewriting.subtensor import is_full_slice, register_useless |
31 | 31 | from pytensor.tensor.shape import Shape, SpecifyShape, Unbroadcast, unbroadcast |
| 32 | +from pytensor.tensor.special import Softmax, softmax |
32 | 33 | from pytensor.tensor.subtensor import ( |
33 | 34 | AdvancedSubtensor1, |
34 | 35 | Subtensor, |
|
42 | 43 |
|
43 | 44 |
|
44 | 45 | def _dims_dropped_by_basic_index(idxs) -> tuple[int, ...]: |
| 46 | + # Inputs can be slice or integer indexes |
| 47 | + # Slices keep the dimensions, integers collapse them |
45 | 48 | return tuple(i for i, idx in enumerate(idxs) if not isinstance(idx, slice)) |
46 | 49 |
|
47 | 50 |
|
| 51 | +def _ndim_dropped_left_of_axis_by_basic_index(idxs, axis: int) -> int: |
| 52 | + return len(_dims_dropped_by_basic_index(idxs[:axis])) |
| 53 | + |
| 54 | + |
| 55 | +def _axis_is_indexed_by_basic_index( |
| 56 | + idxs: tuple[Variable], axis: int | Sequence[int] |
| 57 | +) -> bool: |
| 58 | + if isinstance(axis, int): |
| 59 | + axis = (axis,) |
| 60 | + return any(ax < len(idxs) and not is_full_slice(idxs[ax]) for ax in axis) |
| 61 | + |
| 62 | + |
48 | 63 | @register_canonicalize |
49 | 64 | @register_stabilize |
50 | 65 | @register_specialize |
@@ -235,6 +250,84 @@ def local_subtensor_of_reduce(fgraph, node): |
235 | 250 | return [out] |
236 | 251 |
|
237 | 252 |
|
| 253 | +@register_canonicalize |
| 254 | +@register_specialize |
| 255 | +@node_rewriter([Subtensor]) |
| 256 | +def local_subtensor_of_softmax(fgraph, node): |
| 257 | + """Lift a Subtensor through a Softmax. |
| 258 | +
|
| 259 | + softmax(x, axis=1)[0] -> softmax(x[0], axis=0) |
| 260 | + softmax(x, axis=1)[:, :, 0] -> softmax(x[:, :, 0], axis=1) |
| 261 | +
|
| 262 | + If part of the indexing acts on the axis of reduction, we split it |
| 263 | + softmax(x, axis=1)[:, 0, 1:] -> softmax(x[:, :, 1:], axis=1)[:, 0] |
| 264 | +
|
| 265 | + """ |
| 266 | + sm, *idx = node.inputs |
| 267 | + |
| 268 | + if not (sm.owner and isinstance(sm.owner.op, Softmax)): |
| 269 | + return None |
| 270 | + |
| 271 | + if len(fgraph.clients[sm]) > 1: |
| 272 | + return None |
| 273 | + |
| 274 | + [x] = sm.owner.inputs |
| 275 | + axis = sm.owner.op.axis |
| 276 | + |
| 277 | + if axis is None: |
| 278 | + if x.type.ndim == 1: |
| 279 | + axis = 0 |
| 280 | + else: |
| 281 | + # All dimensions are mixed; we can't lift the subtensor |
| 282 | + return None |
| 283 | + else: |
| 284 | + # Softmax currently only allows None or a single integer axis |
| 285 | + # Unlike CAReduce it does not normalize negative indices |
| 286 | + axis = normalize_axis_index(axis, sm.ndim) |
| 287 | + |
| 288 | + [old_out] = node.outputs |
| 289 | + idx_tuple = indices_from_subtensor(idx, node.op.idx_list) |
| 290 | + |
| 291 | + if _axis_is_indexed_by_basic_index(idx_tuple, axis): |
| 292 | + # If other dimensions are indexed as well, we can split the indexing: |
| 293 | + # lift the non-axis indices and reapply the axis index afterwards |
| 294 | + real_indices = [idx for idx in idx_tuple if not is_full_slice(idx)] |
| 295 | + if len(real_indices) > 1 and sm.type.ndim > 1: |
| 296 | + # Split the subtensor |
| 297 | + idx_to_keep = idx_tuple[axis] |
| 298 | + idxs_to_lift = (*idx_tuple[:axis], slice(None), *idx_tuple[axis + 1 :]) |
| 299 | + |
| 300 | + # Lift the non-axis indexes by calling the rewrite itself |
| 301 | + opt_sm = sm[idxs_to_lift] |
| 302 | + [opt_sm] = local_subtensor_of_softmax.transform(fgraph, opt_sm.owner) |
| 303 | + copy_stack_trace([old_out, sm], opt_sm) |
| 304 | + |
| 305 | + # Then reintroduce the axis index |
| 306 | + ndim_reduced_left = _ndim_dropped_left_of_axis_by_basic_index( |
| 307 | + idx_tuple, axis |
| 308 | + ) |
| 309 | + new_axis = axis - ndim_reduced_left |
| 310 | + idxs_to_keep = (*(slice(None),) * new_axis, idx_to_keep) |
| 311 | + new_out = opt_sm[idxs_to_keep] |
| 312 | + copy_stack_trace(old_out, new_out) |
| 313 | + return [new_out] |
| 314 | + |
| 315 | + else: |
| 316 | + return None |
| 317 | + |
| 318 | + # Index input to softmax |
| 319 | + x_sub = x[idx_tuple] |
| 320 | + |
| 321 | + # Adjust axis of reduction when indexing drops dimensions (integer indexing as opposed to slice indexing) |
| 322 | + axis -= len( |
| 323 | + [idx_item for idx_item in idx_tuple[:axis] if not isinstance(idx_item, slice)] |
| 324 | + ) |
| 325 | + |
| 326 | + out = softmax(x_sub, axis=axis) |
| 327 | + copy_stack_trace(old_out, out) |
| 328 | + return [out] |
| 329 | + |
| 330 | + |
238 | 331 | @register_canonicalize("shape_unsafe") |
239 | 332 | @register_specialize("shape_unsafe") |
240 | 333 | @node_rewriter([Subtensor]) |
|
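For reviewers who want to see the new rewrite in action, here is a minimal sketch of how it could be exercised. It is not part of the diff; it assumes `rewrite_graph` is importable from `pytensor.graph.rewriting.utils` and that the rewrite is applied through the "canonicalize" database it is registered in above.

import pytensor.tensor as pt
from pytensor.graph.rewriting.utils import rewrite_graph
from pytensor.tensor.special import softmax

# Indexing a non-reduced dimension: the whole Subtensor is lifted,
# so the softmax is computed on the smaller, already-indexed input.
x = pt.matrix("x")
rewritten = rewrite_graph(softmax(x, axis=1)[0], include=("canonicalize",))
# expected: softmax(x[0], axis=0)

# Indexing the reduced axis as well: only the non-axis indices are lifted,
# and the axis index is reapplied on top of the new softmax.
y = pt.tensor3("y")
rewritten = rewrite_graph(softmax(y, axis=1)[:, 0, 1:], include=("canonicalize",))
# expected: softmax(y[:, :, 1:], axis=1)[:, 0]

Other canonicalize rewrites may transform these graphs further; the comments describe only the effect of local_subtensor_of_softmax.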