Skip to content

Commit 4781957

Browse files
committed
Optimize DiracDelta logprob for exact equality
1 parent a7c856b commit 4781957

File tree

1 file changed

+18
-8
lines changed

1 file changed

+18
-8
lines changed

pymc/logprob/utils.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
from pytensor.graph.op import HasInnerGraph
4949
from pytensor.raise_op import CheckAndRaise
5050
from pytensor.scalar.basic import Mul
51-
from pytensor.tensor.basic import get_underlying_scalar_constant_value
51+
from pytensor.tensor.basic import AllocEmpty, get_underlying_scalar_constant_value
5252
from pytensor.tensor.elemwise import Elemwise
5353
from pytensor.tensor.exceptions import NotScalarConstantError
5454
from pytensor.tensor.random.op import RandomVariable
@@ -244,7 +244,7 @@ class DiracDelta(MeasurableOp, Op):
244244

245245
__props__ = ("rtol", "atol")
246246

247-
def __init__(self, rtol=1e-5, atol=1e-8):
247+
def __init__(self, rtol, atol):
248248
self.rtol = rtol
249249
self.atol = atol
250250

@@ -267,15 +267,25 @@ def infer_shape(self, fgraph, node, input_shapes):
267267
return input_shapes
268268

269269

270-
dirac_delta = DiracDelta()
270+
def dirac_delta(x, rtol=1e-5, atol=1e-8):
271+
return DiracDelta(rtol, atol)(x)
271272

272273

273274
@_logprob.register(DiracDelta)
274-
def diracdelta_logprob(op, values, *inputs, **kwargs):
275-
(values,) = values
276-
(const_value,) = inputs
277-
values, const_value = pt.broadcast_arrays(values, const_value)
278-
return pt.switch(pt.isclose(values, const_value, rtol=op.rtol, atol=op.atol), 0.0, -np.inf)
275+
def diracdelta_logprob(op, values, const_value, **kwargs):
276+
[value] = values
277+
278+
if const_value.owner and isinstance(const_value.owner.op, AllocEmpty):
279+
# Any value is considered valid for an AllocEmpty array
280+
return pt.zeros_like(value)
281+
282+
if op.rtol == 0 and op.atol == 0:
283+
# Strict equality, cheaper logp
284+
match = pt.eq(value, const_value)
285+
else:
286+
# Loose equality, more expensive logp
287+
match = pt.isclose(value, const_value, rtol=op.rtol, atol=op.atol)
288+
return pt.switch(match, np.array(0, dtype=value.dtype), -np.inf)
279289

280290

281291
def find_negated_var(var):

0 commit comments

Comments
 (0)