Skip to content

Commit b83ca3f

Browse files
committed
Vectorize CumOp and simplify infer_shape for vector case
1 parent 5024d54 commit b83ca3f

File tree

1 file changed: +21 additions, −1 deletion

pytensor/tensor/extra_ops.py

Lines changed: 21 additions & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -13,6 +13,7 @@
1313
)
1414
from pytensor.graph.basic import Apply, Constant, Variable
1515
from pytensor.graph.op import Op
16+
from pytensor.graph.replace import _vectorize_node
1617
from pytensor.link.c.op import COp
1718
from pytensor.link.c.params_type import ParamsType
1819
from pytensor.link.c.type import EnumList, Generic
@@ -360,7 +361,7 @@ def grad(self, inputs, output_gradients):
360361
)
361362

362363
def infer_shape(self, fgraph, node, shapes):
363-
if self.axis is None:
364+
if self.axis is None and len(shapes[0]) > 1:
364365
return [(prod(shapes[0]),)] # Flatten
365366

366367
return shapes
@@ -473,6 +474,25 @@ def cumprod(x, axis=None):
473474
return CumOp(axis=axis, mode="mul")(x)
474475

475476

477+
@_vectorize_node.register(CumOp)
478+
def vectorize_cum_op(op: CumOp, node: Apply, batch_x):
479+
"""Vectorize the CumOp to work on a batch of inputs."""
480+
[original_x] = node.inputs
481+
batch_ndim = batch_x.ndim - original_x.ndim
482+
axis = op.axis
483+
if axis is None and original_x.ndim == 1:
484+
axis = 0
485+
elif axis is not None:
486+
axis = normalize_axis_index(op.axis, original_x.ndim)
487+
488+
if axis is None:
489+
# Ravel all unbatched dimensions and perform CumOp on the last axis
490+
batch_x_raveled = [batch_x.flatten(ndim=batch_ndim + 1) for x in batch_x]
491+
return type(op)(axis=-1, mode=op.mode).make_node(batch_x_raveled)
492+
else:
493+
return type(op)(axis=axis + batch_ndim, mode=op.mode).make_node(batch_x)
494+
495+
476496
def diff(x, n=1, axis=-1):
477497
"""Calculate the `n`-th order discrete difference along the given `axis`.
478498

0 commit comments

Comments
 (0)