1 change: 1 addition & 0 deletions pymc/logprob/__init__.py
@@ -53,6 +53,7 @@
import pymc.logprob.mixture
import pymc.logprob.order
import pymc.logprob.scan
import pymc.logprob.switch
import pymc.logprob.tensor
import pymc.logprob.transforms

202 changes: 202 additions & 0 deletions pymc/logprob/switch.py
@@ -0,0 +1,202 @@
# Copyright 2024 - present The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# MIT License
#
# Copyright (c) 2021-2022 aesara-devs
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

"""Measurable switch-based transforms."""

from typing import cast

import pytensor.tensor as pt

from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.scalar import Switch
from pytensor.scalar import switch as scalar_switch
from pytensor.scalar.basic import GE, GT, LE, LT, Mul
from pytensor.tensor.basic import switch as tensor_switch
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.variable import TensorVariable

from pymc.logprob.abstract import MeasurableElemwise, MeasurableOp, _logprob, _logprob_helper
from pymc.logprob.rewriting import measurable_ir_rewrites_db
from pymc.logprob.transforms import MeasurableTransform
from pymc.logprob.utils import (
CheckParameterValue,
check_potential_measurability,
filter_measurable_variables,
)


class MeasurableSwitchNonOverlapping(MeasurableElemwise):
"""Placeholder for switch transforms whose branch images do not overlap.

Currently matches leaky-ReLU graphs of the form `switch(x > 0, x, a * x)`.
"""

valid_scalar_types = (Switch,)


measurable_switch_non_overlapping = MeasurableSwitchNonOverlapping(scalar_switch)


def _is_x_threshold_condition(cond: TensorVariable, x: TensorVariable) -> bool:
Member: better name for this helper to emphasize zero? perhaps _is_zero_x_threshold

"""Check whether `cond` is equivalent to `x > 0` / `x >= 0` (and swapped forms)."""
if cond.owner is None:
return False
if not isinstance(cond.owner.op, Elemwise):
return False
scalar_op = cond.owner.op.scalar_op
if not isinstance(scalar_op, GT | GE | LT | LE):
return False

left, right = cond.owner.inputs

def _is_zero(v: TensorVariable) -> bool:
try:
return pt.get_underlying_scalar_constant_value(v) == 0
except NotScalarConstantError:
return False

# x > 0 or x >= 0
if left is x and _is_zero(cast(TensorVariable, right)) and isinstance(scalar_op, GT | GE):
return True
# 0 < x or 0 <= x
if right is x and _is_zero(cast(TensorVariable, left)) and isinstance(scalar_op, LT | LE):
return True

return False


def _extract_scale_from_measurable_mul(
neg_branch: TensorVariable, x: TensorVariable
) -> TensorVariable | None:
"""Extract scale `a` from a measurable multiplication that represents `a * x`."""
if neg_branch is x:
return pt.constant(1.0)

if neg_branch.owner is None:
return None

if not isinstance(neg_branch.owner.op, MeasurableTransform):
return None

op = neg_branch.owner.op
if not isinstance(op.scalar_op, Mul):
return None

# For Mul, MeasurableTransform has exactly two inputs: the measurable input and the scale (in either order)
if len(neg_branch.owner.inputs) != 2:
return None

if neg_branch.owner.inputs[op.measurable_input_idx] is not x:
return None

scale = neg_branch.owner.inputs[1 - op.measurable_input_idx]
return cast(TensorVariable, scale)


@node_rewriter([tensor_switch])
def find_measurable_switch_non_overlapping(fgraph, node):
"""Detect `switch(x > 0, x, a * x)` and replace it by a measurable op."""
Member: Can you open an issue for follow-up extensions? Off the top of my head:

  1. We should try to support the equivalent switch written the other way around, switch(x <= 0, a * x, x). The user shouldn't have to guess the specific format that we support.

  2. Allow scaling factors on both branches, still with the constraint that they must have the same sign (or, if we want to remain more restrictive for now, that they are both non-negative). Just because it isn't too much harder than what we have now.

  3. The more general cases discussed above, with monotonic functions of x that don't overlap on the two branches.

Contributor Author: sure, that makes sense

Contributor Author:

> 1. We should try to support the equivalent switch written the other way around, switch(x <= 0, a * x, x). The user shouldn't have to guess the specific format that we support.
I have raised #8049 for this, please check it and let me know if I should raise the rest of the issues as well, or if you’d prefer a different structure/scope.

I was thinking of raising issues for the following extensions:

  • support equivalent switch spellings (e.g. switch(x <= 0, a*x, x) / swapped comparisons) so users don’t need to match a single canonical form.
  • support scaling on both branches (switch(x > 0, a_pos*x, a_neg*x)) with constraints that guarantee non-overlapping images (start with both scales strictly positive or same-sign + non-zero).
  • track a broader "non-overlapping switch transform" framework for monotone branch functions of x where the observed value implies the branch, while still gating between existing branch logps.
  • support non-zero thresholds (x > k) once the above is stable, since determining the value -> branch predicate becomes more subtle.

Member (@ricardoV94, Jan 13, 2026): Yeah, that's a solid plan. Feel free to describe all that in the issue or open separate ones.
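
For reference, the two spellings in item 1 compute the same function pointwise, so supporting only one of them is purely a pattern-matching limitation. A minimal sketch checking the equivalence numerically (illustrative, not part of this PR):

import numpy as np
import pytensor.tensor as pt

# Both spellings of the leaky-ReLU switch compute the same function of x;
# the rewrite currently only matches the first form.
x = pt.vector("x")
a = 0.5
form_matched = pt.switch(pt.gt(x, 0), x, a * x)  # switch(x > 0, x, a * x)
form_swapped = pt.switch(pt.le(x, 0), a * x, x)  # switch(x <= 0, a * x, x)

vals = np.array([-2.0, 0.0, 1.5])
np.testing.assert_allclose(
    form_matched.eval({x: vals}),
    form_swapped.eval({x: vals}),  # both give [-1.0, 0.0, 1.5]
)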

if isinstance(node.op, MeasurableOp):
return None

cond, pos_branch, neg_branch = node.inputs

# Only mark the switch measurable once both branches are already measurable.
# Then the logprob can simply gate between branch logps evaluated at `value`.
if set(filter_measurable_variables([pos_branch, neg_branch])) != {pos_branch, neg_branch}:
return None

x = cast(TensorVariable, pos_branch)

if x.type.dtype.startswith("int"):
return None

if x.type.broadcastable != node.outputs[0].type.broadcastable:
return None

if not _is_x_threshold_condition(cast(TensorVariable, cond), x):
return None

a = _extract_scale_from_measurable_mul(cast(TensorVariable, neg_branch), x)
if a is None:
return None

# Disallow slope `a` that could be (directly or indirectly) measurable.
# This rewrite targets deterministic, non-overlapping transforms parametrized by non-RVs.
if check_potential_measurability([a]):
return None

return [
measurable_switch_non_overlapping(
cast(TensorVariable, cond),
x,
cast(TensorVariable, neg_branch),
)
]


@_logprob.register(MeasurableSwitchNonOverlapping)
def logprob_switch_non_overlapping(op, values, cond, x, neg_branch, **kwargs):
(value,) = values

a = _extract_scale_from_measurable_mul(
cast(TensorVariable, neg_branch), cast(TensorVariable, x)
)
if a is None:
raise NotImplementedError("Could not extract non-overlapping scale")

a_is_positive = pt.all(pt.gt(a, 0))
Member: Does it need to be positive, or is non-negative enough?

Contributor Author: it needs to be positive because if a == 0, the negative branch becomes 0 * x = 0, so a whole half-line of x collapses to a single value, which is not invertible. I'll add a comment about this in the file.

Member: oh yeah, that makes sense

Member (@ricardoV94, Jan 13, 2026): OTOH, if it goes on the 0 * x branch logp, that will already return nan (or inf if the value is zero):

pm.logp(pm.Normal.dist() * pt.scalar("a"), 1).eval({"a": 0})  # np.nan

Probably better to leave as is, though.


# For `a > 0`, `switch(x > 0, x, a * x)` maps to disjoint regions in `value`:
# true branch -> value > 0, false branch -> value <= 0.
value_implies_true_branch = pt.gt(value, 0)
Member: Mathematically it should be the same with a continuous function, but the logp may be more stable? If we want to be pedantic, we could check whether the original cond had the exact zero in the true or neg branch.

Suggested change:
-    value_implies_true_branch = pt.gt(value, 0)
+    value_implies_true_branch = pt.ge(value, 0)
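
The gating here is the usual change of variables applied branch-wise: for value <= 0 the graph inverts to x = value / a, picking up a 1/a Jacobian, so logp_y(value) = logp_x(value / a) - log(a). A minimal numeric check of that identity against scipy, assuming a standard normal base and a = 0.5:

import numpy as np
import scipy.stats as st
import pymc as pm

# Negative-branch density of y = switch(x > 0, x, a * x) with x ~ Normal(0, 1):
# logp_y(v) = logp_x(v / a) - log(a) for v <= 0.
a = 0.5
x = pm.Normal.dist(mu=0, sigma=1)
y = pm.math.switch(x > 0, x, a * x)

v = -2.0
np.testing.assert_allclose(
    pm.logp(y, v, warn_rvs=False).eval(),
    st.norm(0, 1).logpdf(v / a) - np.log(a),
)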


logp_expr = pt.switch(
value_implies_true_branch,
_logprob_helper(x, value, **kwargs),
_logprob_helper(neg_branch, value, **kwargs),
)

return CheckParameterValue("switch non-overlapping scale > 0")(logp_expr, a_is_positive)


measurable_ir_rewrites_db.register(
"find_measurable_switch_non_overlapping",
find_measurable_switch_non_overlapping,
"basic",
"transform",
)
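
Since a > 0 makes the two branch images partition the real line, the gated logp defines a proper density. A quick sanity check, assuming the rewrite applies as above, that the implied density integrates to 1 on a fine grid:

import numpy as np
import pymc as pm

a = 0.5
x = pm.Normal.dist(mu=0, sigma=1)
y = pm.math.switch(x > 0, x, a * x)

# Riemann sum of exp(logp) over a fine grid; the probability mass splits
# evenly between the branch images (value > 0 and value <= 0).
grid = np.linspace(-10.0, 10.0, 20001)
dens = np.exp(pm.logp(y, grid, warn_rvs=False).eval())
dx = grid[1] - grid[0]
np.testing.assert_allclose((dens * dx).sum(), 1.0, rtol=1e-3)
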
54 changes: 54 additions & 0 deletions tests/logprob/test_transforms.py
@@ -43,6 +43,8 @@

from pytensor.graph.basic import equal_computations

import pymc as pm

from pymc.distributions.continuous import Cauchy, ChiSquared
from pymc.distributions.discrete import Bernoulli
from pymc.logprob.basic import conditional_logp, icdf, logcdf, logp
@@ -219,6 +221,7 @@ def test_exp_transform_rv():
logp_fn(y_val),
sp.stats.lognorm(s=1).logpdf(y_val),
)

np.testing.assert_almost_equal(
logcdf_fn(y_val),
sp.stats.lognorm(s=1).logcdf(y_val),
@@ -229,6 +232,57 @@
)


def test_leaky_relu_switch_logp_scalar():
Member: tests should be moved to a test_switch.py file

Member: Also add a separate test that shows the failure if x is broadcast by cond or a, or if it's discrete.

Contributor Author:

> tests should be moved to a test_switch.py file

oh right, that's my bad

Contributor Author:

> Also add a separate test that shows the failure if x is broadcast by cond or a, or if it's discrete.

sure

a = 0.5
x = pm.Normal.dist(mu=0, sigma=1)
y = pm.math.switch(x > 0, x, a * x)

v_pos = 1.2
np.testing.assert_allclose(
pm.logp(y, v_pos, warn_rvs=False).eval(),
pm.logp(x, v_pos, warn_rvs=False).eval(),
Member:

Suggested change:
-    pm.logp(y, v_pos, warn_rvs=False).eval(),
-    pm.logp(x, v_pos, warn_rvs=False).eval(),
+    pm.logp(y, v_pos).eval(),
+    pm.logp(x, v_pos).eval(),

)

v_neg = -2.0
np.testing.assert_allclose(
pm.logp(y, v_neg, warn_rvs=False).eval(),
Member (@ricardoV94, Jan 12, 2026): If you're testing with two values, define the logp variable once (or compile a function with the logp as output once) and reuse it. That will avoid duplicated logp inference calls.

pm.logp(x, v_neg / a, warn_rvs=False).eval() - np.log(a),
)

# boundary point (measure-zero for continuous RVs): should still produce a finite logp
assert np.isfinite(pm.logp(y, 0.0, warn_rvs=False).eval())
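
Following the suggestion above, the logp graph can be inferred and compiled once and then reused for every test value. A possible shape for that refactor (a sketch; names are illustrative):

import numpy as np
import pymc as pm
import pytensor
import pytensor.tensor as pt

a = 0.5
x = pm.Normal.dist(mu=0, sigma=1)
y = pm.math.switch(x > 0, x, a * x)

# Run logp inference once, compile once, evaluate at every test value.
value = pt.scalar("value")
logp_fn = pytensor.function([value], pm.logp(y, value, warn_rvs=False))

np.testing.assert_allclose(logp_fn(1.2), pm.logp(x, 1.2, warn_rvs=False).eval())
np.testing.assert_allclose(
    logp_fn(-2.0), pm.logp(x, -2.0 / a, warn_rvs=False).eval() - np.log(a)
)
assert np.isfinite(logp_fn(0.0))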


def test_leaky_relu_switch_logp_vectorized():
Member: Use vectorized in the first test, and remove this. It's not conceptually that different.

a = 0.5
x = pm.Normal.dist(mu=0, sigma=1, size=(3,))
y = pm.math.switch(x > 0, x, a * x)

v = np.array([-2.0, 0.0, 1.5])
expected = pm.logp(x, np.where(v > 0, v, v / a), warn_rvs=False).eval() + np.where(
v > 0, 0.0, -np.log(a)
)
np.testing.assert_allclose(pm.logp(y, v, warn_rvs=False).eval(), expected)


def test_leaky_relu_switch_logp_symbolic_slope_checks_positive():
Member: use symbolic a in the first test and test the error there. Still a pretty straightforward test.

a = pt.scalar("a")
x = pm.Normal.dist(mu=0, sigma=1)
y = pm.math.switch(x > 0, x, a * x)

# positive slope passes
res = pm.logp(y, -1.0, warn_rvs=False).eval({a: 0.5})
expected = pm.logp(x, -1.0 / 0.5, warn_rvs=False).eval() - np.log(0.5)
np.testing.assert_allclose(res, expected)

# non-positive slope raises
with pytest.raises(ParameterValueError, match="switch non-overlapping scale > 0"):
pm.logp(y, -1.0, warn_rvs=False).eval({a: -0.5})

with pytest.raises(ParameterValueError, match="switch non-overlapping scale > 0"):
pm.logp(y, -1.0, warn_rvs=False).eval({a: 0.0})


def test_log_transform_rv():
base_rv = pt.random.lognormal(0, 1, size=2, name="base_rv")
y_rv = pt.log(base_rv)