Skip to content

Commit da3db75

Browse files
committed
Handle edge case of logdiffexp(-inf, -inf)
1 parent 416737e commit da3db75

File tree

3 files changed

+16
-4
lines changed

3 files changed

+16
-4
lines changed

pymc/logprob/transforms.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@
123123
filter_measurable_variables,
124124
find_negated_var,
125125
)
126+
from pymc.math import logdiffexp
126127

127128

128129
class Transform(abc.ABC):
@@ -267,7 +268,7 @@ def measurable_transform_logcdf(op: MeasurableTransform, value, *inputs, **kwarg
267268
logcdf_zero = _logcdf_helper(measurable_input, 0)
268269
logcdf = pt.switch(
269270
pt.lt(backward_value, 0),
270-
pt.log(pt.exp(logcdf_zero) - pt.exp(logcdf)),
271+
logdiffexp(logcdf_zero, logcdf),
271272
pt.logaddexp(logccdf, logcdf_zero),
272273
)
273274
else:

pymc/math.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,14 @@ def kron_diag(*diags):
282282

283283
def logdiffexp(a, b):
284284
"""Return log(exp(a) - exp(b))."""
285-
return a + pt.log1mexp(b - a)
285+
return pt.where(
286+
# Handle special case where b is -inf
287+
# If a == b == -inf, this will return the correct result of -inf
288+
# whereas the default else branch would get a nan due to -inf - (-inf)
289+
pt.isneginf(b),
290+
a,
291+
a + pt.log1mexp(b - a),
292+
)
286293

287294

288295
invlogit = sigmoid

tests/test_math.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,9 +144,13 @@ def test_log1mexp_deprecation_warnings():
144144
def test_logdiffexp():
145145
a = np.log([1, 2, 3, 4])
146146
with warnings.catch_warnings():
147-
warnings.filterwarnings("ignore", "divide by zero encountered in log", RuntimeWarning)
148147
b = np.log([0, 1, 2, 3])
149-
assert np.allclose(logdiffexp(a, b).eval(), 0)
148+
np.testing.assert_allclose(logdiffexp(a, b).eval(), 0, atol=1e-15)
149+
150+
np.testing.assert_allclose(
151+
logdiffexp(-np.inf, -np.inf).eval(),
152+
-np.inf,
153+
)
150154

151155

152156
class TestLogDet:

0 commit comments

Comments
 (0)