Skip to content

Commit 6049ae8

Browse files
authored
Create inverse_scaled_logistic_saturation and the corresponding class (#827)
1 parent 717702a commit 6049ae8

File tree

5 files changed

+110
-0
lines changed

5 files changed

+110
-0
lines changed

pymc_marketing/mmm/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
)
2424
from pymc_marketing.mmm.components.saturation import (
2525
HillSaturation,
26+
InverseScaledLogisticSaturation,
2627
LogisticSaturation,
2728
MichaelisMentenSaturation,
2829
SaturationTransformation,
@@ -45,6 +46,7 @@
4546
"GeometricAdstock",
4647
"HillSaturation",
4748
"LogisticSaturation",
49+
"InverseScaledLogisticSaturation",
4850
"MMM",
4951
"MMMModelBuilder",
5052
"MichaelisMentenSaturation",

pymc_marketing/mmm/components/saturation.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ def function(self, x, b):
7676
from pymc_marketing.mmm.components.base import Transformation
7777
from pymc_marketing.mmm.transformers import (
7878
hill_saturation,
79+
inverse_scaled_logistic_saturation,
7980
logistic_saturation,
8081
michaelis_menten,
8182
tanh_saturation,
@@ -201,6 +202,39 @@ def function(self, x, lam, beta):
201202
}
202203

203204

205+
class InverseScaledLogisticSaturation(SaturationTransformation):
206+
"""Wrapper around inverse scaled logistic saturation function.
207+
208+
For more information, see :func:`pymc_marketing.mmm.transformers.inverse_scaled_logistic_saturation`.
209+
210+
.. plot::
211+
:context: close-figs
212+
213+
import matplotlib.pyplot as plt
214+
import numpy as np
215+
from pymc_marketing.mmm import InverseScaledLogisticSaturation
216+
217+
rng = np.random.default_rng(0)
218+
219+
adstock = InverseScaledLogisticSaturation()
220+
prior = adstock.sample_prior(random_seed=rng)
221+
curve = adstock.sample_curve(prior)
222+
adstock.plot_curve(curve, sample_kwargs={"rng": rng})
223+
plt.show()
224+
225+
"""
226+
227+
lookup_name = "inverse_scaled_logistic"
228+
229+
def function(self, x, lam, beta):
230+
return beta * inverse_scaled_logistic_saturation(x, lam)
231+
232+
default_priors = {
233+
"lam": Prior("Gamma", alpha=0.5, beta=1),
234+
"beta": Prior("HalfNormal", sigma=2),
235+
}
236+
237+
204238
class TanhSaturation(SaturationTransformation):
205239
"""Wrapper around tanh saturation function.
206240
@@ -339,6 +373,7 @@ class HillSaturation(SaturationTransformation):
339373
cls.lookup_name: cls
340374
for cls in [
341375
LogisticSaturation,
376+
InverseScaledLogisticSaturation,
342377
TanhSaturation,
343378
TanhSaturationBaselined,
344379
MichaelisMentenSaturation,

pymc_marketing/mmm/transformers.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -478,6 +478,55 @@ def logistic_saturation(x, lam: npt.NDArray[np.float64] | float = 0.5):
478478
return (1 - pt.exp(-lam * x)) / (1 + pt.exp(-lam * x))
479479

480480

481+
def inverse_scaled_logistic_saturation(
482+
x, lam: npt.NDArray[np.float64] | float = 0.5, eps: float = np.log(3)
483+
):
484+
"""Inverse scaled logistic saturation transformation.
485+
It offers a more intuitive alternative to logistic_saturation,
486+
allowing for lambda to be interpreted as the half saturation point
487+
when using default value for eps.
488+
489+
.. math::
490+
f(x) = \\frac{1 - e^{-x*\epsilon/\lambda}}{1 + e^{-x*\epsilon/\lambda}}
491+
492+
.. plot::
493+
:context: close-figs
494+
495+
import matplotlib.pyplot as plt
496+
import numpy as np
497+
import arviz as az
498+
from pymc_marketing.mmm.transformers import inverse_scaled_logistic_saturation
499+
plt.style.use('arviz-darkgrid')
500+
lam = np.array([0.25, 0.5, 1, 2, 4])
501+
x = np.linspace(0, 5, 100)
502+
ax = plt.subplot(111)
503+
for l in lam:
504+
y = inverse_scaled_logistic_saturation(x, lam=l).eval()
505+
plt.plot(x, y, label=f'lam = {l}')
506+
plt.xlabel('spend', fontsize=12)
507+
plt.ylabel('f(spend)', fontsize=12)
508+
box = ax.get_position()
509+
ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
510+
ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
511+
plt.show()
512+
513+
Parameters
514+
----------
515+
x : tensor
516+
Input tensor.
517+
lam : float or array-like, optional, by default 0.5
518+
Saturation parameter.
519+
eps : float or array-like, optional, by default ln(3)
520+
Scaling parameter. ln(3) results in halfway saturation at lam
521+
522+
Returns
523+
-------
524+
tensor
525+
Transformed tensor.
526+
""" # noqa: W605
527+
return logistic_saturation(x, eps / lam)
528+
529+
481530
class TanhSaturationParameters(NamedTuple):
482531
"""Container for tanh saturation parameters.
483532

tests/mmm/components/test_saturation.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
from pymc_marketing.mmm.components.saturation import (
2424
HillSaturation,
25+
InverseScaledLogisticSaturation,
2526
LogisticSaturation,
2627
MichaelisMentenSaturation,
2728
TanhSaturation,
@@ -40,6 +41,7 @@ def model() -> pm.Model:
4041
def saturation_functions():
4142
return [
4243
LogisticSaturation(),
44+
InverseScaledLogisticSaturation(),
4345
TanhSaturation(),
4446
TanhSaturationBaselined(),
4547
MichaelisMentenSaturation(),
@@ -93,6 +95,7 @@ def test_support_for_lift_test_integrations(saturation) -> None:
9395
@pytest.mark.parametrize(
9496
"name, saturation_cls",
9597
[
98+
("inverse_scaled_logistic", InverseScaledLogisticSaturation),
9699
("logistic", LogisticSaturation),
97100
("tanh", TanhSaturation),
98101
("tanh_baselined", TanhSaturationBaselined),

tests/mmm/test_transformers.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
delayed_adstock,
2929
geometric_adstock,
3030
hill_saturation,
31+
inverse_scaled_logistic_saturation,
3132
logistic_saturation,
3233
michaelis_menten,
3334
tanh_saturation,
@@ -343,6 +344,26 @@ def test_logistic_saturation_min_max_value(self, x, lam):
343344
assert y_eval.max() <= 1
344345
assert y_eval.min() >= 0
345346

347+
def test_inverse_scaled_logistic_saturation_lam_half(self):
348+
x = np.array([0.01, 0.1, 0.5, 1, 100])
349+
y = inverse_scaled_logistic_saturation(x=x, lam=x)
350+
expected = np.array([0.5] * len(x))
351+
np.testing.assert_almost_equal(
352+
y.eval(),
353+
expected,
354+
decimal=5,
355+
err_msg="The function does not behave as expected at the default value for eps",
356+
)
357+
358+
def test_inverse_scaled_logistic_saturation_min_max_value(self):
359+
x = np.array([0, 1, 100, 500, 5000])
360+
lam = np.array([0.01, 0.25, 0.75, 1.5, 5.0, 10.0, 15.0])[:, None]
361+
362+
y = inverse_scaled_logistic_saturation(x=x, lam=lam)
363+
y_eval = y.eval()
364+
assert y_eval.max() <= 1
365+
assert y_eval.min() >= 0
366+
346367
@pytest.mark.parametrize(
347368
"x, b, c",
348369
[

0 commit comments

Comments
 (0)