-
Notifications
You must be signed in to change notification settings - Fork 2.2k
Add logprob support for leaky-ReLU switch transforms #7995
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
36705b4
4d620a3
565e191
7d1db41
f9f36c0
4530814
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,202 @@ | ||||||
| # Copyright 2024 - present The PyMC Developers | ||||||
| # | ||||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
| # you may not use this file except in compliance with the License. | ||||||
| # You may obtain a copy of the License at | ||||||
| # | ||||||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||||||
| # | ||||||
| # Unless required by applicable law or agreed to in writing, software | ||||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
| # See the License for the specific language governing permissions and | ||||||
| # limitations under the License. | ||||||
| # | ||||||
| # MIT License | ||||||
| # | ||||||
| # Copyright (c) 2021-2022 aesara-devs | ||||||
| # | ||||||
| # Permission is hereby granted, free of charge, to any person obtaining a copy | ||||||
| # of this software and associated documentation files (the "Software"), to deal | ||||||
| # in the Software without restriction, including without limitation the rights | ||||||
| # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||||||
| # copies of the Software, and to permit persons to whom the Software is | ||||||
| # furnished to do so, subject to the following conditions: | ||||||
| # | ||||||
| # The above copyright notice and this permission notice shall be included in all | ||||||
| # copies or substantial portions of the Software. | ||||||
| # | ||||||
| # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||||||
| # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||||||
| # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||||||
| # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||||||
| # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||||||
| # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||||||
| # SOFTWARE. | ||||||
|
|
||||||
| """Measurable switch-based transforms.""" | ||||||
|
|
||||||
| from typing import cast | ||||||
|
|
||||||
| import pytensor.tensor as pt | ||||||
|
|
||||||
| from pytensor.graph.rewriting.basic import node_rewriter | ||||||
| from pytensor.scalar import Switch | ||||||
| from pytensor.scalar import switch as scalar_switch | ||||||
| from pytensor.scalar.basic import GE, GT, LE, LT, Mul | ||||||
| from pytensor.tensor.basic import switch as tensor_switch | ||||||
| from pytensor.tensor.elemwise import Elemwise | ||||||
| from pytensor.tensor.exceptions import NotScalarConstantError | ||||||
| from pytensor.tensor.variable import TensorVariable | ||||||
|
|
||||||
| from pymc.logprob.abstract import MeasurableElemwise, MeasurableOp, _logprob, _logprob_helper | ||||||
| from pymc.logprob.rewriting import measurable_ir_rewrites_db | ||||||
| from pymc.logprob.transforms import MeasurableTransform | ||||||
| from pymc.logprob.utils import ( | ||||||
| CheckParameterValue, | ||||||
| check_potential_measurability, | ||||||
| filter_measurable_variables, | ||||||
| ) | ||||||
|
|
||||||
|
|
||||||
class MeasurableSwitchNonOverlapping(MeasurableElemwise):
    """Measurable switch whose branch images map to disjoint regions.

    Currently this only matches leaky-ReLU-style graphs of the form
    ``switch(x > 0, x, a * x)``.
    """

    # Restrict the wrapped scalar op to `Switch` only.
    valid_scalar_types = (Switch,)


# Shared instance applied by the rewrite below.
measurable_switch_non_overlapping = MeasurableSwitchNonOverlapping(scalar_switch)
|
|
||||||
|
|
||||||
def _is_x_threshold_condition(cond: TensorVariable, x: TensorVariable) -> bool:
    """Return True when `cond` compares `x` against a zero threshold.

    Matches ``x > 0`` / ``x >= 0`` as well as the swapped forms
    ``0 < x`` / ``0 <= x``.
    """
    node = cond.owner
    if node is None or not isinstance(node.op, Elemwise):
        return False

    comparison = node.op.scalar_op
    if not isinstance(comparison, GT | GE | LT | LE):
        return False

    lhs, rhs = node.inputs

    def _is_constant_zero(v: TensorVariable) -> bool:
        # Anything that is not a scalar constant cannot be the 0 threshold.
        try:
            return pt.get_underlying_scalar_constant_value(v) == 0
        except NotScalarConstantError:
            return False

    if isinstance(comparison, GT | GE):
        # `x > 0` or `x >= 0`
        return lhs is x and _is_constant_zero(cast(TensorVariable, rhs))

    # `0 < x` or `0 <= x`
    return rhs is x and _is_constant_zero(cast(TensorVariable, lhs))
|
|
||||||
|
|
||||||
def _extract_scale_from_measurable_mul(
    neg_branch: TensorVariable, x: TensorVariable
) -> TensorVariable | None:
    """Return the slope `a` when `neg_branch` represents `a * x`, else None.

    A branch that is `x` itself is treated as having unit slope.
    """
    if neg_branch is x:
        # `switch(cond, x, x)` amounts to a slope of 1.
        return pt.constant(1.0)

    node = neg_branch.owner
    if node is None:
        return None

    branch_op = node.op
    if not (isinstance(branch_op, MeasurableTransform) and isinstance(branch_op.scalar_op, Mul)):
        return None

    # A measurable multiplication is expected to take exactly two inputs:
    # the measurable variable and the scale.
    if len(node.inputs) != 2:
        return None

    measurable_idx = branch_op.measurable_input_idx
    if node.inputs[measurable_idx] is not x:
        return None

    # The other input is the scale.
    return cast(TensorVariable, node.inputs[1 - measurable_idx])
