Skip to content

Commit 488dc51

Browse files
committed
Remove useless switch on log transformed parameters
1 parent 11dcadf commit 488dc51

File tree

2 files changed

+65
-9
lines changed

2 files changed

+65
-9
lines changed

pymc/logprob/utils.py

Lines changed: 57 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,14 +43,14 @@
4343

4444
from pytensor import tensor as pt
4545
from pytensor.graph import Apply, Op, node_rewriter
46-
from pytensor.graph.basic import Constant, Variable, clone_get_equiv, graph_inputs, walk
46+
from pytensor.graph.basic import Constant, Variable, ancestors, clone_get_equiv, graph_inputs, walk
4747
from pytensor.graph.fg import FunctionGraph
4848
from pytensor.graph.op import HasInnerGraph
4949
from pytensor.link.c.type import CType
5050
from pytensor.raise_op import CheckAndRaise
51-
from pytensor.scalar.basic import Mul
51+
from pytensor.scalar.basic import GE, LE, Exp, Mul
5252
from pytensor.tensor.basic import get_underlying_scalar_constant_value
53-
from pytensor.tensor.elemwise import Elemwise
53+
from pytensor.tensor.elemwise import DimShuffle, Elemwise
5454
from pytensor.tensor.exceptions import NotScalarConstantError
5555
from pytensor.tensor.random.op import RandomVariable
5656
from pytensor.tensor.variable import TensorVariable
@@ -228,6 +228,52 @@ def local_remove_check_parameter(fgraph, node):
228228
return [node.inputs[0]]
229229

230230

231+
@node_rewriter(tracks=[pt.switch])
232+
def local_remove_useless_bound_checks(fgraph, node):
233+
"""Remove bound checks ensured by the transformations.
234+
235+
switch(exp(x) >= 0, cond1, cond2) -> cond1 if exp(x) in cond1
236+
The reason we don't set it to simply True is that x could be `nan`.
237+
If we see exp(x) exists in cond1 we assume `nan` will be propagated anyway.
238+
239+
This isn't guaranteed to be True, for instance if exp(x) is inside another switch statement.
240+
"""
241+
cond, true_branch, false_branch = node.inputs
242+
if not (cond.owner is not None and isinstance(cond.owner.op, Elemwise)):
243+
return
244+
scalar_op = cond.owner.op.scalar_op
245+
if isinstance(scalar_op, LE):
246+
maybe_zero, var = cond.owner.inputs
247+
elif isinstance(scalar_op, GE):
248+
var, maybe_zero = cond.owner.inputs
249+
else:
250+
return None
251+
252+
if not (isinstance(maybe_zero, Constant) and maybe_zero.unique_value == 0):
253+
return None
254+
255+
# Check if var is exp(x), x is a root variable and x is present in the true branch
256+
if (
257+
var.owner is not None
258+
and isinstance(var.owner.op, Elemwise)
259+
and isinstance(var.owner.op.scalar_op, Exp)
260+
) or (
261+
(
262+
isinstance(var.owner.op, DimShuffle)
263+
and (
264+
var.owner.inputs[0].owner is not None
265+
and isinstance(var.owner.inputs[0].owner.op, Elemwise)
266+
and isinstance(var.owner.inputs[0].owner.op.scalar_op, Exp)
267+
)
268+
)
269+
# and var.owner.inputs[0].owner is None
270+
and var in ancestors([true_branch])
271+
):
272+
return [true_branch]
273+
274+
return None
275+
276+
231277
@node_rewriter(tracks=[CheckParameterValue])
232278
def local_check_parameter_to_ninf_switch(fgraph, node):
233279
if not node.op.can_be_replaced_by_ninf:
@@ -248,17 +294,23 @@ def local_check_parameter_to_ninf_switch(fgraph, node):
248294

249295

250296
pytensor.compile.optdb["canonicalize"].register(
251-
"local_remove_check_parameter",
297+
local_remove_check_parameter.__name__,
252298
local_remove_check_parameter,
253299
use_db_name_as_tag=False,
254300
)
255301

256302
pytensor.compile.optdb["canonicalize"].register(
257-
"local_check_parameter_to_ninf_switch",
303+
local_check_parameter_to_ninf_switch.__name__,
258304
local_check_parameter_to_ninf_switch,
259305
use_db_name_as_tag=False,
260306
)
261307

308+
pytensor.compile.optdb["canonicalize"].register(
309+
local_remove_useless_bound_checks.__name__,
310+
local_remove_useless_bound_checks,
311+
use_db_name_as_tag=False,
312+
)
313+
262314

263315
class DiracDelta(MeasurableOp, Op):
264316
"""An `Op` that represents a Dirac-delta distribution."""

pymc/pytensorf.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -937,12 +937,16 @@ def compile(
937937
check_bounds = model.check_bounds
938938
except TypeError:
939939
check_bounds = True
940-
check_parameter_opt = (
941-
"local_check_parameter_to_ninf_switch" if check_bounds else "local_remove_check_parameter"
942-
)
940+
if check_bounds:
941+
check_parameter_opt = ("local_check_parameter_to_ninf_switch",)
942+
else:
943+
check_parameter_opt = (
944+
"local_remove_check_parameter",
945+
"local_remove_useless_bound_checks",
946+
)
943947

944948
mode = get_mode(mode)
945-
opt_qry = mode.provided_optimizer.including("random_make_inplace", check_parameter_opt)
949+
opt_qry = mode.provided_optimizer.including("random_make_inplace", *check_parameter_opt)
946950
mode = Mode(linker=mode.linker, optimizer=opt_qry)
947951
pytensor_function = pytensor.function(
948952
inputs,

0 commit comments

Comments
 (0)