Commit ae9ee59

Cleanup
1 parent e401574 commit ae9ee59

1 file changed: +36, -113 lines

pymc_extras/inference/pathfinder/numba_dispatch.py

Lines changed: 36 additions & 113 deletions
@@ -13,34 +13,15 @@
     - Existing JAX dispatch in jax_dispatch.py
 """
 
-import numba
 import numpy as np
 import pytensor.tensor as pt
 
 from pytensor.graph import Apply, Op
 from pytensor.link.numba.dispatch import basic as numba_basic
 from pytensor.link.numba.dispatch import numba_funcify
 
-# Import existing ops for registration
 
-# Module version for tracking
-__version__ = "0.1.0"
-
-
-# NOTE: LogLike Op registration for Numba is intentionally removed
-#
-# The LogLike Op cannot be compiled with Numba due to fundamental incompatibility:
-# - LogLike uses arbitrary Python function closures (logp_func)
-# - Numba requires concrete, statically-typeable operations
-# - Function closures from PyTensor compilation cannot be analyzed by Numba
-#
-# Instead, the vectorized_logp module handles Numba mode by using scan-based
-# approaches that avoid LogLike Op entirely.
-#
-# This is documented as a known limitation in CLAUDE.md
-
-
-# @numba_funcify.register(LogLike)  # DISABLED - see note above
+# @numba_funcify.register(LogLike)  # DISABLED
 def _disabled_numba_funcify_LogLike(op, node, **kwargs):
     """DISABLED: LogLike Op registration for Numba.
 
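The deleted NOTE block explained why LogLike is not registered for Numba: the Op closes over an arbitrary Python callable (logp_func), and Numba's nopython mode cannot type such closures. A minimal standalone sketch of that limitation, using hypothetical names and independent of this repo:

import numba
import numpy as np


def make_logp(data):
    # Arbitrary Python closure, analogous in spirit to logp_func
    def logp(x):
        return float(-0.5 * np.sum((x - data) ** 2))

    return logp


logp_func = make_logp(np.zeros(3))


@numba.njit
def call_through_closure(x):
    # Numba cannot determine a type for the captured pure-Python function
    return logp_func(x)


try:
    call_through_closure(np.zeros(3))
except Exception as err:  # Numba raises a TypingError at compile time
    print("compilation failed:", type(err).__name__)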
@@ -59,7 +40,6 @@ def _disabled_numba_funcify_LogLike(op, node, **kwargs):
     )
 
 
-# Custom Op for Numba-compatible chi matrix computation
 class NumbaChiMatrixOp(Op):
     """Numba-optimized Chi matrix computation.
 
@@ -96,7 +76,7 @@ def make_node(self, diff):
             Computation node for chi matrix
         """
         diff = pt.as_tensor_variable(diff)
-        # Output shape: (L, N, J) - use None for dynamic dimensions
+
         output = pt.tensor(
             dtype=diff.dtype,
             shape=(None, None, self.J),  # Only J is static
@@ -118,21 +98,18 @@ def perform(self, node, inputs, outputs):
         outputs : list
             Output arrays [chi_matrix]
         """
-        diff = inputs[0]  # Shape: (L, N)
+        diff = inputs[0]
         L, N = diff.shape
         J = self.J
 
-        # Create output matrix
         chi_matrix = np.zeros((L, N, J), dtype=diff.dtype)
 
-        # Compute sliding window matrix (same logic as JAX version)
+        # Compute sliding window matrix
         for idx in range(L):
-            # For each row idx, we want the last J values of diff up to position idx
             start_idx = max(0, idx - J + 1)
             end_idx = idx + 1
 
-            # Get the relevant slice
-            relevant_diff = diff[start_idx:end_idx]  # Shape: (actual_length, N)
+            relevant_diff = diff[start_idx:end_idx]
             actual_length = end_idx - start_idx
 
             # If we have fewer than J values, pad with zeros at the beginning
@@ -142,8 +119,7 @@ def perform(self, node, inputs, outputs):
             else:
                 padded_diff = relevant_diff
 
-            # Assign to chi matrix
-            chi_matrix[idx] = padded_diff.T  # Transpose to get (N, J)
+            chi_matrix[idx] = padded_diff.T
 
         outputs[0][0] = chi_matrix
 
@@ -198,11 +174,9 @@ def chi_matrix_numba(diff):
 
         # Optimized sliding window with manual loop unrolling
         for batch_idx in range(L):
-            # Efficient window extraction
             start_idx = max(0, batch_idx - J + 1)
             window_size = min(J, batch_idx + 1)
 
-            # Direct memory copy for efficiency
             for j in range(window_size):
                 source_idx = start_idx + j
                 target_idx = J - window_size + j
@@ -214,7 +188,6 @@ def chi_matrix_numba(diff):
     return chi_matrix_numba
 
 
-# Custom Op for Numba-compatible BFGS sampling
 class NumbaBfgsSampleOp(Op):
     """Numba-optimized BFGS sampling with conditional logic.
 
@@ -262,7 +235,6 @@ def make_node(
         Apply
             Computation node with two outputs: phi and logdet
         """
-        # Convert all inputs to tensor variables (same as JAX version)
         inputs = [
             pt.as_tensor_variable(inp)
             for inp in [
@@ -278,10 +250,8 @@ def make_node(
             ]
         ]
 
-        # Output phi: shape (L, M, N) - same as u
         phi_out = pt.tensor(dtype=u.dtype, shape=(None, None, None))
 
-        # Output logdet: shape (L,) - same as first dimension of x
         logdet_out = pt.tensor(dtype=u.dtype, shape=(None,))
 
         return Apply(self, inputs, [phi_out, logdet_out])
@@ -299,20 +269,12 @@ def perform(self, node, inputs, outputs):
 
         x, g, alpha, beta, gamma, alpha_diag, inv_sqrt_alpha_diag, sqrt_alpha_diag, u = inputs
 
-        # Get shapes
         L, M, N = u.shape
         L, N, JJ = beta.shape
 
-        # Define the condition: use dense when JJ >= N, sparse otherwise
-        condition = JJ >= N
-
-        # Regularization term (from pathfinder.py REGULARISATION_TERM)
         REGULARISATION_TERM = 1e-8
 
-        if condition:
-            # Dense BFGS sampling branch
-
-            # Create identity matrix with regularization
+        if JJ >= N:
             IdN = np.eye(N)[None, ...]
             IdN = IdN + IdN * REGULARISATION_TERM
 
@@ -325,68 +287,49 @@ def perform(self, node, inputs, outputs):
                 @ inv_sqrt_alpha_diag
             )
 
-            # Full inverse Hessian
             H_inv = sqrt_alpha_diag @ (IdN + middle_term) @ sqrt_alpha_diag
 
-            # Cholesky decomposition (upper triangular)
             Lchol = np.array([cholesky(H_inv[i], lower=False) for i in range(L)])
 
-            # Compute log determinant from Cholesky diagonal
             logdet = 2.0 * np.sum(np.log(np.abs(np.diagonal(Lchol, axis1=-2, axis2=-1))), axis=-1)
 
-            # Compute mean: mu = x - H_inv @ g
             mu = x - np.sum(H_inv * g[..., None, :], axis=-1)
 
-            # Sample: phi = mu + Lchol @ u.T, then transpose back
             phi_transposed = mu[..., None] + Lchol @ np.transpose(u, axes=(0, 2, 1))
             phi = np.transpose(phi_transposed, axes=(0, 2, 1))
 
         else:
-            # Sparse BFGS sampling branch
-
-            # QR decomposition of qr_input = inv_sqrt_alpha_diag @ beta
+            # Sparse BFGS sampling
             qr_input = inv_sqrt_alpha_diag @ beta
 
-            # NumPy QR decomposition (applied along batch dimension)
-            Q = np.zeros((L, qr_input.shape[1], qr_input.shape[2]))  # (L, N, JJ)
-            R = np.zeros((L, qr_input.shape[2], qr_input.shape[2]))  # (L, JJ, JJ)
+            Q = np.zeros((L, qr_input.shape[1], qr_input.shape[2]))
+            R = np.zeros((L, qr_input.shape[2], qr_input.shape[2]))
             for i in range(L):
                 Q[i], R[i] = qr(qr_input[i], mode="economic")
 
-            # Identity matrix with regularization
             IdJJ = np.eye(R.shape[1])[None, ...]
             IdJJ = IdJJ + IdJJ * REGULARISATION_TERM
 
-            # Cholesky input: IdJJ + R @ gamma @ R.T
             Lchol_input = IdJJ + R @ gamma @ np.transpose(R, axes=(0, 2, 1))
 
-            # Cholesky decomposition (upper triangular)
             Lchol = np.array([cholesky(Lchol_input[i], lower=False) for i in range(L)])
 
-            # Compute log determinant: includes both Cholesky and alpha terms
             logdet_chol = 2.0 * np.sum(
                 np.log(np.abs(np.diagonal(Lchol, axis1=-2, axis2=-1))), axis=-1
             )
             logdet_alpha = np.sum(np.log(alpha), axis=-1)
             logdet = logdet_chol + logdet_alpha
 
-            # Compute inverse Hessian for sparse case: H_inv = alpha_diag + beta @ gamma @ beta.T
             H_inv = alpha_diag + (beta @ gamma @ np.transpose(beta, axes=(0, 2, 1)))
 
-            # Compute mean: mu = x - H_inv @ g
             mu = x - np.sum(H_inv * g[..., None, :], axis=-1)
 
-            # Complex sampling transformation for sparse case
-            # First part: Q @ (Lchol - IdJJ)
             Q_Lchol_diff = Q @ (Lchol - IdJJ)
 
-            # Second part: Q.T @ u.T
             Qt_u = np.transpose(Q, axes=(0, 2, 1)) @ np.transpose(u, axes=(0, 2, 1))
 
-            # Combine: (Q @ (Lchol - IdJJ)) @ (Q.T @ u.T) + u.T
             combined = Q_Lchol_diff @ Qt_u + np.transpose(u, axes=(0, 2, 1))
 
-            # Final transformation: mu + sqrt_alpha_diag @ combined
             phi_transposed = mu[..., None] + sqrt_alpha_diag @ combined
             phi = np.transpose(phi_transposed, axes=(0, 2, 1))
 
@@ -424,10 +367,9 @@ def numba_funcify_BfgsSampleOp(op, node, **kwargs):
         Numba-compiled function that performs conditional BFGS sampling
     """
 
-    # Regularization term constant
     REGULARISATION_TERM = 1e-8
 
-    @numba_basic.numba_njit(fastmath=True, parallel=True)
+    @numba_basic.numba_njit(fastmath=True, cache=True)
     def dense_bfgs_numba(
         x, g, alpha, beta, gamma, alpha_diag, inv_sqrt_alpha_diag, sqrt_alpha_diag, u
     ):
@@ -464,47 +406,37 @@ def dense_bfgs_numba(
         """
         L, M, N = u.shape
 
-        # Create identity matrix with regularization
         IdN = np.eye(N) + np.eye(N) * REGULARISATION_TERM
 
-        # Compute inverse Hessian using batched operations
         phi = np.empty((L, M, N), dtype=u.dtype)
         logdet = np.empty(L, dtype=u.dtype)
 
-        for batch_idx in numba.prange(L):  # Parallel over batch dimension
-            # Middle term computation for batch element batch_idx
-            # middle_term = inv_sqrt_alpha_diag @ beta @ gamma @ beta.T @ inv_sqrt_alpha_diag
-            beta_l = beta[batch_idx]  # (N, 2J)
-            gamma_l = gamma[batch_idx]  # (2J, 2J)
-            inv_sqrt_alpha_diag_l = inv_sqrt_alpha_diag[batch_idx]  # (N, N)
-            sqrt_alpha_diag_l = sqrt_alpha_diag[batch_idx]  # (N, N)
-
-            # Compute middle term step by step for efficiency
-            temp1 = inv_sqrt_alpha_diag_l @ beta_l  # (N, 2J)
-            temp2 = temp1 @ gamma_l  # (N, 2J)
-            temp3 = temp2 @ beta_l.T  # (N, N)
-            middle_term = temp3 @ inv_sqrt_alpha_diag_l  # (N, N)
-
-            # Full inverse Hessian: H_inv = sqrt_alpha_diag @ (IdN + middle_term) @ sqrt_alpha_diag
+        for l in range(L):  # noqa: E741
+            beta_l = beta[l]
+            gamma_l = gamma[l]
+            inv_sqrt_alpha_diag_l = inv_sqrt_alpha_diag[l]
+            sqrt_alpha_diag_l = sqrt_alpha_diag[l]
+
+            temp1 = inv_sqrt_alpha_diag_l @ beta_l
+            temp2 = temp1 @ gamma_l
+            temp3 = temp2 @ beta_l.T
+            middle_term = temp3 @ inv_sqrt_alpha_diag_l
+
             temp_matrix = IdN + middle_term
             H_inv_l = sqrt_alpha_diag_l @ temp_matrix @ sqrt_alpha_diag_l
 
-            # Cholesky decomposition (upper triangular)
             Lchol_l = np.linalg.cholesky(H_inv_l).T
 
-            # Log determinant from Cholesky diagonal
-            logdet[batch_idx] = 2.0 * np.sum(np.log(np.abs(np.diag(Lchol_l))))
+            logdet[l] = 2.0 * np.sum(np.log(np.abs(np.diag(Lchol_l))))
 
-            # Mean computation: mu = x - H_inv @ g
-            mu_l = x[batch_idx] - H_inv_l @ g[batch_idx]
+            mu_l = x[l] - H_inv_l @ g[l]
 
-            # Sample generation: phi = mu + Lchol @ u.T
             for m in range(M):
-                phi[batch_idx, m] = mu_l + Lchol_l @ u[batch_idx, m]
+                phi[l, m] = mu_l + Lchol_l @ u[l, m]
 
         return phi, logdet
 
-    @numba_basic.numba_njit(fastmath=True, parallel=True)
+    @numba_basic.numba_njit(fastmath=True, cache=True)
     def sparse_bfgs_numba(
         x, g, alpha, beta, gamma, alpha_diag, inv_sqrt_alpha_diag, sqrt_alpha_diag, u
     ):
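This hunk swaps parallel=True with numba.prange for a plain serial loop under cache=True, trading thread-level parallelism for on-disk caching of the compiled function. An illustrative, repo-independent sketch of the two decorator variants being exchanged:

import numba
import numpy as np


@numba.njit(fastmath=True, parallel=True)
def batched_sum_parallel(x):
    out = np.empty(x.shape[0])
    for i in numba.prange(x.shape[0]):  # threads across the batch dimension
        out[i] = x[i].sum()
    return out


@numba.njit(fastmath=True, cache=True)
def batched_sum_cached(x):
    out = np.empty(x.shape[0])
    for i in range(x.shape[0]):  # serial, but compilation is cached to disk
        out[i] = x[i].sum()
    return out


data = np.random.default_rng(1).random((8, 100))
assert np.allclose(batched_sum_parallel(data), batched_sum_cached(data))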
@@ -545,38 +477,30 @@ def sparse_bfgs_numba(
         phi = np.empty((L, M, N), dtype=u.dtype)
         logdet = np.empty(L, dtype=u.dtype)
 
-        for batch_idx in numba.prange(L):  # Parallel over batch dimension
-            # QR decomposition of qr_input = inv_sqrt_alpha_diag @ beta
-            qr_input_l = inv_sqrt_alpha_diag[batch_idx] @ beta[batch_idx]
+        for l in range(L):  # noqa: E741
+            qr_input_l = inv_sqrt_alpha_diag[l] @ beta[l]
             Q_l, R_l = np.linalg.qr(qr_input_l)
 
-            # Identity matrix with regularization
             IdJJ = np.eye(JJ) + np.eye(JJ) * REGULARISATION_TERM
 
-            # Cholesky input: IdJJ + R @ gamma @ R.T
-            Lchol_input_l = IdJJ + R_l @ gamma[batch_idx] @ R_l.T
+            Lchol_input_l = IdJJ + R_l @ gamma[l] @ R_l.T
 
-            # Cholesky decomposition (upper triangular)
             Lchol_l = np.linalg.cholesky(Lchol_input_l).T
 
-            # Compute log determinant
             logdet_chol = 2.0 * np.sum(np.log(np.abs(np.diag(Lchol_l))))
-            logdet_alpha = np.sum(np.log(alpha[batch_idx]))
-            logdet[batch_idx] = logdet_chol + logdet_alpha
+            logdet_alpha = np.sum(np.log(alpha[l]))
+            logdet[l] = logdet_chol + logdet_alpha
 
-            # Inverse Hessian for sparse case
-            H_inv_l = alpha_diag[batch_idx] + beta[batch_idx] @ gamma[batch_idx] @ beta[batch_idx].T
+            H_inv_l = alpha_diag[l] + beta[l] @ gamma[l] @ beta[l].T
 
-            # Mean computation
-            mu_l = x[batch_idx] - H_inv_l @ g[batch_idx]
+            mu_l = x[l] - H_inv_l @ g[l]
 
-            # Complex sampling transformation for sparse case
             Q_Lchol_diff = Q_l @ (Lchol_l - IdJJ)
 
             for m in range(M):
-                Qt_u_lm = Q_l.T @ u[batch_idx, m]
-                combined = Q_Lchol_diff @ Qt_u_lm + u[batch_idx, m]
-                phi[batch_idx, m] = mu_l + sqrt_alpha_diag[batch_idx] @ combined
+                Qt_u_lm = Q_l.T @ u[l, m]
+                combined = Q_Lchol_diff @ Qt_u_lm + u[l, m]
+                phi[l, m] = mu_l + sqrt_alpha_diag[l] @ combined
 
         return phi, logdet
 
@@ -604,7 +528,6 @@ def bfgs_sample_numba(
         L, M, N = u.shape
         JJ = beta.shape[2]
 
-        # Numba-optimized conditional compilation
         if JJ >= N:
             return dense_bfgs_numba(
                 x, g, alpha, beta, gamma, alpha_diag, inv_sqrt_alpha_diag, sqrt_alpha_diag, u
