@@ -1,6 +1,8 @@
 from copy import copy
+from typing import Union
 
 import numpy as np
+from numpy.core.numeric import normalize_axis_tuple
 
 import pytensor.tensor.basic
 from pytensor.configdefaults import config
@@ -1399,7 +1401,7 @@ def make_node(self, input):
         # scalar inputs are treated as 1D regarding axis in this `Op`
         if axis is not None:
             try:
-                axis = np.core.numeric.normalize_axis_tuple(axis, ndim=max(1, inp_dims))
+                axis = normalize_axis_tuple(axis, ndim=max(1, inp_dims))
             except np.AxisError:
                 raise np.AxisError(axis, ndim=inp_dims)
 
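For reference, `normalize_axis_tuple` is a NumPy helper (imported above from its NumPy 1.x location, `numpy.core.numeric`; NumPy >= 2.0 exposes it publicly as `numpy.lib.array_utils.normalize_axis_tuple`). It accepts an int or a tuple of ints, wraps negative axes, and always returns a tuple of non-negative integers, raising `AxisError` for out-of-range values. A minimal sketch of its behavior:

    import numpy as np
    from numpy.core.numeric import normalize_axis_tuple

    normalize_axis_tuple(0, ndim=2)        # -> (0,)
    normalize_axis_tuple(-1, ndim=3)       # -> (2,)
    normalize_axis_tuple((0, -1), ndim=3)  # -> (0, 2)
    normalize_axis_tuple(3, ndim=2)        # raises np.AxisError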
@@ -1757,18 +1759,36 @@ def vectorize_dimshuffle(op: DimShuffle, node: Apply, x: TensorVariable) -> Apply:
     return DimShuffle(input_broadcastable, new_order).make_node(x)
 
 
-@_vectorize_node.register(CAReduce)
-def vectorize_careduce(op: CAReduce, node: Apply, x: TensorVariable) -> Apply:
-    batched_ndims = x.type.ndim - node.inputs[0].type.ndim
-    if not batched_ndims:
-        return node.op.make_node(x)
-    axes = op.axis
-    # e.g., sum(matrix, axis=None) -> sum(tensor4, axis=(2, 3))
-    # e.g., sum(matrix, axis=0) -> sum(tensor4, axis=(2,))
-    if axes is None:
-        axes = list(range(node.inputs[0].type.ndim))
+def get_normalized_batch_axes(
+    core_axes: Union[None, int, tuple[int, ...]],
+    core_ndim: int,
+    batch_ndim: int,
+) -> tuple[int, ...]:
+    """Compute batch axes for a batched operation, from the core input ndim and axes.
+
+    e.g., sum(matrix, axis=None) -> sum(tensor4, axis=(2, 3))
+    get_normalized_batch_axes(None, 2, 2) -> (2, 3)
+
+    e.g., sum(matrix, axis=0) -> sum(tensor4, axis=(2,))
+    get_normalized_batch_axes(0, 2, 2) -> (2,)
+
+    e.g., sum(tensor3, axis=(0, -1)) -> sum(tensor4, axis=(1, 3))
+    get_normalized_batch_axes((0, -1), 3, 1) -> (1, 3)
+    """
+    if core_axes is None:
+        core_axes = tuple(range(core_ndim))
     else:
-        axes = list(axes)
-    new_axes = [axis + batched_ndims for axis in axes]
-    new_op = op.clone(axis=new_axes)
-    return new_op.make_node(x)
+        core_axes = normalize_axis_tuple(core_axes, core_ndim)
+    return tuple(core_axis + batch_ndim for core_axis in core_axes)
+
+
+@_vectorize_node.register(CAReduce)
+def vectorize_careduce(op: CAReduce, node: Apply, batch_x: TensorVariable) -> Apply:
+    core_ndim = node.inputs[0].type.ndim
+    batch_ndim = batch_x.type.ndim - core_ndim
+
+    if not batch_ndim:
+        return node.op.make_node(batch_x)
+
+    batch_axes = get_normalized_batch_axes(op.axis, core_ndim, batch_ndim)
+    return op.clone(axis=batch_axes).make_node(batch_x)
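To make the axis arithmetic concrete, here is a standalone sketch that mirrors the new helper and checks the equivalence its docstring describes using plain NumPy (no PyTensor; the data shapes and variable names outside the helper are illustrative):

    import numpy as np
    from numpy.core.numeric import normalize_axis_tuple

    def get_normalized_batch_axes(core_axes, core_ndim, batch_ndim):
        # Mirror of the helper added in this diff.
        if core_axes is None:
            core_axes = tuple(range(core_ndim))
        else:
            core_axes = normalize_axis_tuple(core_axes, core_ndim)
        return tuple(core_axis + batch_ndim for core_axis in core_axes)

    # sum(matrix, axis=0) batched over two leading dims -> sum(tensor4, axis=(2,))
    rng = np.random.default_rng(0)
    batch_x = rng.normal(size=(5, 6, 3, 4))  # batch shape (5, 6), core shape (3, 4)
    batch_axes = get_normalized_batch_axes(0, core_ndim=2, batch_ndim=2)
    assert batch_axes == (2,)

    # Reducing the shifted axes on the batched input matches applying the core
    # reduction to each (3, 4) core matrix independently.
    expected = np.stack([[core.sum(axis=0) for core in row] for row in batch_x])
    np.testing.assert_allclose(batch_x.sum(axis=batch_axes), expected)

    # Negative core axes are normalized first: (0, -1) on a tensor3 core with
    # one batch dimension becomes (1, 3).
    assert get_normalized_batch_axes((0, -1), core_ndim=3, batch_ndim=1) == (1, 3)

With this helper in place, `vectorize_careduce` only has to clone the reduce `Op` with the shifted axes (`op.clone(axis=batch_axes)`), so the batched node reduces the same core dimensions while leaving the leading batch dimensions intact.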