Skip to content

Commit 9366bca

Browse files
committed
Test cleanup
1 parent 945728d commit 9366bca

File tree

4 files changed

+22
-236
lines changed

4 files changed

+22
-236
lines changed

pymc_extras/inference/pathfinder/numba_dispatch.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -239,8 +239,6 @@ def numba_funcify_LogLike(op, node=None, **kwargs):
239239
"""
240240
logp_func = op.logp_func
241241

242-
# Strategy: Use objmode for calling the Python function while keeping
243-
# the vectorization and error handling in nopython mode for performance
244242
@numba_basic.numba_njit(parallel=True, fastmath=True, cache=True)
245243
def loglike_vectorized_hybrid(phi):
246244
"""Vectorized log-likelihood with hybrid Python/Numba approach.
@@ -251,25 +249,17 @@ def loglike_vectorized_hybrid(phi):
251249
L, N = phi.shape
252250
logP = np.empty(L, dtype=phi.dtype)
253251

254-
# Parallel computation using objmode for each row
255252
for i in numba.prange(L):
256-
row = phi[i].copy() # Ensure contiguous memory for objmode
253+
row = phi[i].copy()
257254
with numba.objmode(val="float64"):
258-
# Call the Python function in objmode
259255
val = logp_func(row)
260256
logP[i] = val
261257

262-
# Handle NaN/Inf values exactly like the original implementation
263-
# Original: mask = np.isnan(logP) | np.isinf(logP)
264-
# Original: outputs[0][0] = np.where(mask, -np.inf, logP)
265258
mask = np.isnan(logP) | np.isinf(logP)
266259

267-
# Check if ALL values are invalid (would trigger PathInvalidLogP in original)
268260
if np.all(mask):
269-
# All values are invalid - signal this by returning all -inf
270261
logP[:] = -np.inf
271262
else:
272-
# Replace invalid values with -inf, preserve valid ones
273263
logP = np.where(mask, -np.inf, logP)
274264

275265
return logP

0 commit comments

Comments
 (0)