observe(sum) for Normal via rewrite to equivalent NormalRV #8067
Conversation
eclipse1605 commented Jan 22, 2026

Closes ENH: pymc.math.sum could not be observed #7990
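For background, the rewrite relies on the fact that a sum of independent Normals is itself Normal, with the means summed and the variances added. A quick NumPy/SciPy check of that identity (illustrative only, not code from this PR; the parameter values are made up):

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(0)

mu = np.array([0.5, -1.0, 2.0])
sigma = np.array([1.0, 0.5, 2.0])

# Parameters of the equivalent NormalRV that the rewrite would construct for sum(x)
mu_sum = mu.sum()                      # 1.5
sigma_sum = np.sqrt((sigma**2).sum())  # sqrt(5.25) ≈ 2.29

# Monte Carlo check that sum(Normal(mu, sigma)) matches Normal(mu_sum, sigma_sum)
draws = rng.normal(loc=mu, scale=sigma, size=(200_000, 3)).sum(axis=1)
print(draws.mean(), mu_sum)
print(draws.std(), sigma_sum)

# Observing the sum then reduces to evaluating the equivalent Normal's logpdf
print(stats.norm.logpdf(3.0, loc=mu_sum, scale=sigma_sum))
```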
Codecov Report

✅ All modified and coverable lines are covered by tests.

Additional details and impacted files:

    @@            Coverage Diff             @@
    ##             main    #8067      +/-   ##
    ==========================================
    + Coverage   90.22%   90.89%   +0.66%
    ==========================================
      Files         116      123       +7
      Lines       18972    19489     +517
    ==========================================
    + Hits        17117    17714     +597
    + Misses      1855     1775      -80
ricardoV94 left a comment
Looks good, left some minor comments
pymc/logprob/order.py
Outdated
    if not filter_measurable_variables(node.inputs):
        return None
Not needed, NormalRV is always measurable
Suggested change: remove these two lines.
pymc/logprob/order.py
Outdated
    @node_rewriter([Sum])
    def find_measurable_sum(fgraph: FunctionGraph, node: Apply) -> list[TensorVariable] | None:
Shouldn't be in this file. Perhaps arithmetic.py?
True, but won't tensor.py be better for this?
That's more meant for shape / concatenation operations, not mathematical in nature
makes sense
pymc/logprob/order.py
Outdated
    if getattr(latent_op, "ndim_supp", None) != 0:
        return None
Not needed, always the case for NormalRV
Suggested change: remove these two lines.
pymc/logprob/order.py
Outdated
        return None
    if getattr(latent_op, "ndim_supp", None) != 0:
        return None
    base_var = cast(TensorVariable, base_var)
don't add type casts
this was also failing mypy
pymc/logprob/order.py
Outdated
    if axis != tuple(range(base_var.ndim)):
        return None
We could support this; it just means we do the mean/std aggregations along the summed axis.
pymc/logprob/order.py
Outdated
    mu_sum = pt.sum(mu_b)
    sigma_sum = pt.sqrt(pt.sum(pt.square(sigma_b)))
This will handle arbitrary axes?
Suggested change:

    mu_sum = pt.sum(mu_b, axis=axis)
    sigma_sum = pt.sqrt(pt.sum(pt.square(sigma_b), axis=axis))
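To illustrate the suggested axis-aware aggregation (a small PyTensor sketch; mu_b, sigma_b, and axis are stand-in values, not the PR's actual graph inputs):

```python
import numpy as np
import pytensor
import pytensor.tensor as pt

# Stand-ins for the base NormalRV's parameters and the axis being summed over
mu_b = pt.as_tensor(np.arange(12.0).reshape(4, 3))
sigma_b = pt.as_tensor(np.full((4, 3), 0.5))
axis = (0,)

# Aggregations along the summed axis, mirroring the suggestion above
mu_sum = pt.sum(mu_b, axis=axis)
sigma_sum = pt.sqrt(pt.sum(pt.square(sigma_b), axis=axis))

f = pytensor.function([], [mu_sum, sigma_sum])
print(f())  # means summed down axis 0; sigma = sqrt(4 * 0.25) = 1.0 per column
```

Summing a Normal array along an axis then corresponds to a Normal whose mean is the sum of the means and whose sigma is the square root of the summed variances along that same axis.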
pymc/logprob/order.py
Outdated
    sigma_sum = pt.sqrt(pt.sum(pt.square(sigma_b)))

    # Create a scalar NormalRV for the sum
    rng = base_var.owner.inputs[0]
If you are using all the args anyway, just unpack them once:

    rng, size, mu, sigma = base_var.owner.inputs
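A sketch of what the reviewed block might look like after the suggested single unpack (illustrative only; base_var and axis are stand-ins for what the rewrite extracts from the Sum node, and it assumes a current PyTensor where a NormalRV node's inputs are rng, size, mu, sigma, as in the comment above):

```python
import pytensor.tensor as pt

# Stand-ins for the measurable NormalRV and the summed axis found by the rewrite
base_var = pt.random.normal(pt.zeros((4, 3)), pt.ones((4, 3)))
axis = (0,)

# Unpack the NormalRV inputs once, as suggested (size is unused in this sketch)
rng, size, mu, sigma = base_var.owner.inputs

# Parameters of the equivalent NormalRV for sum(base_var, axis=axis)
mu_sum = pt.sum(mu, axis=axis)
sigma_sum = pt.sqrt(pt.sum(pt.square(sigma), axis=axis))

# Equivalent NormalRV, reusing the original rng
sum_rv = pt.random.normal(mu_sum, sigma_sum, rng=rng)
```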
@ricardoV94 does this seem good?

@eclipse1605 left some more comments. Also for future PRs, feel free to close addressed comments.

Sure, you mean "resolve conversation", right?

Yup
We should follow up with |