Skip to content

Commit cc39ea3

Browse files
maresbricardoV94
authored andcommitted
Rewrite log1mexp(log1mexp(x)) to x
1 parent 10e5c92 commit cc39ea3

File tree

2 files changed

+30
-0
lines changed

2 files changed

+30
-0
lines changed

pytensor/tensor/rewriting/math.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -625,6 +625,13 @@ def local_exp_log_nan_switch(fgraph, node):
625625
new_out = switch(ge(x, 0), log1p(-x), np.asarray(np.nan, old_out.dtype))
626626
return [new_out]
627627

628+
# Case for log1mexp(log1mexp(x)) -> x
629+
if isinstance(prev_op, ps_math.Log1mexp) and isinstance(node_op, ps_math.Log1mexp):
630+
x = x.owner.inputs[0]
631+
old_out = node.outputs[0]
632+
new_out = switch(le(x, 0), x, np.asarray(np.nan, old_out.dtype))
633+
return [new_out]
634+
628635

629636
@register_canonicalize
630637
@register_specialize

tests/tensor/rewriting/test_math.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2016,6 +2016,29 @@ def test_log1mexp_log(self):
20162016
np.testing.assert_almost_equal(f(data_valid), expected)
20172017
assert np.all(np.isnan(f(data_invalid)))
20182018

2019+
def test_log1mexp_log1mexp(self):
2020+
# log1mexp(log1mexp(x)) -> x
2021+
data_valid = -np.random.random((4, 3)).astype("float32")
2022+
data_valid[0, 0] = 0 # edge case
2023+
data_invalid = data_valid + 1.1
2024+
2025+
x = fmatrix()
2026+
f = function([x], log1mexp(log1mexp(x)), mode=self.mode.excluding("inplace"))
2027+
assert equal_computations(
2028+
f.maker.fgraph.outputs,
2029+
[
2030+
pt.switch(
2031+
x <= np.array([[0]], dtype=np.int8),
2032+
x,
2033+
np.array([[np.nan]], dtype=np.float32),
2034+
)
2035+
],
2036+
)
2037+
2038+
expected = data_valid
2039+
np.testing.assert_almost_equal(f(data_valid), expected)
2040+
assert np.all(np.isnan(f(data_invalid)))
2041+
20192042
@pytest.mark.parametrize(
20202043
["nested_expression", "expected_switches"],
20212044
[

0 commit comments

Comments
 (0)