Skip to content

Commit a3f44f5

Browse files
committed
Move simulator tests to appropriate module
1 parent 637ced4 commit a3f44f5

File tree

2 files changed

+306
-308
lines changed

2 files changed

+306
-308
lines changed

pymc/tests/distributions/test_simulator.py

Lines changed: 306 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,17 +11,322 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import warnings
1415

16+
import aesara
1517
import numpy as np
1618
import pytest
1719
import scipy.stats as st
1820

21+
from aesara.graph import ancestors
22+
from aesara.tensor.random.op import RandomVariable
23+
from aesara.tensor.random.var import (
24+
RandomGeneratorSharedVariable,
25+
RandomStateSharedVariable,
26+
)
27+
from aesara.tensor.sort import SortOp
28+
1929
import pymc as pm
2030

31+
from pymc import floatX
2132
from pymc.initial_point import make_initial_point_fn
33+
from pymc.smc.smc import IMH
34+
from pymc.tests.helpers import SeededTest
35+
36+
37+
class TestSimulator(SeededTest):
38+
@staticmethod
def count_rvs(end_node):
    """Count the random-variable nodes that ``end_node`` depends on.

    Iterates over ``ancestors([end_node])`` and tallies every variable that
    has an owner whose op is a ``RandomVariable`` instance.
    """
    total = 0
    for var in ancestors([end_node]):
        # Variables without an owner cannot be the output of a RandomVariable op.
        if var.owner is None:
            continue
        if isinstance(var.owner.op, RandomVariable):
            total += 1
    return total
47+
48+
@staticmethod
def normal_sim(rng, a, b, size):
    """Simulator function: draw ``size`` values from ``rng.normal(a, b)``.

    ``a`` and ``b`` are forwarded positionally (presumably location and scale
    when ``rng`` is a NumPy random generator); ``size`` is passed by keyword.
    """
    draws = rng.normal(a, b, size=size)
    return draws
51+
52+
@staticmethod
def abs_diff(eps, obs_data, sim_data):
    """Distance function: mean absolute difference between data sets, scaled by ``eps``.

    Computes ``mean(|(obs_data - sim_data) / eps|)``.
    """
    scaled_residuals = (obs_data - sim_data) / eps
    return np.mean(np.abs(scaled_residuals))
55+
56+
@staticmethod
def quantiles(x):
    """Summary statistic: the 0.25, 0.5 and 0.75 quantiles of ``x``."""
    quartile_levels = [0.25, 0.5, 0.75]
    return np.quantile(x, quartile_levels)
59+
60+
def setup_class(self):
    """Build the observed data and the two models shared by the tests.

    Sets ``self.data`` (1000 standard-normal draws), ``self.SMABC_test``
    (simulator with ``sum_stat="sort"``, also stored as ``self.s``) and
    ``self.SMABC_potential`` (same priors plus a potential on ``a``).

    NOTE(review): this is a pytest class-level hook, so ``self`` is
    presumably the class object rather than an instance -- confirm.
    """
    super().setup_class()
    self.data = np.random.normal(loc=0, scale=1, size=1000)

    with pm.Model() as self.SMABC_test:
        loc_rv = pm.Normal("a", mu=0, sigma=1)
        scale_rv = pm.HalfNormal("b", sigma=1)
        self.s = pm.Simulator(
            "s", self.normal_sim, loc_rv, scale_rv, sum_stat="sort", observed=self.data
        )

    with pm.Model() as self.SMABC_potential:
        loc_rv = pm.Normal("a", mu=0, sigma=1, initval=0.5)
        scale_rv = pm.HalfNormal("b", sigma=1)
        # The potential contributes -inf whenever ``a`` is not positive.
        pm.Potential("c", pm.math.switch(loc_rv > 0, 0, -np.inf))
        pm.Simulator("s", self.normal_sim, loc_rv, scale_rv, observed=self.data)
