|
29 | 29 | from aesara import tensor as at
|
30 | 30 | from aesara.compile.builders import OpFromGraph
|
31 | 31 | from aesara.graph import node_rewriter
|
32 |
| -from aesara.graph.basic import Node, Variable, clone_replace |
| 32 | +from aesara.graph.basic import Node, clone_replace |
33 | 33 | from aesara.graph.rewriting.basic import in2out
|
34 | 34 | from aesara.graph.utils import MetaType
|
35 | 35 | from aesara.tensor.basic import as_tensor_variable
|
|
42 | 42 | from pymc.distributions.shape_utils import (
|
43 | 43 | Dims,
|
44 | 44 | Shape,
|
45 |
| - StrongDims, |
46 |
| - StrongShape, |
47 |
| - change_dist_size, |
48 | 45 | convert_dims,
|
49 | 46 | convert_shape,
|
50 | 47 | convert_size,
|
@@ -154,35 +151,6 @@ def fn(*args, **kwargs):
|
154 | 151 | return fn
|
155 | 152 |
|
156 | 153 |
|
157 |
| -def _make_rv_and_resize_shape_from_dims( |
158 |
| - *, |
159 |
| - cls, |
160 |
| - dims: Optional[StrongDims], |
161 |
| - model, |
162 |
| - observed, |
163 |
| - args, |
164 |
| - **kwargs, |
165 |
| -) -> Tuple[Variable, StrongShape]: |
166 |
| - """Creates the RV, possibly using dims or observed to determine a resize shape (if needed).""" |
167 |
| - resize_shape_from_dims = None |
168 |
| - size_or_shape = kwargs.get("size") or kwargs.get("shape") |
169 |
| - |
170 |
| - # Preference is given to size or shape. If not specified, we rely on dims and |
171 |
| - # finally, observed, to determine the shape of the variable. Because dims can be |
172 |
| - # specified on the fly, we need a two-step process where we first create the RV |
173 |
| - # without dims information and then resize it. |
174 |
| - if not size_or_shape and observed is not None: |
175 |
| - kwargs["shape"] = tuple(observed.shape) |
176 |
| - |
177 |
| - # Create the RV without dims information |
178 |
| - rv_out = cls.dist(*args, **kwargs) |
179 |
| - |
180 |
| - if not size_or_shape and dims is not None: |
181 |
| - resize_shape_from_dims = shape_from_dims(dims, tuple(rv_out.shape), model) |
182 |
| - |
183 |
| - return rv_out, resize_shape_from_dims |
184 |
| - |
185 |
| - |
186 | 154 | class SymbolicRandomVariable(OpFromGraph):
|
187 | 155 | """Symbolic Random Variable
|
188 | 156 |
|
@@ -311,17 +279,15 @@ def __new__(
|
311 | 279 | if observed is not None:
|
312 | 280 | observed = convert_observed_data(observed)
|
313 | 281 |
|
314 |
| - # Create the RV, without taking `dims` into consideration |
315 |
| - rv_out, resize_shape_from_dims = _make_rv_and_resize_shape_from_dims( |
316 |
| - cls=cls, dims=dims, model=model, observed=observed, args=args, **kwargs |
317 |
| - ) |
| 282 | + # Preference is given to size or shape. If not specified, we rely on dims and |
| 283 | + # finally, observed, to determine the shape of the variable. |
| 284 | + if not ("size" in kwargs or "shape" in kwargs): |
| 285 | + if dims is not None: |
| 286 | + kwargs["shape"] = shape_from_dims(dims, model) |
| 287 | + elif observed is not None: |
| 288 | + kwargs["shape"] = tuple(observed.shape) |
318 | 289 |
|
319 |
| - # Resize variable based on `dims` information |
320 |
| - if resize_shape_from_dims: |
321 |
| - resize_size_from_dims = find_size( |
322 |
| - shape=resize_shape_from_dims, size=None, ndim_supp=rv_out.owner.op.ndim_supp |
323 |
| - ) |
324 |
| - rv_out = change_dist_size(dist=rv_out, new_size=resize_size_from_dims, expand=False) |
| 290 | + rv_out = cls.dist(*args, **kwargs) |
325 | 291 |
|
326 | 292 | rv_out = model.register_rv(
|
327 | 293 | rv_out,
|
|
0 commit comments