Skip to content

Commit 95ffc94

Browse files
committed
Merge branch 'main' into add-zerosumnormal
2 parents 5ee950a + e419d53 commit 95ffc94

File tree

9 files changed

+338
-47
lines changed

9 files changed

+338
-47
lines changed

pymc/distributions/discrete.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
from aesara.tensor.random.basic import (
2020
RandomVariable,
21+
ScipyRandomVariable,
2122
bernoulli,
2223
betabinom,
2324
binomial,
@@ -1117,15 +1118,15 @@ def logcdf(value, good, bad, n):
11171118
)
11181119

11191120

1120-
class DiscreteUniformRV(RandomVariable):
1121+
class DiscreteUniformRV(ScipyRandomVariable):
11211122
name = "discrete_uniform"
11221123
ndim_supp = 0
11231124
ndims_params = [0, 0]
11241125
dtype = "int64"
11251126
_print_name = ("DiscreteUniform", "\\operatorname{DiscreteUniform}")
11261127

11271128
@classmethod
1128-
def rng_fn(cls, rng, lower, upper, size=None):
1129+
def rng_fn_scipy(cls, rng, lower, upper, size=None):
11291130
return stats.randint.rvs(lower, upper + 1, size=size, random_state=rng)
11301131

11311132

pymc/distributions/timeseries.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
import aesara.tensor as at
2020
import numpy as np
2121

22-
from aeppl.abstract import _get_measurable_outputs
2322
from aeppl.logprob import _logprob
2423
from aesara.graph.basic import Node, clone_replace
2524
from aesara.tensor import TensorVariable
@@ -144,12 +143,6 @@ def rv_op(cls, init_dist, innovation_dist, steps, size=None):
144143
)(init_dist, innovation_dist, steps)
145144

146145

147-
@_get_measurable_outputs.register(RandomWalkRV)
148-
def _get_measurable_outputs_random_walk(op, node):
149-
# Ignore steps output
150-
return [node.default_output()]
151-
152-
153146
@_change_dist_size.register(RandomWalkRV)
154147
def change_random_walk_size(op, dist, new_size, expand):
155148
init_dist, innovation_dist, steps = dist.owner.inputs

