ricardoV94 (Member) commented May 17, 2025

This is a more thorough fix for the problem highlighted in #7778.

The normalizing constant for MinibatchRVs was computed from the symbolic shape of the RVs, so its graph included the RVs themselves.

Even though the shape of a MinibatchRV can be derived without evaluating its draws, passing any graph that contains RVs to pytensorf.compile automatically adds their RNG updates, which requires evaluating the RVs anyway. This PR makes sure the RVs are not pulled into the graph just to obtain the symbolic normalizing constant.


📚 Documentation preview 📚: https://pymc--7787.org.readthedocs.build/en/7787/

ricardoV94 added the maintenance and VI Variational Inference labels May 17, 2025

Copilot AI left a comment

Pull Request Overview

This PR ensures that random variables (RVs) are not included in the symbolic normalizing constant graph by folding shapes to constants, adds shape inference for minibatch RVs, includes a test for the new behavior, and fixes a small typo.

  • Use constant_fold to derive batch shapes instead of carrying RVs into symbolic_normalizing_constant
  • Implement infer_shape on MinibatchRandomVariable so shape propagation works correctly
  • Add a dedicated test (assert_no_rvs) to confirm no RVs appear in the symbolic normalizing constant
  • Correct a typo in the constant_fold comment

Reviewed Changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 1 comment.

File Description
tests/variational/test_opvi.py Added test_symbolic_normalizing_constant_no_rvs with assert_no_rvs
pymc/variational/opvi.py Swapped direct .shape usage for constant_fold([...].shape) in scaling
pymc/variational/minibatch_rv.py Added infer_shape method to propagate shapes without evaluation
pymc/pytensorf.py Fixed typo in comment (constand_folding → constant_folding)
Comments suppressed due to low confidence (3)

tests/variational/test_opvi.py:284

  • [nitpick] The test verifies no RVs are in the graph but doesn't assert that the symbolic normalizing constant still produces the expected scalar or tensor shape. Consider adding an assertion on the returned value or shape to guard against regressions.
def test_symbolic_normalizing_constant_no_rvs():

pymc/variational/opvi.py:1109

  • Calling constant_fold inside the list comprehension for each RV will repeatedly clone and rewrite the graph, which may be costly. Consider computing all shapes once (e.g., collect inputs, call constant_fold outside the loop) or caching results before the comprehension.
get_scaling(

pymc/variational/opvi.py:1279

  • This mirrored use of constant_fold in another list comprehension also risks redundant graph rewriting. Extract a helper or hoist the folding step to improve efficiency and reduce duplicated logic.
get_scaling(

codecov bot commented May 17, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 92.84%. Comparing base (3a718f2) to head (0867cde).
⚠️ Report is 57 commits behind head on main.

Additional details and impacted files

@@           Coverage Diff           @@
##             main    #7787   +/-   ##
=======================================
  Coverage   92.84%   92.84%           
=======================================
  Files         107      107           
  Lines       18378    18380    +2     
=======================================
+ Hits        17063    17065    +2     
  Misses       1315     1315           
Files with missing lines Coverage Δ
pymc/pytensorf.py 89.76% <ø> (ø)
pymc/variational/minibatch_rv.py 100.00% <100.00%> (ø)
pymc/variational/opvi.py 86.75% <ø> (ø)

ricardoV94 force-pushed the no_rvs_symbolic_normalizing_constant branch from 2551adf to 0867cde May 18, 2025 13:30
ricardoV94 merged commit 618634b into pymc-devs:main May 18, 2025
3 of 4 checks passed