21
21
from arviz .data .base import CoordSpec , DimSpec , dict_to_dataset , requires
22
22
from pytensor .graph .basic import Constant
23
23
from pytensor .tensor .sharedvar import SharedVariable
24
- from pytensor .tensor .subtensor import AdvancedIncSubtensor , AdvancedIncSubtensor1
25
24
26
25
import pymc
27
26
@@ -153,7 +152,7 @@ def __init__(
153
152
trace = None ,
154
153
prior = None ,
155
154
posterior_predictive = None ,
156
- log_likelihood = True ,
155
+ log_likelihood = False ,
157
156
predictions = None ,
158
157
coords : Optional [CoordSpec ] = None ,
159
158
dims : Optional [DimSpec ] = None ,
@@ -246,68 +245,6 @@ def split_trace(self) -> Tuple[Union[None, "MultiTrace"], Union[None, "MultiTrac
246
245
trace_posterior = self .trace [self .ntune :]
247
246
return trace_posterior , trace_warmup
248
247
249
- def log_likelihood_vals_point (self , point , var , log_like_fun ):
250
- """Compute log likelihood for each observed point."""
251
- # TODO: This is a cheap hack; we should filter-out the correct
252
- # variables some other way
253
- point = {i .name : point [i .name ] for i in log_like_fun .f .maker .inputs if i .name in point }
254
- log_like_val = np .atleast_1d (log_like_fun (point ))
255
-
256
- if isinstance (var .owner .op , (AdvancedIncSubtensor , AdvancedIncSubtensor1 )):
257
- try :
258
- obs_data = extract_obs_data (self .model .rvs_to_values [var ])
259
- except TypeError :
260
- warnings .warn (f"Could not extract data from symbolic observation { var } " )
261
-
262
- mask = obs_data .mask
263
- if np .ndim (mask ) > np .ndim (log_like_val ):
264
- mask = np .any (mask , axis = - 1 )
265
- log_like_val = np .where (mask , np .nan , log_like_val )
266
- return log_like_val
267
-
268
- def _extract_log_likelihood (self , trace ):
269
- """Compute log likelihood of each observation."""
270
- if self .trace is None :
271
- return None
272
- if self .model is None :
273
- return None
274
-
275
- # TODO: We no longer need one function per observed variable
276
- if self .log_likelihood is True :
277
- cached = [
278
- (
279
- var ,
280
- self .model .compile_fn (
281
- self .model .logp (var , sum = False )[0 ],
282
- inputs = self .model .value_vars ,
283
- on_unused_input = "ignore" ,
284
- ),
285
- )
286
- for var in self .model .observed_RVs
287
- ]
288
- else :
289
- cached = [
290
- (
291
- var ,
292
- self .model .compile_fn (
293
- self .model .logp (var , sum = False )[0 ],
294
- inputs = self .model .value_vars ,
295
- on_unused_input = "ignore" ,
296
- ),
297
- )
298
- for var in self .model .observed_RVs
299
- if var .name in self .log_likelihood
300
- ]
301
- log_likelihood_dict = _DefaultTrace (len (trace .chains ))
302
- for var , log_like_fun in cached :
303
- for k , chain in enumerate (trace .chains ):
304
- log_like_chain = [
305
- self .log_likelihood_vals_point (point , var , log_like_fun )
306
- for point in trace .points ([chain ])
307
- ]
308
- log_likelihood_dict .insert (var .name , np .stack (log_like_chain ), k )
309
- return log_likelihood_dict .trace_dict
310
-
311
248
@requires ("trace" )
312
249
def posterior_to_xarray (self ):
313
250
"""Convert the posterior to an xarray dataset."""
@@ -382,49 +319,6 @@ def sample_stats_to_xarray(self):
382
319
),
383
320
)
384
321
385
- @requires ("trace" )
386
- @requires ("model" )
387
- def log_likelihood_to_xarray (self ):
388
- """Extract log likelihood and log_p data from PyMC trace."""
389
- if self .predictions or not self .log_likelihood :
390
- return None
391
- data_warmup = {}
392
- data = {}
393
- warn_msg = (
394
- "Could not compute log_likelihood, it will be omitted. "
395
- "Check your model object or set log_likelihood=False"
396
- )
397
- if self .posterior_trace :
398
- try :
399
- data = self ._extract_log_likelihood (self .posterior_trace )
400
- except TypeError :
401
- warnings .warn (warn_msg )
402
- if self .warmup_trace :
403
- try :
404
- data_warmup = self ._extract_log_likelihood (self .warmup_trace )
405
- except TypeError :
406
- warnings .warn (warn_msg )
407
- return (
408
- dict_to_dataset (
409
- data ,
410
- library = pymc ,
411
- dims = self .dims ,
412
- coords = self .coords ,
413
- skip_event_dims = True ,
414
- ),
415
- dict_to_dataset (
416
- data_warmup ,
417
- library = pymc ,
418
- dims = self .dims ,
419
- coords = self .coords ,
420
- skip_event_dims = True ,
421
- ),
422
- )
423
-
424
- return dict_to_dataset (
425
- data , library = pymc , coords = self .coords , dims = self .dims , default_dims = self .sample_dims
426
- )
427
-
428
322
@requires (["posterior_predictive" ])
429
323
def posterior_predictive_to_xarray (self ):
430
324
"""Convert posterior_predictive samples to xarray."""
@@ -509,7 +403,6 @@ def to_inference_data(self):
509
403
id_dict = {
510
404
"posterior" : self .posterior_to_xarray (),
511
405
"sample_stats" : self .sample_stats_to_xarray (),
512
- "log_likelihood" : self .log_likelihood_to_xarray (),
513
406
"posterior_predictive" : self .posterior_predictive_to_xarray (),
514
407
"predictions" : self .predictions_to_xarray (),
515
408
** self .priors_to_xarray (),
@@ -519,15 +412,27 @@ def to_inference_data(self):
519
412
id_dict ["predictions_constant_data" ] = self .constant_data_to_xarray ()
520
413
else :
521
414
id_dict ["constant_data" ] = self .constant_data_to_xarray ()
522
- return InferenceData (save_warmup = self .save_warmup , ** id_dict )
415
+ idata = InferenceData (save_warmup = self .save_warmup , ** id_dict )
416
+ if self .log_likelihood :
417
+ from pymc .stats .log_likelihood import compute_log_likelihood
418
+
419
+ idata = compute_log_likelihood (
420
+ idata ,
421
+ var_names = None if self .log_likelihood is True else self .log_likelihood ,
422
+ extend_inferencedata = True ,
423
+ model = self .model ,
424
+ sample_dims = self .sample_dims ,
425
+ progressbar = False ,
426
+ )
427
+ return idata
523
428
524
429
525
430
def to_inference_data (
526
431
trace : Optional ["MultiTrace" ] = None ,
527
432
* ,
528
433
prior : Optional [Mapping [str , Any ]] = None ,
529
434
posterior_predictive : Optional [Mapping [str , Any ]] = None ,
530
- log_likelihood : Union [bool , Iterable [str ]] = True ,
435
+ log_likelihood : Union [bool , Iterable [str ]] = False ,
531
436
coords : Optional [CoordSpec ] = None ,
532
437
dims : Optional [DimSpec ] = None ,
533
438
sample_dims : Optional [List ] = None ,
0 commit comments