Skip to content

Commit 327ac97

Browse files
authored
Register new media transformations automatically (#1320)
1 parent 1630589 commit 327ac97

File tree

6 files changed

+140
-43
lines changed

6 files changed

+140
-43
lines changed

pymc_marketing/mmm/components/adstock.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@ def function(self, x, alpha):
5252
5353
"""
5454

55+
from __future__ import annotations
56+
5557
import numpy as np
5658
import xarray as xr
5759
from pydantic import Field, validate_call
@@ -60,6 +62,7 @@ def function(self, x, alpha):
6062
from pymc_marketing.mmm.components.base import (
6163
SupportedPrior,
6264
Transformation,
65+
create_registration_meta,
6366
)
6467
from pymc_marketing.mmm.transformers import (
6568
ConvMode,
@@ -70,8 +73,12 @@ def function(self, x, alpha):
7073
)
7174
from pymc_marketing.prior import Prior
7275

76+
ADSTOCK_TRANSFORMATIONS: dict[str, type[AdstockTransformation]] = {}
77+
78+
AdstockRegistrationMeta: type[type] = create_registration_meta(ADSTOCK_TRANSFORMATIONS)
79+
7380

74-
class AdstockTransformation(Transformation):
81+
class AdstockTransformation(Transformation, metaclass=AdstockRegistrationMeta): # type: ignore
7582
"""Subclass for all adstock functions.
7683
7784
In order to use a custom saturation function, inherit from this class and define:
@@ -322,17 +329,6 @@ def function(self, x, lam, k):
322329
}
323330

324331

325-
ADSTOCK_TRANSFORMATIONS: dict[str, type[AdstockTransformation]] = {
326-
cls.lookup_name: cls # type: ignore
327-
for cls in [
328-
GeometricAdstock,
329-
DelayedAdstock,
330-
WeibullPDFAdstock,
331-
WeibullCDFAdstock,
332-
]
333-
}
334-
335-
336332
def register_adstock_transformation(cls: type[AdstockTransformation]) -> None:
337333
"""Register a new adstock transformation."""
338334
ADSTOCK_TRANSFORMATIONS[cls.lookup_name] = cls

pymc_marketing/mmm/components/base.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -592,3 +592,47 @@ def _serialize_value(value: Any) -> Any:
592592
return value.tolist()
593593

594594
return value
595+
596+
597+
class DuplicatedTransformationError(Exception):
598+
"""Exception when a transformation is duplicated."""
599+
600+
def __init__(self, name: str, lookup_name: str):
601+
self.name = name
602+
self.lookup_name = lookup_name
603+
super().__init__(f"Duplicate {name}. The name {lookup_name!r} already exists.")
604+
605+
606+
def create_registration_meta(subclasses: dict[str, Any]) -> type[type]:
607+
"""Create a metaclass for registering subclasses.
608+
609+
Parameters
610+
----------
611+
subclasses : dict[str, type[Transformation]]
612+
The subclasses to register.
613+
614+
Returns
615+
-------
616+
type
617+
The metaclass for registering subclasses.
618+
619+
"""
620+
621+
class RegistrationMeta(type):
622+
def __new__(cls, name, bases, attrs):
623+
new_cls = super().__new__(cls, name, bases, attrs)
624+
625+
if "lookup_name" not in attrs:
626+
return new_cls
627+
628+
base_name = bases[0].__name__
629+
630+
lookup_name = attrs["lookup_name"]
631+
if lookup_name in subclasses:
632+
raise DuplicatedTransformationError(base_name, lookup_name)
633+
634+
subclasses[lookup_name] = new_cls
635+
636+
return new_cls
637+
638+
return RegistrationMeta

pymc_marketing/mmm/components/saturation.py

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,8 @@ def function(self, x, b):
7171
7272
"""
7373

74+
from __future__ import annotations
75+
7476
import numpy as np
7577
import pytensor.tensor as pt
7678
import xarray as xr
@@ -79,6 +81,7 @@ def function(self, x, b):
7981
from pymc_marketing.deserialize import deserialize, register_deserialization
8082
from pymc_marketing.mmm.components.base import (
8183
Transformation,
84+
create_registration_meta,
8285
)
8386
from pymc_marketing.mmm.transformers import (
8487
hill_function,
@@ -92,8 +95,12 @@ def function(self, x, b):
9295
)
9396
from pymc_marketing.prior import Prior
9497

98+
SATURATION_TRANSFORMATIONS: dict[str, type[SaturationTransformation]] = {}
99+
100+
SaturationRegistrationMeta = create_registration_meta(SATURATION_TRANSFORMATIONS)
101+
95102

96-
class SaturationTransformation(Transformation):
103+
class SaturationTransformation(Transformation, metaclass=SaturationRegistrationMeta): # type: ignore
97104
"""Subclass for all saturation transformations.
98105
99106
In order to use a custom saturation transformation, subclass and define:
@@ -452,21 +459,6 @@ def function(self, x, alpha, beta):
452459
}
453460

454461