75+
76+
def test_one_gaussian(self):
    """Sample ``SMABC_test`` with SMC and check posterior, prior and posterior predictive draws."""
    logp_graph = self.SMABC_test.logp()
    assert self.count_rvs(logp_graph) == 1

    with self.SMABC_test:
        posterior = pm.sample_smc(draws=1000, chains=1, return_inferencedata=False)
        prior_pred = pm.sample_prior_predictive(1000, return_inferencedata=False)
        post_pred = pm.sample_posterior_predictive(
            posterior, keep_size=False, return_inferencedata=False
        )

    obs_mean = self.data.mean()
    obs_std = self.data.std()
    expected_shape = (1000, 1000)

    # Posterior means of the parameters should match the data moments.
    assert abs(obs_mean - posterior["a"].mean()) < 0.05
    assert abs(obs_std - posterior["b"].mean()) < 0.05

    # NOTE(review): 1.4 is presumably ~sqrt(2), the prior predictive std
    # implied by the priors on ``a`` and ``b`` -- confirm.
    assert prior_pred["s"].shape == expected_shape
    assert abs(0 - prior_pred["s"].mean()) < 0.15
    assert abs(1.4 - prior_pred["s"].std()) < 0.10

    assert post_pred["s"].shape == expected_shape
    assert abs(obs_mean - post_pred["s"].mean()) < 0.10
    assert abs(obs_std - post_pred["s"].std()) < 0.10
96+
97+
@pytest.mark.parametrize("floatX", ["float32", "float64"])
def test_custom_dist_sum_stat(self, floatX):
    """Check a simulator with user-supplied distance and summary statistic under both float widths."""
    with aesara.config.change_flags(floatX=floatX):
        with pm.Model() as m:
            a = pm.Normal("a", mu=0, sigma=1)
            b = pm.HalfNormal("b", sigma=1)
            pm.Simulator(
                "s",
                self.normal_sim,
                a,
                b,
                distance=self.abs_diff,
                sum_stat=self.quantiles,
                observed=self.data,
            )

        assert self.count_rvs(m.logp()) == 1

        # Only the chains-vs-draws UserWarning is silenced while sampling.
        with m, warnings.catch_warnings():
            warnings.filterwarnings("ignore", "More chains .* than draws .*", UserWarning)
            pm.sample_smc(draws=100)
119+
120+
@pytest.mark.parametrize("floatX", ["float32", "float64"])
def test_custom_dist_sum_stat_scalar(self, floatX):
    """
    Test that automatically wrapped functions cope well with scalar inputs
    """
    scalar_data = 5

    with aesara.config.change_flags(floatX=floatX):
        # First a custom callable summary statistic, then the built-in "mean".
        for summary in (self.quantiles, "mean"):
            with pm.Model() as m:
                pm.Simulator(
                    "s",
                    self.normal_sim,
                    0,
                    1,
                    distance=self.abs_diff,
                    sum_stat=summary,
                    observed=scalar_data,
                )
            assert self.count_rvs(m.logp()) == 1
151+
152+
def test_model_with_potential(self):
    """Check that SMC draws of ``a`` from ``SMABC_potential`` are all non-negative."""
    potential_model = self.SMABC_potential
    assert self.count_rvs(potential_model.logp()) == 1

    with potential_model:
        samples = pm.sample_smc(draws=100, chains=1, return_inferencedata=False)
        a_draws = samples["a"]
        assert np.all(a_draws >= 0)
158+
159+
def test_simulator_metropolis_mcmc(self):
    """Sample ``SMABC_test`` with a Metropolis step on ``a`` and ``b`` and compare with the data moments."""
    with self.SMABC_test as m:
        # Metropolis is given the value variables of the two free parameters.
        value_vars = [m.rvs_to_values[m[rv_name]] for rv_name in ("a", "b")]
        step = pm.Metropolis(value_vars)
        trace = pm.sample(step=step, return_inferencedata=False)

    mean_error = abs(self.data.mean() - trace["a"].mean())
    assert mean_error < 0.05
    std_error = abs(self.data.std() - trace["b"].mean())
    assert std_error < 0.05
