@ricardoV94 ricardoV94 commented Aug 27, 2024

Description

This PR fixes issues related to Minibatch indexing reported in https://discourse.pymc.io/t/warning-using-minibatch-and-censored-together-rng-variable-has-shared-clients/14943 and extends the MinibatchRV functionality for derived RVs.

Minibatch value variables are uniquely tricky because they are random graphs that can share an RNG with other variables in the forward / logp graph. We therefore need to make sure they are not mutated, so that the default updates keep working. We tried some tricks in the past, but as the Discourse issue revealed, they were not enough. This PR solves the problem by encapsulating the random graph in an OpFromGraph, so that the inner graph is not touched by PyMC's logprob derivation routines. It is still inlined in the final compiled functions to avoid overhead.

I also decided to deprecate Generators as data, which showed up during the refactoring. The GeneratorOp is not a true Op, since an Op should not have any side effects. It is also incompatible with non-default backends like Numba and JAX, which we are moving towards. If needed, the logic should be handled by the sampler, by consuming the generator and setting the values before subsequent function calls.
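
The suggested replacement pattern can be sketched in plain Python. Everything here is a hypothetical stand-in: `SharedValue` mimics a shared data container with `set_value` / `get_value`, `step` stands in for a compiled function that reads the container, and `run` plays the role of the sampler loop. None of these names come from PyMC.

```python
class SharedValue:
    """Hypothetical stand-in for a shared data container."""

    def __init__(self, value):
        self._value = value

    def set_value(self, value):
        self._value = value

    def get_value(self):
        return self._value


def batches(data, batch_size):
    """Yield consecutive slices of `data` of at most `batch_size` items."""
    for start in range(0, len(data), batch_size):
        yield data[start:start + batch_size]


def run(generator, shared, step):
    """Consume the generator, setting the shared value before each call."""
    results = []
    for batch in generator:
        shared.set_value(batch)  # the side effect lives in the loop, not in an Op
        results.append(step())
    return results


shared = SharedValue([])


def step():
    # Stand-in for a compiled function that reads the current shared value.
    return sum(shared.get_value())


totals = run(batches(list(range(10)), 4), shared, step)
```

The point of the design is that the function being called stays free of side effects; only the driving loop advances the generator.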

@ricardoV94 ricardoV94 added the enhancements and major (Include in major changes release notes section) labels Aug 27, 2024
@ricardoV94 ricardoV94 requested a review from ferrine August 27, 2024 15:34

Member Author

This was already checked in the test above

Member

ok

@ricardoV94 ricardoV94 force-pushed the minibatch_censored branch 2 times, most recently from 479c4a4 to 290a643 August 29, 2024 14:41
codecov bot commented Aug 29, 2024

Codecov Report

Attention: Patch coverage is 97.91667% with 1 line in your changes missing coverage. Please review.

Project coverage is 92.15%. Comparing base (c92a9a9) to head (49542b5).
Report is 119 commits behind head on main.

Files with missing lines    Patch %    Lines
pymc/data.py                96.00%     1 Missing ⚠️

Additional details and impacted files

@@            Coverage Diff             @@
##             main    #7480      +/-   ##
==========================================
- Coverage   92.16%   92.15%   -0.02%     
==========================================
  Files         103      103              
  Lines       17214    17224      +10     
==========================================
+ Hits        15866    15873       +7     
- Misses       1348     1351       +3     
Files with missing lines            Coverage Δ
pymc/logprob/basic.py               94.36% <100.00%> (ø)
pymc/logprob/rewriting.py           89.75% <100.00%> (ø)
pymc/model/core.py                  91.75% <100.00%> (-0.03%) ⬇️
pymc/pytensorf.py                   90.62% <100.00%> (+0.11%) ⬆️
pymc/variational/minibatch_rv.py    100.00% <100.00%> (ø)
pymc/variational/opvi.py            87.42% <100.00%> (ø)
pymc/data.py                        89.09% <96.00%> (-0.36%) ⬇️

... and 2 files with indirect coverage changes

mb_tensors = [tensor[mb_indices] for tensor in tensors]

# Wrap graph in OFG so it's easily identifiable and not rewritten accidentally
*mb_tensors, _ = MinibatchOp([*tensors, rng], [*mb_tensors, rng_update])(*tensors, rng)

Member

nice trick, did not know that

    __props__ = ("generator",)

    def __init__(self, gen, default=None):
        warnings.warn(

Member

lgtm

Member

ok



def test_fit_oo(inference, fit_kwargs, simple_model_data):
    # Minibatch data can't be extracted into the `observed_data` group in the final InferenceData

Member

no more issues there?

Member Author

Nope, I now allow extracting the data into the idata; it extracts the whole data.

minibatch_idx = minibatch_index(0, 10, size=(9,))
AD_mt = AD[minibatch_idx]
TD_mt = TD[minibatch_idx]
AD_mt, TD_mt = Minibatch(AD, TD, batch_size=9)

Member

this is much cleaner


ferrine commented Sep 7, 2024

I've created an issue to continue this work later and improve the scalability of minibatches: #7496

@ricardoV94 ricardoV94 merged commit 2856062 into pymc-devs:main Sep 7, 2024
@ricardoV94 ricardoV94 deleted the minibatch_censored branch September 7, 2024 17:34