@@ -86,6 +86,15 @@ class Simulator(Distribution):
86
86
Keyword form of ''unnamed_params''.
87
87
One of unnamed_params or params must be provided.
88
88
If passed both unnamed_params and params, an error is raised.
89
+ class_name : str
90
+ Name for the RandomVariable class which will wrap the Simulator methods.
91
+ When not specified, it will be given the name of the variable.
92
+
93
+ .. warning:: New Simulators created with the same class_name will override the
94
+ methods dispatched onto the previous classes. If using Simulators with
95
+ different methods across separate models, be sure to use distinct
96
+ class_names.
97
+
89
98
distance : Aesara_Op, callable or str, default "gaussian"
90
99
Distance function. Available options are ``"gaussian"``, ``"laplace"``,
91
100
``"kullback_leibler"`` or a user defined function (or Aesara_Op) that takes
@@ -137,12 +146,19 @@ def simulator_fn(rng, loc, scale, size):
137
146
138
147
"""
139
148
140
- def __new__ (
149
+ rv_type = SimulatorRV
150
+
151
+ def __new__ (cls , name , * args , ** kwargs ):
152
+ kwargs .setdefault ("class_name" , name )
153
+ return super ().__new__ (cls , name , * args , ** kwargs )
154
+
155
+ @classmethod
156
+ def dist ( # type: ignore
141
157
cls ,
142
- name ,
143
158
fn ,
144
159
* unnamed_params ,
145
160
params = None ,
161
+ class_name : str ,
146
162
distance = "gaussian" ,
147
163
sum_stat = "identity" ,
148
164
epsilon = 1 ,
@@ -196,11 +212,38 @@ def __new__(
196
212
if ndims_params is None :
197
213
ndims_params = [0 ] * len (params )
198
214
215
+ return super ().dist (
216
+ params ,
217
+ class_name = class_name ,
218
+ fn = fn ,
219
+ ndim_supp = ndim_supp ,
220
+ ndims_params = ndims_params ,
221
+ dtype = dtype ,
222
+ distance = distance ,
223
+ sum_stat = sum_stat ,
224
+ epsilon = epsilon ,
225
+ ** kwargs ,
226
+ )
227
+
228
+ @classmethod
229
+ def rv_op (
230
+ cls ,
231
+ * params ,
232
+ class_name ,
233
+ fn ,
234
+ ndim_supp ,
235
+ ndims_params ,
236
+ dtype ,
237
+ distance ,
238
+ sum_stat ,
239
+ epsilon ,
240
+ ** kwargs ,
241
+ ):
199
242
sim_op = type (
200
- f"Simulator_{ name } " ,
243
+ f"Simulator_{ class_name } " ,
201
244
(SimulatorRV ,),
202
245
dict (
203
- name = "Simulator " ,
246
+ name = f"Simulator_ { class_name } " ,
204
247
ndim_supp = ndim_supp ,
205
248
ndims_params = ndims_params ,
206
249
dtype = dtype ,
@@ -211,50 +254,35 @@ def __new__(
211
254
epsilon = epsilon ,
212
255
),
213
256
)()
214
-
215
- # The logp function is registered to the more general SimulatorRV,
216
- # in order to avoid issues with multiprocessing / pickling,
217
- # otherwise it would be registered to `type(sim_op)`
218
-
219
- @_logprob .register (SimulatorRV )
220
- def logp (op , value_var_list , * dist_params , ** kwargs ):
221
- _dist_params = dist_params [3 :]
222
- value_var = value_var_list [0 ]
223
- return cls .logp (value_var , op , dist_params )
224
-
225
- @_moment .register (SimulatorRV )
226
- def moment (op , rv , rng , size , dtype , * rv_inputs ):
227
- return cls .moment (rv , * rv_inputs )
228
-
229
- cls .rv_op = sim_op
230
- return super ().__new__ (cls , name , * params , ** kwargs )
231
-
232
- @classmethod
233
- def dist (cls , * params , ** kwargs ):
234
- return super ().dist (params , ** kwargs )
235
-
236
- @classmethod
237
- def moment (cls , rv , * sim_inputs ):
238
- # Take the mean of 10 draws
239
- multiple_sim = rv .owner .op (* sim_inputs , size = at .concatenate ([[10 ], rv .shape ]))
240
- return at .mean (multiple_sim , axis = 0 )
241
-
242
- @classmethod
243
- def logp (cls , value , sim_op , sim_inputs ):
244
- # Use a new rng to avoid non-randomness in parallel sampling
245
- # TODO: Model rngs should be updated prior to multiprocessing split,
246
- # in which case this would not be needed. However, that would have to be
247
- # done for every sampler that may accomodate Simulators
248
- rng = aesara .shared (np .random .default_rng (), name = "simulator_rng" )
249
- # Create a new simulatorRV with identical inputs as the original one
250
- sim_value = sim_op .make_node (rng , * sim_inputs [1 :]).default_output ()
251
- sim_value .name = "simulator_value"
252
-
253
- return sim_op .distance (
254
- sim_op .epsilon ,
255
- sim_op .sum_stat (value ),
256
- sim_op .sum_stat (sim_value ),
257
- )
257
+ return sim_op (* params , ** kwargs )
258
+
259
+
260
+ @_moment .register (SimulatorRV ) # type: ignore
261
+ def simulator_moment (op , rv , * inputs ):
262
+ sim_inputs = inputs [3 :]
263
+ # Take the mean of 10 draws
264
+ multiple_sim = rv .owner .op (* sim_inputs , size = at .concatenate ([[10 ], rv .shape ]))
265
+ return at .mean (multiple_sim , axis = 0 )
266
+
267
+
268
+ @_logprob .register (SimulatorRV )
269
+ def simulator_logp (op , values , * inputs , ** kwargs ):
270
+ (value ,) = values
271
+
272
+ # Use a new rng to avoid non-randomness in parallel sampling
273
+ # TODO: Model rngs should be updated prior to multiprocessing split,
274
+ # in which case this would not be needed. However, that would have to be
275
+ # done for every sampler that may accomodate Simulators
276
+ rng = aesara .shared (np .random .default_rng (), name = "simulator_rng" )
277
+ # Create a new simulatorRV with identical inputs as the original one
278
+ sim_value = op .make_node (rng , * inputs [1 :]).default_output ()
279
+ sim_value .name = "simulator_value"
280
+
281
+ return op .distance (
282
+ op .epsilon ,
283
+ op .sum_stat (value ),
284
+ op .sum_stat (sim_value ),
285
+ )
258
286
259
287
260
288
def identity (x ):
0 commit comments