166+
167+
def test_multiple_simulators(self):
    """Check two independent simulators in one model.

    ``sim1`` uses ``sum_stat="sort"`` and ``sim2`` uses ``sum_stat="mean"``;
    the compiled logp of each is inspected for a ``SortOp`` node, then the
    model is sampled with SMC and the parameter means are compared with the
    values used to generate the data.
    """
    true_a = 2
    true_b = -2

    data1 = np.random.normal(true_a, 0.1, size=1000)
    data2 = np.random.normal(true_b, 0.1, size=1000)

    with pm.Model() as m:
        a = pm.Normal("a", mu=0, sigma=3)
        b = pm.Normal("b", mu=0, sigma=3)
        sim1 = pm.Simulator(
            "sim1",
            self.normal_sim,
            a,
            0.1,
            distance="gaussian",
            sum_stat="sort",
            observed=data1,
        )
        sim2 = pm.Simulator(
            "sim2",
            self.normal_sim,
            b,
            0.1,
            distance="laplace",
            sum_stat="mean",
            epsilon=0.1,
            observed=data2,
        )

    assert self.count_rvs(m.logp()) == 2

    # Check that the logps use the correct methods
    a_val = m.rvs_to_values[a]
    sim1_val = m.rvs_to_values[sim1]
    logp_sim1 = pm.joint_logp(sim1, sim1_val)
    logp_sim1_fn = aesara.function([a_val], logp_sim1)

    b_val = m.rvs_to_values[b]
    sim2_val = m.rvs_to_values[sim2]
    logp_sim2 = pm.joint_logp(sim2, sim2_val)
    logp_sim2_fn = aesara.function([b_val], logp_sim2)

    # Test the predicate itself instead of relying on the truthiness of the
    # yielded node objects.
    assert any(
        isinstance(node.op, SortOp) for node in logp_sim1_fn.maker.fgraph.toposort()
    )

    assert not any(
        isinstance(node.op, SortOp) for node in logp_sim2_fn.maker.fgraph.toposort()
    )

    with m:
        trace = pm.sample_smc(return_inferencedata=False)

    assert abs(true_a - trace["a"].mean()) < 0.05
    assert abs(true_b - trace["b"].mean()) < 0.05
223+
224+
def test_nested_simulators(self):
    """Check a simulator whose first parameter is another (unobserved) simulator."""
    true_a = 2
    data = self.get_random_state().normal(true_a, 0.1, size=1000)

    with pm.Model() as m:
        # Unobserved simulator; its output is the first parameter of ``sim2``.
        sim1 = pm.Simulator(
            "sim1",
            self.normal_sim,
            params=(0, 4),
            distance="gaussian",
            sum_stat="identity",
        )
        pm.Simulator(
            "sim2",
            self.normal_sim,
            params=(sim1, 0.1),
            distance="gaussian",
            sum_stat="mean",
            epsilon=0.1,
            observed=data,
        )

    n_random_nodes = self.count_rvs(m.logp())
    assert n_random_nodes == 2

    with m:
        trace = pm.sample_smc(return_inferencedata=False)

    sim1_mean = trace["sim1"].mean()
    assert np.abs(true_a - sim1_mean) < 0.1
253+
254+
def test_upstream_rngs_not_in_compiled_logp(self):
    """Check the compiled SMC likelihood is stochastic yet has exactly one shared RNG ancestor."""
    kernel = IMH(model=self.SMABC_test)
    kernel.initialize_population()
    kernel._initialize_kernel()
    likelihood_func = kernel.likelihood_logp_func

    # Test graph is stochastic
    point = floatX(np.array([0, 0]))
    first_eval = likelihood_func(point)
    second_eval = likelihood_func(point)
    assert first_eval != second_eval

    # Test only one shared RNG is present
    rng_types = (RandomStateSharedVariable, RandomGeneratorSharedVariable)
    graph_outputs = likelihood_func.maker.fgraph.outputs
    shared_rng_vars = [var for var in ancestors(graph_outputs) if isinstance(var, rng_types)]
    assert len(shared_rng_vars) == 1
