Skip to content

Commit 9129a9e

Browse files
PatrickRobothampatrick-conundrmwilliambdean
authored
Fix Visual for hill_saturation function (Issue #851 ) (#857)
* Fix plotting by evaluating tensors. * Add space after sphinx directive. * Remove indentation from blank line. * Add shared y axis for subplots. --------- Co-authored-by: Patrick Robotham <[email protected]> Co-authored-by: Will Dean <[email protected]>
1 parent 8ee9254 commit 9129a9e

File tree

1 file changed

+13
-10
lines changed

1 file changed

+13
-10
lines changed

pymc_marketing/mmm/transformers.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -924,47 +924,50 @@ def hill_saturation(
924924
925925
.. plot::
926926
:context: close-figs
927+
927928
import numpy as np
928929
import matplotlib.pyplot as plt
929930
from pymc_marketing.mmm.transformers import hill_saturation
930931
x = np.linspace(0, 10, 100)
931932
# Varying sigma
932933
sigmas = [0.5, 1, 1.5]
933-
plt.figure(figsize=(12, 4))
934+
fig, axes = plt.subplots(1, 3, figsize=(12, 4), sharey=True)
934935
for i, sigma in enumerate(sigmas):
935936
plt.subplot(1, 3, i+1)
936-
y = hill_saturation(x, sigma, 2, 5)
937+
y = hill_saturation(x, sigma, 2, 5).eval()
937938
plt.plot(x, y)
938939
plt.xlabel('x')
939-
plt.ylabel('Hill Saturation')
940940
plt.title(f'Sigma = {sigma}')
941+
plt.subplot(1,3,1)
942+
plt.ylabel('Hill Saturation')
941943
plt.tight_layout()
942944
plt.show()
943945
# Varying beta
944946
betas = [1, 2, 3]
945-
plt.figure(figsize=(12, 4))
947+
fig, axes = plt.subplots(1, 3, figsize=(12, 4), sharey=True)
946948
for i, beta in enumerate(betas):
947949
plt.subplot(1, 3, i+1)
948-
y = hill_saturation(x, 1, beta, 5)
950+
y = hill_saturation(x, 1, beta, 5).eval()
949951
plt.plot(x, y)
950952
plt.xlabel('x')
951-
plt.ylabel('Hill Saturation')
952953
plt.title(f'Beta = {beta}')
954+
plt.subplot(1,3,1)
955+
plt.ylabel('Hill Saturation')
953956
plt.tight_layout()
954957
plt.show()
955958
# Varying lam
956959
lams = [3, 5, 7]
957-
plt.figure(figsize=(12, 4))
960+
fig, axes = plt.subplots(1, 3, figsize=(12, 4), sharey=True)
958961
for i, lam in enumerate(lams):
959962
plt.subplot(1, 3, i+1)
960-
y = hill_saturation(x, 1, 2, lam)
963+
y = hill_saturation(x, 1, 2, lam).eval()
961964
plt.plot(x, y)
962965
plt.xlabel('x')
963-
plt.ylabel('Hill Saturation')
964966
plt.title(f'Lambda = {lam}')
967+
plt.subplot(1,3,1)
968+
plt.ylabel('Hill Saturation')
965969
plt.tight_layout()
966970
plt.show()
967-
968971
Parameters
969972
----------
970973
x : float or array-like

0 commit comments

Comments
 (0)