diff --git a/pymc_marketing/mmm/plot.py b/pymc_marketing/mmm/plot.py index b7a2ddae..e461b765 100644 --- a/pymc_marketing/mmm/plot.py +++ b/pymc_marketing/mmm/plot.py @@ -243,6 +243,40 @@ def _build_subplot_title( return ", ".join(title_parts) return fallback_title + def _align_y_axes(self, ax, ax2, include_zero=False): + """Align y=0 of primary and secondary y-axis.""" + if ax.axes.get_ylim()[0] < 0 or ax2.axes.get_ylim()[0] < 0: + ylims1 = ax.axes.get_ylim() + ylims2 = ax2.axes.get_ylim() + # Find the ratio of negative vs. positive part of the axes. + if ylims1[1]: + ax1_yratio = ylims1[0] / ylims1[1] + else: + # Fully negative axis. + ax1_yratio = -1 + + if ylims2[1]: + ax2_yratio = ylims2[0] / ylims2[1] + else: + # Fully negative axis, may need to reflect the other + ax2_yratio = -1 + + # Make axis adjustments. If both axes fully negative, no adjustment. + if ax1_yratio < ax2_yratio: + ax2.set_ylim(bottom=ylims2[1] * ax1_yratio) + if ax1_yratio == -1: + # if the axis is fully negative, center zero. + ax.set_ylim(top=-ylims1[0]) + elif ax2_yratio < ax1_yratio: + ax.set_ylim(bottom=ylims1[1] * ax2_yratio) + if ax2_yratio == -1: + # if the axis is fully negative, center zero. + ax2.set_ylim(top=-ylims2[0]) + elif include_zero: + # Ensure both axes start at zero + ax.set_ylim(bottom=0) + ax2.set_ylim(bottom=0) + def _get_additional_dim_combinations( self, data: xr.Dataset, @@ -1190,6 +1224,9 @@ def _plot_budget_allocation_bars( ax.set_xticklabels(channels) ax.tick_params(axis="x", rotation=90) + # Ensure that y=0 are aligned between ax and ax2. + self._align_y_axes(ax, ax2, include_zero=True) + # Turn off grid and add legend ax.grid(False) ax2.grid(False)