@@ -239,8 +239,6 @@ def numba_funcify_LogLike(op, node=None, **kwargs):
239
239
"""
240
240
logp_func = op .logp_func
241
241
242
- # Strategy: Use objmode for calling the Python function while keeping
243
- # the vectorization and error handling in nopython mode for performance
244
242
@numba_basic .numba_njit (parallel = True , fastmath = True , cache = True )
245
243
def loglike_vectorized_hybrid (phi ):
246
244
"""Vectorized log-likelihood with hybrid Python/Numba approach.
@@ -251,25 +249,17 @@ def loglike_vectorized_hybrid(phi):
251
249
L , N = phi .shape
252
250
logP = np .empty (L , dtype = phi .dtype )
253
251
254
- # Parallel computation using objmode for each row
255
252
for i in numba .prange (L ):
256
- row = phi [i ].copy () # Ensure contiguous memory for objmode
253
+ row = phi [i ].copy ()
257
254
with numba .objmode (val = "float64" ):
258
- # Call the Python function in objmode
259
255
val = logp_func (row )
260
256
logP [i ] = val
261
257
262
- # Handle NaN/Inf values exactly like the original implementation
263
- # Original: mask = np.isnan(logP) | np.isinf(logP)
264
- # Original: outputs[0][0] = np.where(mask, -np.inf, logP)
265
258
mask = np .isnan (logP ) | np .isinf (logP )
266
259
267
- # Check if ALL values are invalid (would trigger PathInvalidLogP in original)
268
260
if np .all (mask ):
269
- # All values are invalid - signal this by returning all -inf
270
261
logP [:] = - np .inf
271
262
else :
272
- # Replace invalid values with -inf, preserve valid ones
273
263
logP = np .where (mask , - np .inf , logP )
274
264
275
265
return logP
0 commit comments