|
33 | 33 | from aesara.tensor.var import TensorVariable
|
34 | 34 | from typing_extensions import TypeAlias
|
35 | 35 |
|
36 |
| -from pymc.aesaraf import change_rv_size |
| 36 | +from pymc.aesaraf import change_rv_size, convert_observed_data |
37 | 37 | from pymc.distributions.shape_utils import (
|
38 | 38 | Dims,
|
39 | 39 | Shape,
|
40 | 40 | Size,
|
| 41 | + StrongDims, |
41 | 42 | StrongShape,
|
42 | 43 | convert_dims,
|
43 | 44 | convert_shape,
|
44 | 45 | convert_size,
|
45 | 46 | find_size,
|
46 |
| - resize_from_dims, |
47 |
| - resize_from_observed, |
| 47 | + shape_from_dims, |
48 | 48 | )
|
49 | 49 | from pymc.printing import str_for_dist, str_for_symbolic_dist
|
50 | 50 | from pymc.util import UNSET
|
@@ -152,29 +152,28 @@ def fn(*args, **kwargs):
|
152 | 152 | def _make_rv_and_resize_shape(
|
153 | 153 | *,
|
154 | 154 | cls,
|
155 |
| - dims: Optional[Dims], |
| 155 | + dims: Optional[StrongDims], |
156 | 156 | model,
|
157 | 157 | observed,
|
158 | 158 | args,
|
159 | 159 | **kwargs,
|
160 |
| -) -> Tuple[Variable, Optional[Dims], Optional[Union[np.ndarray, Variable]], StrongShape]: |
161 |
| - """Creates the RV and processes dims or observed to determine a resize shape.""" |
162 |
| - # Create the RV without dims information, because that's not something tracked at the Aesara level. |
163 |
| - # If necessary we'll later replicate to a different size implied by already known dims. |
164 |
| - rv_out = cls.dist(*args, **kwargs) |
165 |
| - ndim_actual = rv_out.ndim |
| 160 | +) -> Tuple[Variable, StrongShape]: |
| 161 | + """Creates the RV, possibly using dims or observed to determine a resize shape (if needed).""" |
166 | 162 | resize_shape = None
|
| 163 | + size_or_shape = kwargs.get("size") or kwargs.get("shape") |
| 164 | + |
| 165 | + # Create the RV without dims or observed information |
| 166 | + rv_out = cls.dist(*args, **kwargs) |
167 | 167 |
|
168 |
| - # # `dims` are only available with this API, because `.dist()` can be used |
169 |
| - # # without a modelcontext and dims are not tracked at the Aesara level. |
170 |
| - dims = convert_dims(dims) |
171 |
| - dims_can_resize = kwargs.get("shape", None) is None and kwargs.get("size", None) is None |
172 |
| - if dims is not None: |
173 |
| - if dims_can_resize: |
174 |
| - resize_shape, dims = resize_from_dims(dims, ndim_actual, model) |
175 |
| - elif observed is not None: |
176 |
| - resize_shape, observed = resize_from_observed(observed, ndim_actual) |
177 |
| - return rv_out, dims, observed, resize_shape |
| 168 | + # Preference is given to size or shape, if not provided we use dims and observed |
| 169 | + # to resize the variable |
| 170 | + if not size_or_shape: |
| 171 | + if dims is not None: |
| 172 | + resize_shape = shape_from_dims(dims, tuple(rv_out.shape), model) |
| 173 | + elif observed is not None: |
| 174 | + resize_shape = tuple(observed.shape) |
| 175 | + |
| 176 | + return rv_out, resize_shape |
178 | 177 |
|
179 | 178 |
|
180 | 179 | class Distribution(metaclass=DistributionMeta):
|
@@ -254,15 +253,20 @@ def __new__(
|
254 | 253 | if not isinstance(name, string_types):
|
255 | 254 | raise TypeError(f"Name needs to be a string but got: {name}")
|
256 | 255 |
|
257 |
| - # Create the RV and process dims and observed to determine |
258 |
| - # a shape by which the created RV may need to be resized. |
259 |
| - rv_out, dims, observed, resize_shape = _make_rv_and_resize_shape( |
| 256 | + dims = convert_dims(dims) |
| 257 | + if observed is not None: |
| 258 | + observed = convert_observed_data(observed) |
| 259 | + |
| 260 | + # Create the RV, possibly taking into consideration dims and observed to |
| 261 | + # determine its shape |
| 262 | + rv_out, resize_shape = _make_rv_and_resize_shape( |
260 | 263 | cls=cls, dims=dims, model=model, observed=observed, args=args, **kwargs
|
261 | 264 | )
|
262 | 265 |
|
| 266 | + # A shape was specified only through `dims`, or implied by `observed`. |
263 | 267 | if resize_shape:
|
264 |
| - # A batch size was specified through `dims`, or implied by `observed`. |
265 |
| - rv_out = change_rv_size(rv=rv_out, new_size=resize_shape, expand=True) |
| 268 | + resize_size = find_size(shape=resize_shape, size=None, ndim_supp=cls.rv_op.ndim_supp) |
| 269 | + rv_out = change_rv_size(rv=rv_out, new_size=resize_size, expand=False) |
266 | 270 |
|
267 | 271 | rv_out = model.register_rv(
|
268 | 272 | rv_out,
|
@@ -336,11 +340,7 @@ def dist(
|
336 | 340 | shape = convert_shape(shape)
|
337 | 341 | size = convert_size(size)
|
338 | 342 |
|
339 |
| - create_size, ndim_expected, ndim_batch, ndim_supp = find_size( |
340 |
| - shape=shape, size=size, ndim_supp=cls.rv_op.ndim_supp |
341 |
| - ) |
342 |
| - # Create the RV with a `size` right away. |
343 |
| - # This is not necessarily the final result. |
| 343 | + create_size = find_size(shape=shape, size=size, ndim_supp=cls.rv_op.ndim_supp) |
344 | 344 | rv_out = cls.rv_op(*dist_params, size=create_size, **kwargs)
|
345 | 345 |
|
346 | 346 | rv_out.logp = _make_nice_attr_error("rv.logp(x)", "pm.logp(rv, x)")
|
@@ -448,19 +448,20 @@ def __new__(
|
448 | 448 | if not isinstance(name, string_types):
|
449 | 449 | raise TypeError(f"Name needs to be a string but got: {name}")
|
450 | 450 |
|
451 |
| - # Create the RV and process dims and observed to determine |
452 |
| - # a shape by which the created RV may need to be resized. |
453 |
| - rv_out, dims, observed, resize_shape = _make_rv_and_resize_shape( |
| 451 | + dims = convert_dims(dims) |
| 452 | + if observed is not None: |
| 453 | + observed = convert_observed_data(observed) |
| 454 | + |
| 455 | + # Create the RV, possibly taking into consideration dims and observed to |
| 456 | + # determine its shape |
| 457 | + rv_out, resize_shape = _make_rv_and_resize_shape( |
454 | 458 | cls=cls, dims=dims, model=model, observed=observed, args=args, **kwargs
|
455 | 459 | )
|
456 | 460 |
|
| 461 | + # A shape was specified only through `dims`, or implied by `observed`. |
457 | 462 | if resize_shape:
|
458 |
| - # A batch size was specified through `dims`, or implied by `observed`. |
459 |
| - rv_out = cls.change_size( |
460 |
| - rv=rv_out, |
461 |
| - new_size=resize_shape, |
462 |
| - expand=True, |
463 |
| - ) |
| 463 | + resize_size = find_size(shape=resize_shape, size=None, ndim_supp=rv_out.tag.ndim_supp) |
| 464 | + rv_out = cls.change_size(rv=rv_out, new_size=resize_size, expand=False) |
464 | 465 |
|
465 | 466 | rv_out = model.register_rv(
|
466 | 467 | rv_out,
|
@@ -529,18 +530,17 @@ def dist(
|
529 | 530 | shape = convert_shape(shape)
|
530 | 531 | size = convert_size(size)
|
531 | 532 |
|
532 |
| - create_size, ndim_expected, ndim_batch, ndim_supp = find_size( |
533 |
| - shape=shape, size=size, ndim_supp=cls.ndim_supp(*dist_params) |
534 |
| - ) |
535 |
| - # Create the RV with a `size` right away. |
536 |
| - # This is not necessarily the final result. |
537 |
| - graph = cls.rv_op(*dist_params, size=create_size, **kwargs) |
| 533 | + ndim_supp = cls.ndim_supp(*dist_params) |
| 534 | + create_size = find_size(shape=shape, size=size, ndim_supp=ndim_supp) |
| 535 | + rv_out = cls.rv_op(*dist_params, size=create_size, **kwargs) |
| 536 | + # This is needed for resizing from dims in `__new__` |
| 537 | + rv_out.tag.ndim_supp = ndim_supp |
538 | 538 |
|
539 | 539 | # TODO: Create new attr error stating that these are not available for DerivedDistribution
|
540 | 540 | # rv_out.logp = _make_nice_attr_error("rv.logp(x)", "pm.logp(rv, x)")
|
541 | 541 | # rv_out.logcdf = _make_nice_attr_error("rv.logcdf(x)", "pm.logcdf(rv, x)")
|
542 | 542 | # rv_out.random = _make_nice_attr_error("rv.random()", "rv.eval()")
|
543 |
| - return graph |
| 543 | + return rv_out |
544 | 544 |
|
545 | 545 |
|
546 | 546 | @singledispatch
|
|
0 commit comments