|
|
||||||
|
|
||||||
@node_rewriter([tensor_switch])
def find_measurable_switch_non_overlapping(fgraph, node):
    """Detect `switch(x > 0, x, a * x)` and replace it by a measurable op.

    Returns a one-element list with the measurable replacement, or ``None``
    when the node does not match the supported leaky-ReLU pattern.
    """
    # Already rewritten into a measurable op: nothing to do.
    if isinstance(node.op, MeasurableOp):
        return None

    cond, pos_branch, neg_branch = node.inputs

    # Only mark the switch measurable once both branches are already measurable.
    # Then the logprob can simply gate between branch logps evaluated at `value`.
    if set(filter_measurable_variables([pos_branch, neg_branch])) != {pos_branch, neg_branch}:
        return None

    x = cast(TensorVariable, pos_branch)

    # Reject discrete variables: this transform only makes sense for
    # continuous measures. Checking only `startswith("int")` would let
    # `uint*` and `bool` dtypes slip through, so test all discrete prefixes.
    if x.type.dtype.startswith(("bool", "int", "uint")):
        return None

    # Only handle switches whose output has the same broadcast pattern as `x`
    # (i.e. no branch broadcasting is involved).
    if x.type.broadcastable != node.outputs[0].type.broadcastable:
        return None

    # The condition must be `x > 0` / `x >= 0` (or the swapped forms).
    if not _is_x_threshold_condition(cast(TensorVariable, cond), x):
        return None

    slope = _extract_scale_from_measurable_mul(cast(TensorVariable, neg_branch), x)
    if slope is None:
        return None

    # Disallow slope `a` that could be (directly or indirectly) measurable.
    # This rewrite targets deterministic, non-overlapping transforms
    # parametrized by non-RVs.
    if check_potential_measurability([slope]):
        return None

    return [
        measurable_switch_non_overlapping(
            cast(TensorVariable, cond),
            x,
            cast(TensorVariable, neg_branch),
        )
    ]
|
|
||||||
|
|
||||||
| @_logprob.register(MeasurableSwitchNonOverlapping) | ||||||
| def logprob_switch_non_overlapping(op, values, cond, x, neg_branch, **kwargs): | ||||||
| (value,) = values | ||||||
|
|
||||||
| a = _extract_scale_from_measurable_mul( | ||||||
| cast(TensorVariable, neg_branch), cast(TensorVariable, x) | ||||||
| ) | ||||||
| if a is None: | ||||||
| raise NotImplementedError("Could not extract non-overlapping scale") | ||||||
|
|
||||||
| a_is_positive = pt.all(pt.gt(a, 0)) | ||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does it need to be positive or non-negative is enough?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. it needs to be positive because if
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. oh yeah that makes sense
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OTOH If it goes on the pm.logp(pm.Normal.dist() * pt.scalar("a"), 1).eval({"a": 0}) # np.nan Probably better to leave as is though. |
||||||
|
|
||||||
| # For `a > 0`, `switch(x > 0, x, a * x)` maps to disjoint regions in `value`: | ||||||
| # true branch -> value > 0, false branch -> value <= 0. | ||||||
| value_implies_true_branch = pt.gt(value, 0) | ||||||
|
||||||
| value_implies_true_branch = pt.gt(value, 0) | |
| value_implies_true_branch = pt.ge(value, 0) |
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -43,6 +43,8 @@ | |||||||||
|
|
||||||||||
| from pytensor.graph.basic import equal_computations | ||||||||||
|
|
||||||||||
| import pymc as pm | ||||||||||
|
|
||||||||||
| from pymc.distributions.continuous import Cauchy, ChiSquared | ||||||||||
| from pymc.distributions.discrete import Bernoulli | ||||||||||
| from pymc.logprob.basic import conditional_logp, icdf, logcdf, logp | ||||||||||
|
|
@@ -219,6 +221,7 @@ def test_exp_transform_rv(): | |||||||||
| logp_fn(y_val), | ||||||||||
| sp.stats.lognorm(s=1).logpdf(y_val), | ||||||||||
| ) | ||||||||||
|
|
||||||||||
| np.testing.assert_almost_equal( | ||||||||||
| logcdf_fn(y_val), | ||||||||||
| sp.stats.lognorm(s=1).logcdf(y_val), | ||||||||||
|
|
@@ -229,6 +232,57 @@ def test_exp_transform_rv(): | |||||||||
| ) | ||||||||||
|
|
||||||||||
|
|
||||||||||
| def test_leaky_relu_switch_logp_scalar(): | ||||||||||
|
||||||||||
| a = 0.5 | ||||||||||
| x = pm.Normal.dist(mu=0, sigma=1) | ||||||||||
| y = pm.math.switch(x > 0, x, a * x) | ||||||||||
|
|
||||||||||
| v_pos = 1.2 | ||||||||||
| np.testing.assert_allclose( | ||||||||||
| pm.logp(y, v_pos, warn_rvs=False).eval(), | ||||||||||
| pm.logp(x, v_pos, warn_rvs=False).eval(), | ||||||||||
|
||||||||||
| pm.logp(y, v_pos, warn_rvs=False).eval(), | |
| pm.logp(x, v_pos, warn_rvs=False).eval(), | |
| pm.logp(y, v_pos).eval(), | |
| pm.logp(x, v_pos).eval(), |
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If you're testing with two values, define the logp variable once (or compile a function with the logp as output once), and reuse it. That will avoid duplicated logp inference calls.
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use vectorized in the first test, and remove this. It's not conceptually that different
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
use symbolic a in the first test and test the error there. Still a pretty straightforward test
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
better name for this helper to emphasize zero? perhaps
_is_zero_x_threshold