Add logprob support for leaky-ReLU switch transforms #7995
Conversation
The more general requirement, which we could try to see if we can cover, is a switch where the two branches don't overlap, so when you invert you know for sure which branch it came from.

That makes sense. Correct me if I'm wrong, but the key property we need is that the two branches map to non-overlapping regions in `value`? If so, we can extend the current pattern matcher to cover those cases.

@ricardoV94 if my understanding is right, would you prefer that I extend this PR with that generalization, or should I open a follow-up PR to keep the changes scoped?
Yeah exactly! But you don't need to implement the logic to invert the branches, or the jacobian. If both branches of the switch are measurable, it means PyMC figured out the logp and you can just evaluate it at the final value (that will take care of those details). The question posed by the switch is which of them you need. Logp for this switch might look something like (pseudocode):

```python
def conditional_switch_logp(value, true_branch_rv, false_branch_rv):
    value_implies_true_branch = f(value)
    # Note the logp is evaluated at the value; the switch just gates which one is selected
    logp = switch(value_implies_true_branch, logp(true_branch_rv, value), logp(false_branch_rv, value))
    return logp
```

I think (need to confirm) that the way things are set up, the rewrite that marks the switch as being measurable only has to worry about whether we meet the constraints you mentioned. Current code can already figure out the `logp(Normal, value)`, or `logp(Normal * a, value)`, for you. The strategy may look something like this sequence of checks. Examples: ... But also: ... (`y` is what becomes `value` in the logp function.) Restricting to the original monotonically increasing leaky-ReLU case is fine, but I would like to structure the code so it's ready to extend to more cases in the future. If, once you figure it out, you want to extend it, that's awesome and welcome, but not a blocker. How does that sound?
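To make the gating idea concrete, here is a hedged sketch (not from the thread) that builds the conditional logp by hand with the public `pm.logp`, relying on PyMC already knowing `logp(Normal, value)` and `logp(a * Normal, value)`:

```python
import pytensor.tensor as pt
import pymc as pm

a = pt.scalar("a")
value = pt.scalar("value")
x = pm.Normal.dist(mu=0, sigma=1)

# PyMC can already infer both branch logps; the switch logp only
# gates between them based on which branch the value implies.
logp_true = pm.logp(x, value)       # true branch: y = x, so value > 0
logp_false = pm.logp(a * x, value)  # false branch: y = a * x, so value <= 0
logp_y = pt.switch(pt.ge(value, 0), logp_true, logp_false)

print(logp_y.eval({a: 0.5, value: -2.0}))  # logp of the leaky-ReLU transform at -2
```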
@ricardoV94 thanks, that makes sense. I refactored the implementation to follow that approach.
This is currently scoped to the leaky-ReLU pattern, but the structure is such that extending to other non-overlapping switch patterns should be straightforward (separate predicate + constraints logic). If approved, I'll go ahead and try to implement the general "non-overlapping images" framework.
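For context, a minimal usage sketch of what this PR enables (assuming the rewrite applies as described; the negative-branch logp picks up the change-of-variables Jacobian `-log(a)`):

```python
import numpy as np
import pytensor.tensor as pt
import pymc as pm

a = 0.5  # strictly positive slope on the negative branch
x = pm.Normal.dist(mu=0, sigma=1)
y = pt.switch(x > 0, x, a * x)  # leaky-ReLU transform of x

# Positive values come from the identity branch...
np.testing.assert_allclose(pm.logp(y, 1.2).eval(), pm.logp(x, 1.2).eval())
# ...negative values from the scaled branch, inverted with its Jacobian.
np.testing.assert_allclose(
    pm.logp(y, -2.0).eval(),
    pm.logp(x, -2.0 / a).eval() - np.log(a),
)
```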
@ricardoV94 should I make any more changes in this PR?
ricardoV94
left a comment
Sorry for the delay.
This is looking great, but I do have some comments / questions.
Let me know if my comments are off.
How about this, @ricardoV94?
ricardoV94
left a comment
This is pretty great. I only have minor requests/questions left.
pymc/logprob/switch.py (Outdated)

```python
measurable_switch_non_overlapping = MeasurableSwitchNonOverlapping(scalar_switch)
```

```python
def _is_x_threshold_condition(cond: TensorVariable, x: TensorVariable) -> bool:
```

Better name for this helper to emphasize zero? Perhaps `_is_zero_x_threshold`.
```python
@node_rewriter([tensor_switch])
def find_measurable_switch_non_overlapping(fgraph, node):
    """Detect `switch(x > 0, x, a * x)` and replace it by a measurable op."""
```

Can you open an issue for follow-up extensions? From the top of my mind (see the sketch after this list):

- We should try and support the equivalent switch written the other way around, `switch(x <= 0, a * x, x)`. The user shouldn't have to guess the specific format that we support.
- Allow scaling factors on both branches, still with the constraint that they must have the same sign (or, if we want to remain more restrictive for now, that they are both non-negative). Just because it isn't too much harder than what we have now.
- The more general cases discussed above with monotonic functions of `x` that don't overlap on the two branches.
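For the first item, an illustrative pair of spellings (hypothetical, just to pin down the equivalence the rewrite would need to recognize):

```python
import pytensor.tensor as pt
import pymc as pm

x = pm.Normal.dist(mu=0, sigma=1)
a = 0.5

y1 = pt.switch(x > 0, x, a * x)   # spelling supported by this PR
y2 = pt.switch(x <= 0, a * x, x)  # same transform, condition and branches swapped
```

Both graphs describe the same random variable, so both should end up with the same inferred logp.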
Sure, that makes sense.
> We should try and support the equivalent switch written the other way around, `switch(x <= 0, a * x, x)`. The user shouldn't have to guess the specific format that we support.

I have raised #8049 for this, please check it and let me know if I should raise the rest of the issues as well, or if you'd prefer a different structure/scope.
I was thinking to raise issues for the following extensions (a sketch of the second one follows this list):

- support equivalent `switch` spellings (e.g. `switch(x <= 0, a*x, x)` / swapped comparisons) so users don't need to match a single canonical form.
- support scaling on both branches (`switch(x > 0, a_pos*x, a_neg*x)`) with constraints that guarantee non-overlapping images (start with both scales strictly positive, or same-sign + non-zero).
- track a broader "non-overlapping switch transform" framework for monotone branch functions of `x` where the observed value implies the branch, while still gating between existing branch logps.
- support non-zero thresholds (`x ? k`) once the above is stable, since determining the `value -> branch` predicate becomes more subtle.
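A hedged sketch of the second extension (both branches scaled), following the `conditional_switch_logp` pseudocode above; the helper name and structure are hypothetical:

```python
import pytensor.tensor as pt
import pymc as pm

def both_branches_scaled_logp(value, x, a_pos, a_neg):
    # For a_pos, a_neg > 0 the images stay disjoint:
    # true branch -> value > 0, false branch -> value <= 0,
    # so the value alone determines which branch logp to use.
    logp_true = pm.logp(a_pos * x, value)
    logp_false = pm.logp(a_neg * x, value)
    return pt.switch(pt.ge(value, 0), logp_true, logp_false)

x = pm.Normal.dist(mu=0, sigma=1)
print(both_branches_scaled_logp(-1.0, x, 2.0, 0.5).eval())
```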
Yeah that's a solid plan. Feel free to describe all that in the issue or open separate ones
pymc/logprob/switch.py (Outdated)

```python
# For `a > 0`, `switch(x > 0, x, a * x)` maps to disjoint regions in `value`:
# true branch -> value > 0, false branch -> value <= 0.
value_implies_true_branch = pt.gt(value, 0)
```

Mathematically it should be the same with a continuous function, but the logp may be more stable? If we want to be pedantic we could check if the original cond had the exact zero in the true or neg branch.

Suggested change:

```diff
-value_implies_true_branch = pt.gt(value, 0)
+value_implies_true_branch = pt.ge(value, 0)
```
tests/logprob/test_transforms.py (Outdated)

```python
pm.logp(y, v_pos, warn_rvs=False).eval(),
pm.logp(x, v_pos, warn_rvs=False).eval(),
```

Suggested change:

```diff
-pm.logp(y, v_pos, warn_rvs=False).eval(),
-pm.logp(x, v_pos, warn_rvs=False).eval(),
+pm.logp(y, v_pos).eval(),
+pm.logp(x, v_pos).eval(),
```
tests/logprob/test_transforms.py (Outdated)

```python
v_neg = -2.0
np.testing.assert_allclose(
    pm.logp(y, v_neg, warn_rvs=False).eval(),
```

If you're testing with two values, define the logp variable once (or compile a function with the logp as output once), and reuse it. That will avoid duplicated logp inference calls.
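A sketch of that suggestion (illustrative; assumes the leaky-ReLU `x`/`y` setup from the test and that the rewrite is in place):

```python
import numpy as np
import pytensor
import pytensor.tensor as pt
import pymc as pm

x = pm.Normal.dist(mu=0, sigma=1)
y = pt.switch(x > 0, x, 0.5 * x)

# Infer and compile the logp graph once, then evaluate it at many values.
value = pt.scalar("value")
logp_fn = pytensor.function([value], pm.logp(y, value))

np.testing.assert_allclose(logp_fn(1.2), pm.logp(x, 1.2).eval())
np.testing.assert_allclose(logp_fn(-2.0), pm.logp(x, -2.0 / 0.5).eval() - np.log(0.5))
```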
tests/logprob/test_transforms.py (Outdated)

```python
assert np.isfinite(pm.logp(y, 0.0, warn_rvs=False).eval())


def test_leaky_relu_switch_logp_vectorized():
```

Use vectorized in the first test, and remove this. It's not conceptually that different.
tests/logprob/test_transforms.py (Outdated)

```python
np.testing.assert_allclose(pm.logp(y, v, warn_rvs=False).eval(), expected)


def test_leaky_relu_switch_logp_symbolic_slope_checks_positive():
```

Use symbolic `a` in the first test and test the error there. Still a pretty straightforward test.
tests/logprob/test_transforms.py (Outdated)

```python
)


def test_leaky_relu_switch_logp_scalar():
```

Tests should be moved to a `test_switch.py` file.

Also add a separate test that shows the failure if `x` is broadcast by `cond` or `a`, or if it's discrete.

> Tests should be moved to a `test_switch.py` file.

Oh right, that's my bad.

> Also add a separate test that shows the failure if `x` is broadcast by `cond` or `a`, or if it's discrete.

Sure (see the sketch below).
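An illustrative sketch of the broadcast failure case (test name hypothetical; the expected error message matches the one used later in this thread):

```python
import numpy as np
import pytest
import pytensor.tensor as pt
import pymc as pm

def test_switch_rejects_broadcast_x():
    # cond broadcasts the scalar x to shape (2, 3), so the value no longer
    # identifies a unique draw of x and the rewrite must not apply.
    x = pm.Normal.dist(mu=0, sigma=1)
    cond = pt.gt(x, pt.zeros((2, 3)))
    y = pt.switch(cond, x, 0.5 * x)
    with pytest.raises(NotImplementedError, match="Logprob method not implemented for Switch"):
        pm.logp(y, np.zeros((2, 3)))
```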
@ricardoV94 now?
ricardoV94
left a comment
Couple more nits, just about tests now.
I don't understand why you keep passing `warn_rvs=False`; there shouldn't be any (stray RVs to warn about), and the default is more strict.
```python
if a is None:
    raise NotImplementedError("Could not extract non-overlapping scale")

a_is_positive = pt.all(pt.gt(a, 0))
```

Does it need to be positive, or is non-negative enough?
It needs to be positive: if `a == 0`, the negative branch becomes `0 * x = 0`, so a whole half-line of `x` collapses to a single value, which is not invertible. I'll add a comment about this in the file.
Oh yeah, that makes sense.
OTOH, if it goes on the `0 * x` branch logp, that will already return nan (or inf if the value is zero):

```python
pm.logp(pm.Normal.dist() * pt.scalar("a"), 1).eval({"a": 0})  # np.nan
```

Probably better to leave as is though.
tests/logprob/test_switch.py (Outdated)

```python
def test_switch_non_overlapping_logp_matches_change_of_variables():
    scale = pt.scalar("scale")
    x = pm.Normal.dist(mu=0, sigma=1, size=(3,))
    y = cast(TensorVariable, pt.switch(x > 0, x, scale * x))
```

No need for casting in the tests, we don't run mypy here and it clutters the code.
Oh, sure.
tests/logprob/test_switch.py (Outdated)

```python
# No warning-based shortcuts: also match under default warn_rvs (scalar case)
x_s = pm.Normal.dist(mu=0, sigma=1)
y_s = cast(TensorVariable, pt.switch(x_s > 0, x_s, scale * x_s))

v_pos = 1.2
np.testing.assert_allclose(logp(y_s, v_pos).eval({scale: 0.5}), logp(x_s, v_pos).eval())

v_neg = -2.0
np.testing.assert_allclose(
    logp(y_s, v_neg).eval({scale: 0.5}),
    logp(x_s, v_neg / 0.5).eval() - np.log(0.5),
)

# boundary point (measure-zero for continuous RVs): should still produce a finite logp
assert np.isfinite(logp(y_s, 0.0, warn_rvs=False).eval({scale: 0.5}))
```

Why are these needed? Can't you test everything with the first functions?
Yeah, they aren't needed.
tests/logprob/test_switch.py (Outdated)

```python
y = cast(TensorVariable, pt.switch(cond, x, scale * x))

with pytest.raises(NotImplementedError, match="Logprob method not implemented for Switch"):
    logp(y, np.zeros((2, 3)), warn_rvs=False).eval({scale: 0.5})
```

No point in the eval, as it should already fail in the logp call.
tests/logprob/test_switch.py (Outdated)

```python
y = cast(TensorVariable, pt.switch(x > 0, x, scale * x))

with pytest.raises(NotImplementedError, match="Logprob method not implemented for Switch"):
    logp(y, np.zeros((3,)), warn_rvs=False).eval({scale: np.array([0.5, 0.5, 0.5])})
```

Same. Also I would merge with the previous test, conceptually similar.
```python
x = pm.Normal.dist(mu=0, sigma=1)
y = cast(TensorVariable, pt.switch(x > 0, x, scale * x))

with pytest.raises(ParameterValueError, match="switch non-overlapping scale > 0"):
```

You can use the function from the first test to check the negative-scale raising; no need for a separate test (see the sketch below).
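One hedged way to fold the negative-scale check into the first test's helper (structure hypothetical; the `ParameterValueError` import path is assumed to be `pymc.logprob.utils`):

```python
import numpy as np
import pytest
import pytensor.tensor as pt
import pymc as pm
from pymc.logprob.basic import logp
from pymc.logprob.utils import ParameterValueError

def eval_leaky_relu_logp(scale_value):
    # Same graph as the first test; only the runtime scale changes.
    scale = pt.scalar("scale")
    x = pm.Normal.dist(mu=0, sigma=1, size=(3,))
    y = pt.switch(x > 0, x, scale * x)
    return logp(y, np.full((3,), -2.0)).eval({scale: scale_value})

eval_leaky_relu_logp(0.5)  # fine: positive scale
with pytest.raises(ParameterValueError, match="switch non-overlapping scale > 0"):
    eval_leaky_relu_logp(-0.5)  # the a > 0 runtime check fires
```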
ricardoV94
left a comment
Awesome
Codecov Report

Additional details and impacted files:

```
@@            Coverage Diff             @@
##             main    #7995      +/-   ##
==========================================
+ Coverage   90.22%   91.42%   +1.20%
==========================================
  Files         116      117       +1
  Lines       18972    19154     +182
==========================================
+ Hits        17117    17512     +395
+ Misses       1855     1642     -213
```

Description

Added log-probability support for leaky-ReLU graphs constructed as `switch(x > 0, x, a * x)`, where `x` is a single continuous measurable variable.

Notes:

- `a` must be non-measurable and strictly positive.
- `y == 0` follows the `y <= 0` branch (measure-zero set).

Related Issue

Checklist

Type of change