4848from pytensor .graph .op import HasInnerGraph
4949from pytensor .raise_op import CheckAndRaise
5050from pytensor .scalar .basic import Mul
51- from pytensor .tensor .basic import get_underlying_scalar_constant_value
51+ from pytensor .tensor .basic import AllocEmpty , get_underlying_scalar_constant_value
5252from pytensor .tensor .elemwise import Elemwise
5353from pytensor .tensor .exceptions import NotScalarConstantError
5454from pytensor .tensor .random .op import RandomVariable
@@ -244,7 +244,7 @@ class DiracDelta(MeasurableOp, Op):
244244
245245 __props__ = ("rtol" , "atol" )
246246
247- def __init__ (self , rtol = 1e-5 , atol = 1e-8 ):
247+ def __init__ (self , rtol , atol ):
248248 self .rtol = rtol
249249 self .atol = atol
250250
@@ -267,15 +267,25 @@ def infer_shape(self, fgraph, node, input_shapes):
267267 return input_shapes
268268
269269
270- dirac_delta = DiracDelta ()
270+ def dirac_delta (x , rtol = 1e-5 , atol = 1e-8 ):
271+ return DiracDelta (rtol , atol )(x )
271272
272273
273274@_logprob .register (DiracDelta )
274- def diracdelta_logprob (op , values , * inputs , ** kwargs ):
275- (values ,) = values
276- (const_value ,) = inputs
277- values , const_value = pt .broadcast_arrays (values , const_value )
278- return pt .switch (pt .isclose (values , const_value , rtol = op .rtol , atol = op .atol ), 0.0 , - np .inf )
275+ def diracdelta_logprob (op , values , const_value , ** kwargs ):
276+ [value ] = values
277+
278+ if const_value .owner and isinstance (const_value .owner .op , AllocEmpty ):
279+ # Any value is considered valid for an AllocEmpty array
280+ return pt .zeros_like (value )
281+
282+ if op .rtol == 0 and op .atol == 0 :
283+ # Strict equality, cheaper logp
284+ match = pt .eq (value , const_value )
285+ else :
286+ # Loose equality, more expensive logp
287+ match = pt .isclose (value , const_value , rtol = op .rtol , atol = op .atol )
288+ return pt .switch (match , np .array (0 , dtype = value .dtype ), - np .inf )
279289
280290
281291def find_negated_var (var ):
0 commit comments