We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent a1f11ec commit 024db7cCopy full SHA for 024db7c
pymc_marketing/prior.py
@@ -465,6 +465,8 @@ def dims(self) -> Dims:
465
def dims(self, dims) -> None:
466
if isinstance(dims, str):
467
dims = (dims,)
468
+ elif isinstance(dims, list):
469
+ dims = tuple(dims)
470
471
self._dims = dims or ()
472
tests/test_prior.py
@@ -1164,3 +1164,8 @@ def test_scaled_sample_prior() -> None:
1164
assert prior.sizes == {"chain": 1, "draw": 25, "channel": 3}
1165
assert "scaled_var" in prior
1166
assert "scaled_var_unscaled" in prior
1167
+
1168
1169
+def test_prior_list_dims() -> None:
1170
+ dist = Prior("Normal", dims=["channel", "geo"])
1171
+ assert dist.dims == ("channel", "geo")
0 commit comments