pymc/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1149,7 +1149,7 @@ def add_coord(
11491149
length = len(values)
11501150
if not isinstance(length, Variable):
11511151
if mutable:
1152-
length = aesara.shared(length)
1152+
length = aesara.shared(length, name=name)
11531153
else:
11541154
length = aesara.tensor.constant(length)
11551155
self._dim_lengths[name] = length

pymc/sampling.py

Lines changed: 62 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1621,6 +1621,8 @@ def compile_forward_sampling_function(
16211621
vars_in_trace: List[Variable],
16221622
basic_rvs: Optional[List[Variable]] = None,
16231623
givens_dict: Optional[Dict[Variable, Any]] = None,
1624+
constant_data: Optional[Dict[str, np.ndarray]] = None,
1625+
constant_coords: Optional[Set[str]] = None,
16241626
**kwargs,
16251627
) -> Tuple[Callable[..., Union[np.ndarray, List[np.ndarray]]], Set[Variable]]:
16261628
"""Compile a function to draw samples, conditioned on the values of some variables.
@@ -1634,18 +1636,18 @@ def compile_forward_sampling_function(
16341636
compiled function or after inference has been run. These variables are:
16351637
16361638
- Variables in the outputs list
1637-
- ``SharedVariable`` instances that are not ``RandomStateSharedVariable`` or ``RandomGeneratorSharedVariable``
1639+
- ``SharedVariable`` instances that are not ``RandomStateSharedVariable`` or ``RandomGeneratorSharedVariable``, and whose values changed with respect to what they were at inference time
16381640
- Variables that are in the `basic_rvs` list but not in the ``vars_in_trace`` list
16391641
- Variables that are keys in the ``givens_dict``
16401642
- Variables that have volatile inputs
16411643
16421644
Concretely, this function can be used to compile a function to sample from the
16431645
posterior predictive distribution of a model that has variables that are conditioned
1644-
on ``MutableData`` instances. The variables that depend on the mutable data will be
1645-
considered volatile, and as such, they wont be included as inputs into the compiled function.
1646-
This means that if they have values stored in the posterior, these values will be ignored
1647-
and new values will be computed (in the case of deterministics and potentials) or sampled
1648-
(in the case of random variables).
1646+
on ``MutableData`` instances. The variables that depend on the mutable data that have changed
1647+
will be considered volatile, and as such, they wont be included as inputs into the compiled
1648+
function. This means that if they have values stored in the posterior, these values will be
1649+
ignored and new values will be computed (in the case of deterministics and potentials) or
1650+
sampled (in the case of random variables).
16491651
16501652
This function also enables a way to impute values for any variable in the computational
16511653
graph that produces the desired outputs: the ``givens_dict``. This dictionary can be used
@@ -1672,6 +1674,25 @@ def compile_forward_sampling_function(
16721674
A dictionary that maps tensor variables to the values that should be used to replace them
16731675
in the compiled function. The types of the key and value should match or an error will be
16741676
raised during compilation.
1677+
constant_data : Optional[Dict[str, numpy.ndarray]]
1678+
A dictionary that maps the names of ``MutableData`` or ``ConstantData`` instances to their
1679+
corresponding values at inference time. If a model was created with ``MutableData``, these
1680+
are stored as ``SharedVariable`` with the name of the data variable and a value equal to
1681+
the initial data. At inference time, this information is stored in ``InferenceData``
1682+
objects under the ``constant_data`` group, which allows us to check whether a
1683+
``SharedVariable`` instance changed its values after inference or not. If the values have
1684+
changed, then the ``SharedVariable`` is assumed to be volatile. If it has not changed, then
1685+
the ``SharedVariable`` is assumed to not be volatile. If a ``SharedVariable`` is not found
1686+
in either ``constant_data`` or ``constant_coords``, then it is assumed to be volatile.
1687+
Setting ``constant_data`` to ``None`` is equivalent to passing an empty dictionary.
1688+
constant_coords : Optional[Set[str]]
1689+
A set with the names of the mutable coordinates that have not changed their shape after
1690+
inference. If a model was created with mutable coordinates, these are stored as
1691+
``SharedVariable`` with the name of the coordinate and a value equal to the length of said
1692+
coordinate. This set let's us check if a ``SharedVariable`` is a mutated coordinate, in
1693+
which case, it is considered volatile. If a ``SharedVariable`` is not found
1694+
in either ``constant_data`` or ``constant_coords``, then it is assumed to be volatile.
1695+
Setting ``constant_coords`` to ``None`` is equivalent to passing an empty set.
16751696
16761697
Returns
16771698
-------
@@ -1687,6 +1708,20 @@ def compile_forward_sampling_function(
16871708
if basic_rvs is None:
16881709
basic_rvs = []
16891710

1711+
if constant_data is None:
1712+
constant_data = {}
1713+
if constant_coords is None:
1714+
constant_coords = set()
1715+
1716+
# We define a helper function to check if shared values match to an array
1717+
def shared_value_matches(var):
1718+
try:
1719+
old_array_value = constant_data[var.name]
1720+
except KeyError:
1721+
return var.name in constant_coords
1722+
current_shared_value = var.get_value(borrow=True)
1723+
return np.array_equal(old_array_value, current_shared_value)
1724+
16901725
# We need a function graph to walk the clients and propagate the volatile property
16911726
fg = FunctionGraph(outputs=outputs, clone=False)
16921727

@@ -1702,6 +1737,7 @@ def compile_forward_sampling_function(
17021737
or ( # SharedVariables, except RandomState/Generators
17031738
isinstance(node, SharedVariable)
17041739
and not isinstance(node, (RandomStateSharedVariable, RandomGeneratorSharedVariable))
1740+
and not shared_value_matches(node)
17051741
)
17061742
or ( # Basic RVs that are not in the trace
17071743
node in basic_rvs and node not in vars_in_trace
@@ -1835,16 +1871,24 @@ def sample_posterior_predictive(
18351871
idata_kwargs = {}
18361872
else:
18371873
idata_kwargs = idata_kwargs.copy()
1874+
constant_data: Dict[str, np.ndarray] = {}
1875+
trace_coords: Dict[str, np.ndarray] = {}
18381876
if "coords" not in idata_kwargs:
18391877
idata_kwargs["coords"] = {}
18401878
if isinstance(trace, InferenceData):
18411879
idata_kwargs["coords"].setdefault("draw", trace["posterior"]["draw"])
18421880
idata_kwargs["coords"].setdefault("chain", trace["posterior"]["chain"])
1881+
_constant_data = getattr(trace, "constant_data", None)
1882+
if _constant_data is not None:
1883+
trace_coords.update({str(k): v.data for k, v in _constant_data.coords.items()})
1884+
constant_data.update({str(k): v.data for k, v in _constant_data.items()})
1885+
trace_coords.update({str(k): v.data for k, v in trace["posterior"].coords.items()})
18431886
_trace = dataset_to_point_list(trace["posterior"])
18441887
nchain, len_trace = chains_and_samples(trace)
18451888
elif isinstance(trace, xarray.Dataset):
18461889
idata_kwargs["coords"].setdefault("draw", trace["draw"])
18471890
idata_kwargs["coords"].setdefault("chain", trace["chain"])
1891+
trace_coords.update({str(k): v.data for k, v in trace.coords.items()})
18481892
_trace = dataset_to_point_list(trace)
18491893
nchain, len_trace = chains_and_samples(trace)
18501894
elif isinstance(trace, MultiTrace):
@@ -1901,6 +1945,16 @@ def sample_posterior_predictive(
19011945
stacklevel=2,
19021946
)
19031947

1948+
constant_coords = set()
1949+
for dim, coord in trace_coords.items():
1950+
current_coord = model.coords.get(dim, None)
1951+
if (
1952+
current_coord is not None
1953+
and len(coord) == len(current_coord)
1954+
and np.all(coord == current_coord)
1955+
):
1956+
constant_coords.add(dim)
1957+
19041958
if var_names is not None:
19051959
vars_ = [model[x] for x in var_names]
19061960
else:
@@ -1935,6 +1989,8 @@ def sample_posterior_predictive(
19351989
basic_rvs=model.basic_RVs,
19361990
givens_dict=None,
19371991
random_seed=random_seed,
1992+
constant_data=constant_data,
1993+
constant_coords=constant_coords,
19381994
**compile_kwargs,
19391995
)
19401996
sampler_fn = point_wrapper(_sampler_fn)

pymc/smc/smc.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -274,9 +274,7 @@ def update_beta_and_weights(self):
274274

275275
def resample(self):
276276
"""Resample particles based on importance weights"""
277-
self.resampling_indexes = self.rng.choice(
278-
np.arange(self.draws), size=self.draws, p=self.weights
279-
)
277+
self.resampling_indexes = systematic_resampling(self.weights, self.rng)
280278

281279
self.tempered_posterior = self.tempered_posterior[self.resampling_indexes]
282280
self.prior_logp = self.prior_logp[self.resampling_indexes]
@@ -546,6 +544,36 @@ def sample_settings(self):
546544
return stats
547545

548546

547+
def systematic_resampling(weights, rng):
548+
"""
549+
Systematic resampling.
550+
551+
Parameters
552+
----------
553+
weights :
554+
The weights should be probabilities and the total sum should be 1.
555+
556+
Returns
557+
-------
558+
new_indices: array
559+
A vector of indices in the interval 0, ..., len(normalized_weights)
560+
"""
561+
lnw = len(weights)
562+
arange = np.arange(lnw)
563+
uniform = (rng.random(1) + arange) / lnw
564+
565+
idx = 0
566+
weight_accu = weights[0]
567+
new_indices = np.empty(lnw, dtype=int)
568+
for i in arange:
569+
while uniform[i] > weight_accu:
570+
idx += 1
571+
weight_accu += weights[idx]
572+
new_indices[i] = idx
573+
574+
return new_indices
575+
576+
549577
def _logp_forw(point, out_vars, in_vars, shared):
550578
"""Compile Aesara function of the model and the input and output variables.
551579

pymc/tests/distributions/test_discrete.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1042,6 +1042,10 @@ def discrete_uniform_rng_fn(self, size, lower, upper, rng):
10421042
"check_rv_size",
10431043
]
10441044

