Implement MultiDimensionalMMM with pymc.dims#2204

Open
ricardoV94 wants to merge 6 commits into pymc-labs:main from ricardoV94:mmm_dims

Conversation

@ricardoV94
Contributor

@ricardoV94 ricardoV94 commented Jan 20, 2026

TODO:

  • Dev migration guide
  • Adapt new incrementality code
  • Censored Prior
  • Deprecate dims to Transformation.apply

Closes #1981
Closes #1514
Closes #1630 (just a regression test, it worked fine before)
Related to #2017


📚 Documentation preview 📚: https://pymc-marketing--2204.org.readthedocs.build/en/2204/

@review-notebook-app

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

@github-actions github-actions bot added docs Improvements or additions to documentation MMM labels Jan 20, 2026
@ricardoV94 ricardoV94 changed the title POC implement MultiDimensionalMMM with dimmed variables POC implement MultiDimensionalMMM with pymc.dims Jan 20, 2026
name="channel_contribution",
var=baseline_channel_contribution * media_broadcast,
dims=("date", *self.dims, "channel"),
channel_contribution = pmd.Deterministic(
Contributor

That's awesome.

@codecov

codecov bot commented Jan 20, 2026

Codecov Report

❌ Patch coverage is 96.18321% with 25 lines in your changes missing coverage. Please review.
✅ Project coverage is 93.11%. Comparing base (f4e85e2) to head (36aaee4).
⚠️ Report is 3 commits behind head on main.

Files with missing lines                      Patch %   Lines
pymc_marketing/mmm/utility.py                 81.25%    9 Missing ⚠️
pymc_marketing/mmm/components/base.py         90.00%    4 Missing ⚠️
pymc_marketing/special_priors.py              88.88%    4 Missing ⚠️
pymc_marketing/mmm/hsgp.py                    94.54%    3 Missing ⚠️
pymc_marketing/mmm/sensitivity_analysis.py    94.11%    2 Missing ⚠️
pymc_marketing/mmm/dims.py                    94.44%    1 Missing ⚠️
pymc_marketing/mmm/linear_trend.py            95.00%    1 Missing ⚠️
pymc_marketing/model_graph.py                 95.45%    1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2204      +/-   ##
==========================================
- Coverage   93.15%   93.11%   -0.05%     
==========================================
  Files          79       80       +1     
  Lines       12646    12708      +62     
==========================================
+ Hits        11781    11833      +52     
- Misses        865      875      +10     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Contributor

@cetagostini cetagostini left a comment

Hey! I found a few typos while reviewing the pymc.dims implementation. See inline comments below 👇

Contributor

@cetagostini cetagostini left a comment

A few more inconsistencies I noticed - some return type annotations say TensorVariable but the code actually returns XTensorVariable 🤔

@cetagostini
Contributor

Function Equivalence Testing Results 🧪

I tested the transformation functions to verify they produce equivalent outputs with the new pymc.dims API. Here's what I found:

✅ All Functions Pass Numerical Equivalence

Functions tested (all match):

  • logistic_saturation
  • tanh_saturation
  • hill_function
  • michaelis_menten
  • root_saturation
  • geometric_adstock weights

⚠️ Important Finding: Dimension Ordering

When broadcasting with xtensor, the dimension order may differ from numpy!

# Example:
x_2d = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])  # (3 dates, 2 channels)
lam_channel = np.array([0.3, 0.7])  # per-channel lambda

# Old API (numpy-style):
# Output shape: (3, 2) → (date, channel) order preserved

# New API (xtensor):
# Output shape: (2, 3) → (channel, date) order CHANGED!

The values are correct (match after transpose), but any code relying on axis indices (axis=0, axis=1) instead of dimension names may break.

This is expected behavior for xarray-like semantics, but worth noting for anyone working with the codebase! 👍
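
A minimal sketch of the ordering effect described above. It uses plain xarray as a stand-in for pymc.dims, which is an assumption based on the statement in this thread that the dimension order matches xarray; it does not call any code from this PR, and the variable names are illustrative only. In xarray the result dims follow the order of first appearance across the operands, so swapping the operands changes the layout but not the values.

```python
import numpy as np
import xarray as xr

# Same data as the example above: 3 dates, 2 channels.
x = xr.DataArray(
    np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]),
    dims=("date", "channel"),
)
lam = xr.DataArray(np.array([0.3, 0.7]), dims=("channel",))

a = x * lam  # dims follow x first: ("date", "channel")
b = lam * x  # dims follow lam first: ("channel", "date")

# Reductions by name agree regardless of the layout.
sum_a = a.sum("date")
sum_b = b.sum("date")
```

Code that reduces with `axis=0` on the raw arrays would sum over dates for `a` but over channels for `b`, which is the kind of breakage the comment warns about.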

@ricardoV94
Contributor Author

Yes, order doesn't matter while in xtensor land. You can always transpose explicitly to another order, which you should do if you need to go to tensor land (via .values).

Note that, even if it's not so relevant, the dimension order matches xarray's.
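
The advice to transpose before leaving named-dimension land can be sketched as follows. This again uses xarray as a stand-in (an assumption, not the PR's code), and `to_positional` is a hypothetical helper name introduced only for illustration.

```python
import numpy as np
import xarray as xr


def to_positional(da, *dims):
    """Return the raw array with its axes in the requested named order."""
    return da.transpose(*dims).values


x = xr.DataArray(
    np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]),
    dims=("date", "channel"),
)
lam = xr.DataArray(np.array([0.3, 0.7]), dims=("channel",))

result = lam * x  # layout is ("channel", "date") here
arr = to_positional(result, "date", "channel")  # fixed (date, channel) layout
```

Pinning the order at the boundary keeps any downstream axis-index code independent of how the named computation happened to broadcast.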

@ricardoV94 ricardoV94 force-pushed the mmm_dims branch 2 times, most recently from 2a75405 to c2193c6 Compare February 2, 2026 15:43
@ricardoV94 ricardoV94 force-pushed the mmm_dims branch 2 times, most recently from aa79d7b to c9a5a89 Compare February 2, 2026 18:36
budgets_expanded * self._budget_distribution_over_period_tensor
) # Shape: (num_periods, num_optimized_budgets)
budgets_optimized * self._budget_distribution_over_period_tensor
).transpose("date", "budgets_flat")
Contributor

This is very nice!

@ricardoV94 ricardoV94 force-pushed the mmm_dims branch 3 times, most recently from d8bfaaa to 872eeec Compare February 6, 2026 22:27
@ricardoV94
Contributor Author

> Could you also run the new (short) notebook mmm_custom_splines.ipynb? It is failing, but I can not reproduce :)

It's failing with numba mode, probably that's why you can't repro

@juanitorduz
Collaborator

> Could you also run the new (short) notebook mmm_custom_splines.ipynb? It is failing, but I can not reproduce :)
>
> It's failing with numba mode, probably that's why you can't repro

Is there a fundamental problem? We can also add it to the BLACKLIST in the runner script and create an issue.

@ricardoV94
Contributor Author

@juanitorduz it's fixed by pymc-devs/pytensor#1937

Hopefully by the time I cut a release @isofer hasn't finished 20 more PRs and this is still up to date :D

@isofer
Contributor

isofer commented Mar 5, 2026

> @juanitorduz it's fixed by pymc-devs/pytensor#1937
>
> Hopefully by the time I cut a release @isofer hasn't finished 20 more PRs and this is still up to date :D

My goal is to finish 30 PRs in 30 days and for each PR to touch all the files in the repo :D

@ricardoV94
Contributor Author

ricardoV94 commented Mar 6, 2026

Question. Many notebooks use adstock.apply directly. This PR is a breaking change for these, unless the user passes a DataArray or xtensor input and specifies core_dim (the dimension that is adstocked over).

  • Saturation transforms don't require core_dim, but still require x to have dims.
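
To illustrate why an adstock needs a core_dim while an elementwise saturation does not, here is a minimal sketch of an unnormalized geometric adstock over a named dimension. It is written with plain xarray and is not the PR's implementation; `geometric_adstock_sketch` and its signature are hypothetical, and the actual `adstock.apply` arguments may differ.

```python
import numpy as np
import xarray as xr


def geometric_adstock_sketch(x, alpha, l_max, core_dim):
    """Hypothetical helper: sum of alpha**lag * x lagged along `core_dim`."""
    if core_dim not in x.dims:
        raise ValueError(f"x has no dimension named {core_dim!r}")
    out = 0.0
    for lag in range(l_max):
        # Shift along the named dimension only; pad the start with zeros.
        lagged = x.shift({core_dim: lag}, fill_value=0.0)
        out = out + (alpha**lag) * lagged
    return out


x = xr.DataArray(
    np.array([[1.0, 2.0], [0.0, 0.0], [0.0, 0.0]]),
    dims=("date", "channel"),
)
alpha = xr.DataArray(np.array([0.5, 0.1]), dims=("channel",))

y = geometric_adstock_sketch(x, alpha, l_max=3, core_dim="date")
```

The lag is the only step that depends on a specific axis, which is why it has to be named; everything else broadcasts by dimension name.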

Do we consider that this was mostly for devs and it's fine to break, or do we want to try to infer dims positionally, like Prior is doing after pymc-devs/pymc-extras#657 (with a warning that the behavior is only temporarily supported)?

The logic would have to be something like: if there are no dims in x, we have to look at the transformation params, check their broadcasted dims, and try to map positionally to the dimensions of x, possibly adding our own dims for batch dimensions of x that go beyond those of the parameters.

When core_dim is needed, it would default to the leading dim of x.

I thought it was useful to do it for Prior objects because those are clearly user-facing when they customize models. But transforms feel more internal?

Note that this behavior is compatible with the old use (same result), but it can be wrong in uses that are now supported (like x broadcasting with arbitrary dimensions of the parameters that x itself does not contain).
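
The fallback logic described above could look roughly like the sketch below. Everything in it is an assumption made for illustration: the helper name `infer_dims_positionally`, the right-aligned (numpy-style) mapping of parameter dims onto the trailing axes of x, the generated `batch_i` names, and the warning text are not taken from the PR.

```python
import warnings


def infer_dims_positionally(x_ndim, param_dims, prefix="batch"):
    """Hypothetical: guess names for the axes of an input that has no dims.

    The broadcasted parameter dims are mapped onto the trailing axes of x,
    any extra leading axes get generated names, and the core dim defaults
    to the leading axis.
    """
    param_dims = tuple(param_dims)
    n_extra = x_ndim - len(param_dims)
    if n_extra < 0:
        raise ValueError("x has fewer axes than the broadcasted parameter dims")
    warnings.warn(
        "Inferring dims positionally is only temporarily supported; "
        "pass an input with named dims instead.",
        FutureWarning,
        stacklevel=2,
    )
    extra = tuple(f"{prefix}_{i}" for i in range(n_extra))
    x_dims = extra + param_dims
    core_dim = x_dims[0] if x_dims else None
    return x_dims, core_dim


# Old-style call: x is (date, channel) by position, parameters have ("channel",).
x_dims, core_dim = infer_dims_positionally(2, ("channel",))
```

As the caveat above says, a guess like this reproduces the old positional behavior but cannot be right when x is meant to broadcast against parameter dims it does not contain.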


Somewhat related, these transforms used to accept a dims argument, which is irrelevant now since we always know the dims of the parameters. It was unclear what this should mean in some contexts (see my questions #767 (comment) and #1693 (comment)).

Can someone confirm it shouldn't be needed anymore?

@juanitorduz
Collaborator

Thanks @ricardoV94!

From what I have seen, adstock.apply is mostly for devs, and it's fine to break. We just need to replace the notebooks with the corresponding corrected expressions (I would help with this). Is that fair to say, @williambdean?

Regarding the "transforms used to accept a dims argument", I would let @williambdean give his input, as I have not used this that much.

@ricardoV94
Contributor Author

The notebooks are already updated

@juanitorduz
Collaborator

Then even better 😄.

@ricardoV94 ricardoV94 marked this pull request as ready for review March 8, 2026 13:50
@ricardoV94
Contributor Author

CI is passing. I didn't add the deprecation of dims to transforms (to tell users they aren't needed) or the migration guide.

Can add in this PR or a quick follow up. What do you think?

@ricardoV94
Contributor Author

For review, the notebook changes are separate, so you don't have to worry about "19k LOC".

Each commit is logically consistent, so it can also be merged with rebase to avoid the same issue of putting those nb changes together

Collaborator

@juanitorduz juanitorduz left a comment

@ricardoV94 This looks great! I left some minor comments.

All in all, I think we should merge this one soon!

@juanitorduz juanitorduz requested a review from cetagostini March 9, 2026 09:49
@juanitorduz
Collaborator

  1. Thanks @ricardoV94! I suggest we wait for additional input until Wednesday, and then feel free to merge (I think the dev guide is missing from your TODOs, but you can also tackle that in a separate PR, as you wish).


Labels

Bass model Dealing with the Bass Defusion model bug Something isn't working customer choice Related to customer choice module docs Improvements or additions to documentation MMM multidimensional Prior class tests TVPs Related to time varying parameters

Projects

None yet

Development

Successfully merging this pull request may close these issues.

  • Basis functions fail with multidimensional events
  • single channel MMMs don't work
  • Handle dims for intercept

7 participants