|
17 | 17 | A collection of common shape operations needed for broadcasting |
18 | 18 | samples from probability distributions for stochastic nodes in PyMC. |
19 | 19 | """ |
| 20 | +import warnings |
| 21 | + |
20 | 22 | from functools import singledispatch |
21 | 23 | from typing import Optional, Sequence, Tuple, Union |
22 | 24 |
|
@@ -579,8 +581,8 @@ def change_dist_size( |
579 | 581 | Returns |
580 | 582 | ------- |
581 | 583 | A new distribution variable that is equivalent to the original distribution with |
582 | | - the new size. The new distribution may reuse the same RandomState/Generator inputs |
583 | | - as the original distribution. |
| 584 | + the new size. The new distribution will not reuse the old RandomState/Generator |
| 585 | + input, so it will be independent from the original distribution. |
584 | 586 |
|
585 | 587 | Examples |
586 | 588 | -------- |
@@ -618,24 +620,29 @@ def change_dist_size( |
618 | 620 | def change_rv_size(op, rv, new_size, expand) -> TensorVariable: |
619 | 621 | # Extract the RV node that is to be resized |
620 | 622 | rv_node = rv.owner |
621 | | - rng, size, dtype, *dist_params = rv_node.inputs |
| 623 | + old_rng, old_size, dtype, *dist_params = rv_node.inputs |
622 | 624 |
|
623 | 625 | if expand: |
624 | | - shape = tuple(rv_node.op._infer_shape(size, dist_params)) |
625 | | - size = shape[: len(shape) - rv_node.op.ndim_supp] |
626 | | - new_size = tuple(new_size) + tuple(size) |
| 626 | + shape = tuple(rv_node.op._infer_shape(old_size, dist_params)) |
| 627 | + old_size = shape[: len(shape) - rv_node.op.ndim_supp] |
| 628 | + new_size = tuple(new_size) + tuple(old_size) |
627 | 629 |
|
628 | 630 | # Make sure the new size is a tensor. This dtype-aware conversion helps |
629 | 631 | # to not unnecessarily pick up a `Cast` in some cases (see #4652). |
630 | 632 | new_size = at.as_tensor(new_size, ndim=1, dtype="int64") |
631 | 633 |
|
632 | | - new_rv_node = rv_node.op.make_node(rng, new_size, dtype, *dist_params) |
633 | | - new_rv = new_rv_node.outputs[-1] |
| 634 | + new_rv = rv_node.op(*dist_params, size=new_size, dtype=dtype) |
634 | 635 |
|
635 | | - # Update "traditional" rng default_update, if that was set for old RV |
636 | | - default_update = getattr(rng, "default_update", None) |
637 | | - if default_update is not None and default_update is rv_node.outputs[0]: |
638 | | - rng.default_update = new_rv_node.outputs[0] |
| 636 | + # Replicate "traditional" rng default_update, if that was set for old_rng |
| 637 | + default_update = getattr(old_rng, "default_update", None) |
| 638 | + if default_update is not None: |
| 639 | + if default_update is rv_node.outputs[0]: |
| 640 | + new_rv.owner.inputs[0].default_update = new_rv.owner.outputs[0] |
| 641 | + else: |
| 642 | + warnings.warn( |
| 643 | + f"Update expression of {rv} RNG could not be replicated in resized variable", |
| 644 | + UserWarning, |
| 645 | + ) |
639 | 646 |
|
640 | 647 | return new_rv |
641 | 648 |
|
|
0 commit comments