4343
4444from pytensor import tensor as pt
4545from pytensor .graph import Apply , Op , node_rewriter
46- from pytensor .graph .basic import Constant , Variable , clone_get_equiv , graph_inputs , walk
46+ from pytensor .graph .basic import Constant , Variable , ancestors , clone_get_equiv , graph_inputs , walk
4747from pytensor .graph .fg import FunctionGraph
4848from pytensor .graph .op import HasInnerGraph
4949from pytensor .link .c .type import CType
5050from pytensor .raise_op import CheckAndRaise
51- from pytensor .scalar .basic import Mul
51+ from pytensor .scalar .basic import GE , LE , Exp , Mul
5252from pytensor .tensor .basic import get_underlying_scalar_constant_value
53- from pytensor .tensor .elemwise import Elemwise
53+ from pytensor .tensor .elemwise import DimShuffle , Elemwise
5454from pytensor .tensor .exceptions import NotScalarConstantError
5555from pytensor .tensor .random .op import RandomVariable
5656from pytensor .tensor .variable import TensorVariable
@@ -228,6 +228,52 @@ def local_remove_check_parameter(fgraph, node):
228228 return [node .inputs [0 ]]
229229
230230
231+ @node_rewriter (tracks = [pt .switch ])
232+ def local_remove_useless_bound_checks (fgraph , node ):
233+ """Remove bound checks ensured by the transformations.
234+
235+ switch(exp(x) >= 0, cond1, cond2) -> cond1 if exp(x) in cond1
236+ The reason we don't set it to simply True is that x could be `nan`.
237+ If we see exp(x) exists in cond1 we assume `nan` will be propagated anyway.
238+
239+ This isn't guaranteed to be True, for instance if exp(x) is inside another switch statement.
240+ """
241+ cond , true_branch , false_branch = node .inputs
242+ if not (cond .owner is not None and isinstance (cond .owner .op , Elemwise )):
243+ return
244+ scalar_op = cond .owner .op .scalar_op
245+ if isinstance (scalar_op , LE ):
246+ maybe_zero , var = cond .owner .inputs
247+ elif isinstance (scalar_op , GE ):
248+ var , maybe_zero = cond .owner .inputs
249+ else :
250+ return None
251+
252+ if not (isinstance (maybe_zero , Constant ) and maybe_zero .unique_value == 0 ):
253+ return None
254+
255+ # Check if var is exp(x), x is a root variable and x is present in the true branch
256+ if (
257+ var .owner is not None
258+ and isinstance (var .owner .op , Elemwise )
259+ and isinstance (var .owner .op .scalar_op , Exp )
260+ ) or (
261+ (
262+ isinstance (var .owner .op , DimShuffle )
263+ and (
264+ var .owner .inputs [0 ].owner is not None
265+ and isinstance (var .owner .inputs [0 ].owner .op , Elemwise )
266+ and isinstance (var .owner .inputs [0 ].owner .op .scalar_op , Exp )
267+ )
268+ )
269+ # and var.owner.inputs[0].owner is None
270+ and var in ancestors ([true_branch ])
271+ ):
272+ return [true_branch ]
273+
274+ return None
275+
276+
231277@node_rewriter (tracks = [CheckParameterValue ])
232278def local_check_parameter_to_ninf_switch (fgraph , node ):
233279 if not node .op .can_be_replaced_by_ninf :
@@ -248,17 +294,23 @@ def local_check_parameter_to_ninf_switch(fgraph, node):
248294
249295
250296pytensor .compile .optdb ["canonicalize" ].register (
251- " local_remove_check_parameter" ,
297+ local_remove_check_parameter . __name__ ,
252298 local_remove_check_parameter ,
253299 use_db_name_as_tag = False ,
254300)
255301
256302pytensor .compile .optdb ["canonicalize" ].register (
257- " local_check_parameter_to_ninf_switch" ,
303+ local_check_parameter_to_ninf_switch . __name__ ,
258304 local_check_parameter_to_ninf_switch ,
259305 use_db_name_as_tag = False ,
260306)
261307
308+ pytensor .compile .optdb ["canonicalize" ].register (
309+ local_remove_useless_bound_checks .__name__ ,
310+ local_remove_useless_bound_checks ,
311+ use_db_name_as_tag = False ,
312+ )
313+
262314
263315class DiracDelta (MeasurableOp , Op ):
264316 """An `Op` that represents a Dirac-delta distribution."""
0 commit comments