diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py
index 9694a022e3..aef363655e 100644
--- a/pytensor/tensor/rewriting/math.py
+++ b/pytensor/tensor/rewriting/math.py
@@ -64,6 +64,7 @@
     log,
     log1mexp,
     log1p,
+    log1pexp,
     makeKeepDims,
     maximum,
     mul,
@@ -2999,12 +3000,6 @@ def _is_1(expr):
     tracks=[sigmoid],
     get_nodes=get_clients_at_depth2,
 )
-log1pexp_to_softplus = PatternNodeRewriter(
-    (log1p, (exp, "x")),
-    (softplus, "x"),
-    values_eq_approx=values_eq_approx_remove_inf,
-    allow_multiple_clients=True,
-)
 log1p_neg_sigmoid = PatternNodeRewriter(
     (log1p, (neg, (sigmoid, "x"))),
     (neg, (softplus, "x")),
@@ -3016,7 +3011,6 @@ def _is_1(expr):
 
 register_stabilize(logsigm_to_softplus, name="logsigm_to_softplus")
 register_stabilize(log1msigm_to_softplus, name="log1msigm_to_softplus")
-register_stabilize(log1pexp_to_softplus, name="log1pexp_to_softplus")
 register_stabilize(log1p_neg_sigmoid, name="log1p_neg_sigmoid")
 register_specialize(log1p_neg_sigmoid, name="log1p_neg_sigmoid")
 
@@ -3582,12 +3576,40 @@ def local_reciprocal_1_plus_exp(fgraph, node):
 register_specialize(local_1msigmoid)
 
 
-log1pmexp_to_log1mexp = PatternNodeRewriter(
-    (log1p, (neg, (exp, "x"))),
-    (log1mexp, "x"),
-    allow_multiple_clients=True,
-)
-register_stabilize(log1pmexp_to_log1mexp, name="log1pmexp_to_log1mexp")
+@register_stabilize
+@node_rewriter([log1p])
+def local_log1p_plusminus_exp(fgraph, node):
+    """Transforms log1p of ±exp(x) into log1pexp (aka softplus) / log1mexp
+    ``log1p(exp(x)) -> log1pexp(x)``
+    ``log1p(-exp(x)) -> log1mexp(x)``
+    where "-" can be "neg" or any other expression detected by "is_neg"
+    """
+    (log1p_arg,) = node.inputs
+    exp_info = is_exp(log1p_arg)
+    if exp_info is not None:
+        exp_neg, exp_arg = exp_info
+        if exp_neg:
+            return [log1mexp(exp_arg)]
+        else:
+            return [log1pexp(exp_arg)]  # aka softplus
+
+
+@register_stabilize
+@node_rewriter([expm1])
+def logmexpm1_to_log1mexp(fgraph, node):
+    """``log(-expm1(x)) -> log1mexp(x)``
+    where "-" can be "neg" or any other expression detected by "is_neg"
+    """
+    rewrites = {}
+    for node in get_clients_at_depth(fgraph, node, depth=2):
+        if node.op == log:
+            (log_arg,) = node.inputs
+            neg_arg = is_neg(log_arg)
+            if neg_arg is not None and neg_arg.owner and neg_arg.owner.op == expm1:
+                (expm1_arg,) = neg_arg.owner.inputs
+                rewrites[node.outputs[0]] = log1mexp(expm1_arg)
+    return rewrites
+
 
 # log(exp(a) - exp(b)) -> a + log1mexp(b - a)
 logdiffexp_to_log1mexpdiff = PatternNodeRewriter(
diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py
index 9a092663a9..c4999fcd33 100644
--- a/tests/tensor/rewriting/test_math.py
+++ b/tests/tensor/rewriting/test_math.py
@@ -4438,11 +4438,22 @@ def test_local_add_neg_to_sub(first_negative):
     assert np.allclose(f(x_test, y_test), exp)
 
 
-def test_log1mexp_stabilization():
+@pytest.mark.parametrize(
+    "op_name",
+    ["log_1_minus_exp", "log1p_minus_exp", "log_minus_expm1", "log_minus_exp_minus_1"],
+)
+def test_log1mexp_stabilization(op_name):
     mode = Mode("py").including("stabilize")
     x = vector()
-    f = function([x], log(1 - exp(x)), mode=mode)
+    if op_name == "log_1_minus_exp":
+        f = function([x], log(1 - exp(x)), mode=mode)
+    elif op_name == "log1p_minus_exp":
+        f = function([x], log1p(-exp(x)), mode=mode)
+    elif op_name == "log_minus_expm1":
+        f = function([x], log(-expm1(x)), mode=mode)
+    elif op_name == "log_minus_exp_minus_1":
+        f = function([x], log(-(exp(x) - 1)), mode=mode)
     nodes = [node.op for node in f.maker.fgraph.toposort()]
     assert nodes == [pt.log1mexp]
 
 
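
Note (not part of the patch): the point of routing `log1p(exp(x))` to `log1pexp` (softplus) and `log1p(-exp(x))` / `log(-expm1(x))` to `log1mexp` is numerical stability. Below is a minimal NumPy sketch of the two failure modes these stabilize rewrites avoid; `np.logaddexp(0, x)` is used here as a stand-in for `log1pexp`, and `log(-expm1(x))` is the stable `log1mexp` formula valid for x > -log(2):

```python
import numpy as np

# Failure mode 1: log1p(exp(x)) overflows for large x.
x = 800.0
naive = np.log1p(np.exp(x))    # exp(800) overflows float64 -> inf
stable = np.logaddexp(0.0, x)  # log(e**0 + e**x) == log1pexp(x) -> 800.0

# Failure mode 2: log1p(-exp(x)) cancels catastrophically for x near 0-.
x = -1e-10
naive = np.log1p(-np.exp(x))   # 1 - exp(x) loses roughly half the digits
stable = np.log(-np.expm1(x))  # log1mexp(x) for x > -log(2): full precision
```

Both pairs agree to only a few significant digits; the stabilized forms are the values the rewritten graphs compute.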