Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ jobs:
tests/logprob/test_rewriting.py
tests/logprob/test_scan.py
tests/logprob/test_tensor.py
tests/logprob/test_switch.py
tests/logprob/test_transform_value.py
tests/logprob/test_transforms.py
tests/logprob/test_utils.py
Expand Down
1 change: 1 addition & 0 deletions pymc/logprob/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
import pymc.logprob.mixture
import pymc.logprob.order
import pymc.logprob.scan
import pymc.logprob.switch
import pymc.logprob.tensor
import pymc.logprob.transforms

Expand Down
225 changes: 225 additions & 0 deletions pymc/logprob/switch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@
# Copyright 2024 - present The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# MIT License
#
# Copyright (c) 2021-2022 aesara-devs
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

"""Measurable switch-based transforms."""

from typing import cast

import pytensor.tensor as pt

from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.scalar import Switch
from pytensor.scalar import switch as scalar_switch
from pytensor.scalar.basic import GE, GT, LE, LT, Mul
from pytensor.tensor.basic import switch as tensor_switch
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.variable import TensorVariable

from pymc.logprob.abstract import MeasurableElemwise, MeasurableOp, _logprob, _logprob_helper
from pymc.logprob.rewriting import measurable_ir_rewrites_db
from pymc.logprob.transforms import MeasurableTransform
from pymc.logprob.utils import (
CheckParameterValue,
check_potential_measurability,
filter_measurable_variables,
)


class MeasurableSwitchNonOverlapping(MeasurableElemwise):
"""Placeholder for switch transforms whose branch images do not overlap.

Currently matches leaky-ReLU graphs of the form `switch(x > 0, x, a * x)`.
"""

valid_scalar_types = (Switch,)


measurable_switch_non_overlapping = MeasurableSwitchNonOverlapping(scalar_switch)


def _zero_x_threshold_true_includes_zero(cond: TensorVariable, x: TensorVariable) -> bool | None:
"""Return whether `cond` is a zero threshold on `x` and includes `0` in the true branch.

Matches `x > 0`, `x >= 0` and swapped forms `0 < x`, `0 <= x`.

Returns
-------
- `False` for strict comparisons (`>`/`<`)
- `True` for non-strict comparisons (`>=`/`<=`)
- `None` if `cond` doesn't match a zero-threshold comparison on `x`
"""
if cond.owner is None:
return None
if not isinstance(cond.owner.op, Elemwise):
return None
scalar_op = cond.owner.op.scalar_op
if not isinstance(scalar_op, GT | GE | LT | LE):
return None

left, right = cond.owner.inputs

def _is_zero(v: TensorVariable) -> bool:
try:
return pt.get_underlying_scalar_constant_value(v) == 0
except NotScalarConstantError:
return False

# x > 0 or x >= 0
if left is x and _is_zero(cast(TensorVariable, right)) and isinstance(scalar_op, GT | GE):
return isinstance(scalar_op, GE)
# 0 < x or 0 <= x
if right is x and _is_zero(cast(TensorVariable, left)) and isinstance(scalar_op, LT | LE):
return isinstance(scalar_op, LE)

return None


def _extract_scale_from_measurable_mul(
neg_branch: TensorVariable, x: TensorVariable
) -> TensorVariable | None:
"""Extract scale `a` from a measurable multiplication that represents `a * x`."""
if neg_branch is x:
return pt.constant(1.0)

if neg_branch.owner is None:
return None

if not isinstance(neg_branch.owner.op, MeasurableTransform):
return None

op = neg_branch.owner.op
if not isinstance(op.scalar_op, Mul):
return None

# MeasurableTransform takes (measurable_input, scale)
if len(neg_branch.owner.inputs) != 2:
return None

if neg_branch.owner.inputs[op.measurable_input_idx] is not x:
return None

scale = neg_branch.owner.inputs[1 - op.measurable_input_idx]
return cast(TensorVariable, scale)


@node_rewriter([tensor_switch])
def find_measurable_switch_non_overlapping(fgraph, node):
"""Detect `switch(x > 0, x, a * x)` and replace it by a measurable op."""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you open an issue for follow-up extensions? From the top of my mind:

  1. We should try and support the equivalent switch written the other way around switch(x <= 0, a * x, x). The user shouldn't have to guess the specific format that we support

  2. Allow scaling factors on both branches, still with the constraint that they must have the same sign (or if we want to remain more restrictive for now, that they are both non-negative). Just because it isn't too much harder than what we have now.

  3. The more general cases discussed above with monotonic functions of x that don't overlap an the two branches

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure, that makes sense

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. We should try and support the equivalent switch written the other way around switch(x <= 0, a * x, x). The user shouldn't have to guess the specific format that we support

