Skip to content

Commit c68c56e

Browse files
authored
Add logprob support for leaky-ReLU switch transforms (#7995)
1 parent 689f736 commit c68c56e

File tree

5 files changed

+305
-0
lines changed

5 files changed

+305
-0
lines changed

.github/workflows/tests.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ jobs:
130130
tests/logprob/test_rewriting.py
131131
tests/logprob/test_scan.py
132132
tests/logprob/test_tensor.py
133+
tests/logprob/test_switch.py
133134
tests/logprob/test_transform_value.py
134135
tests/logprob/test_transforms.py
135136
tests/logprob/test_utils.py

pymc/logprob/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
import pymc.logprob.mixture
5555
import pymc.logprob.order
5656
import pymc.logprob.scan
57+
import pymc.logprob.switch
5758
import pymc.logprob.tensor
5859
import pymc.logprob.transforms
5960

pymc/logprob/switch.py

Lines changed: 227 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,227 @@
1+
# Copyright 2024 - present The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
# MIT License
16+
#
17+
# Copyright (c) 2021-2022 aesara-devs
18+
#
19+
# Permission is hereby granted, free of charge, to any person obtaining a copy
20+
# of this software and associated documentation files (the "Software"), to deal
21+
# in the Software without restriction, including without limitation the rights
22+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
23+
# copies of the Software, and to permit persons to whom the Software is
24+
# furnished to do so, subject to the following conditions:
25+
#
26+
# The above copyright notice and this permission notice shall be included in all
27+
# copies or substantial portions of the Software.
28+
#
29+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
30+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
31+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
32+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
33+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
34+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
35+
# SOFTWARE.
36+
37+
"""Measurable switch-based transforms."""
38+
39+
from typing import cast
40+
41+
import pytensor.tensor as pt
42+
43+
from pytensor.graph.rewriting.basic import node_rewriter
44+
from pytensor.scalar import Switch
45+
from pytensor.scalar import switch as scalar_switch
46+
from pytensor.scalar.basic import GE, GT, LE, LT, Mul
47+
from pytensor.tensor.basic import switch as tensor_switch
48+
from pytensor.tensor.elemwise import Elemwise
49+
from pytensor.tensor.exceptions import NotScalarConstantError
50+
from pytensor.tensor.random.op import RandomVariable
51+
from pytensor.tensor.variable import TensorVariable
52+
53+
from pymc.logprob.abstract import MeasurableElemwise, MeasurableOp, _logprob, _logprob_helper
54+
from pymc.logprob.rewriting import measurable_ir_rewrites_db
55+
from pymc.logprob.transforms import MeasurableTransform
56+
from pymc.logprob.utils import (
57+
CheckParameterValue,
58+
check_potential_measurability,
59+
filter_measurable_variables,
60+
)
61+
62+
63+
class MeasurableSwitchNonOverlapping(MeasurableElemwise):
64+
"""Placeholder for switch transforms whose branch images do not overlap.
65+
66+
Currently matches leaky-ReLU graphs of the form `switch(x > 0, x, a * x)`.
67+
"""
68+
69+
valid_scalar_types = (Switch,)
70+
71+
72+
measurable_switch_non_overlapping = MeasurableSwitchNonOverlapping(scalar_switch)
73+
74+
75+
def _zero_x_threshold_true_includes_zero(cond: TensorVariable, x: TensorVariable) -> bool | None:
76+
"""Return whether `cond` is a zero threshold on `x` and includes `0` in the true branch.
77+
78+
Matches `x > 0`, `x >= 0` and swapped forms `0 < x`, `0 <= x`.
79+
80+
Returns
81+
-------
82+
- `False` for strict comparisons (`>`/`<`)
83+
- `True` for non-strict comparisons (`>=`/`<=`)
84+
- `None` if `cond` doesn't match a zero-threshold comparison on `x`
85+
"""
86+
if cond.owner is None:
87+
return None
88+
if not isinstance(cond.owner.op, Elemwise):
89+
return None
90+
scalar_op = cond.owner.op.scalar_op
91+
if not isinstance(scalar_op, GT | GE | LT | LE):
92+
return None
93+
94+
left, right = cond.owner.inputs
95+
96+
def _is_zero(v: TensorVariable) -> bool:
97+
try:
98+
return pt.get_underlying_scalar_constant_value(v) == 0
99+
except NotScalarConstantError:
100+
return False
101+
102+
# x > 0 or x >= 0
103+
if left is x and _is_zero(cast(TensorVariable, right)) and isinstance(scalar_op, GT | GE):
104+
return isinstance(scalar_op, GE)
105+
# 0 < x or 0 <= x
106+
if right is x and _is_zero(cast(TensorVariable, left)) and isinstance(scalar_op, LT | LE):
107+
return isinstance(scalar_op, LE)
108+
109+
return None
110+
111+
112+
def _extract_scale_from_measurable_mul(
113+
neg_branch: TensorVariable, x: TensorVariable
114+
) -> TensorVariable | None:
115+
"""Extract scale `a` from a measurable multiplication that represents `a * x`."""
116+
if neg_branch is x:
117+
return pt.constant(1.0)
118+
119+
if neg_branch.owner is None:
120+
return None
121+
122+
if not isinstance(neg_branch.owner.op, MeasurableTransform):
123+
return None
124+
125+
op = neg_branch.owner.op
126+
if not isinstance(op.scalar_op, Mul):
127+
return None
128+
129+
# MeasurableTransform takes (measurable_input, scale)
130+
if len(neg_branch.owner.inputs) != 2:
131+
return None
132+
133+
if neg_branch.owner.inputs[op.measurable_input_idx] is not x:
134+
return None
135+
136+
scale = neg_branch.owner.inputs[1 - op.measurable_input_idx]
137+
return cast(TensorVariable, scale)
138+
139+
140+
@node_rewriter([tensor_switch])
141+
def find_measurable_switch_non_overlapping(fgraph, node):
142+
"""Detect `switch(x > 0, x, a * x)` and replace it by a measurable op."""
143+
if isinstance(node.op, MeasurableOp):
144+
return None
145+
146+
cond, pos_branch, neg_branch = node.inputs
147+
148+
# Only mark the switch measurable once both branches are already measurable.
149+
# Then the logprob can simply gate between branch logps evaluated at `value`.
150+
if set(filter_measurable_variables([pos_branch, neg_branch])) != {pos_branch, neg_branch}:
151+
return None
152+
153+
x = cast(TensorVariable, pos_branch)
154+
155+
if x.type.numpy_dtype.kind != "f":
156+
return None
157+
158+
# Avoid rewriting cases where `x` is broadcasted/replicated by `cond` or `a`.
159+
# We require the positive branch to be a base `RandomVariable` output.
160+
if x.owner is None or not isinstance(x.owner.op, RandomVariable):
161+
return None
162+
163+
if x.type.broadcastable != node.outputs[0].type.broadcastable:
164+
return None
165+
166+
includes_zero_in_true = _zero_x_threshold_true_includes_zero(cast(TensorVariable, cond), x)
167+
if includes_zero_in_true is None:
168+
return None
169+
170+
a = _extract_scale_from_measurable_mul(cast(TensorVariable, neg_branch), x)
171+
if a is None:
172+
return None
173+
174+
# Disallow slope `a` that could be (directly or indirectly) measurable.
175+
# This rewrite targets deterministic, non-overlapping transforms parametrized by non-RVs.
176+
if check_potential_measurability([a]):
177+
return None
178+
179+
return [
180+
measurable_switch_non_overlapping(
181+
cast(TensorVariable, cond),
182+
x,
183+
cast(TensorVariable, neg_branch),
184+
)
185+
]
186+
187+
188+
@_logprob.register(MeasurableSwitchNonOverlapping)
189+
def logprob_switch_non_overlapping(op, values, cond, x, neg_branch, **kwargs):
190+
(value,) = values
191+
192+
a = _extract_scale_from_measurable_mul(
193+
cast(TensorVariable, neg_branch), cast(TensorVariable, x)
194+
)
195+
if a is None:
196+
raise NotImplementedError("Could not extract non-overlapping scale")
197+
198+
# Must be strictly positive: a == 0 is not invertible (collapses a region) and
199+
# invalidates the non-overlapping branch inference.
200+
a_is_positive = pt.all(pt.gt(a, 0))
201+
202+
includes_zero_in_true = _zero_x_threshold_true_includes_zero(
203+
cast(TensorVariable, cond), cast(TensorVariable, x)
204+
)
205+
if includes_zero_in_true is None:
206+
raise NotImplementedError("Could not identify zero-threshold condition")
207+
208+
# For `a > 0`, `switch(x > 0, x, a * x)` maps to disjoint regions in `value`.
209+
# Select the branch using the observed `value` and the strictness of the original
210+
# comparison (`>` vs `>=`).
211+
value_implies_true_branch = pt.ge(value, 0) if includes_zero_in_true else pt.gt(value, 0)
212+
213+
logp_expr = pt.switch(
214+
value_implies_true_branch,
215+
_logprob_helper(x, value, **kwargs),
216+
_logprob_helper(neg_branch, value, **kwargs),
217+
)
218+
219+
return CheckParameterValue("switch non-overlapping scale > 0")(logp_expr, a_is_positive)
220+
221+
222+
measurable_ir_rewrites_db.register(
223+
"find_measurable_switch_non_overlapping",
224+
find_measurable_switch_non_overlapping,
225+
"basic",
226+
"transform",
227+
)

tests/logprob/test_switch.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# Copyright 2026 - present The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import numpy as np
15+
import pytensor.tensor as pt
16+
import pytest
17+
18+
from pytensor.compile.function import function
19+
20+
import pymc as pm
21+
22+
from pymc.logprob.basic import logp
23+
from pymc.logprob.utils import ParameterValueError
24+
25+
26+
def test_switch_non_overlapping_logp_matches_change_of_variables():
27+
scale = pt.scalar("scale")
28+
x = pm.Normal.dist(mu=0, sigma=1, size=(3,))
29+
y = pt.switch(x > 0, x, scale * x)
30+
31+
vv = pt.vector("vv")
32+
33+
logp_y = logp(y, vv)
34+
inv = pt.switch(pt.gt(vv, 0), vv, vv / scale)
35+
expected = logp(x, inv) + pt.switch(pt.gt(vv, 0), 0.0, -pt.log(scale))
36+
37+
logp_y_fn = function([vv, scale], logp_y)
38+
expected_fn = function([vv, scale], expected)
39+
40+
v = np.array([-2.0, 0.0, 1.5])
41+
np.testing.assert_allclose(logp_y_fn(v, 0.5), expected_fn(v, 0.5))
42+
43+
with pytest.raises(ParameterValueError, match="switch non-overlapping scale > 0"):
44+
logp_y_fn(v, -0.5)
45+
46+
with pytest.raises(ParameterValueError, match="switch non-overlapping scale > 0"):
47+
logp_y_fn(v, 0.0)
48+
49+
50+
def test_switch_non_overlapping_does_not_rewrite_if_x_replicated_by_condition():
51+
scale = pt.scalar("scale")
52+
x = pm.Normal.dist(mu=0, sigma=1, size=(3,))
53+
cond = (x[None, :] > 0) & pt.ones((2, 1), dtype="bool")
54+
y = pt.switch(cond, x, scale * x)
55+
56+
with pytest.raises(NotImplementedError, match="Logprob method not implemented for Switch"):
57+
logp(y, np.zeros((2, 3)))
58+
59+
60+
def test_switch_non_overlapping_does_not_rewrite_if_scale_broadcasts_x():
61+
x = pm.Normal.dist(mu=0, sigma=1)
62+
scale = pt.vector("scale")
63+
y = pt.switch(x > 0, x, scale * x)
64+
65+
with pytest.raises(NotImplementedError, match="Logprob method not implemented for Switch"):
66+
logp(y, np.zeros((3,)))
67+
68+
69+
def test_switch_non_overlapping_does_not_apply_to_discrete_rv():
70+
a = pt.scalar("a")
71+
x = pm.Poisson.dist(3)
72+
y = pt.switch(x > 0, x, a * x)
73+
74+
with pytest.raises(NotImplementedError, match="Logprob method not implemented for Switch"):
75+
logp(y, 1)

tests/logprob/test_transforms.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,7 @@ def test_exp_transform_rv():
219219
logp_fn(y_val),
220220
sp.stats.lognorm(s=1).logpdf(y_val),
221221
)
222+
222223
np.testing.assert_almost_equal(
223224
logcdf_fn(y_val),
224225
sp.stats.lognorm(s=1).logcdf(y_val),

0 commit comments

Comments
 (0)