@@ -2122,40 +2122,6 @@ def test_dirichlet_invalid(self):
2122
2122
valid_dist = Dirichlet .dist (a = [1 , 1 , 1 ])
2123
2123
assert np .all (np .isfinite (pm .logp (valid_dist , value ).eval ()) == np .array ([True , False ]))
2124
2124
2125
- @pytest .mark .parametrize (
2126
- "value,alpha,K,logp" ,
2127
- [
2128
- (np .array ([5 , 4 , 3 , 2 , 1 ]) / 15 , 0.5 , 4 , 1.5126301307277439 ),
2129
- (np .tile (1 , 13 ) / 13 , 2 , 12 , 13.980045245672827 ),
2130
- (np .array ([0.001 ] * 10 + [0.99 ]), 0.1 , 10 , - 22.971662448814723 ),
2131
- (np .append (0.5 ** np .arange (1 , 20 ), 0.5 ** 20 ), 5 , 19 , 94.20462772778092 ),
2132
- (
2133
- (np .array ([[7 , 5 , 3 , 2 ], [19 , 17 , 13 , 11 ]]) / np .array ([[17 ], [60 ]])),
2134
- 2.5 ,
2135
- 3 ,
2136
- np .array ([1.29317672 , 1.50126157 ]),
2137
- ),
2138
- ],
2139
- )
2140
- def test_stickbreakingweights_logp (self , value , alpha , K , logp ):
2141
- with Model () as model :
2142
- sbw = StickBreakingWeights ("sbw" , alpha = alpha , K = K , transform = None )
2143
- pt = {"sbw" : value }
2144
- assert_almost_equal (
2145
- pm .logp (sbw , value ).eval (),
2146
- logp ,
2147
- decimal = select_by_precision (float64 = 6 , float32 = 2 ),
2148
- err_msg = str (pt ),
2149
- )
2150
-
2151
- def test_stickbreakingweights_invalid (self ):
2152
- sbw = pm .StickBreakingWeights .dist (3.0 , 3 )
2153
- sbw_wrong_K = pm .StickBreakingWeights .dist (3.0 , 7 )
2154
- assert pm .logp (sbw , np .array ([0.4 , 0.3 , 0.2 , 0.15 ])).eval () == - np .inf
2155
- assert pm .logp (sbw , np .array ([1.1 , 0.3 , 0.2 , 0.1 ])).eval () == - np .inf
2156
- assert pm .logp (sbw , np .array ([0.4 , 0.3 , 0.2 , - 0.1 ])).eval () == - np .inf
2157
- assert pm .logp (sbw_wrong_K , np .array ([0.4 , 0.3 , 0.2 , 0.1 ])).eval () == - np .inf
2158
-
2159
2125
@pytest .mark .parametrize (
2160
2126
"a" ,
2161
2127
[
@@ -2318,6 +2284,40 @@ def test_dirichlet_multinomial_vectorized(self, n, a, size):
2318
2284
err_msg = f"vals={ vals } " ,
2319
2285
)
2320
2286
2287
+ @pytest .mark .parametrize (
2288
+ "value,alpha,K,logp" ,
2289
+ [
2290
+ (np .array ([5 , 4 , 3 , 2 , 1 ]) / 15 , 0.5 , 4 , 1.5126301307277439 ),
2291
+ (np .tile (1 , 13 ) / 13 , 2 , 12 , 13.980045245672827 ),
2292
+ (np .array ([0.001 ] * 10 + [0.99 ]), 0.1 , 10 , - 22.971662448814723 ),
2293
+ (np .append (0.5 ** np .arange (1 , 20 ), 0.5 ** 20 ), 5 , 19 , 94.20462772778092 ),
2294
+ (
2295
+ (np .array ([[7 , 5 , 3 , 2 ], [19 , 17 , 13 , 11 ]]) / np .array ([[17 ], [60 ]])),
2296
+ 2.5 ,
2297
+ 3 ,
2298
+ np .array ([1.29317672 , 1.50126157 ]),
2299
+ ),
2300
+ ],
2301
+ )
2302
+ def test_stickbreakingweights_logp (self , value , alpha , K , logp ):
2303
+ with Model () as model :
2304
+ sbw = StickBreakingWeights ("sbw" , alpha = alpha , K = K , transform = None )
2305
+ pt = {"sbw" : value }
2306
+ assert_almost_equal (
2307
+ pm .logp (sbw , value ).eval (),
2308
+ logp ,
2309
+ decimal = select_by_precision (float64 = 6 , float32 = 2 ),
2310
+ err_msg = str (pt ),
2311
+ )
2312
+
2313
+ def test_stickbreakingweights_invalid (self ):
2314
+ sbw = pm .StickBreakingWeights .dist (3.0 , 3 )
2315
+ sbw_wrong_K = pm .StickBreakingWeights .dist (3.0 , 7 )
2316
+ assert pm .logp (sbw , np .array ([0.4 , 0.3 , 0.2 , 0.15 ])).eval () == - np .inf
2317
+ assert pm .logp (sbw , np .array ([1.1 , 0.3 , 0.2 , 0.1 ])).eval () == - np .inf
2318
+ assert pm .logp (sbw , np .array ([0.4 , 0.3 , 0.2 , - 0.1 ])).eval () == - np .inf
2319
+ assert pm .logp (sbw_wrong_K , np .array ([0.4 , 0.3 , 0.2 , 0.1 ])).eval () == - np .inf
2320
+
2321
2321
@aesara .config .change_flags (compute_test_value = "raise" )
2322
2322
def test_categorical_bounds (self ):
2323
2323
with Model ():
0 commit comments