Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions pymc_marketing/mmm/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,40 @@ def _build_subplot_title(
return ", ".join(title_parts)
return fallback_title

def _align_y_axes(self, ax, ax2, include_zero=False):
"""Align y=0 of primary and secondary y-axis."""
if ax.axes.get_ylim()[0] < 0 or ax2.axes.get_ylim()[0] < 0:
ylims1 = ax.axes.get_ylim()
ylims2 = ax2.axes.get_ylim()
# Find the ratio of negative vs. positive part of the axes.
if ylims1[1]:
ax1_yratio = ylims1[0] / ylims1[1]
else:
# Fully negative axis.
ax1_yratio = -1

if ylims2[1]:
ax2_yratio = ylims2[0] / ylims2[1]
else:
# Fully negative axis, may need to reflect the other
ax2_yratio = -1

# Make axis adjustments. If both axes fully negative, no adjustment.
if ax1_yratio < ax2_yratio:
ax2.set_ylim(bottom=ylims2[1] * ax1_yratio)
if ax1_yratio == -1:
# if the axis is fully negative, center zero.
ax.set_ylim(top=-ylims1[0])
elif ax2_yratio < ax1_yratio:
ax.set_ylim(bottom=ylims1[1] * ax2_yratio)
if ax2_yratio == -1:
# if the axis is fully negative, center zero.
ax2.set_ylim(top=-ylims2[0])
Comment on lines +248 to +274
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you use NamedTuple or something. Quite confusing to read

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking at this now, I fully agree. I also think that this is not capturing every possible combination the axes can get odd. I'll move this to a draft and rework it.

elif include_zero:
# Ensure both axes start at zero
ax.set_ylim(bottom=0)
ax2.set_ylim(bottom=0)

def _get_additional_dim_combinations(
self,
data: xr.Dataset,
Expand Down Expand Up @@ -1190,6 +1224,9 @@ def _plot_budget_allocation_bars(
ax.set_xticklabels(channels)
ax.tick_params(axis="x", rotation=90)

# Ensure that y=0 are aligned between ax and ax2.
self._align_y_axes(ax, ax2, include_zero=True)

# Turn off grid and add legend
ax.grid(False)
ax2.grid(False)
Expand Down