I have raised #8049 for this, please check it and let me know if I should raise the rest of the issues as well, or if you’d prefer a different structure/scope.

I was thinking to raise issues for the following extensions:

  • support equivalent switch spellings (e.g. switch(x <= 0, a*x, x) / swapped comparisons) so users don’t need to match a single canonical form.
  • support scaling on both branches (switch(x > 0, a_pos*x, a_neg*x)) with constraints that guarantee non-overlapping images (start with both scales strictly positive or same-sign + non-zero).
  • track a broader "non-overlapping switch transform" framework for monotone branch functions of x where the observed value implies the branch, while still gating between existing branch logps.
  • support non-zero thresholds (x ? k) once the above is stable, since determining the value -> branch predicate becomes more subtle.

Copy link
Member

@ricardoV94 ricardoV94 Jan 13, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah that's a solid plan. Feel free to describe all that in the issue or open separate ones

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah that's a solid plan

if isinstance(node.op, MeasurableOp):
return None

cond, pos_branch, neg_branch = node.inputs

# Only mark the switch measurable once both branches are already measurable.
# Then the logprob can simply gate between branch logps evaluated at `value`.
if set(filter_measurable_variables([pos_branch, neg_branch])) != {pos_branch, neg_branch}:
return None

x = cast(TensorVariable, pos_branch)

if x.type.numpy_dtype.kind != "f":
return None

# Avoid rewriting cases where `x` is broadcasted/replicated by `cond` or `a`.
# We require the positive branch to be a base `RandomVariable` output.
if x.owner is None or not isinstance(x.owner.op, RandomVariable):
return None

if x.type.broadcastable != node.outputs[0].type.broadcastable:
return None

includes_zero_in_true = _zero_x_threshold_true_includes_zero(cast(TensorVariable, cond), x)
if includes_zero_in_true is None:
return None

a = _extract_scale_from_measurable_mul(cast(TensorVariable, neg_branch), x)
if a is None:
return None

# Disallow slope `a` that could be (directly or indirectly) measurable.
# This rewrite targets deterministic, non-overlapping transforms parametrized by non-RVs.
if check_potential_measurability([a]):
return None

return [
measurable_switch_non_overlapping(
cast(TensorVariable, cond),
x,
cast(TensorVariable, neg_branch),
)
]


@_logprob.register(MeasurableSwitchNonOverlapping)
def logprob_switch_non_overlapping(op, values, cond, x, neg_branch, **kwargs):
(value,) = values

a = _extract_scale_from_measurable_mul(
cast(TensorVariable, neg_branch), cast(TensorVariable, x)
)
if a is None:
raise NotImplementedError("Could not extract non-overlapping scale")

a_is_positive = pt.all(pt.gt(a, 0))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it need to be positive or non-negative is enough?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it needs to be positive because if a == 0, the negative branch becomes 0 * x = 0, so a whole half-line of x collapses to a single value which is not invertible. ill add a comment about this in the file

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh yeah that makes sense

Copy link
Member

@ricardoV94 ricardoV94 Jan 13, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OTOH If it goes on the 0 * x branch logp, that will already return nan (or inf if the value is zero):

pm.logp(pm.Normal.dist() * pt.scalar("a"), 1).eval({"a": 0})  # np.nan 

Probably better to leave as is though.


includes_zero_in_true = _zero_x_threshold_true_includes_zero(
cast(TensorVariable, cond), cast(TensorVariable, x)
)
if includes_zero_in_true is None:
raise NotImplementedError("Could not identify zero-threshold condition")

# For `a > 0`, `switch(x > 0, x, a * x)` maps to disjoint regions in `value`.
# Select the branch using the observed `value` and the strictness of the original
# comparison (`>` vs `>=`).
value_implies_true_branch = pt.ge(value, 0) if includes_zero_in_true else pt.gt(value, 0)

logp_expr = pt.switch(
value_implies_true_branch,
_logprob_helper(x, value, **kwargs),
_logprob_helper(neg_branch, value, **kwargs),
)

return CheckParameterValue("switch non-overlapping scale > 0")(logp_expr, a_is_positive)


measurable_ir_rewrites_db.register(
"find_measurable_switch_non_overlapping",
find_measurable_switch_non_overlapping,
"basic",
"transform",
)
103 changes: 103 additions & 0 deletions tests/logprob/test_switch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
# Copyright 2026 - present The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import cast

