@@ -230,6 +230,50 @@ def test_model_with_nonstandard_dimensionality(rng):
230
230
assert "class" in list (idata .unconstrained_posterior .sigma_log__ .coords .keys ())
231
231
232
232
233
+ def test_laplace_nonstandard_dims_2d ():
234
+ true_P = np .array ([[0.5 , 0.3 , 0.2 ], [0.1 , 0.6 , 0.3 ], [0.2 , 0.4 , 0.4 ]])
235
+ y_obs = pm .draw (
236
+ pmx .DiscreteMarkovChain .dist (
237
+ P = true_P ,
238
+ init_dist = pm .Categorical .dist (
239
+ logit_p = np .ones (
240
+ 3 ,
241
+ )
242
+ ),
243
+ shape = (100 , 5 ),
244
+ )
245
+ )
246
+
247
+ with pm .Model (
248
+ coords = {
249
+ "time" : range (y_obs .shape [0 ]),
250
+ "state" : list ("ABC" ),
251
+ "next_state" : list ("ABC" ),
252
+ "unit" : [1 , 2 , 3 , 4 , 5 ],
253
+ }
254
+ ) as model :
255
+ y = pm .Data ("y" , y_obs , dims = ["time" , "unit" ])
256
+ init_dist = pm .Categorical .dist (
257
+ logit_p = np .ones (
258
+ 3 ,
259
+ )
260
+ )
261
+ P = pm .Dirichlet ("P" , a = np .eye (3 ) * 2 + 1 , dims = ["state" , "next_state" ])
262
+ y_hat = pmx .DiscreteMarkovChain (
263
+ "y_hat" , P = P , init_dist = init_dist , dims = ["time" , "unit" ], observed = y_obs
264
+ )
265
+
266
+ idata = pmx .fit_laplace (progressbar = True )
267
+
268
+ # The simplex transform should drop from the right-most dimension, so the left dimension should be unmodified
269
+ assert "state" in list (idata .unconstrained_posterior .P_simplex__ .coords .keys ())
270
+
271
+ # The mutated dimension should be unknown coords
272
+ assert "P_simplex___dim_1" in list (idata .unconstrained_posterior .P_simplex__ .coords .keys ())
273
+
274
+ assert idata .unconstrained_posterior .P_simplex__ .shape [- 2 :] == (3 , 2 )
275
+
276
+
233
277
def test_laplace_nonscalar_rv_without_dims ():
234
278
with pm .Model (coords = {"test" : ["A" , "B" , "C" ]}) as model :
235
279
x_loc = pm .Normal ("x_loc" , mu = 0 , sigma = 1 , dims = ["test" ])
0 commit comments