4343
4444from pytensor import tensor as pt
4545from pytensor .graph import Apply , Op , node_rewriter
46- from pytensor .graph .basic import Constant , Variable , clone_get_equiv , graph_inputs , walk
46+ from pytensor .graph .basic import Constant , Variable , ancestors , clone_get_equiv , graph_inputs , walk
4747from pytensor .graph .fg import FunctionGraph
4848from pytensor .graph .op import HasInnerGraph
4949from pytensor .link .c .type import CType
5050from pytensor .raise_op import CheckAndRaise
51- from pytensor .scalar .basic import Mul
51+ from pytensor .scalar .basic import GE , LE , Exp , Mul
5252from pytensor .tensor .basic import get_underlying_scalar_constant_value
53- from pytensor .tensor .elemwise import Elemwise
53+ from pytensor .tensor .elemwise import DimShuffle , Elemwise
5454from pytensor .tensor .exceptions import NotScalarConstantError
5555from pytensor .tensor .random .op import RandomVariable
5656from pytensor .tensor .variable import TensorVariable
@@ -228,6 +228,55 @@ def local_remove_check_parameter(fgraph, node):
228228 return [node .inputs [0 ]]
229229
230230
231+ @node_rewriter (tracks = [pt .switch ])
232+ def local_remove_useless_bound_switch (fgraph , node ):
233+ """Remove bound checks ensured by the transformations.
234+
235+ switch(exp(x) >= 0, cond1, -inf) -> cond1 if exp(x) in cond1.
236+
237+ The reason we don't set it to simply True is that x could be `nan`.
238+ If we see exp(x) exists in cond1 we assume `nan` will be propagated anyway.
239+
240+ This isn't guaranteed to be True, for instance if exp(x) is inside another switch statement.
241+ """
242+ cond , true_branch , false_branch = node .inputs
243+ if not (cond .owner is not None and isinstance (cond .owner .op , Elemwise )):
244+ return
245+ scalar_op = cond .owner .op .scalar_op
246+ if isinstance (scalar_op , LE ):
247+ maybe_zero , var = cond .owner .inputs
248+ elif isinstance (scalar_op , GE ):
249+ var , maybe_zero = cond .owner .inputs
250+ else :
251+ return None
252+
253+ if not (
254+ (isinstance (maybe_zero , Constant ) and maybe_zero .unique_value == 0 )
255+ and (isinstance (false_branch , Constant ) and false_branch .unique_value == - np .inf )
256+ ):
257+ return None
258+
259+ # Check if var is exp(x) and x is present in the true branch
260+ if (
261+ var .owner is not None
262+ and (
263+ (isinstance (var .owner .op , Elemwise ) and isinstance (var .owner .op .scalar_op , Exp ))
264+ or (
265+ isinstance (var .owner .op , DimShuffle )
266+ and (
267+ var .owner .inputs [0 ].owner is not None
268+ and isinstance (var .owner .inputs [0 ].owner .op , Elemwise )
269+ and isinstance (var .owner .inputs [0 ].owner .op .scalar_op , Exp )
270+ )
271+ )
272+ )
273+ and var in ancestors ([true_branch ])
274+ ):
275+ return [true_branch ]
276+
277+ return None
278+
279+
231280@node_rewriter (tracks = [CheckParameterValue ])
232281def local_check_parameter_to_ninf_switch (fgraph , node ):
233282 if not node .op .can_be_replaced_by_ninf :
@@ -248,17 +297,23 @@ def local_check_parameter_to_ninf_switch(fgraph, node):
248297
249298
250299pytensor .compile .optdb ["canonicalize" ].register (
251- " local_remove_check_parameter" ,
300+ local_remove_check_parameter . __name__ ,
252301 local_remove_check_parameter ,
253302 use_db_name_as_tag = False ,
254303)
255304
256305pytensor .compile .optdb ["canonicalize" ].register (
257- " local_check_parameter_to_ninf_switch" ,
306+ local_check_parameter_to_ninf_switch . __name__ ,
258307 local_check_parameter_to_ninf_switch ,
259308 use_db_name_as_tag = False ,
260309)
261310
311+ pytensor .compile .optdb ["canonicalize" ].register (
312+ local_remove_useless_bound_switch .__name__ ,
313+ local_remove_useless_bound_switch ,
314+ use_db_name_as_tag = False ,
315+ )
316+
262317
263318class DiracDelta (MeasurableOp , Op ):
264319 """An `Op` that represents a Dirac-delta distribution."""
0 commit comments