import numpy as np
import pytensor.tensor as pt
import pytest

from pytensor.compile.function import function
from pytensor.tensor.variable import TensorVariable

import pymc as pm

from pymc.logprob.basic import logp
from pymc.logprob.utils import ParameterValueError


def test_switch_non_overlapping_logp_matches_change_of_variables():
scale = pt.scalar("scale")
x = pm.Normal.dist(mu=0, sigma=1, size=(3,))
y = cast(TensorVariable, pt.switch(x > 0, x, scale * x))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need for casting in the tests, we don't run mypy here and it clutters the code

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh sure


vv = pt.vector("vv")

logp_y = logp(y, vv, warn_rvs=False)
inv = cast(TensorVariable, pt.switch(pt.gt(vv, 0), vv, vv / scale))
expected = logp(x, inv, warn_rvs=False) + cast(
TensorVariable,
pt.switch(pt.gt(vv, 0), 0.0, -cast(TensorVariable, pt.log(scale))),
)

logp_y_fn = function([vv, scale], logp_y)
expected_fn = function([vv, scale], expected)

v = np.array([-2.0, 0.0, 1.5])
np.testing.assert_allclose(logp_y_fn(v, 0.5), expected_fn(v, 0.5))

# No warning-based shortcuts: also match under default warn_rvs (scalar case)
x_s = pm.Normal.dist(mu=0, sigma=1)
y_s = cast(TensorVariable, pt.switch(x_s > 0, x_s, scale * x_s))

v_pos = 1.2
np.testing.assert_allclose(logp(y_s, v_pos).eval({scale: 0.5}), logp(x_s, v_pos).eval())

v_neg = -2.0
np.testing.assert_allclose(
logp(y_s, v_neg).eval({scale: 0.5}),
logp(x_s, v_neg / 0.5).eval() - np.log(0.5),
)

# boundary point (measure-zero for continuous RVs): should still produce a finite logp
assert np.isfinite(logp(y_s, 0.0, warn_rvs=False).eval({scale: 0.5}))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are these needed? Can't you test everything with the first functions?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah they aren’t needed.



def test_switch_non_overlapping_requires_positive_scale():
scale = pt.scalar("scale")
x = pm.Normal.dist(mu=0, sigma=1)
y = cast(TensorVariable, pt.switch(x > 0, x, scale * x))

with pytest.raises(ParameterValueError, match="switch non-overlapping scale > 0"):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you can use the function from the first test to check the negative scale raising, no need for separate test

logp(y, -1.0, warn_rvs=False).eval({scale: -0.5})

with pytest.raises(ParameterValueError, match="switch non-overlapping scale > 0"):
logp(y, -1.0, warn_rvs=False).eval({scale: 0.0})


def test_switch_non_overlapping_does_not_rewrite_if_x_replicated_by_condition():
scale = pt.scalar("scale")
x = pm.Normal.dist(mu=0, sigma=1, size=(3,))
cond = (x[None, :] > 0) & pt.ones((2, 1), dtype="bool")
y = cast(TensorVariable, pt.switch(cond, x, scale * x))

with pytest.raises(NotImplementedError, match="Logprob method not implemented for Switch"):
logp(y, np.zeros((2, 3)), warn_rvs=False).eval({scale: 0.5})
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No point in eval as it should already fail in the logp call



def test_switch_non_overlapping_does_not_rewrite_if_scale_broadcasts_x():
x = pm.Normal.dist(mu=0, sigma=1)
scale = pt.vector("scale")
y = cast(TensorVariable, pt.switch(x > 0, x, scale * x))

with pytest.raises(NotImplementedError, match="Logprob method not implemented for Switch"):
logp(y, np.zeros((3,)), warn_rvs=False).eval({scale: np.array([0.5, 0.5, 0.5])})
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also I would merge with previous test, conceptually similar

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also I would merge with previous test, conceptually similar



def test_switch_non_overlapping_does_not_apply_to_discrete_rv():
a = pt.scalar("a")
x = pm.Poisson.dist(3)
y = cast(TensorVariable, pt.switch(x > 0, x, a * x))

with pytest.raises(NotImplementedError, match="Logprob method not implemented for Switch"):
logp(y, 1, warn_rvs=False).eval({a: 0.5})
1 change: 1 addition & 0 deletions tests/logprob/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,7 @@ def test_exp_transform_rv():
logp_fn(y_val),
sp.stats.lognorm(s=1).logpdf(y_val),
)

np.testing.assert_almost_equal(
logcdf_fn(y_val),
sp.stats.lognorm(s=1).logcdf(y_val),
Expand Down