Skip to content

Commit 024db7c

Browse files
committed
convert list to tuple in Prior
1 parent a1f11ec commit 024db7c

File tree

2 files changed

+7
-0
lines changed

2 files changed

+7
-0
lines changed

pymc_marketing/prior.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -465,6 +465,8 @@ def dims(self) -> Dims:
465465
def dims(self, dims) -> None:
466466
if isinstance(dims, str):
467467
dims = (dims,)
468+
elif isinstance(dims, list):
469+
dims = tuple(dims)
468470

469471
self._dims = dims or ()
470472

tests/test_prior.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1164,3 +1164,8 @@ def test_scaled_sample_prior() -> None:
11641164
assert prior.sizes == {"chain": 1, "draw": 25, "channel": 3}
11651165
assert "scaled_var" in prior
11661166
assert "scaled_var_unscaled" in prior
1167+
1168+
1169+
def test_prior_list_dims() -> None:
1170+
dist = Prior("Normal", dims=["channel", "geo"])
1171+
assert dist.dims == ("channel", "geo")

0 commit comments

Comments
 (0)