@@ -17,15 +17,14 @@
 from scipy.special import log_softmax, logsumexp
 from scipy.stats import halfnorm, norm
 
-from pymc_experimental.model.marginal.graph_analysis import is_conditional_dependent
 from pymc_experimental.model.marginal.marginal_model import (
     MarginalModel,
     marginalize,
 )
 from tests.utils import equal_computations_up_to_root
 
 
-def test_marginalized_basic():
+def test_basic_marginalized_rv():
     data = [2] * 5
 
     with MarginalModel() as m:
@@ -69,7 +68,8 @@ def test_marginalized_basic():
     )
 
 
-def test_multiple_independent_marginalized_rvs():
+def test_one_to_one_marginalized_rvs():
+    """Test case with multiple, independent marginalized RVs"""
     with MarginalModel() as m:
         sigma = pm.HalfNormal("sigma")
         idx1 = pm.Bernoulli("idx1", p=0.75)
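Note: the reference log-probabilities these tests compare against can be reproduced with the scipy imports at the top of the file. A minimal sketch of the underlying identity for a single marginalized Bernoulli index, using hypothetical component means 0 and 1 rather than the test's actual model:

    import numpy as np
    from scipy.special import logsumexp
    from scipy.stats import norm

    p, y = 0.75, 0.5  # hypothetical index probability and observed value
    # logp(y) = logsumexp over the two index states, weighted by their priors
    ref_logp = logsumexp(
        [np.log(1 - p) + norm.logpdf(y, loc=0), np.log(p) + norm.logpdf(y, loc=1)]
    )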
@@ -95,7 +95,7 @@ def test_multiple_independent_marginalized_rvs():
     np.testing.assert_array_almost_equal(y_logp, y_ref_logp)
 
 
-def test_multiple_dependent_marginalized_rvs():
+def test_one_to_many_marginalized_rvs():
    """Test that marginalization works when there is more than one dependent RV"""
     with MarginalModel() as m:
         sigma = pm.HalfNormal("sigma")
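When several RVs depend on the same marginalized index, the sum over index states is taken once for the whole joint factor, not once per dependent RV. A hedged scipy sketch of that identity, with made-up values in place of the test's model:

    import numpy as np
    from scipy.special import logsumexp
    from scipy.stats import norm

    p, x, y = 0.7, 0.1, -0.3  # hypothetical index probability and observed values
    # logp(x, y) = logsumexp_i[logp(idx=i) + logp(x | idx=i) + logp(y | idx=i)]
    logp_xy = logsumexp(
        [
            np.log(1 - p) + norm.logpdf(x, loc=0) + norm.logpdf(y, loc=0),
            np.log(p) + norm.logpdf(x, loc=1) + norm.logpdf(y, loc=1),
        ]
    )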
@@ -118,7 +118,37 @@ def test_multiple_dependent_marginalized_rvs():
     np.testing.assert_array_almost_equal(logp_x_y, ref_logp_x_y)
 
 
-def test_rv_dependent_multiple_marginalized_rvs():
+def test_one_to_many_unaligned_marginalized_rvs():
+    """Test that marginalization works when there is more than one dependent RV whose batch dimensions are not aligned"""
+
+    def build_model(build_batched: bool):
+        with MarginalModel() as m:
+            if build_batched:
+                idx = pm.Bernoulli("idx", p=[0.75, 0.4], shape=(3, 2))
+            else:
+                idxs = [pm.Bernoulli(f"idx_{i}", p=(0.75 if i % 2 == 0 else 0.4)) for i in range(6)]
+                idx = pt.stack(idxs, axis=0).reshape((3, 2))
+
+            x = pm.Normal("x", mu=idx.T[:, :, None], shape=(2, 3, 1))
+            y = pm.Normal("y", mu=(idx * 2 - 1), shape=(1, 3, 2))
+
+        return m
+
+    m = build_model(build_batched=True)
+    with pytest.warns(UserWarning, match="There are multiple dependent variables"):
+        m.marginalize(["idx"])
+
+    ref_m = build_model(build_batched=False)
+    ref_m.marginalize([f"idx_{i}" for i in range(6)])
+
+    test_point = m.initial_point()
+    np.testing.assert_allclose(
+        m.compile_logp()(test_point),
+        ref_m.compile_logp()(test_point),
+    )
+
+
+def test_many_to_one_marginalized_rvs():
     """Test when random variables depend on multiple marginalized variables"""
     with MarginalModel() as m:
         x = pm.Bernoulli("x", 0.1)
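The unbatched reference model in the new test stacks six scalar Bernoullis so that each one lines up with an entry of the batched (3, 2) idx. A quick numpy check of that correspondence, derived from the test's own probabilities (not part of the diff):

    import numpy as np

    # The scalar idx_{i} probabilities, stacked and reshaped exactly as in the
    # test, reproduce the broadcast pattern of the batched p=[0.75, 0.4]:
    idx_p = np.array([0.75 if i % 2 == 0 else 0.4 for i in range(6)]).reshape(3, 2)
    assert np.allclose(idx_p, np.broadcast_to([0.75, 0.4], (3, 2)))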
@@ -133,13 +163,13 @@ def test_rv_dependent_multiple_marginalized_rvs():
     np.testing.assert_allclose(np.exp(logp({"z": 2})), 0.1 * 0.3)
 
 
-@pytest.mark.parametrize("batched", (False, True))
+@pytest.mark.parametrize("batched", (False, "left", "right"))
 def test_nested_marginalized_rvs(batched):
     """Test that marginalization works when there are nested marginalized RVs"""
 
     def build_model(build_batched: bool) -> MarginalModel:
         idx_shape = (3,) if build_batched else ()
-        sub_idx_shape = (3, 5) if build_batched else (5,)
+        sub_idx_shape = (5,) if not build_batched else (5, 3) if batched == "left" else (3, 5)
 
         with MarginalModel() as m:
             sigma = pm.HalfNormal("sigma")
@@ -148,9 +178,9 @@ def build_model(build_batched: bool) -> MarginalModel:
             dep = pm.Normal("dep", mu=pt.switch(pt.eq(idx, 0), -1000.0, 1000.0), sigma=sigma)
 
             sub_idx_p = pt.switch(pt.eq(idx, 0), 0.15, 0.95)
-            if build_batched:
-                sub_idx_p = sub_idx_p[:, None]
-                dep = dep[:, None]
+            if build_batched and batched == "right":
+                sub_idx_p = sub_idx_p[..., None]
+                dep = dep[..., None]
             sub_idx = pm.Bernoulli("sub_idx", p=sub_idx_p, shape=sub_idx_shape)
             sub_dep = pm.Normal("sub_dep", mu=dep + sub_idx * 100, sigma=sigma)
 
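The new "left"/"right" cases exercise the two ways the (3,)-shaped dep can broadcast against sub_idx. A numpy sketch of the shapes involved, derived from the test but not executed by it:

    import numpy as np

    dep = np.zeros(3)  # stands in for the (3,)-shaped dep above
    assert (dep[..., None] + np.zeros((3, 5))).shape == (3, 5)  # batched == "right"
    assert (dep + np.zeros((5, 3))).shape == (5, 3)             # batched == "left"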
@@ -204,22 +234,22 @@ def test_marginalized_index_as_key(advanced_indexing):
 
     with MarginalModel() as m:
         x = pm.Categorical("x", p=w, shape=shape)
-        y = pm.Normal("y", mu[x], sigma=1, observed=y_val)
+        y = pm.Normal("y", mu[x].T, sigma=1, observed=y_val)
 
     m.marginalize(x)
 
     marginal_logp = m.compile_logp(sum=False)({})[0]
-    ref_logp = pm.logp(pm.NormalMixture.dist(w=w, mu=mu, sigma=1, shape=shape), y_val).eval()
+    ref_logp = pm.logp(pm.NormalMixture.dist(w=w, mu=mu.T, sigma=1, shape=shape), y_val).eval()
 
     np.testing.assert_allclose(marginal_logp, ref_logp)
 
 
 def test_marginalized_index_as_value_and_key():
     """Test we can marginalize graphs where the marginalized_rv is indexed."""
 
-    def build_model(batch: bool) -> MarginalModel:
+    def build_model(build_batched: bool) -> MarginalModel:
         with MarginalModel() as m:
-            if batch:
+            if build_batched:
                 latent_state = pm.Bernoulli("latent_state", p=0.3, size=(4,))
             else:
                 latent_state = pm.math.stack(
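The mu[x].T change keeps the marginalized model and the NormalMixture reference indexing mu the same way. The underlying identity, sketched with hypothetical weights and means rather than the test's fixtures:

    import numpy as np
    from scipy.special import logsumexp
    from scipy.stats import norm

    w = np.array([0.1, 0.3, 0.6])    # hypothetical mixture weights
    mu = np.array([-1.0, 0.0, 1.0])  # hypothetical component means
    y = 0.25
    # Marginalizing a Categorical index into mu yields a mixture density:
    logp_y = logsumexp(np.log(w) + norm.logpdf(y, loc=mu, scale=1))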
@@ -237,8 +267,8 @@ def build_model(batch: bool) -> MarginalModel:
         return m
 
     # We compare with the equivalent but less efficient unbatched model
-    m = build_model(batch=True)
-    ref_m = build_model(batch=False)
+    m = build_model(build_batched=True)
+    ref_m = build_model(build_batched=False)
 
     m.marginalize(["latent_state"])
     ref_m.marginalize([f"latent_state_{i}" for i in range(4)])
@@ -317,6 +347,14 @@ def test_mixed_dims_via_support_dimension(self):
         with pytest.raises(NotImplementedError):
             m.marginalize(x)
 
+    def test_mixed_dims_via_nested_marginalization(self):
+        with MarginalModel() as m:
+            x = pm.Bernoulli("x", p=0.7, shape=(3,))
+            y = pm.Bernoulli("y", p=0.7, shape=(2,))
+            z = pm.Normal("z", mu=pt.add.outer(x, y), shape=(3, 2))
+        with pytest.raises(NotImplementedError):
+            m.marginalize([x, y])
+
 
 def test_marginalized_deterministic_and_potential():
     rng = np.random.default_rng(299)
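The new negative test appears to cover the case where the outer product couples every element of x with every element of y, so the indices cannot be marginalized one batch element at a time; a brute-force joint sum would have to visit every configuration of all five Bernoullis. (This reading is inferred from the test, not from the implementation.) A trivial count of those configurations:

    from itertools import product

    # Joint configurations a brute-force marginalization would have to visit:
    assert sum(1 for _ in product((0, 1), repeat=3 + 2)) == 2**5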
@@ -432,17 +470,6 @@ def test_marginalized_transforms(transform, expected_warning):
     np.testing.assert_allclose(m.compile_logp()(ip), m_ref.compile_logp()(ip))
 
 
-def test_is_conditional_dependent_static_shape():
-    """Test that we don't consider dependencies through "constant" shape Ops"""
-    x1 = pt.matrix("x1", shape=(None, 5))
-    y1 = pt.random.normal(size=pt.shape(x1))
-    assert is_conditional_dependent(y1, x1, [x1, y1])
-
-    x2 = pt.matrix("x2", shape=(9, 5))
-    y2 = pt.random.normal(size=pt.shape(x2))
-    assert not is_conditional_dependent(y2, x2, [x2, y2])
-
-
 def test_data_container():
     """Test that MarginalModel can handle Data containers."""
     with MarginalModel(coords={"obs": [0]}) as marginal_m:
@@ -469,49 +496,6 @@ def test_data_container():
     np.testing.assert_allclose(logp_fn(ip), ref_logp_fn(ip))
 
 
-@pytest.mark.parametrize("univariate", (True, False))
-def test_vector_univariate_mixture(univariate):
-    with MarginalModel() as m:
-        idx = pm.Bernoulli("idx", p=0.5, shape=(2,) if univariate else ())
-
-        def dist(idx, size):
-            return pm.math.switch(
-                pm.math.eq(idx, 0),
-                pm.Normal.dist([-10, -10], 1),
-                pm.Normal.dist([10, 10], 1),
-            )
-
-        pm.CustomDist("norm", idx, dist=dist)
-
-    m.marginalize(idx)
-    logp_fn = m.compile_logp()
-
-    if univariate:
-        with pm.Model() as ref_m:
-            pm.NormalMixture("norm", w=[0.5, 0.5], mu=[[-10, 10], [-10, 10]], shape=(2,))
-    else:
-        with pm.Model() as ref_m:
-            pm.Mixture(
-                "norm",
-                w=[0.5, 0.5],
-                comp_dists=[
-                    pm.MvNormal.dist([-10, -10], np.eye(2)),
-                    pm.MvNormal.dist([10, 10], np.eye(2)),
-                ],
-                shape=(2,),
-            )
-    ref_logp_fn = ref_m.compile_logp()
-
-    for test_value in (
-        [-10, -10],
-        [10, 10],
-        [-10, 10],
-        [-10, 10],
-    ):
-        pt = {"norm": test_value}
-        np.testing.assert_allclose(logp_fn(pt), ref_logp_fn(pt))
-
-
 def test_mutable_indexing_jax_backend():
     pytest.importorskip("jax")
     from pymc.sampling.jax import get_jaxified_logp
@@ -631,11 +615,51 @@ def test_change_point_model_sampling(self, disaster_model):
             rtol=1e-2,
         )
 
-    @pytest.mark.parametrize(
-        "batch_right", (True, pytest.param(False, marks=pytest.mark.xfail(reason="NotImplemented")))
-    )
+    @pytest.mark.parametrize("univariate", (True, False))
+    def test_vector_univariate_mixture(self, univariate):
+        with MarginalModel() as m:
+            idx = pm.Bernoulli("idx", p=0.5, shape=(2,) if univariate else ())
+
+            def dist(idx, size):
+                return pm.math.switch(
+                    pm.math.eq(idx, 0),
+                    pm.Normal.dist([-10, -10], 1),
+                    pm.Normal.dist([10, 10], 1),
+                )
+
+            pm.CustomDist("norm", idx, dist=dist)
+
+        m.marginalize(idx)
+        logp_fn = m.compile_logp()
+
+        if univariate:
+            with pm.Model() as ref_m:
+                pm.NormalMixture("norm", w=[0.5, 0.5], mu=[[-10, 10], [-10, 10]], shape=(2,))
+        else:
+            with pm.Model() as ref_m:
+                pm.Mixture(
+                    "norm",
+                    w=[0.5, 0.5],
+                    comp_dists=[
+                        pm.MvNormal.dist([-10, -10], np.eye(2)),
+                        pm.MvNormal.dist([10, 10], np.eye(2)),
+                    ],
+                    shape=(2,),
+                )
+        ref_logp_fn = ref_m.compile_logp()
+
+        for test_value in (
+            [-10, -10],
+            [10, 10],
+            [-10, 10],
+            [-10, 10],
+        ):
+            pt = {"norm": test_value}
+            np.testing.assert_allclose(logp_fn(pt), ref_logp_fn(pt))
+
+    @pytest.mark.parametrize("batch_right", (True, False))
     def test_k_censored_clusters_model(self, batch_right):
-        def build_model(batch: bool) -> MarginalModel:
+        def build_model(build_batched: bool) -> MarginalModel:
             data = np.array([[-1.0, -1.0], [0.0, 0.0], [1.0, 1.0]])
             nobs = data.shape[0]
             n_clusters = 5
@@ -645,7 +669,7 @@ def build_model(batch: bool) -> MarginalModel:
                 "obs": range(nobs),
             }
             with MarginalModel(coords=coords) as m:
-                if batch:
+                if build_batched:
                     idx = pm.Categorical("idx", p=np.ones(n_clusters) / n_clusters, dims=["obs"])
                 else:
                     idx = pm.math.stack(
@@ -682,8 +706,8 @@ def build_model(batch: bool) -> MarginalModel:
 
             return m
 
-        m = build_model(batch=True)
-        ref_m = build_model(batch=False)
+        m = build_model(build_batched=True)
+        ref_m = build_model(build_batched=False)
 
         m.marginalize([m["idx"]])
         ref_m.marginalize([n for n in ref_m.named_vars if n.startswith("idx_")])
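The relocated test_vector_univariate_mixture distinguishes two regimes: a vector-valued idx switches each element independently (a product of univariate mixtures), while a scalar idx selects a whole component vector at once (a 2-component multivariate mixture). A scipy sketch of both marginals, reusing the test's component means of -10 and 10:

    import numpy as np
    from scipy.special import logsumexp
    from scipy.stats import norm

    y = np.array([-10.0, 10.0])
    # Vector idx: each element is marginalized on its own.
    logp_univariate = sum(
        logsumexp([np.log(0.5) + norm.logpdf(v, -10), np.log(0.5) + norm.logpdf(v, 10)])
        for v in y
    )
    # Scalar idx: one shared index selects an entire component vector.
    logp_multivariate = logsumexp(
        [np.log(0.5) + norm.logpdf(y, -10).sum(), np.log(0.5) + norm.logpdf(y, 10).sum()]
    )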