1045+
def test_implied_degenerate_shape(self):
1046+
x = pm.DiscreteUniform.dist(0, [1])
1047+
assert x.eval().shape == (1,)
1048+
10451049

10461050
class TestDiracDelta(BaseTestDistributionRandom):
10471051
def diracdelta_rng_fn(self, size, c):

pymc/tests/distributions/test_distribution.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -339,11 +339,25 @@ class TestInlinedSymbolicRV(SymbolicRandomVariable):
339339
x_inline = TestInlinedSymbolicRV([], [Flat.dist()], ndim_supp=0)()
340340
assert np.isclose(logp(x_inline, 0).eval(), 0)
341341

342-
def test_measurable_outputs(self):
342+
def test_measurable_outputs_rng_ignored(self):
343+
"""Test that any RandomType outputs are ignored as a measurable_outputs"""
344+
343345
class TestSymbolicRV(SymbolicRandomVariable):
344346
pass
345347

346348
next_rng_, dirac_delta_ = DiracDelta.dist(5).owner.outputs
347349
next_rng, dirac_delta = TestSymbolicRV([], [next_rng_, dirac_delta_], ndim_supp=0)()
348350
node = dirac_delta.owner
349351
assert get_measurable_outputs(node.op, node) == [dirac_delta]
352+
353+
@pytest.mark.parametrize("default_output_idx", (0, 1))
354+
def test_measurable_outputs_default_output(self, default_output_idx):
355+
"""Test that if provided, a default output is considered the only measurable_output"""
356+
357+
class TestSymbolicRV(SymbolicRandomVariable):
358+
default_output = default_output_idx
359+
360+
dirac_delta_1_ = DiracDelta.dist(5)
361+
dirac_delta_2_ = DiracDelta.dist(10)
362+
node = TestSymbolicRV([], [dirac_delta_1_, dirac_delta_2_], ndim_supp=0)().owner
363+
assert get_measurable_outputs(node.op, node) == [node.outputs[default_output_idx]]

0 commit comments

Comments
 (0)