9
9
from arviz import InferenceData , dict_to_dataset
10
10
from pymc .backends .arviz import coords_and_dims_for_inferencedata , dataset_to_point_list
11
11
from pymc .distributions .discrete import Bernoulli , Categorical , DiscreteUniform
12
+ from pymc .distributions .multivariate import MvNormal
12
13
from pymc .distributions .transforms import Chain
13
14
from pymc .logprob .transforms import IntervalTransform
14
15
from pymc .model import Model
45
46
from pymc_extras .model .marginal .distributions import (
46
47
MarginalDiscreteMarkovChainRV ,
47
48
MarginalFiniteDiscreteRV ,
49
+ MarginalLaplaceRV ,
48
50
MarginalRV ,
49
51
NonSeparableLogpWarning ,
50
52
get_domain_of_finite_discrete_rv ,
@@ -144,7 +146,9 @@ def _unique(seq: Sequence) -> list:
144
146
return [x for x in seq if not (x in seen or seen_add (x ))]
145
147
146
148
147
- def marginalize (model : Model , rvs_to_marginalize : ModelRVs ) -> MarginalModel :
149
+ def marginalize (
150
+ model : Model , rvs_to_marginalize : ModelRVs , use_laplace : bool = False , ** marginalize_kwargs
151
+ ) -> MarginalModel :
148
152
"""Marginalize a subset of variables in a PyMC model.
149
153
150
154
This creates a class of `MarginalModel` from an existing `Model`, with the specified
@@ -158,6 +162,8 @@ def marginalize(model: Model, rvs_to_marginalize: ModelRVs) -> MarginalModel:
158
162
PyMC model to marginalize. Original variables well be cloned.
159
163
rvs_to_marginalize : Sequence[TensorVariable]
160
164
Variables to marginalize in the returned model.
165
+ use_laplace : bool
166
+ Whether to use Laplace appoximations to marginalize out rvs_to_marginalize.
161
167
162
168
Returns
163
169
-------
@@ -186,7 +192,12 @@ def marginalize(model: Model, rvs_to_marginalize: ModelRVs) -> MarginalModel:
186
192
raise NotImplementedError (
187
193
"Marginalization for DiscreteMarkovChain with non-matrix transition probability is not supported"
188
194
)
189
- elif not isinstance (rv_op , Bernoulli | Categorical | DiscreteUniform ):
195
+ elif use_laplace and not isinstance (rv_op , MvNormal ):
196
+ raise ValueError (
197
+ f"Marginalisation method set to Laplace but RV { rv_to_marginalize } is not instance of MvNormal. Has distribution { rv_to_marginalize .owner .op } "
198
+ )
199
+
200
+ elif not use_laplace and not isinstance (rv_op , Bernoulli | Categorical | DiscreteUniform ):
190
201
raise NotImplementedError (
191
202
f"Marginalization of RV with distribution { rv_to_marginalize .owner .op } is not supported"
192
203
)
@@ -241,7 +252,9 @@ def marginalize(model: Model, rvs_to_marginalize: ModelRVs) -> MarginalModel:
241
252
]
242
253
input_rvs = _unique ((* marginalized_rv_input_rvs , * other_direct_rv_ancestors ))
243
254
244
- replace_finite_discrete_marginal_subgraph (fg , rv_to_marginalize , dependent_rvs , input_rvs )
255
+ replace_marginal_subgraph (
256
+ fg , rv_to_marginalize , dependent_rvs , input_rvs , use_laplace , ** marginalize_kwargs
257
+ )
245
258
246
259
return model_from_fgraph (fg , mutate_fgraph = True )
247
260
@@ -551,22 +564,32 @@ def remove_model_vars(vars):
551
564
return fgraph .outputs
552
565
553
566
554
- def replace_finite_discrete_marginal_subgraph (
555
- fgraph , rv_to_marginalize , dependent_rvs , input_rvs
567
+ def replace_marginal_subgraph (
568
+ fgraph ,
569
+ rv_to_marginalize ,
570
+ dependent_rvs ,
571
+ input_rvs ,
572
+ use_laplace = False ,
573
+ ** marginalize_kwargs ,
556
574
) -> None :
557
575
# If the marginalized RV has multiple dimensions, check that graph between
558
576
# marginalized RV and dependent RVs does not mix information from batch dimensions
559
577
# (otherwise logp would require enumerating over all combinations of batch dimension values)
560
- try :
561
- dependent_rvs_dim_connections = subgraph_batch_dim_connection (
562
- rv_to_marginalize , dependent_rvs
563
- )
564
- except (ValueError , NotImplementedError ) as e :
565
- # For the perspective of the user this is a NotImplementedError
566
- raise NotImplementedError (
567
- "The graph between the marginalized and dependent RVs cannot be marginalized efficiently. "
568
- "You can try splitting the marginalized RV into separate components and marginalizing them separately."
569
- ) from e
578
+ if not use_laplace :
579
+ try :
580
+ dependent_rvs_dim_connections = subgraph_batch_dim_connection (
581
+ rv_to_marginalize , dependent_rvs
582
+ )
583
+ except (ValueError , NotImplementedError ) as e :
584
+ # For the perspective of the user this is a NotImplementedError
585
+ raise NotImplementedError (
586
+ "The graph between the marginalized and dependent RVs cannot be marginalized efficiently. "
587
+ "You can try splitting the marginalized RV into separate components and marginalizing them separately."
588
+ ) from e
589
+ else :
590
+ dependent_rvs_dim_connections = [
591
+ (None ,),
592
+ ]
570
593
571
594
output_rvs = [rv_to_marginalize , * dependent_rvs ]
572
595
rng_updates = collect_default_updates (output_rvs , inputs = input_rvs , must_be_shared = False )
@@ -581,6 +604,8 @@ def replace_finite_discrete_marginal_subgraph(
581
604
582
605
if isinstance (inner_outputs [0 ].owner .op , DiscreteMarkovChain ):
583
606
marginalize_constructor = MarginalDiscreteMarkovChainRV
607
+ elif use_laplace :
608
+ marginalize_constructor = MarginalLaplaceRV
584
609
else :
585
610
marginalize_constructor = MarginalFiniteDiscreteRV
586
611
@@ -590,6 +615,7 @@ def replace_finite_discrete_marginal_subgraph(
590
615
outputs = inner_outputs ,
591
616
dims_connections = dependent_rvs_dim_connections ,
592
617
dims = dims ,
618
+ ** marginalize_kwargs ,
593
619
)
594
620
595
621
new_outputs = marginalization_op (* inputs )
0 commit comments