Skip to content

Commit 6d6cf83

Browse files
committed
specific tests for the alternative dict format
1 parent 965b458 commit 6d6cf83

File tree

2 files changed

+38
-1
lines changed

2 files changed

+38
-1
lines changed

pymc_marketing/prior.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1280,7 +1280,7 @@ def from_dict(cls, data: dict[str, Any]) -> Censored:
12801280
"""Create a censored distribution from a dictionary."""
12811281
data = data["data"]
12821282
return cls( # type: ignore
1283-
distribution=Prior.from_dict(data["dist"]),
1283+
distribution=deserialize(data["dist"]),
12841284
lower=data["lower"],
12851285
upper=data["upper"],
12861286
)

tests/test_prior.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1169,3 +1169,40 @@ def test_scaled_sample_prior() -> None:
11691169
def test_prior_list_dims() -> None:
11701170
dist = Prior("Normal", dims=["channel", "geo"])
11711171
assert dist.dims == ("channel", "geo")
1172+
1173+
1174+
@pytest.mark.parametrize(
1175+
"data, expected",
1176+
[
1177+
pytest.param(
1178+
{
1179+
"distribution": "Laplace",
1180+
"mu": 1,
1181+
"b": 2,
1182+
"dims": ("x", "y"),
1183+
"transform": "sigmoid",
1184+
},
1185+
Prior("Laplace", mu=1, b=2, dims=("x", "y"), transform="sigmoid"),
1186+
id="Prior",
1187+
),
1188+
pytest.param(
1189+
{"distribution": "Normal", "mu": {"distribution": "Normal"}},
1190+
Prior("Normal", mu=Prior("Normal")),
1191+
id="Prior with nested distribution",
1192+
),
1193+
pytest.param(
1194+
{
1195+
"class": "Censored",
1196+
"data": {
1197+
"dist": {"distribution": "Normal"},
1198+
"lower": 0,
1199+
"upper": 10,
1200+
},
1201+
},
1202+
Censored(Prior("Normal"), lower=0, upper=10),
1203+
id="Censored with alternative",
1204+
),
1205+
],
1206+
)
1207+
def test_alternative_prior_deserialize(data, expected) -> None:
1208+
assert deserialize(data) == expected

0 commit comments

Comments
 (0)