From a1211a56ceb6e04dc4c6e16ffd52a63335171711 Mon Sep 17 00:00:00 2001 From: Luca Citi Date: Tue, 17 Jun 2025 09:35:53 +0000 Subject: [PATCH 1/5] Created some tests that fail due to #1476 --- tests/tensor/rewriting/test_math.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index 9a092663a9..c4999fcd33 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -4438,11 +4438,22 @@ def test_local_add_neg_to_sub(first_negative): assert np.allclose(f(x_test, y_test), exp) -def test_log1mexp_stabilization(): +@pytest.mark.parametrize( + "op_name", + ["log_1_minus_exp", "log1p_minus_exp", "log_minus_expm1", "log_minus_exp_minus_1"], +) +def test_log1mexp_stabilization(op_name): mode = Mode("py").including("stabilize") x = vector() - f = function([x], log(1 - exp(x)), mode=mode) + if op_name == "log_1_minus_exp": + f = function([x], log(1 - exp(x)), mode=mode) + elif op_name == "log1p_minus_exp": + f = function([x], log1p(-exp(x)), mode=mode) + elif op_name == "log_minus_expm1": + f = function([x], log(-expm1(x)), mode=mode) + elif op_name == "log_minus_exp_minus_1": + f = function([x], log(-(exp(x) - 1)), mode=mode) nodes = [node.op for node in f.maker.fgraph.toposort()] assert nodes == [pt.log1mexp] From ef228325095a428bc7c188f1c79ab98734d138e1 Mon Sep 17 00:00:00 2001 From: Luca Citi Date: Tue, 17 Jun 2025 11:51:37 +0000 Subject: [PATCH 2/5] Fixes 1476 and other ways to create a log1mexp --- pytensor/tensor/rewriting/math.py | 35 +++++++++++++++++++++++++------ 1 file changed, 29 insertions(+), 6 deletions(-) diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index 9694a022e3..06b062f64b 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -3582,12 +3582,35 @@ def local_reciprocal_1_plus_exp(fgraph, node): register_specialize(local_1msigmoid) -log1pmexp_to_log1mexp = PatternNodeRewriter( - (log1p, (neg, (exp, "x"))), - (log1mexp, "x"), - allow_multiple_clients=True, -) -register_stabilize(log1pmexp_to_log1mexp, name="log1pmexp_to_log1mexp") +@register_stabilize +@node_rewriter([log1p]) +def log1pmexp_to_log1mexp(fgraph, node): + """``log1p(-exp(x)) -> log1mexp(x)`` + where "-" can be "neg" or any other expression detected by "is_neg" + """ + (log1p_arg,) = node.inputs + exp_info = is_exp(log1p_arg) + if exp_info is not None: + exp_neg, exp_arg = exp_info + if exp_neg: + return [log1mexp(exp_arg)] + else: + return # We could return [log1pexp(exp_arg)] here but that would conflict with log1pexp_to_softplus + + +@register_stabilize +@node_rewriter([log]) +def logmexpm1_to_log1mexp(fgraph, node): + """``log(-expm1(x)) -> log1mexp(x)``""" + (log_arg,) = node.inputs + if log_arg.owner and any( # ⬐ additional match to less-frequent expm1 + i.owner and i.owner.op == expm1 for i in log_arg.owner.inputs + ): + neg_arg = is_neg(log_arg) + if neg_arg.owner and neg_arg.owner.op == expm1: + (expm1_arg,) = neg_arg.owner.inputs + return [log1mexp(expm1_arg)] + # log(exp(a) - exp(b)) -> a + log1mexp(b - a) logdiffexp_to_log1mexpdiff = PatternNodeRewriter( From 8cd1b5247d64912a91e4afea6b364f1dc23a0388 Mon Sep 17 00:00:00 2001 From: Luca Citi Date: Tue, 17 Jun 2025 22:27:30 +0000 Subject: [PATCH 3/5] Reimplemented logmexpm1_to_log1mexp by tracking expm1 and then looking through the clients --- pytensor/tensor/rewriting/math.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index 06b062f64b..11a51840ce 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -3599,17 +3599,20 @@ def log1pmexp_to_log1mexp(fgraph, node): @register_stabilize -@node_rewriter([log]) +@node_rewriter([expm1]) def logmexpm1_to_log1mexp(fgraph, node): - """``log(-expm1(x)) -> log1mexp(x)``""" - (log_arg,) = node.inputs - if log_arg.owner and any( # ⬐ additional match to less-frequent expm1 - i.owner and i.owner.op == expm1 for i in log_arg.owner.inputs - ): - neg_arg = is_neg(log_arg) - if neg_arg.owner and neg_arg.owner.op == expm1: - (expm1_arg,) = neg_arg.owner.inputs - return [log1mexp(expm1_arg)] + """``log(-expm1(x)) -> log1mexp(x)`` + where "-" can be "neg" or any other expression detected by "is_neg" + """ + rewrites = {} + for node in get_clients_at_depth(fgraph, node, depth=2): + if node.op == log: + (log_arg,) = node.inputs + neg_arg = is_neg(log_arg) + if neg_arg.owner and neg_arg.owner.op == expm1: + (expm1_arg,) = neg_arg.owner.inputs + rewrites[node.outputs[0]] = log1mexp(expm1_arg) + return rewrites # log(exp(a) - exp(b)) -> a + log1mexp(b - a) From 92a29d240aaf39009085c6f890ae918faefcfcba Mon Sep 17 00:00:00 2001 From: Luca Citi Date: Wed, 18 Jun 2025 10:46:34 +0000 Subject: [PATCH 4/5] Absorbed the rewrite log1pexp_to_softplus into the new rewrite for log1mexp --- pytensor/tensor/rewriting/math.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index 11a51840ce..a7d484ef52 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -64,6 +64,7 @@ log, log1mexp, log1p, + log1pexp, makeKeepDims, maximum, mul, @@ -2999,12 +3000,6 @@ def _is_1(expr): tracks=[sigmoid], get_nodes=get_clients_at_depth2, ) -log1pexp_to_softplus = PatternNodeRewriter( - (log1p, (exp, "x")), - (softplus, "x"), - values_eq_approx=values_eq_approx_remove_inf, - allow_multiple_clients=True, -) log1p_neg_sigmoid = PatternNodeRewriter( (log1p, (neg, (sigmoid, "x"))), (neg, (softplus, "x")), @@ -3016,7 +3011,6 @@ def _is_1(expr): register_stabilize(logsigm_to_softplus, name="logsigm_to_softplus") register_stabilize(log1msigm_to_softplus, name="log1msigm_to_softplus") -register_stabilize(log1pexp_to_softplus, name="log1pexp_to_softplus") register_stabilize(log1p_neg_sigmoid, name="log1p_neg_sigmoid") register_specialize(log1p_neg_sigmoid, name="log1p_neg_sigmoid") @@ -3584,8 +3578,10 @@ def local_reciprocal_1_plus_exp(fgraph, node): @register_stabilize @node_rewriter([log1p]) -def log1pmexp_to_log1mexp(fgraph, node): - """``log1p(-exp(x)) -> log1mexp(x)`` +def local_log1p_plusminus_exp(fgraph, node): + """Transforms log1p of ±exp(x) into log1pexp (aka softplus) / log1mexp + ``log1p(exp(x)) -> log1pexp(x)`` + ``log1p(-exp(x)) -> log1mexp(x)`` where "-" can be "neg" or any other expression detected by "is_neg" """ (log1p_arg,) = node.inputs @@ -3595,7 +3591,7 @@ def log1pmexp_to_log1mexp(fgraph, node): if exp_neg: return [log1mexp(exp_arg)] else: - return # We could return [log1pexp(exp_arg)] here but that would conflict with log1pexp_to_softplus + return [log1pexp(exp_arg)] # aka softplus @register_stabilize From 9f9bc8bbb721bde772550f936ad03625b6b57783 Mon Sep 17 00:00:00 2001 From: Luca Citi Date: Wed, 18 Jun 2025 14:20:14 +0000 Subject: [PATCH 5/5] Fixed bug where I forgot to check whether result of is_neg was None or not before proceeding --- pytensor/tensor/rewriting/math.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index a7d484ef52..aef363655e 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -3605,7 +3605,7 @@ def logmexpm1_to_log1mexp(fgraph, node): if node.op == log: (log_arg,) = node.inputs neg_arg = is_neg(log_arg) - if neg_arg.owner and neg_arg.owner.op == expm1: + if neg_arg is not None and neg_arg.owner and neg_arg.owner.op == expm1: (expm1_arg,) = neg_arg.owner.inputs rewrites[node.outputs[0]] = log1mexp(expm1_arg) return rewrites