272+
273+
def test_simulator_error_msg(self):
    """Check the ``ValueError`` messages raised for invalid ``pm.Simulator`` arguments.

    Covers an unknown ``distance`` name, an unknown ``sum_stat`` name, and
    passing positional parameters together with ``params``.
    """
    msg = "The distance metric not_real is not implemented"
    with pytest.raises(ValueError, match=msg):
        with pm.Model():
            pm.Simulator("sim", self.normal_sim, 0, 1, distance="not_real")

    msg = "The summary statistic not_real is not implemented"
    with pytest.raises(ValueError, match=msg):
        with pm.Model():
            pm.Simulator("sim", self.normal_sim, 0, 1, sum_stat="not_real")

    msg = "Cannot pass both unnamed parameters and `params`"
    with pytest.raises(ValueError, match=msg):
        with pm.Model():
            # ``(1,)`` is a one-element tuple; a bare ``(1)`` is just the int 1.
            pm.Simulator("sim", self.normal_sim, 0, params=(1,))
288+
289+
@pytest.mark.xfail(reason="KL not refactored")
def test_automatic_use_of_sort(self):
    """Expect the ``identity`` summary statistic when ``"sort"`` is requested with the KL distance.

    Marked as an expected failure (reason: "KL not refactored").
    """
    simulator_kwargs = dict(
        params=None,
        distance="kullback_leibler",
        sum_stat="sort",
        observed=self.data,
    )
    with pm.Model():
        s_k = pm.Simulator("s_k", None, **simulator_kwargs)
    expected_sum_stat = pm.distributions.simulator.identity
    assert s_k.distribution.sum_stat is expected_sum_stat
301+
302+
def test_name_is_string_type(self):
    """Check that an unnamed model still yields a ``str`` name on the first SMC chain trace."""
    model = self.SMABC_potential
    ignored_warnings = (
        (".*number of samples.*", UserWarning),
        ("invalid value encountered in true_divide", RuntimeWarning),
    )
    with model:
        assert not model.name
        with warnings.catch_warnings():
            for pattern, category in ignored_warnings:
                warnings.filterwarnings("ignore", pattern, category)
            trace = pm.sample_smc(draws=10, chains=1, return_inferencedata=False)
        first_chain = trace._straces[0]
        assert isinstance(first_chain.name, str)
312+
313+
def test_named_model(self):
    """Check that SMC on a named model reports variables prefixed with ``<name>::``."""
    # Named models used to fail with Simulator because the arguments to the
    # random fn used to be passed by name. This is no longer true.
    # https://github.com/pymc-devs/pymc/pull/4365#issuecomment-761221146
    name = "NamedModel"
    expected_varnames = (f"{name}::a", f"{name}::b", f"{name}::b_log__")

    with pm.Model(name=name):
        a = pm.Normal("a", mu=0, sigma=1)
        b = pm.HalfNormal("b", sigma=1)
        pm.Simulator("s", self.normal_sim, a, b, observed=self.data)

        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning)
            trace = pm.sample_smc(draws=10, chains=2, return_inferencedata=False)
        for varname in expected_varnames:
            assert varname in trace.varnames
23329

24-
class TestMoments:
25330
@pytest.mark.parametrize("mu", [0, np.arange(3)], ids=str)
26331
@pytest.mark.parametrize("sigma", [1, np.array([1, 2, 5])], ids=str)
27332
@pytest.mark.parametrize("size", [None, 3, (5, 3)], ids=str)

0 commit comments

Comments
 (0)