
Commit c6d565d

Enable dist API for Simulator
1 parent a3f44f5 commit c6d565d

2 files changed: +91 -48 lines
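
For context, a minimal sketch of what this commit enables: Simulator.dist can now be called outside a model context, and the resulting unnamed variable can be passed to pm.logp, mirroring the other PyMC distributions. The names below follow the new test added in this commit; normal_sim is a stand-in simulator function.

import numpy as np
import pymc as pm

def normal_sim(rng, mu, sigma, size):
    # Stand-in simulator: forward-samples from a normal distribution
    return rng.normal(mu, sigma, size=size)

# Unnamed variable built via .dist, no model context required;
# class_name is now a required keyword when calling dist directly
x = pm.Simulator.dist(normal_sim, 0, 1, sum_stat="sort", shape=(3,), class_name="example")
x_logp = pm.logp(x, [0.0, 1.0, 2.0])  # ABC distance-based pseudo-log-likelihood graph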

pymc/distributions/simulator.py

Lines changed: 76 additions & 48 deletions
@@ -86,6 +86,15 @@ class Simulator(Distribution):
         Keyword form of ``unnamed_params``.
         One of unnamed_params or params must be provided.
         If passed both unnamed_params and params, an error is raised.
+    class_name : str
+        Name for the RandomVariable class which will wrap the Simulator methods.
+        When not specified, it will be given the name of the variable.
+
+        .. warning:: New Simulators created with the same class_name will override the
+            methods dispatched onto the previous classes. If using Simulators with
+            different methods across separate models, be sure to use distinct
+            class_names.
+
     distance : Aesara_Op, callable or str, default "gaussian"
         Distance function. Available options are ``"gaussian"``, ``"laplace"``,
         ``"kullback_leibler"`` or a user defined function (or Aesara_Op) that takes
@@ -137,12 +146,19 @@ def simulator_fn(rng, loc, scale, size):
 
     """
 
-    def __new__(
+    rv_type = SimulatorRV
+
+    def __new__(cls, name, *args, **kwargs):
+        kwargs.setdefault("class_name", name)
+        return super().__new__(cls, name, *args, **kwargs)
+
+    @classmethod
+    def dist(  # type: ignore
         cls,
-        name,
         fn,
         *unnamed_params,
         params=None,
+        class_name: str,
         distance="gaussian",
         sum_stat="identity",
         epsilon=1,
@@ -196,11 +212,38 @@ def __new__(
         if ndims_params is None:
             ndims_params = [0] * len(params)
 
+        return super().dist(
+            params,
+            class_name=class_name,
+            fn=fn,
+            ndim_supp=ndim_supp,
+            ndims_params=ndims_params,
+            dtype=dtype,
+            distance=distance,
+            sum_stat=sum_stat,
+            epsilon=epsilon,
+            **kwargs,
+        )
+
+    @classmethod
+    def rv_op(
+        cls,
+        *params,
+        class_name,
+        fn,
+        ndim_supp,
+        ndims_params,
+        dtype,
+        distance,
+        sum_stat,
+        epsilon,
+        **kwargs,
+    ):
         sim_op = type(
-            f"Simulator_{name}",
+            f"Simulator_{class_name}",
             (SimulatorRV,),
             dict(
-                name="Simulator",
+                name=f"Simulator_{class_name}",
                 ndim_supp=ndim_supp,
                 ndims_params=ndims_params,
                 dtype=dtype,
@@ -211,50 +254,35 @@ def __new__(
                 epsilon=epsilon,
             ),
         )()
-
-        # The logp function is registered to the more general SimulatorRV,
-        # in order to avoid issues with multiprocessing / pickling,
-        # otherwise it would be registered to `type(sim_op)`
-
-        @_logprob.register(SimulatorRV)
-        def logp(op, value_var_list, *dist_params, **kwargs):
-            _dist_params = dist_params[3:]
-            value_var = value_var_list[0]
-            return cls.logp(value_var, op, dist_params)
-
-        @_moment.register(SimulatorRV)
-        def moment(op, rv, rng, size, dtype, *rv_inputs):
-            return cls.moment(rv, *rv_inputs)
-
-        cls.rv_op = sim_op
-        return super().__new__(cls, name, *params, **kwargs)
-
-    @classmethod
-    def dist(cls, *params, **kwargs):
-        return super().dist(params, **kwargs)
-
-    @classmethod
-    def moment(cls, rv, *sim_inputs):
-        # Take the mean of 10 draws
-        multiple_sim = rv.owner.op(*sim_inputs, size=at.concatenate([[10], rv.shape]))
-        return at.mean(multiple_sim, axis=0)
-
-    @classmethod
-    def logp(cls, value, sim_op, sim_inputs):
-        # Use a new rng to avoid non-randomness in parallel sampling
-        # TODO: Model rngs should be updated prior to multiprocessing split,
-        # in which case this would not be needed. However, that would have to be
-        # done for every sampler that may accommodate Simulators
-        rng = aesara.shared(np.random.default_rng(), name="simulator_rng")
-        # Create a new SimulatorRV with identical inputs as the original one
-        sim_value = sim_op.make_node(rng, *sim_inputs[1:]).default_output()
-        sim_value.name = "simulator_value"
-
-        return sim_op.distance(
-            sim_op.epsilon,
-            sim_op.sum_stat(value),
-            sim_op.sum_stat(sim_value),
-        )
+        return sim_op(*params, **kwargs)
+
+
+@_moment.register(SimulatorRV)  # type: ignore
+def simulator_moment(op, rv, *inputs):
+    sim_inputs = inputs[3:]
+    # Take the mean of 10 draws
+    multiple_sim = rv.owner.op(*sim_inputs, size=at.concatenate([[10], rv.shape]))
+    return at.mean(multiple_sim, axis=0)
+
+
+@_logprob.register(SimulatorRV)
+def simulator_logp(op, values, *inputs, **kwargs):
+    (value,) = values
+
+    # Use a new rng to avoid non-randomness in parallel sampling
+    # TODO: Model rngs should be updated prior to multiprocessing split,
+    # in which case this would not be needed. However, that would have to be
+    # done for every sampler that may accommodate Simulators
+    rng = aesara.shared(np.random.default_rng(), name="simulator_rng")
+    # Create a new SimulatorRV with identical inputs as the original one
+    sim_value = op.make_node(rng, *inputs[1:]).default_output()
+    sim_value.name = "simulator_value"
+
+    return op.distance(
+        op.epsilon,
+        op.sum_stat(value),
+        op.sum_stat(sim_value),
+    )
 
 
 def identity(x):
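
To illustrate the class_name warning added to the docstring above, a hypothetical sketch (poisson_sim, normal_sim, data_a, and data_b are placeholder names, not part of this commit): two Simulators sharing the variable name "y" across separate models would both default to class_name "y", so explicit, distinct class_names keep the dynamically created RandomVariable classes from overriding each other's methods.

import numpy as np
import pymc as pm

def poisson_sim(rng, lam, size):
    return rng.poisson(lam, size=size)

def normal_sim(rng, mu, sigma, size):
    return rng.normal(mu, sigma, size=size)

data_a = np.random.default_rng(1).poisson(3, size=100)
data_b = np.random.default_rng(2).normal(0, 1, size=100)

with pm.Model() as model_a:
    lam = pm.HalfNormal("lam", 5)
    # class_name would default to "y"; make it explicit and unique
    pm.Simulator("y", poisson_sim, lam, class_name="poisson_abc", observed=data_a)

with pm.Model() as model_b:
    mu = pm.Normal("mu", 0, 1)
    pm.Simulator("y", normal_sim, mu, 1, class_name="normal_abc", observed=data_b)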

pymc/tests/distributions/test_simulator.py

Lines changed: 15 additions & 0 deletions
@@ -29,6 +29,7 @@
 import pymc as pm
 
 from pymc import floatX
+from pymc.aesaraf import compile_pymc
 from pymc.initial_point import make_initial_point_fn
 from pymc.smc.smc import IMH
 from pymc.tests.helpers import SeededTest
@@ -363,3 +364,17 @@ def normal_sim(rng, mu, sigma, size):
         cutoff = st.norm().ppf(1 - (alpha / 2))
 
         assert np.all(np.abs((result - expected_sample_mean) / expected_sample_mean_std) < cutoff)
+
+    def test_dist(self):
+        x = pm.Simulator.dist(self.normal_sim, 0, 1, sum_stat="sort", shape=(3,), class_name="test")
+        x_logp = pm.logp(x, [0, 1, 2])
+
+        x_logp_fn = compile_pymc([], x_logp, random_seed=1)
+        res1, res2 = x_logp_fn(), x_logp_fn()
+        assert res1.shape == (3,)
+        assert np.all(res1 != res2)
+
+        x_logp_fn = compile_pymc([], x_logp, random_seed=1)
+        res3, res4 = x_logp_fn(), x_logp_fn()
+        assert np.all(res1 == res3)
+        assert np.all(res2 == res4)
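
Beyond the new test, the usual model-context workflow is unchanged by this commit. A sketch of the typical ABC setup with SMC sampling, under assumed illustrative priors and data (normal_sim and observed are placeholders):

import numpy as np
import pymc as pm

def normal_sim(rng, mu, sigma, size):
    return rng.normal(mu, sigma, size=size)

observed = np.random.default_rng(42).normal(0, 1, size=1000)

with pm.Model() as model:
    mu = pm.Normal("mu", 0, 5)
    sigma = pm.HalfNormal("sigma", 2)
    # Named Simulator inside a model; class_name defaults to the variable name "s"
    pm.Simulator("s", normal_sim, mu, sigma, sum_stat="sort", epsilon=1, observed=observed)
    idata = pm.sample_smc()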
