Skip to content

Commit a135893

Browse files
Add 2d test for nonstandard dims
1 parent 2ebff6f commit a135893

File tree

1 file changed

+44
-0
lines changed

1 file changed

+44
-0
lines changed

tests/inference/laplace_approx/test_laplace.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,50 @@ def test_model_with_nonstandard_dimensionality(rng):
230230
assert "class" in list(idata.unconstrained_posterior.sigma_log__.coords.keys())
231231

232232

233+
def test_laplace_nonstandard_dims_2d():
234+
true_P = np.array([[0.5, 0.3, 0.2], [0.1, 0.6, 0.3], [0.2, 0.4, 0.4]])
235+
y_obs = pm.draw(
236+
pmx.DiscreteMarkovChain.dist(
237+
P=true_P,
238+
init_dist=pm.Categorical.dist(
239+
logit_p=np.ones(
240+
3,
241+
)
242+
),
243+
shape=(100, 5),
244+
)
245+
)
246+
247+
with pm.Model(
248+
coords={
249+
"time": range(y_obs.shape[0]),
250+
"state": list("ABC"),
251+
"next_state": list("ABC"),
252+
"unit": [1, 2, 3, 4, 5],
253+
}
254+
) as model:
255+
y = pm.Data("y", y_obs, dims=["time", "unit"])
256+
init_dist = pm.Categorical.dist(
257+
logit_p=np.ones(
258+
3,
259+
)
260+
)
261+
P = pm.Dirichlet("P", a=np.eye(3) * 2 + 1, dims=["state", "next_state"])
262+
y_hat = pmx.DiscreteMarkovChain(
263+
"y_hat", P=P, init_dist=init_dist, dims=["time", "unit"], observed=y_obs
264+
)
265+
266+
idata = pmx.fit_laplace(progressbar=True)
267+
268+
# The simplex transform should drop from the right-most dimension, so the left dimension should be unmodified
269+
assert "state" in list(idata.unconstrained_posterior.P_simplex__.coords.keys())
270+
271+
# The mutated dimension should be unknown coords
272+
assert "P_simplex___dim_1" in list(idata.unconstrained_posterior.P_simplex__.coords.keys())
273+
274+
assert idata.unconstrained_posterior.P_simplex__.shape[-2:] == (3, 2)
275+
276+
233277
def test_laplace_nonscalar_rv_without_dims():
234278
with pm.Model(coords={"test": ["A", "B", "C"]}) as model:
235279
x_loc = pm.Normal("x_loc", mu=0, sigma=1, dims=["test"])

0 commit comments

Comments
 (0)