Skip to content

Commit 285f704

Browse files
authored
make hill pass through the origin (#920)
1 parent b070375 commit 285f704

File tree

3 files changed

+33
-14
lines changed

3 files changed

+33
-14
lines changed

pymc_marketing/mmm/components/saturation.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -367,9 +367,9 @@ class HillSaturation(SaturationTransformation):
367367
function = hill_saturation
368368

369369
default_priors = {
370-
"sigma": Prior("HalfNormal", sigma=2),
371-
"beta": Prior("HalfNormal", sigma=2),
372-
"lam": Prior("HalfNormal", sigma=2),
370+
"sigma": Prior("HalfNormal", sigma=1.5),
371+
"beta": Prior("HalfNormal", sigma=1.5),
372+
"lam": Prior("HalfNormal", sigma=1.5),
373373
}
374374

375375

pymc_marketing/mmm/transformers.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -908,10 +908,10 @@ def hill_saturation(
908908
r"""Hill Saturation Function
909909
910910
.. math::
911-
f(x) = \frac{\sigma}{1 + e^{-\beta(x - \lambda)}}
911+
f(x) = \frac{\sigma}{1 + e^{-\beta(x - \lambda)}} - \frac{\sigma}{1 + e^{\beta\lambda}}
912912
913913
where:
914-
- :math:`\sigma` is the maximum value (upper asymptote)
914+
- :math:`\sigma` is the upper asymptote
915915
- :math:`\beta` is the slope parameter
916916
- :math:`\lambda` is the transition point on the X-axis
917917
- :math:`x` is the independent variable
@@ -920,7 +920,9 @@ def hill_saturation(
920920
used to describe the saturation effect in biological systems. The curve is
921921
characterized by its sigmoidal shape, representing a gradual transition from
922922
a low, nearly zero level to a high plateau, the maximum value the function
923-
will approach as the independent variable grows large.
923+
will approach as the independent variable grows large. In this implementation,
924+
we add an offset to the sigmoid function to ensure that the function always passes
925+
through the origin as we expect zero spend to result in zero contribution.
924926
925927
.. plot::
926928
:context: close-figs
@@ -968,6 +970,7 @@ def hill_saturation(
968970
plt.ylabel('Hill Saturation')
969971
plt.tight_layout()
970972
plt.show()
973+
971974
Parameters
972975
----------
973976
x : float or array-like
@@ -987,7 +990,7 @@ def hill_saturation(
987990
float or array-like
988991
The value of the Hill function for each input value of x.
989992
"""
990-
return sigma / (1 + pt.exp(-beta * (x - lam)))
993+
return sigma / (1 + pt.exp(-beta * (x - lam))) - sigma / (1 + pt.exp(beta * lam))
991994

992995

993996
def root_saturation(

tests/mmm/test_transformers.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -466,11 +466,23 @@ def test_michaelis_menten(self, x, alpha, lam, expected):
466466
(3, 2, -1),
467467
],
468468
)
469-
def test_monotonicity(self, sigma, beta, lam):
469+
def test_hill_monotonicity(self, sigma, beta, lam):
470470
x = np.linspace(-10, 10, 100)
471471
y = hill_saturation(x, sigma, beta, lam).eval()
472472
assert np.all(np.diff(y) >= 0), "The function is not monotonic."
473473

474+
@pytest.mark.parametrize(
475+
"sigma, beta, lam",
476+
[
477+
(1, 1, 0),
478+
(2, 0.5, 1),
479+
(3, 2, -1),
480+
],
481+
)
482+
def test_hill_zero(self, sigma, beta, lam):
483+
y = hill_saturation(0, sigma, beta, lam).eval()
484+
assert y == pytest.approx(0.0)
485+
474486
@pytest.mark.parametrize(
475487
"x, sigma, beta, lam",
476488
[
@@ -479,7 +491,7 @@ def test_monotonicity(self, sigma, beta, lam):
479491
(-3, 3, 2, -1),
480492
],
481493
)
482-
def test_sigma_upper_bound(self, x, sigma, beta, lam):
494+
def test_hill_sigma_upper_bound(self, x, sigma, beta, lam):
483495
y = hill_saturation(x, sigma, beta, lam).eval()
484496
assert y <= sigma, f"The output {y} exceeds the upper bound sigma {sigma}."
485497

@@ -491,11 +503,13 @@ def test_sigma_upper_bound(self, x, sigma, beta, lam):
491503
(-1, 3, 2, -1, 1.5),
492504
],
493505
)
494-
def test_behavior_at_lambda(self, x, sigma, beta, lam, expected):
506+
def test_hill_behavior_at_lambda(self, x, sigma, beta, lam, expected):
495507
y = hill_saturation(x, sigma, beta, lam).eval()
508+
offset = sigma / (1 + np.exp(beta * lam))
509+
expected_with_offset = expected - offset
496510
np.testing.assert_almost_equal(
497511
y,
498-
expected,
512+
expected_with_offset,
499513
decimal=5,
500514
err_msg="The function does not behave as expected at lambda.",
501515
)
@@ -508,7 +522,7 @@ def test_behavior_at_lambda(self, x, sigma, beta, lam, expected):
508522
(np.array([1, 2, 3]), 3, 2, 2),
509523
],
510524
)
511-
def test_vectorized_input(self, x, sigma, beta, lam):
525+
def test_hill_vectorized_input(self, x, sigma, beta, lam):
512526
y = hill_saturation(x, sigma, beta, lam).eval()
513527
assert (
514528
y.shape == x.shape
@@ -522,12 +536,14 @@ def test_vectorized_input(self, x, sigma, beta, lam):
522536
(3, 2, -1),
523537
],
524538
)
525-
def test_asymptotic_behavior(self, sigma, beta, lam):
539+
def test_hill_asymptotic_behavior(self, sigma, beta, lam):
526540
x = 1e6 # A very large value to approximate infinity
527541
y = hill_saturation(x, sigma, beta, lam).eval()
542+
offset = sigma / (1 + np.exp(beta * lam))
543+
expected = sigma - offset
528544
np.testing.assert_almost_equal(
529545
y,
530-
sigma,
546+
expected,
531547
decimal=5,
532548
err_msg="The function does not approach sigma as x approaches infinity.",
533549
)

0 commit comments

Comments
 (0)