Skip to content

Commit 20c2ca0

Browse files
Allow Lists and NumPy Arrays as Transformation Priors (#1879)
* adds wider functionality for priors * adding tests for this change --------- Co-authored-by: Juan Orduz <[email protected]>
1 parent bb7c6fb commit 20c2ca0

File tree

2 files changed

+50
-7
lines changed

2 files changed

+50
-7
lines changed

pymc_marketing/mmm/components/base.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from collections.abc import Iterable
2525
from copy import deepcopy
2626
from inspect import signature
27-
from typing import Any
27+
from typing import Any, TypeAlias
2828

2929
import numpy as np
3030
import numpy.typing as npt
@@ -49,8 +49,13 @@
4949
# "x" for saturation, "time since exposure" for adstock
5050
NON_GRID_NAMES: frozenset[str] = frozenset({"x", "time since exposure"})
5151

52-
SupportedPrior = (
53-
InstanceOf[Prior] | float | InstanceOf[TensorVariable] | InstanceOf[VariableFactory]
52+
SupportedPrior: TypeAlias = (
53+
InstanceOf[Prior]
54+
| float
55+
| InstanceOf[TensorVariable]
56+
| InstanceOf[VariableFactory]
57+
| list
58+
| InstanceOf[npt.NDArray[np.floating]]
5459
)
5560

5661

@@ -127,7 +132,7 @@ class Transformation:
127132
128133
Parameters
129134
----------
130-
priors : dict[str, Prior | float | TensorVariable | VariableFactory], optional
135+
priors : dict[str, Prior | float | TensorVariable | VariableFactory | list | numpy array], optional
131136
Dictionary with the priors for the parameters of the function. The keys should be the
132137
parameter names and the values the priors. If not provided, it will use the default
133138
priors from the subclass.
@@ -144,8 +149,7 @@ class Transformation:
144149

145150
def __init__(
146151
self,
147-
priors: dict[str, Prior | float | TensorVariable | VariableFactory]
148-
| None = None,
152+
priors: dict[str, SupportedPrior] | None = None,
149153
prefix: str | None = None,
150154
) -> None:
151155
self._checks()

tests/mmm/components/test_base.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,16 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
from unittest.mock import Mock
15+
1416
import matplotlib.pyplot as plt
1517
import numpy as np
1618
import pymc as pm
1719
import pytensor.tensor as pt
1820
import pytest
1921
import xarray as xr
2022
from pymc_extras.prior import Prior, VariableFactory
21-
from pytensor.tensor import TensorVariable
23+
from pytensor.tensor import TensorVariable, scalar
2224

2325
from pymc_marketing.mmm.components.base import (
2426
DuplicatedTransformationError,
@@ -616,3 +618,40 @@ def test_apply_idx_more_dims(new_transformation_class) -> None:
616618
Y.eval(),
617619
expected.eval(),
618620
)
621+
622+
623+
class DummyTransformation(Transformation):
624+
@staticmethod
625+
def function(data, x):
626+
return data + x
627+
628+
prefix = "dummy"
629+
lookup_name = "dummy"
630+
default_priors = {"x": Prior("Normal", mu=0, sigma=1)}
631+
632+
633+
mock_variable_factory = Mock(spec=VariableFactory)
634+
635+
636+
@pytest.mark.parametrize(
637+
"prior_value",
638+
[
639+
Prior("Normal", mu=0, sigma=1),
640+
0.5,
641+
scalar("x"),
642+
mock_variable_factory,
643+
[0.1, 0.2, 0.3],
644+
np.array([0.1, 0.2, 0.3], dtype=float),
645+
],
646+
)
647+
def test_transformation_accepts_supported_priors(prior_value):
648+
priors = {"x": prior_value}
649+
tfm = DummyTransformation(priors=priors)
650+
651+
actual = tfm.function_priors["x"]
652+
653+
# Compare array-likes properly
654+
if isinstance(prior_value, list | np.ndarray):
655+
assert np.array_equal(actual, prior_value)
656+
else:
657+
assert actual == prior_value

0 commit comments

Comments
 (0)