Conversation

@jessegrabowski
Member

@jessegrabowski jessegrabowski commented May 15, 2025

Description

A pain point for me when testing different algorithms (e.g. MCMC vs VI) is that I don't want to write a 2nd version of the model with pm.Minibatch on the data.

This PR adds a model transformation that does that for the user. It's the reverse of the remove_minibatched_nodes transformer that @zaxtax implemented recently.

This is a WIP, it doesn't actually work now, because I can't figure out how to rebuild the observed variable with the total_size set correctly. Help wanted.
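For context, here is a minimal sketch of why total_size matters, written in plain Python rather than PyMC. It assumes the usual minibatch rule that the batch log-likelihood is rescaled by total_size / batch_size so that it estimates the full-data log-likelihood. The function names and the unit-variance normal likelihood are illustrative assumptions, not code from this PR.

```python
import math
import random


def normal_logp(x, mu):
    """Log-density of a unit-variance normal, evaluated at x."""
    return -0.5 * math.log(2 * math.pi) - 0.5 * (x - mu) ** 2


def full_logp(data, mu):
    """Log-likelihood of the whole dataset."""
    return sum(normal_logp(x, mu) for x in data)


def minibatch_logp(batch, mu, total_size):
    """Rescale the batch log-likelihood to the scale of the full dataset."""
    scale = total_size / len(batch)
    return scale * sum(normal_logp(x, mu) for x in batch)


rng = random.Random(0)
data = [rng.gauss(1.0, 1.0) for _ in range(1000)]

# Averaging the rescaled estimate over many random batches
# should approach the full-data log-likelihood.
estimates = [
    minibatch_logp(rng.sample(data, 100), mu=1.0, total_size=len(data))
    for _ in range(2000)
]
mean_estimate = sum(estimates) / len(estimates)
```

Without the rescaling, the likelihood of a 100-row batch would be weighted as if the dataset had only 100 rows, which is why the transformed model needs total_size set on the observed variable.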

Related Issue

  • Closes #
  • Related to #

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pymc--7785.org.readthedocs.build/en/7785/

@jessegrabowski jessegrabowski requested a review from zaxtax May 15, 2025 12:28
@ricardoV94
Member

This is a WIP, it doesn't actually work now, because I can't figure out how to rebuild the observed variable with the total_size set correctly. Help wanted.

You can use the lower level utility:

def create_minibatch_rv(

Then make that a vanilla observed RV

@ricardoV94
Member

Ah, you already did that, so your question is how to get the total size? Grab the batch shape of the variable and constant-fold it, without raising if it can't be fully folded.

@jessegrabowski
Member Author

My real issue was not understanding what needs to be the key and value in the replacements, between:

  1. The model variable
  2. The memo variable
  3. The fgraph variable

@ricardoV94
Member

ricardoV94 commented May 15, 2025

It is usually best to replace the whole fgraph ModelObservedRV with a new one. You probably have to discard any dims on the batch dimension, which is an input to that op.

@jessegrabowski
Member Author

I don't really understand what that answer means

@ricardoV94
Member

dprint the fgraph and it will perhaps be more obvious what I am mumbling

@jessegrabowski
Member Author

The problem I was running into was that I ended up with two beta RVs after doing the replace. Beta was the only RV implicated in the ModelObservedRV sub-graph.

@zaxtax
Contributor

zaxtax commented May 15, 2025 via email

@zaxtax zaxtax force-pushed the model-to-minibatch branch from c1168de to 8d1b479 June 9, 2025 12:52
minibatch_vars = Minibatch(*data_vars, batch_size=batch_size)
replacements = {datum: minibatch_vars[i] for i, datum in enumerate(data_vars)}
assert 0
# Add total_size to all observed RVs
Member

Should only add to those that depend on the minibatch data, no?
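One way to picture that restriction is a dependency walk over the graph. The sketch below uses a toy dictionary graph with hypothetical variable names, not the PyTensor graph utilities the real transformation would rely on. It only selects the observed variables that have a minibatched data variable among their ancestors.

```python
def depends_on(graph, node, targets):
    """Return True if `node` has any of `targets` among its ancestors."""
    stack, seen = list(graph.get(node, ())), set()
    while stack:
        current = stack.pop()
        if current in targets:
            return True
        if current not in seen:
            seen.add(current)
            stack.extend(graph.get(current, ()))
    return False


# Toy graph: each variable maps to its direct inputs (names are hypothetical).
graph = {
    "y_obs": ["mu", "sigma"],
    "mu": ["X_data", "beta"],
    "z_obs": ["alpha"],
}
minibatched = {"X_data"}
observed = ["y_obs", "z_obs"]

# Only observed variables downstream of minibatched data need total_size.
needs_total_size = [rv for rv in observed if depends_on(graph, rv, minibatched)]
```

Here only y_obs is selected, because z_obs never touches the minibatched data.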

Member

The correct thing would be a dim analysis like we do for MarginalModel, to confirm that the first dim of the data maps to the first dim of the observed RVs, which is when the rewrite is valid. We may not want to do that, but we should be clear about the assumptions in the docstrings.

An example where the minibatch rewrite will fail or do the wrong thing is if you transpose the data before you use it in the observations.
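The transpose case can be sketched with plain Python lists. This assumes the rescaling factor is derived from the first dimension of the observed variable, which is the assumption the comment above asks to document. All names here are hypothetical and the shapes are toy values.

```python
def shape(rows):
    """Shape of a rectangular list-of-lists."""
    return (len(rows), len(rows[0]))


def transpose(rows):
    return [list(col) for col in zip(*rows)]


def naive_scale(obs_shape, total_size):
    """Assumed rule: rescale by total_size over the first observed dimension."""
    return total_size / obs_shape[0]


total_size, batch_size = 6, 3
data = [[i, 10 * i] for i in range(total_size)]  # shape (6, 2)
batch = data[:batch_size]  # minibatching slices the first dimension: (3, 2)

direct_obs_shape = shape(batch)                 # batch dim comes first
transposed_obs_shape = shape(transpose(batch))  # batch dim is now last

correct_scale = total_size / batch_size
direct_scale = naive_scale(direct_obs_shape, total_size)
transposed_scale = naive_scale(transposed_obs_shape, total_size)
```

With the data used directly, the naive scale matches the correct one. After the transpose, the first observed dimension is the feature count rather than the batch size, so the scale is silently wrong.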

replacements = {datum: minibatch_vars[i] for i, datum in enumerate(data_vars)}
assert 0
# Add total_size to all observed RVs
total_size = data_vars[0].get_value().shape[0]
Member

Total size can be symbolic, I think?


data_vars = [
memo[datum].owner.inputs[0]
for datum in (model.named_vars[datum_name] for datum_name in model.named_vars)
Member

There's a model.data_vars. You should, however, allow users to specify which data vars are minibatched (defaulting to all is fine). Alternatively, we could restrict this to models with dims, and the user has to tell us which dim is being minibatched?

That makes the graph analysis easier
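A minimal sketch of the first option, where the user may name the data variables to minibatch and the default is all of them. The helper name select_data_vars and its arguments are hypothetical and work on plain name lists, not on an actual PyMC model.

```python
def select_data_vars(available, requested=None):
    """Pick the data variables to minibatch, defaulting to all of them."""
    if requested is None:
        return list(available)
    # Reject names that are not data variables instead of silently ignoring them.
    unknown = [name for name in requested if name not in available]
    if unknown:
        raise ValueError(f"Not data variables in the model: {unknown}")
    return list(requested)


all_selected = select_data_vars(["X", "y"])
only_y = select_data_vars(["X", "y"], requested=["y"])
```

Raising on unknown names keeps a typo from quietly leaving a variable un-minibatched.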

@zaxtax
Contributor

zaxtax commented Jun 11, 2025 via email
