Conversation

@eclipse1605
Contributor

@eclipse1605 eclipse1605 commented Dec 15, 2025

Description

Added log-probability support for leaky-ReLU graphs constructed as

y = switch(x > 0, x, a * x)

where x is a single continuous measurable variable.

Notes

  • only supports a single continuous measurable variable.
  • the slope a must be non-measurable and strictly positive.
  • behavior at y == 0 follows the y <= 0 branch (measure-zero set).
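
For context, a minimal usage sketch of the supported pattern (the slope value 0.5 is illustrative; the expected results mirror the change-of-variables identity checked in the PR's tests):

import numpy as np
import pymc as pm
import pytensor.tensor as pt

a = 0.5
x = pm.Normal.dist(mu=0, sigma=1)
y = pt.switch(x > 0, x, a * x)  # leaky-ReLU transform of x

# Positive values come from the identity branch...
np.testing.assert_allclose(pm.logp(y, 1.2).eval(), pm.logp(x, 1.2).eval())
# ...negative values from the scaled branch: invert (v / a) and add the Jacobian term -log(a).
np.testing.assert_allclose(
    pm.logp(y, -2.0).eval(),
    pm.logp(x, -2.0 / a).eval() - np.log(a),
)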

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

@ricardoV94
Member

The more general requirement, which we could try to see if we can cover, is a switch where the two branches don't overlap, so when you invert you know for sure which branch it came from.

@eclipse1605
Contributor Author

That makes sense. Correct me if I'm wrong, but the key property we need is that the two branches map to non-overlapping regions in y, so that when we observe a y we can tell which branch it came from and apply the correct inverse + Jacobian.

If so, we can extend the current pattern matcher to y = switch(x > k, f1(x), f2(x)), then extract (m1, b1) and (m2, b2) for the two branches, check that each branch is monotone (m != 0), and check that the images don't overlap once you restrict to the domains x ≷ k. If the checks pass, build the inverse per branch x = (y - b) / m and add the Jacobian term -log|m|, using a switch on y with the boundary induced by k.

@eclipse1605
Contributor Author

@ricardoV94 if my understanding is right, would you prefer that I extend this PR with that generalization, or should I open a follow-up PR to keep the changes scoped?

@ricardoV94
Member

ricardoV94 commented Dec 15, 2025

That makes sense. Correct me if I'm wrong, but the key property we need is that the two branches map to non-overlapping regions in y, so that when we observe a y we can tell which branch it came from and apply the correct inverse + Jacobian.

If so, we can extend the current pattern matcher to y = switch(x > k, f1(x), f2(x)), then extract (m1, b1) and (m2, b2) for the two branches, check that each branch is monotone (m != 0), and check that the images don't overlap once you restrict to the domains x ≷ k. If the checks pass, build the inverse per branch x = (y - b) / m and add the Jacobian term -log|m|, using a switch on y with the boundary induced by k.

Yeah exactly! But you don't need to implement the logic to invert the branches, or the Jacobian. If both branches of the switch are measurable, it means PyMC figured out the logp and you can just evaluate it at the final value (that will take care of those details). The question posed by the switch is: which of them do you need?

Logp for this switch might look something like (pseudocode):

def conditional_switch_logp(value, true_branch_rv, false_branch_rv):
  value_implies_true_branch = f(value)
  # Note the logp is evaluated at the value; the switch just gates which one is selected
  logp = switch(value_implies_true_branch, logp(true_branch_rv, value), logp(false_branch_rv, value))
  return logp

I think (need to confirm) that the way things are set up, the rewrite that marks the switch as being measurable only has to worry about whether we meet the constraints you mentioned.

Current code can already figure out the logp(Normal, value), or logp(Normal * a, value), for you.
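
For instance, a quick check of that claim (a sketch, assuming a recent PyMC; the scale rewrite supplies the inverse and Jacobian automatically):

import numpy as np
import pymc as pm
import pytensor.tensor as pt

a = pt.scalar("a")
x = pm.Normal.dist(mu=0, sigma=1)
# logp of the scaled RV is derived for us: logp(x, v / a) - log|a|
np.testing.assert_allclose(
    pm.logp(x * a, 1.0).eval({a: 2.0}),
    pm.logp(x, 1.0 / 2.0).eval() - np.log(2.0),
)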

The strategy may look something like this sequence of checks:

  1. You have a switch
  2. It's not yet Measurable
  3. Both true and false branches are measurable (that means PyMC already knows how to get their logp)
  4. We have a simple condition expression, x > k or something along those lines. This is where you find what x even is.
  5. Both branches are related to x
    5.1 If neither is connected, this is already handled by the switch mixture machinery switch(cond, x, y); do nothing.
    5.2 If only one is connected, and the other is a constant, it's actually a censored process, something like switch(x > 0, k, x). But this will never be the case given requirement 3.
    5.3 If only one is connected, and the other is a measurable variable, it's also fine, but the checks for invertibility may be trickier? I didn't want to think about this right now. If it seems simple to you, let me know.
  6. The setup passes the constraints for invertibility (from what we can infer)

Examples:
y = switch(x > 0, x, x * a), a > 0 -> true_branch = y > 0
y = switch(x > 0, x * a, x * b), a,b > 0 -> true_branch = y > 0
y = switch(x > 0, x * a, x * b), a, b < 0 -> true_branch = y < 0

But also:
y = switch(x > 1, x ** 3, x) -> true_branch = y > 1
y = switch(x > 0, exp(x) - 1, x) -> true_branch = y > 0
y = switch(x > -1, x, exp(x+1) - 2) -> true_branch = y > -1
y = switch(x > 1, exp(x - 1), log(x) + 1) -> true_branch = y > 1

(y is what becomes value in the logp function)
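
To make the sequence of checks concrete, here is a purely illustrative rewriter skeleton (the helpers is_measurable, extract_thresholded_var, depends_on, and images_are_disjoint are placeholders, not existing PyMC APIs; the op name follows the review snippet further down):

from pytensor.graph.rewriting.basic import node_rewriter

@node_rewriter([tensor_switch])
def find_measurable_switch_non_overlapping(fgraph, node):
    cond, true_branch, false_branch = node.inputs
    # 2. Skip switches already rewritten as measurable.
    if isinstance(node.op, MeasurableSwitchNonOverlapping):
        return None
    # 3. Both branches must already have a derivable logp.
    if not (is_measurable(true_branch) and is_measurable(false_branch)):
        return None
    # 4. The condition must be a simple threshold on a single x.
    x = extract_thresholded_var(cond)  # e.g. cond is (x > k)
    if x is None:
        return None
    # 5. Both branches must be functions of that same x.
    if not (depends_on(true_branch, x) and depends_on(false_branch, x)):
        return None
    # 6. The branch images must be disjoint, so the value implies the branch.
    if not images_are_disjoint(true_branch, false_branch):
        return None
    new_op = MeasurableSwitchNonOverlapping(node.op.scalar_op)
    return [new_op(cond, true_branch, false_branch)]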

Restricting to the original monotonically increasing leaky ReLU case is fine, but I would like to structure the code so it's ready to extend to more cases in the future.

If, once you figure it out, you want to extend it, that's awesome and welcome, but not a blocker.

How does that sound?

@eclipse1605
Contributor Author

@ricardoV94 thanks, that makes sense. I refactored the implementation to follow that approach.

  • rewrote find_measurable_leaky_relu_switch so it only tags the switch as measurable when both branches are already measurable, delegating all inversion/Jacobian details to the existing logprob rules for each branch.
  • the _logprob for MeasurableLeakyReLUSwitch now just gates between branch logps evaluated at the observed value:
    switch(value > 0, _logprob_helper(x, value), _logprob_helper(neg_branch, value))
  • kept the runtime CheckParameterValue("leaky_relu slope > 0") guard to ensure the “value implies branch” predicate is valid, and attached it to the returned expression so it can’t get optimized away.

This is currently scoped to the leaky-ReLU pattern, but the structure is such that extending to other non-overlapping switch patterns should be straightforward (separate predicate + constraints logic). If approved, I'll go ahead and try to implement the general “non-overlapping images” framework.
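
For reference, the gated logp plus guard might read roughly like this (a sketch following pymc.logprob registration conventions; the op's input order and the way the slope is recovered are assumptions about this draft, not confirmed API):

import pytensor.tensor as pt
from pymc.logprob.abstract import _logprob, _logprob_helper
from pymc.logprob.utils import CheckParameterValue

@_logprob.register(MeasurableLeakyReLUSwitch)
def leaky_relu_switch_logprob(op, values, cond, x, neg_branch, **kwargs):
    (value,) = values
    logp = pt.switch(
        value > 0,
        _logprob_helper(x, value),
        _logprob_helper(neg_branch, value),
    )
    a = op.slope  # hypothetical: however the rewrite stored the slope
    # Attach the check to the returned expression so it can't be optimized away.
    return CheckParameterValue("leaky_relu slope > 0")(logp, pt.all(pt.gt(a, 0)))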

@eclipse1605
Contributor Author

@ricardoV94 should I make any more changes in this PR?

Member

@ricardoV94 ricardoV94 left a comment

Sorry for the delay.

This is looking great, but I do have some comments / questions.

Let me know if my comments are off.

@eclipse1605
Copy link
Contributor Author

how about this @ricardoV94?

Member

@ricardoV94 ricardoV94 left a comment

This is pretty great. I only have minor requests/questions left

measurable_switch_non_overlapping = MeasurableSwitchNonOverlapping(scalar_switch)


def _is_x_threshold_condition(cond: TensorVariable, x: TensorVariable) -> bool:
Member

better name for this helper to emphasize zero? perhaps _is_zero_x_threshold


@node_rewriter([tensor_switch])
def find_measurable_switch_non_overlapping(fgraph, node):
    """Detect `switch(x > 0, x, a * x)` and replace it by a measurable op."""
Member

Can you open an issue for follow-up extensions? From the top of my mind:

  1. We should try and support the equivalent switch written the other way around switch(x <= 0, a * x, x). The user shouldn't have to guess the specific format that we support

  2. Allow scaling factors on both branches, still with the constraint that they must have the same sign (or if we want to remain more restrictive for now, that they are both non-negative). Just because it isn't too much harder than what we have now.

  3. The more general cases discussed above, with monotonic functions of x whose images on the two branches don't overlap

Contributor Author

sure, that makes sense

Contributor Author

  1. We should try and support the equivalent switch written the other way around switch(x <= 0, a * x, x). The user shouldn't have to guess the specific format that we support

I have raised #8049 for this; please check it and let me know if I should raise the rest of the issues as well, or if you'd prefer a different structure/scope.

I was thinking of raising issues for the following extensions:

  • support equivalent switch spellings (e.g. switch(x <= 0, a*x, x) / swapped comparisons) so users don’t need to match a single canonical form.
  • support scaling on both branches (switch(x > 0, a_pos*x, a_neg*x)) with constraints that guarantee non-overlapping images (start with both scales strictly positive or same-sign + non-zero).
  • track a broader "non-overlapping switch transform" framework for monotone branch functions of x where the observed value implies the branch, while still gating between existing branch logps.
  • support non-zero thresholds (x ≷ k) once the above is stable, since determining the value -> branch predicate becomes more subtle.

Member

@ricardoV94 ricardoV94 Jan 13, 2026

Yeah that's a solid plan. Feel free to describe all that in the issue or open separate ones

# For `a > 0`, `switch(x > 0, x, a * x)` maps to disjoint regions in `value`:
# true branch -> value > 0, false branch -> value <= 0.
value_implies_true_branch = pt.gt(value, 0)
Member

Mathematically it should be the same with a continuous function, but the logp may be more stable? If we want to be pedantic we could check if the original cond had the exact zero in the true or neg branch.

Suggested change
value_implies_true_branch = pt.gt(value, 0)
value_implies_true_branch = pt.ge(value, 0)

Comment on lines 242 to 243
pm.logp(y, v_pos, warn_rvs=False).eval(),
pm.logp(x, v_pos, warn_rvs=False).eval(),
Member

Suggested change
pm.logp(y, v_pos, warn_rvs=False).eval(),
pm.logp(x, v_pos, warn_rvs=False).eval(),
pm.logp(y, v_pos).eval(),
pm.logp(x, v_pos).eval(),


v_neg = -2.0
np.testing.assert_allclose(
    pm.logp(y, v_neg, warn_rvs=False).eval(),
Member

@ricardoV94 ricardoV94 Jan 12, 2026

If you're testing with two values, define the logp variable once (or compile a function with the logp as output once), and reuse it. That will avoid duplicated logp inference calls.
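
For example, a sketch of that reuse pattern (reusing x, y from the test above; names illustrative):

import pytensor
import pytensor.tensor as pt

value = pt.scalar("value")
y_logp = pm.logp(y, value)                      # logp inference runs once
y_logp_fn = pytensor.function([value], y_logp)  # compiled once, reused
y_logp_fn(1.2)   # positive-branch value
y_logp_fn(-2.0)  # negative-branch value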

assert np.isfinite(pm.logp(y, 0.0, warn_rvs=False).eval())


def test_leaky_relu_switch_logp_vectorized():
Member

Use vectorized in the first test, and remove this. It's not conceptually that different.

np.testing.assert_allclose(pm.logp(y, v, warn_rvs=False).eval(), expected)


def test_leaky_relu_switch_logp_symbolic_slope_checks_positive():
Member

Use symbolic a in the first test and test the error there. Still a pretty straightforward test.

)


def test_leaky_relu_switch_logp_scalar():
Member

tests should be moved to a test_switch.py file

Member

Also add a separate test that shows the failure if x is broadcast by cond or a, or if it's discrete.

Contributor Author

tests should be moved to a test_switch.py file

oh right, that's my bad

Contributor Author

Also add a separate test that shows the failure if x is broadcast by cond or a, or if it's discrete.

sure

@eclipse1605
Contributor Author

@ricardoV94 now?

Member

@ricardoV94 ricardoV94 left a comment

Couple more nits, just about tests now.

I don't understand why you keep passing warn_rvs=False; there shouldn't be any, and the default is more strict.

if a is None:
    raise NotImplementedError("Could not extract non-overlapping scale")

a_is_positive = pt.all(pt.gt(a, 0))
Member

Does it need to be positive, or is non-negative enough?

Contributor Author

It needs to be positive because if a == 0, the negative branch becomes 0 * x = 0, so a whole half-line of x collapses to a single value, which is not invertible. I'll add a comment about this in the file.

Member

oh yeah that makes sense

Member

@ricardoV94 ricardoV94 Jan 13, 2026

OTOH, if it goes to the 0 * x branch logp, that will already return nan (or inf if the value is zero):

pm.logp(pm.Normal.dist() * pt.scalar("a"), 1).eval({"a": 0})  # np.nan 

Probably better to leave as is though.

def test_switch_non_overlapping_logp_matches_change_of_variables():
    scale = pt.scalar("scale")
    x = pm.Normal.dist(mu=0, sigma=1, size=(3,))
    y = cast(TensorVariable, pt.switch(x > 0, x, scale * x))
Member

No need for casting in the tests, we don't run mypy here and it clutters the code

Contributor Author

oh sure

Comment on lines 49 to 63
# No warning-based shortcuts: also match under default warn_rvs (scalar case)
x_s = pm.Normal.dist(mu=0, sigma=1)
y_s = cast(TensorVariable, pt.switch(x_s > 0, x_s, scale * x_s))

v_pos = 1.2
np.testing.assert_allclose(logp(y_s, v_pos).eval({scale: 0.5}), logp(x_s, v_pos).eval())

v_neg = -2.0
np.testing.assert_allclose(
    logp(y_s, v_neg).eval({scale: 0.5}),
    logp(x_s, v_neg / 0.5).eval() - np.log(0.5),
)

# boundary point (measure-zero for continuous RVs): should still produce a finite logp
assert np.isfinite(logp(y_s, 0.0, warn_rvs=False).eval({scale: 0.5}))
Member

Why are these needed? Can't you test everything with the first functions?

Contributor Author

yeah they aren’t needed.

y = cast(TensorVariable, pt.switch(cond, x, scale * x))

with pytest.raises(NotImplementedError, match="Logprob method not implemented for Switch"):
    logp(y, np.zeros((2, 3)), warn_rvs=False).eval({scale: 0.5})
Member

No point in eval as it should already fail in the logp call

y = cast(TensorVariable, pt.switch(x > 0, x, scale * x))

with pytest.raises(NotImplementedError, match="Logprob method not implemented for Switch"):
    logp(y, np.zeros((3,)), warn_rvs=False).eval({scale: np.array([0.5, 0.5, 0.5])})
Member

same

Member

Also I would merge with previous test, conceptually similar

x = pm.Normal.dist(mu=0, sigma=1)
y = cast(TensorVariable, pt.switch(x > 0, x, scale * x))

with pytest.raises(ParameterValueError, match="switch non-overlapping scale > 0"):
Member

You can use the function from the first test to check that the negative scale raises; no need for a separate test.

Member

@ricardoV94 ricardoV94 left a comment

Awesome

@codecov

codecov bot commented Jan 13, 2026

Codecov Report

❌ Patch coverage is 80.89888% with 17 lines in your changes missing coverage. Please review.
✅ Project coverage is 91.42%. Comparing base (cadb97a) to head (4530814).
⚠️ Report is 22 commits behind head on main.

Files with missing lines Patch % Lines
pymc/logprob/switch.py 80.68% 17 Missing ⚠️
Additional details and impacted files

@@            Coverage Diff             @@
##             main    #7995      +/-   ##
==========================================
+ Coverage   90.22%   91.42%   +1.20%     
==========================================
  Files         116      117       +1     
  Lines       18972    19154     +182     
==========================================
+ Hits        17117    17512     +395     
+ Misses       1855     1642     -213     
Files with missing lines Coverage Δ
pymc/logprob/__init__.py 100.00% <100.00%> (ø)
pymc/logprob/switch.py 80.68% <80.68%> (ø)

... and 25 files with indirect coverage changes

@ricardoV94 ricardoV94 merged commit c68c56e into pymc-devs:main Jan 13, 2026
40 of 42 checks passed
@welcome

welcome bot commented Jan 13, 2026

Congrats on merging your first pull request! 🎉 We here at PyMC are proud of you! 💖 Thank you so much for your contribution 🎁

@ricardoV94 ricardoV94 changed the title add logprob support for leaky-ReLU switch transforms Add logprob support for leaky-ReLU switch transforms Jan 13, 2026
MSK-005 pushed a commit to MSK-005/pymc that referenced this pull request Jan 15, 2026

Development

Successfully merging this pull request may close these issues.

Derive logprob of Leaky ReLU transform
