|
32 | 32 | from pytensor.graph.rewriting.basic import in2out
|
33 | 33 | from pytensor.graph.utils import MetaType
|
34 | 34 | from pytensor.tensor.basic import as_tensor_variable
|
35 |
| -from pytensor.tensor.random.op import RandomVariable |
| 35 | +from pytensor.tensor.random.op import RandomVariable, RNGConsumerOp |
36 | 36 | from pytensor.tensor.random.rewriting import local_subtensor_rv_lift
|
37 | 37 | from pytensor.tensor.random.utils import normalize_size_param
|
38 | 38 | from pytensor.tensor.rewriting.shape import ShapeFeature
|
@@ -207,7 +207,7 @@ def __get__(self, owner_self, owner_cls):
|
207 | 207 | return self.fget(owner_self if owner_self is not None else owner_cls)
|
208 | 208 |
|
209 | 209 |
|
210 |
| -class SymbolicRandomVariable(MeasurableOp, OpFromGraph): |
| 210 | +class SymbolicRandomVariable(MeasurableOp, RNGConsumerOp, OpFromGraph): |
211 | 211 | """Symbolic Random Variable.
|
212 | 212 |
|
213 | 213 | This is a subclasse of `OpFromGraph` which is used to encapsulate the symbolic
|
@@ -294,7 +294,10 @@ def default_output(cls_or_self) -> int | None:
|
294 | 294 | @staticmethod
|
295 | 295 | def get_input_output_type_idxs(
|
296 | 296 | extended_signature: str | None,
|
297 |
| - ) -> tuple[tuple[tuple[int], int | None, tuple[int]], tuple[tuple[int], tuple[int]]]: |
| 297 | + ) -> tuple[ |
| 298 | + tuple[tuple[int, ...], int | None, tuple[int, ...]], |
| 299 | + tuple[tuple[int, ...], tuple[int, ...]], |
| 300 | + ]: |
298 | 301 | """Parse extended_signature and return indexes for *[rng], [size] and parameters as well as outputs."""
|
299 | 302 | if extended_signature is None:
|
300 | 303 | raise ValueError("extended_signature must be provided")
|
@@ -367,8 +370,27 @@ def __init__(
|
367 | 370 |
|
368 | 371 | kwargs.setdefault("inline", True)
|
369 | 372 | kwargs.setdefault("strict", True)
|
| 373 | + # Many RVS have a size argument, even when this is `None` and is therefore unused |
| 374 | + kwargs.setdefault("on_unused_input", "ignore") |
370 | 375 | super().__init__(*args, **kwargs)
|
371 | 376 |
|
| 377 | + def make_node(self, *inputs): |
| 378 | + # If we try to build the RV with a different size type (vector -> None or None -> vector) |
| 379 | + # We need to rebuild the Op with new size type in the inner graph |
| 380 | + if self.extended_signature is not None: |
| 381 | + (rng_arg_idxs, size_arg_idx, param_idxs), _ = self.get_input_output_type_idxs( |
| 382 | + self.extended_signature |
| 383 | + ) |
| 384 | + if size_arg_idx is not None and len(rng_arg_idxs) == 1: |
| 385 | + new_size_type = normalize_size_param(inputs[size_arg_idx]).type |
| 386 | + if not self.input_types[size_arg_idx].in_same_class(new_size_type): |
| 387 | + params = [inputs[idx] for idx in param_idxs] |
| 388 | + size = inputs[size_arg_idx] |
| 389 | + rng = inputs[rng_arg_idxs[0]] |
| 390 | + return self.rebuild_rv(*params, size=size, rng=rng).owner |
| 391 | + |
| 392 | + return super().make_node(*inputs) |
| 393 | + |
372 | 394 | def update(self, node: Apply) -> dict[Variable, Variable]:
|
373 | 395 | """Symbolic update expression for input random state variables.
|
374 | 396 |
|
|
0 commit comments