Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 127 additions & 1 deletion pymc_marketing/mmm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1155,7 +1155,7 @@ def compute_mean_contributions_over_time(
)

if getattr(self, "yearly_seasonality", None):
contributions_fourier_over_time = (
contributions_fourier_over_time = pd.DataFrame(
az.extract(
self.fit_result,
var_names=["fourier_contributions"],
Expand All @@ -1165,6 +1165,8 @@ def compute_mean_contributions_over_time(
.to_dataframe()
.squeeze()
.unstack()
.sum(axis=1),
columns=["yearly_seasonality"],
)
else:
contributions_fourier_over_time = pd.DataFrame(
Expand Down Expand Up @@ -1293,6 +1295,130 @@ def plot_channel_contribution_share_hdi(
def graphviz(self, **kwargs):
return pm.model_to_graphviz(self.model, **kwargs)

def _process_decomposition_components(self, data: pd.DataFrame):
"""
Process data to compute the sum of contributions by component and calculate their percentages.
Parameters
----------
dataframe : pd.DataFrame
Dataframe containing the time series data.
Returns
-------
dataframe : pd.DataFrame
A dataframe with contributions summed up by component, sorted, and with percentages.
"""
dataframe = data.copy()
stack_dataframe = dataframe.stack().reset_index()
stack_dataframe.columns = pd.Index(["date", "component", "contribution"])
stack_dataframe.set_index(["date", "component"], inplace=True)
dataframe = stack_dataframe.groupby("component").sum()
dataframe.sort_values(by="contribution", ascending=True, inplace=True)
dataframe.reset_index(inplace=True)

total_contribution = dataframe["contribution"].sum()
dataframe["percentage"] = (dataframe["contribution"] / total_contribution) * 100

return dataframe

def plot_waterfall_components_decomposition(
self,
original_scale: bool = True,
figsize: Tuple = (14, 7),
**kwargs,
):
"""
This function creates a waterfall plot. The plot shows the decomposition of the target into its components.

Parameters
----------
original_scale : bool, optional
If True, the contributions are plotted in the original scale of the target.
figsize : Tuple, optional
The size of the figure. The default is (14, 7).
**kwargs
Additional keyword arguments to pass to the matplotlib `subplots` function.

Returns
-------
fig : matplotlib.figure.Figure
The matplotlib figure object.
"""

# Sort the dataframe in ascending order of contribution for the waterfall plot
dataframe = self.compute_mean_contributions_over_time(
original_scale=original_scale
)

dataframe = self._process_decomposition_components(data=dataframe)
total_contribution = dataframe["contribution"].sum()

# Initialize the matplotlib figure and axis
fig, ax = plt.subplots(figsize=figsize, **kwargs)

# Initialize the starting point for the first bar
cumulative_contribution = 0

# Plot each bar with the updated order
for index, row in dataframe.iterrows():
# Choose the color based on the sign of the contribution
color = "lightblue" if row["contribution"] >= 0 else "salmon"

# For negative contributions, start the bar at the cumulative sum minus the contribution
bar_start = (
cumulative_contribution + row["contribution"]
if row["contribution"] < 0
else cumulative_contribution
)
ax.barh(row["component"], row["contribution"], left=bar_start, color=color)

# Only add to the cumulative sum if the contribution is positive
if row["contribution"] > 0:
cumulative_contribution += row["contribution"]

# Label positioning
label_pos = bar_start + (row["contribution"] / 2)
# Ensure that the label is always inside the bar for visibility
if row["contribution"] < 0:
label_pos = bar_start - (row["contribution"] / 2)

# Add labels on top of the bars for the contribution values and percentages
ax.text(
label_pos,
index,
f"{row['contribution']:,.0f}\n({row['percentage']:.1f}%)",
ha="center",
va="center",
color="black",
fontsize=10,
)

# Set the title and labels
ax.set_title("Response Decomposition Waterfall by Components")
ax.set_xlabel("Cumulative Contribution")
ax.set_ylabel("Components")

# Adjust x-axis to show the percentage
xticks = np.linspace(
0, total_contribution, num=11
) # 10 equally spaced ticks from 0 to total
xticklabels = [
f"{(x/total_contribution)*100:.0f}%" for x in xticks
] # Convert to percentages
ax.set_xticks(xticks)
ax.set_xticklabels(xticklabels)

# Hide the right, top, and left spines for a cleaner look
ax.spines["right"].set_visible(False)
ax.spines["top"].set_visible(False)
ax.spines["left"].set_visible(False)

# Add labels on the left to identify the predictor channels, corresponding to the y-ticks
ax.set_yticks(np.arange(len(dataframe)))
ax.set_yticklabels(dataframe["component"])

plt.tight_layout()
return fig


class MMM(BaseMMM, ValidateTargetColumn, ValidateDateColumn, ValidateChannelColumns):
pass
18 changes: 12 additions & 6 deletions tests/mmm/test_delayed_saturated_mmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ def test_fit(self, toy_X: pd.DataFrame, toy_y: pd.Series) -> None:
assert mmm.model_config is not None
n_channel: int = len(mmm.channel_columns)
n_control: int = len(mmm.control_columns)
fourier_terms: int = 2 * mmm.yearly_seasonality
# fourier_terms: int = 2 * mmm.yearly_seasonality
mmm.fit(
X=toy_X,
y=toy_y,
Expand Down Expand Up @@ -323,17 +323,23 @@ def test_fit(self, toy_X: pd.DataFrame, toy_y: pd.Series) -> None:
)
assert mean_model_contributions_ts.shape == (
toy_X.shape[0],
n_channel + n_control + fourier_terms + 1,
n_channel
+ n_control
+ 2, # 2 for yearly seasonality (+1) and intercept (+)
)

processed_df = mmm._process_decomposition_components(
data=mean_model_contributions_ts
)

assert processed_df.shape == (n_channel + n_control + 2, 3)

assert mean_model_contributions_ts.columns.tolist() == [
"channel_1",
"channel_2",
"control_1",
"control_2",
"sin_order_1",
"cos_order_1",
"sin_order_2",
"cos_order_2",
"yearly_seasonality",
"intercept",
]

Expand Down
1 change: 1 addition & 0 deletions tests/mmm/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ class ToyMMM(BaseDelayedSaturatedMMM, MaxAbsScaleTarget):
("plot_posterior_predictive", {"original_scale": True}),
("plot_components_contributions", {}),
("plot_channel_parameter", {"param_name": "alpha"}),
("plot_waterfall_components_decomposition", {"original_scale": True}),
("plot_direct_contribution_curves", {}),
("plot_direct_contribution_curves", {"same_axes": True}),
("plot_direct_contribution_curves", {"channels": ["channel_2"]}),
Expand Down