|
24 | 24 | from pytensor.tensor.var import TensorVariable
|
25 | 25 |
|
26 | 26 |
|
27 |
| -def default_supp_shape_from_params( |
28 |
| - ndim_supp: int, |
29 |
| - dist_params: Sequence[Variable], |
30 |
| - rep_param_idx: int = 0, |
31 |
| - param_shapes: Optional[Sequence[Tuple[ScalarVariable, ...]]] = None, |
32 |
| -) -> Union[TensorVariable, Tuple[ScalarVariable, ...]]: |
33 |
| - """Infer the dimensions for the output of a `RandomVariable`. |
34 |
| -
|
35 |
| - This is a function that derives a random variable's support |
36 |
| - shape/dimensions from one of its parameters. |
37 |
| -
|
38 |
| - XXX: It's not always possible to determine a random variable's support |
39 |
| - shape from its parameters, so this function has fundamentally limited |
40 |
| - applicability and must be replaced by custom logic in such cases. |
41 |
| -
|
42 |
| - XXX: This function is not expected to handle `ndim_supp = 0` (i.e. |
43 |
| - scalars), since that is already definitively handled in the `Op` that |
44 |
| - calls this. |
45 |
| -
|
46 |
| - TODO: Consider using `pytensor.compile.ops.shape_i` alongside `ShapeFeature`. |
47 |
| -
|
48 |
| - Parameters |
49 |
| - ---------- |
50 |
| - ndim_supp: int |
51 |
| - Total number of dimensions for a single draw of the random variable |
52 |
| - (e.g. a multivariate normal draw is 1D, so `ndim_supp = 1`). |
53 |
| - dist_params: list of `pytensor.graph.basic.Variable` |
54 |
| - The distribution parameters. |
55 |
| - rep_param_idx: int (optional) |
56 |
| - The index of the distribution parameter to use as a reference |
57 |
| - In other words, a parameter in `dist_param` with a shape corresponding |
58 |
| - to the support's shape. |
59 |
| - The default is the first parameter (i.e. the value 0). |
60 |
| - param_shapes: list of tuple of `ScalarVariable` (optional) |
61 |
| - Symbolic shapes for each distribution parameter. These will |
62 |
| - be used in place of distribution parameter-generated shapes. |
63 |
| -
|
64 |
| - Results |
65 |
| - ------- |
66 |
| - out: a tuple representing the support shape for a distribution with the |
67 |
| - given `dist_params`. |
68 |
| -
|
69 |
| - """ |
70 |
| - if ndim_supp <= 0: |
71 |
| - raise ValueError("ndim_supp must be greater than 0") |
72 |
| - if param_shapes is not None: |
73 |
| - ref_param = param_shapes[rep_param_idx] |
74 |
| - return (ref_param[-ndim_supp],) |
75 |
| - else: |
76 |
| - ref_param = dist_params[rep_param_idx] |
77 |
| - if ref_param.ndim < ndim_supp: |
78 |
| - raise ValueError( |
79 |
| - "Reference parameter does not match the " |
80 |
| - f"expected dimensions; {ref_param} has less than {ndim_supp} dim(s)." |
81 |
| - ) |
82 |
| - return ref_param.shape[-ndim_supp:] |
83 |
| - |
84 |
| - |
85 | 27 | class RandomVariable(Op):
|
86 | 28 | """An `Op` that produces a sample from a random variable.
|
87 | 29 |
|
@@ -151,15 +93,29 @@ def __init__(
|
151 | 93 | if self.inplace:
|
152 | 94 | self.destroy_map = {0: [0]}
|
153 | 95 |
|
154 |
| - def _supp_shape_from_params(self, dist_params, **kwargs): |
155 |
| - """Determine the support shape of a `RandomVariable`'s output given its parameters. |
| 96 | + def _supp_shape_from_params(self, dist_params, param_shapes=None): |
| 97 | + """Determine the support shape of a multivariate `RandomVariable`'s output given its parameters. |
156 | 98 |
|
157 | 99 | This does *not* consider the extra dimensions added by the `size` parameter
|
158 | 100 | or independent (batched) parameters.
|
159 | 101 |
|
160 |
| - Defaults to `param_supp_shape_fn`. |
| 102 | + When provided, `param_shapes` should be given preference over `[d.shape for d in dist_params]`, |
| 103 | + as it will avoid redundancies in PyTensor shape inference. |
| 104 | +
|
| 105 | + Examples |
| 106 | + -------- |
| 107 | + Common multivariate `RandomVariable`s derive their support shapes implicitly from the |
| 108 | + last dimension of some of their parameters. For example `multivariate_normal` support shape |
| 109 | + corresponds to the last dimension of the mean or covariance parameters, `support_shape=(mu.shape[-1])`. |
| 110 | + For this case the helper `pytensor.tensor.random.utils.supp_shape_from_ref_param_shape` can be used. |
| 111 | +
|
| 112 | + Other variables have fixed support shape such as `support_shape=(2,)` or it is determined by the |
| 113 | + values (not shapes) of some parameters. For instance, a `gaussian_random_walk(steps, size=(2,))`, |
| 114 | + might have `support_shape=(steps,)`. |
161 | 115 | """
|
162 |
| - return default_supp_shape_from_params(self.ndim_supp, dist_params, **kwargs) |
| 116 | + raise NotImplementedError( |
| 117 | + "`_supp_shape_from_params` must be implemented for multivariate RVs" |
| 118 | + ) |
163 | 119 |
|
164 | 120 | def rng_fn(self, rng, *args, **kwargs) -> Union[int, float, np.ndarray]:
|
165 | 121 | """Sample a numeric random variate."""
|
|
0 commit comments