|
5 | 5 | import pytensor.tensor.random.basic as ptr
|
6 | 6 | from pytensor.graph.basic import Variable
|
7 | 7 | from pytensor.tensor.random.op import RandomVariable
|
8 |
| -from pytensor.xtensor import as_xtensor |
9 | 8 | from pytensor.xtensor.math import sqrt
|
| 9 | +from pytensor.xtensor.type import as_xtensor |
10 | 10 | from pytensor.xtensor.vectorization import XRV
|
11 | 11 |
|
12 | 12 |
|
13 | 13 | def _as_xrv(
|
14 | 14 | core_op: RandomVariable,
|
15 | 15 | core_inps_dims_map: Sequence[Sequence[int]] | None = None,
|
16 | 16 | core_out_dims_map: Sequence[int] | None = None,
|
| 17 | + name: str | None = None, |
17 | 18 | ):
|
18 | 19 | """Helper function to define an XRV constructor.
|
19 | 20 |
|
@@ -41,7 +42,14 @@ def _as_xrv(
|
41 | 42 | core_out_dims_map = tuple(range(core_op.ndim_supp))
|
42 | 43 |
|
43 | 44 | core_dims_needed = max(
|
44 |
| - (*(len(i) for i in core_inps_dims_map), len(core_out_dims_map)), default=0 |
| 45 | + max( |
| 46 | + ( |
| 47 | + max((entry + 1 for entry in dims_map), default=0) |
| 48 | + for dims_map in core_inps_dims_map |
| 49 | + ), |
| 50 | + default=0, |
| 51 | + ), |
| 52 | + max((entry + 1 for entry in core_out_dims_map), default=0), |
45 | 53 | )
|
46 | 54 |
|
47 | 55 | @wraps(core_op)
|
@@ -76,7 +84,10 @@ def xrv_constructor(
|
76 | 84 | extra_dims = {}
|
77 | 85 |
|
78 | 86 | return XRV(
|
79 |
| - core_op, core_dims=full_core_dims, extra_dims=tuple(extra_dims.keys()) |
| 87 | + core_op, |
| 88 | + core_dims=full_core_dims, |
| 89 | + extra_dims=tuple(extra_dims.keys()), |
| 90 | + name=name, |
80 | 91 | )(rng, *extra_dims.values(), *params)
|
81 | 92 |
|
82 | 93 | return xrv_constructor
|
|
0 commit comments