1 | | -from collections.abc import Iterable |
| 1 | +from collections.abc import Iterable, Sequence |
2 | 2 |
3 | 3 | import numpy as np |
4 | | -from numpy.core.numeric import normalize_axis_tuple |
| 4 | +from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple |
5 | 5 |
6 | 6 | from pytensor import Variable |
7 | 7 | from pytensor.graph import Constant, node_rewriter |
34 | 34 | specify_shape, |
35 | 35 | unbroadcast, |
36 | 36 | ) |
| 37 | +from pytensor.tensor.special import Softmax, softmax |
37 | 38 | from pytensor.tensor.subtensor import ( |
38 | 39 | AdvancedSubtensor1, |
39 | 40 | Subtensor, |
48 | 49 |
49 | 50 |
50 | 51 | def _dims_dropped_by_basic_index(idxs) -> tuple[int, ...]: |
| 52 | + # Inputs can be slice or integer indexes |
| 53 | + # Slices keep the dimensions, integers collapse them |
51 | 54 | return tuple(i for i, idx in enumerate(idxs) if not isinstance(idx, slice)) |
52 | 55 |
53 | 56 |
| 57 | +def _ndim_dropped_left_of_axis_by_basic_index(idxs, axis: int) -> int: |
| 58 | + return len(_dims_dropped_by_basic_index(idxs[:axis])) |
| 59 | + |
| 60 | + |
| 61 | +def _axis_is_indexed_by_basic_index( |
| 62 | + idxs: tuple[Variable], axis: int | Sequence[int] |
| 63 | +) -> bool: |
| 64 | + if isinstance(axis, int): |
| 65 | + axis = (axis,) |
| 66 | + return any(ax < len(idxs) and not is_full_slice(idxs[ax]) for ax in axis) |
| 67 | + |
| 68 | + |
54 | 69 | @register_canonicalize |
55 | 70 | @register_stabilize |
56 | 71 | @register_specialize |
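(Aside: the three helpers added above are plain index bookkeeping. A self-contained sketch with ordinary Python stand-ins, evaluated on a made-up index tuple for x[:, 0, 1:], shows what each one reports; the names and values are illustrative only, not tied to any PyTensor module path.)

# Stand-ins for the helpers, evaluated on the index tuple behind x[:, 0, 1:]
idxs = (slice(None), 0, slice(1, None))

def full_slice(idx):
    # stand-in for is_full_slice: a bare `:` with no start/stop/step
    return idx == slice(None)

# _dims_dropped_by_basic_index: integers collapse their dimension, slices keep it
assert tuple(i for i, idx in enumerate(idxs) if not isinstance(idx, slice)) == (1,)

# _ndim_dropped_left_of_axis_by_basic_index(idxs, axis=2): the integer at position 1
# drops one dimension to the left of axis 2
assert sum(not isinstance(idx, slice) for idx in idxs[:2]) == 1

# _axis_is_indexed_by_basic_index(idxs, axis=1): axis 1 is hit by something other
# than a full `:`, so an index touches that axis and the rewrite must split
assert any(ax < len(idxs) and not full_slice(idxs[ax]) for ax in (1,))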
@@ -241,6 +256,84 @@ def local_subtensor_of_reduce(fgraph, node): |
241 | 256 | return [out] |
242 | 257 |
243 | 258 |
| 259 | +@register_canonicalize |
| 260 | +@register_specialize |
| 261 | +@node_rewriter([Subtensor]) |
| 262 | +def local_subtensor_of_softmax(fgraph, node): |
| 263 | + """Lift a Subtensor through a Softmax. |
| 264 | +
| 265 | + softmax(x, axis=1)[0] -> softmax(x[0], axis=0) |
| 266 | + softmax(x, axis=1)[:, :, 0] -> softmax(x[:, :, 0], axis=1) |
| 267 | +
| 268 | + If part of the indexing acts on the axis of reduction, we split it:
| 269 | + softmax(x, axis=1)[:, 0, 1:] -> softmax(x[:, :, 1:], axis=1)[:, 0]
| 270 | +
| 271 | + """ |
| 272 | + sm, *idx = node.inputs |
| 273 | + |
| 274 | + if not (sm.owner and isinstance(sm.owner.op, Softmax)): |
| 275 | + return None |
| 276 | + |
| 277 | + if len(fgraph.clients[sm]) > 1: |
| 278 | + return None |
| 279 | + |
| 280 | + [x] = sm.owner.inputs |
| 281 | + axis = sm.owner.op.axis |
| 282 | + |
| 283 | + if axis is None: |
| 284 | + if x.type.ndim == 1: |
| 285 | + axis = 0 |
| 286 | + else: |
| 287 | + # All dimensions are mixed, so we can't lift the subtensor
| 288 | + return None |
| 289 | + else: |
| 290 | + # Softmax currently only allows None or a single integer axis |
| 291 | + # Unlike CAReduce, it does not normalize a negative axis
| 292 | + axis = normalize_axis_index(axis, sm.ndim) |
| 293 | + |
| 294 | + [old_out] = node.outputs |
| 295 | + idx_tuple = indices_from_subtensor(idx, node.op.idx_list) |
| 296 | + |
| 297 | + if _axis_is_indexed_by_basic_index(idx_tuple, axis): |
| 298 | + # If more than the softmax axis is being indexed, we can split the indices
| 299 | + # and lift the non-axis ones while keeping the index on the axis
| 300 | + real_indices = [idx for idx in idx_tuple if not is_full_slice(idx)] |
| 301 | + if len(real_indices) > 1 and sm.type.ndim > 1: |
| 302 | + # Split the subtensor |
| 303 | + idx_to_keep = idx_tuple[axis] |
| 304 | + idxs_to_lift = (*idx_tuple[:axis], slice(None), *idx_tuple[axis + 1 :]) |
| 305 | + |
| 306 | + # Lift the non-axis indexes by calling the rewrite itself |
| 307 | + opt_sm = sm[idxs_to_lift] |
| 308 | + [opt_sm] = local_subtensor_of_softmax.transform(fgraph, opt_sm.owner) |
| 309 | + copy_stack_trace([old_out, sm], opt_sm) |
| 310 | + |
| 311 | + # Then reintroduce the axis index |
| 312 | + ndim_reduced_left = _ndim_dropped_left_of_axis_by_basic_index( |
| 313 | + idx_tuple, axis |
| 314 | + ) |
| 315 | + new_axis = axis - ndim_reduced_left |
| 316 | + idxs_to_keep = (*(slice(None),) * new_axis, idx_to_keep) |
| 317 | + new_out = opt_sm[idxs_to_keep] |
| 318 | + copy_stack_trace(old_out, new_out) |
| 319 | + return [new_out] |
| 320 | + |
| 321 | + else: |
| 322 | + return None |
| 323 | + |
| 324 | + # Index input to softmax |
| 325 | + x_sub = x[idx_tuple] |
| 326 | + |
| 327 | + # Adjust axis of reduction when indexing drops dimensions (integer indexing as opposed to slice indexing)
| 328 | + axis -= len( |
| 329 | + [idx_item for idx_item in idx_tuple[:axis] if not isinstance(idx_item, slice)] |
| 330 | + ) |
| 331 | + |
| 332 | + out = softmax(x_sub, axis=axis) |
| 333 | + copy_stack_trace(old_out, out) |
| 334 | + return [out] |
| 335 | + |
| 336 | + |
244 | 337 | @register_canonicalize("shape_unsafe") |
245 | 338 | @register_specialize("shape_unsafe") |
246 | 339 | @node_rewriter([Subtensor]) |
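Beyond the diff itself, a minimal end-to-end sketch of the new lift (the variable names, shapes, and tolerance are illustrative; it assumes the rewrite is picked up by the default canonicalize/specialize passes that pytensor.function runs, which is where the decorators above register it):

import numpy as np

import pytensor
import pytensor.tensor as pt
from pytensor.tensor.special import softmax

# softmax(x, axis=1)[0] should end up as a Softmax applied to the already-indexed x[0]
x = pt.matrix("x")
out = softmax(x, axis=1)[0]

fn = pytensor.function([x], out)  # compiling applies the rewrite pipeline
pytensor.dprint(fn)               # inspect the rewritten graph

rng = np.random.default_rng(0)
x_val = rng.normal(size=(3, 5)).astype(pytensor.config.floatX)
expected = np.exp(x_val[0]) / np.exp(x_val[0]).sum()
np.testing.assert_allclose(fn(x_val), expected, rtol=1e-5)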