455-
SATURATION_TRANSFORMATIONS: dict[str, type[SaturationTransformation]] = {
456-
cls.lookup_name: cls
457-
for cls in [
458-
LogisticSaturation,
459-
InverseScaledLogisticSaturation,
460-
TanhSaturation,
461-
TanhSaturationBaselined,
462-
MichaelisMentenSaturation,
463-
HillSaturation,
464-
HillSaturationSigmoid,
465-
RootSaturation,
466-
]
467-
}
468-
469-
470462
def register_saturation_transformation(cls: type[SaturationTransformation]) -> None:
471463
"""Register a new saturation transformation.
472464

tests/mmm/components/test_adstock.py

Lines changed: 28 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,6 @@
3232
WeibullCDFAdstock,
3333
WeibullPDFAdstock,
3434
adstock_from_dict,
35-
register_adstock_transformation,
36-
)
37-
from pymc_marketing.mmm.components.adstock import (
38-
ADSTOCK_TRANSFORMATIONS,
3935
)
4036
from pymc_marketing.mmm.transformers import ConvMode
4137
from pymc_marketing.prior import Prior
@@ -161,27 +157,25 @@ def test_adstock_from_dict_without_priors(adstock, deserialize_func) -> None:
161157
}
162158

163159

164-
@pytest.mark.parametrize("deserialize_func", [adstock_from_dict, deserialize])
165-
def test_register_adstock_transformation(deserialize_func) -> None:
166-
class NewTransformation(AdstockTransformation):
167-
lookup_name: str = "new_transformation"
168-
default_priors = {}
160+
class AnotherNewTransformation(AdstockTransformation):
161+
lookup_name: str = "another_new_transformation"
162+
default_priors = {}
169163

170-
def function(self, x):
171-
return x
164+
def function(self, x):
165+
return x
172166

173-
register_adstock_transformation(NewTransformation)
174-
assert "new_transformation" in ADSTOCK_TRANSFORMATIONS
175167

168+
@pytest.mark.parametrize("deserialize_func", [adstock_from_dict, deserialize])
169+
def test_automatic_register_adstock_transformation(deserialize_func) -> None:
176170
data = {
177-
"lookup_name": "new_transformation",
171+
"lookup_name": "another_new_transformation",
178172
"l_max": 10,
179173
"normalize": False,
180174
"mode": "Before",
181175
"priors": {},
182176
}
183177
adstock = deserialize_func(data)
184-
assert adstock == NewTransformation(
178+
assert adstock == AnotherNewTransformation(
185179
l_max=10, mode=ConvMode.Before, normalize=False, priors={}
186180
)
187181

@@ -239,3 +233,22 @@ def test_deserialization(
239233
assert isinstance(alpha, ArbitraryObject)
240234
assert alpha.msg == "hello"
241235
assert alpha.value == 1
236+
237+
238+
def test_deserialize_new_transformation() -> None:
239+
class NewAdstock(AdstockTransformation):
240+
lookup_name = "new_adstock"
241+
242+
def function(self, x):
243+
return x
244+
245+
default_priors = {}
246+
247+
data = {
248+
"lookup_name": "new_adstock",
249+
"l_max": 10,
250+
}
251+
252+
instance = deserialize(data)
253+
assert isinstance(instance, NewAdstock)
254+
assert instance.l_max == 10

tests/mmm/components/test_base.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,11 @@
2020
from pytensor.tensor import TensorVariable
2121

2222
from pymc_marketing.mmm.components.base import (
23+
DuplicatedTransformationError,
2324
MissingDataParameter,
2425
ParameterPriorException,
2526
Transformation,
27+
create_registration_meta,
2628
)
2729
from pymc_marketing.prior import Prior
2830

@@ -417,3 +419,35 @@ def test_serialization(new_transformation_class) -> None:
417419
"b": [1, 2, 3],
418420
},
419421
}
422+
423+
424+
def test_automatic_registration() -> None:
425+
subclasses = {}
426+
427+
RegistrationMeta = create_registration_meta(subclasses)
428+
429+
class BaseTransform:
430+
pass
431+
432+
class Transform(BaseTransform, metaclass=RegistrationMeta):
433+
pass
434+
435+
class NewTransform(Transform):
436+
lookup_name = "new"
437+
438+
assert subclasses == {"new": NewTransform}
439+
440+
class AnotherTransform(Transform):
441+
lookup_name = "another"
442+
443+
assert subclasses == {"new": NewTransform, "another": AnotherTransform}
444+
445+
with pytest.raises(DuplicatedTransformationError) as e:
446+
447+
class _(Transform):
448+
lookup_name = "new"
449+
450+
exception = e.value
451+
452+
assert exception.lookup_name == "new"
453+
assert exception.name == "Transform"

tests/mmm/components/test_saturation.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
LogisticSaturation,
3333
MichaelisMentenSaturation,
3434
RootSaturation,
35+
SaturationTransformation,
3536
TanhSaturation,
3637
TanhSaturationBaselined,
3738
saturation_from_dict,
@@ -287,3 +288,20 @@ def test_deserialization(
287288
assert isinstance(alpha, ArbitraryObject)
288289
assert alpha.msg == "hello"
289290
assert alpha.value == 1
291+
292+
293+
def test_deserialize_new_transformation() -> None:
294+
class NewSaturation(SaturationTransformation):
295+
lookup_name = "new_saturation"
296+
297+
def function(self, x):
298+
return x
299+
300+
default_priors = {}
301+
302+
data = {
303+
"lookup_name": "new_saturation",
304+
}
305+
306+
instance = deserialize(data)
307+
assert isinstance(instance, NewSaturation)

0 commit comments

Comments
 (0)