@@ -13,6 +13,7 @@
 )
 from pytensor.graph.basic import Apply, Constant, Variable
 from pytensor.graph.op import Op
+from pytensor.graph.replace import _vectorize_node
 from pytensor.link.c.op import COp
 from pytensor.link.c.params_type import ParamsType
 from pytensor.link.c.type import EnumList, Generic
@@ -360,7 +361,7 @@ def grad(self, inputs, output_gradients):
         )

     def infer_shape(self, fgraph, node, shapes):
-        if self.axis is None:
+        if self.axis is None and len(shapes[0]) > 1:
             return [(prod(shapes[0]),)]  # Flatten

         return shapes
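Why the extra `len(shapes[0]) > 1` guard: with `axis=None` the op conceptually ravels its input before accumulating, but for a 1-d input that ravel is a no-op, so the input shape can be passed through unchanged. A minimal sketch of the shape behavior this encodes, assuming NumPy's `cumsum` semantics as the reference:

```python
import numpy as np

x1 = np.arange(4)                  # 1-d: cumsum with axis=None keeps the shape
assert np.cumsum(x1).shape == x1.shape

x2 = np.arange(6).reshape(2, 3)    # n-d: cumsum with axis=None flattens
assert np.cumsum(x2).shape == (x2.size,)
```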
@@ -473,6 +474,25 @@ def cumprod(x, axis=None):
     return CumOp(axis=axis, mode="mul")(x)


+@_vectorize_node.register(CumOp)
+def vectorize_cum_op(op: CumOp, node: Apply, batch_x):
+    """Vectorize the CumOp to work on a batch of inputs."""
+    [original_x] = node.inputs
+    batch_ndim = batch_x.ndim - original_x.ndim
+    axis = op.axis
+    if axis is None and original_x.ndim == 1:
+        axis = 0
+    elif axis is not None:
+        axis = normalize_axis_index(op.axis, original_x.ndim)
+
+    if axis is None:
+        # Ravel all unbatched dimensions and perform the CumOp on the last axis
+        batch_x_raveled = batch_x.flatten(ndim=batch_ndim + 1)
+        return type(op)(axis=-1, mode=op.mode).make_node(batch_x_raveled)
+    else:
+        return type(op)(axis=axis + batch_ndim, mode=op.mode).make_node(batch_x)
+
+
 def diff(x, n=1, axis=-1):
     """Calculate the `n`-th order discrete difference along the given `axis`.

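For context, a minimal sketch of how this dispatch is exercised through pytensor's public `vectorize_graph` helper (the variable names here are illustrative, not part of the diff):

```python
import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.graph.replace import vectorize_graph

# Original graph: cumulative sum of a matrix along axis 0
x = pt.matrix("x")
out = pt.cumsum(x, axis=0)

# Replace x with a batched input carrying one extra leading dimension;
# the CumOp dispatch shifts the axis by batch_ndim (0 -> 1 here)
batch_x = pt.tensor3("batch_x")
batch_out = vectorize_graph(out, replace={x: batch_x})

fn = pytensor.function([batch_x], batch_out)
data = np.random.normal(size=(2, 3, 4))
np.testing.assert_allclose(fn(data), np.cumsum(data, axis=1))
```

When `op.axis is None` and the core input has more than one dimension, the dispatch instead ravels only the unbatched dimensions (`flatten(ndim=batch_ndim + 1)`) and accumulates along the last axis, mirroring NumPy's ravel-then-accumulate semantics independently for each batch element.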