-    Existing JAX dispatch in jax_dispatch.py
"""

-import numba
import numpy as np
import pytensor.tensor as pt

from pytensor.graph import Apply, Op
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch import numba_funcify
-# Import existing ops for registration

-# Module version for tracking
-__version__ = "0.1.0"
-
-
-# NOTE: LogLike Op registration for Numba is intentionally removed
-#
-# The LogLike Op cannot be compiled with Numba due to fundamental incompatibility:
-# - LogLike uses arbitrary Python function closures (logp_func)
-# - Numba requires concrete, statically-typeable operations
-# - Function closures from PyTensor compilation cannot be analyzed by Numba
-#
-# Instead, the vectorized_logp module handles Numba mode by using scan-based
-# approaches that avoid LogLike Op entirely.
-#
-# This is documented as a known limitation in CLAUDE.md
-
-
-# @numba_funcify.register(LogLike)  # DISABLED - see note above
+# @numba_funcify.register(LogLike)  # DISABLED
def _disabled_numba_funcify_LogLike(op, node, **kwargs):
    """DISABLED: LogLike Op registration for Numba.
@@ -59,7 +40,6 @@ def _disabled_numba_funcify_LogLike(op, node, **kwargs):
    )


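The removed NOTE above is easy to demonstrate. A minimal sketch (not part of this PR; `make_loglike` and the lambdas are illustrative stand-ins): Numba's nopython mode cannot type a call to an opaque Python callable captured in a closure and raises a TypingError on first call, which is exactly the situation LogLike's `logp_func` creates.

```python
import numba
import numpy as np

def make_loglike(logp_func):
    # logp_func is an opaque Python callable captured in the closure;
    # nopython mode cannot infer its type, so compilation fails.
    @numba.njit
    def loglike(x):
        return logp_func(x)

    return loglike

try:
    make_loglike(lambda x: -0.5 * np.sum(x**2))(np.ones(3))
except Exception:
    pass  # Numba raises a TypingError when compiling the call

# An njit-compiled callable, by contrast, can be called from compiled code:
print(make_loglike(numba.njit(lambda x: -0.5 * np.sum(x**2)))(np.ones(3)))  # -1.5
```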
-# Custom Op for Numba-compatible chi matrix computation
class NumbaChiMatrixOp(Op):
    """Numba-optimized Chi matrix computation.
@@ -96,7 +76,7 @@ def make_node(self, diff):
            Computation node for chi matrix
        """
        diff = pt.as_tensor_variable(diff)
-        # Output shape: (L, N, J) - use None for dynamic dimensions
+
        output = pt.tensor(
            dtype=diff.dtype,
            shape=(None, None, self.J),  # Only J is static
@@ -118,21 +98,18 @@ def perform(self, node, inputs, outputs):
        outputs : list
            Output arrays [chi_matrix]
        """
-        diff = inputs[0]  # Shape: (L, N)
+        diff = inputs[0]
        L, N = diff.shape
        J = self.J

-        # Create output matrix
        chi_matrix = np.zeros((L, N, J), dtype=diff.dtype)

-        # Compute sliding window matrix (same logic as JAX version)
+        # Compute sliding window matrix
        for idx in range(L):
-            # For each row idx, we want the last J values of diff up to position idx
            start_idx = max(0, idx - J + 1)
            end_idx = idx + 1

-            # Get the relevant slice
-            relevant_diff = diff[start_idx:end_idx]  # Shape: (actual_length, N)
+            relevant_diff = diff[start_idx:end_idx]
            actual_length = end_idx - start_idx

            # If we have fewer than J values, pad with zeros at the beginning
@@ -142,8 +119,7 @@ def perform(self, node, inputs, outputs):
            else:
                padded_diff = relevant_diff

-            # Assign to chi matrix
-            chi_matrix[idx] = padded_diff.T  # Transpose to get (N, J)
+            chi_matrix[idx] = padded_diff.T

        outputs[0][0] = chi_matrix
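To make the padding logic concrete, here is a worked trace of the same sliding window on toy values (L=4, N=1, J=3): row `idx` holds the last `J` rows of `diff` up to `idx`, left-padded with zeros.

```python
import numpy as np

diff = np.array([[1.0], [2.0], [3.0], [4.0]])  # (L, N) = (4, 1)
L, N, J = 4, 1, 3
chi = np.zeros((L, N, J))
for idx in range(L):
    window = diff[max(0, idx - J + 1): idx + 1]  # last <= J rows up to idx
    chi[idx, :, J - len(window):] = window.T     # left-pad with zeros
print(chi[:, 0, :])
# [[0. 0. 1.]
#  [0. 1. 2.]
#  [1. 2. 3.]
#  [2. 3. 4.]]
```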
@@ -198,11 +174,9 @@ def chi_matrix_numba(diff):

        # Optimized sliding window with manual loop unrolling
        for batch_idx in range(L):
-            # Efficient window extraction
            start_idx = max(0, batch_idx - J + 1)
            window_size = min(J, batch_idx + 1)

-            # Direct memory copy for efficiency
            for j in range(window_size):
                source_idx = start_idx + j
                target_idx = J - window_size + j
@@ -214,7 +188,6 @@ def chi_matrix_numba(diff):
    return chi_matrix_numba
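A hedged usage sketch for the Op (its constructor is not shown in this diff; `NumbaChiMatrixOp(J=3)` assumes it takes `J`): compiling a graph that contains the Op with PyTensor's NUMBA mode dispatches through the `numba_funcify` registration above.

```python
import numpy as np
import pytensor
import pytensor.tensor as pt

diff = pt.matrix("diff")                         # symbolic (L, N)
chi = NumbaChiMatrixOp(J=3)(diff)                # symbolic (L, N, 3)
fn = pytensor.function([diff], chi, mode="NUMBA")
print(fn(np.arange(10.0).reshape(5, 2)).shape)   # (5, 2, 3)
```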

-# Custom Op for Numba-compatible BFGS sampling
class NumbaBfgsSampleOp(Op):
    """Numba-optimized BFGS sampling with conditional logic.
@@ -262,7 +235,6 @@ def make_node(
        Apply
            Computation node with two outputs: phi and logdet
        """
-        # Convert all inputs to tensor variables (same as JAX version)
        inputs = [
            pt.as_tensor_variable(inp)
            for inp in [
@@ -278,10 +250,8 @@ def make_node(
            ]
        ]

-        # Output phi: shape (L, M, N) - same as u
        phi_out = pt.tensor(dtype=u.dtype, shape=(None, None, None))

-        # Output logdet: shape (L,) - same as first dimension of x
        logdet_out = pt.tensor(dtype=u.dtype, shape=(None,))

        return Apply(self, inputs, [phi_out, logdet_out])
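On the output types above: `pt.tensor(..., shape=(None, ...))` declares dimensions whose sizes are only known at runtime. A minimal standalone illustration of the same API:

```python
import pytensor.tensor as pt

phi_like = pt.tensor(dtype="float64", shape=(None, None, None))  # (L, M, N), all dynamic
logdet_like = pt.tensor(dtype="float64", shape=(None,))          # (L,), dynamic
print(phi_like.type.shape, logdet_like.type.shape)  # (None, None, None) (None,)
```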
@@ -299,20 +269,12 @@ def perform(self, node, inputs, outputs):

        x, g, alpha, beta, gamma, alpha_diag, inv_sqrt_alpha_diag, sqrt_alpha_diag, u = inputs

-        # Get shapes
        L, M, N = u.shape
        L, N, JJ = beta.shape

-        # Define the condition: use dense when JJ >= N, sparse otherwise
-        condition = JJ >= N
-
-        # Regularization term (from pathfinder.py REGULARISATION_TERM)
        REGULARISATION_TERM = 1e-8

-        if condition:
-            # Dense BFGS sampling branch
-
-            # Create identity matrix with regularization
+        if JJ >= N:
            IdN = np.eye(N)[None, ...]
            IdN = IdN + IdN * REGULARISATION_TERM

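For context on the simplified condition: `beta` carries `JJ = 2J` columns of update history (the `(N, 2J)` shapes are noted in `dense_bfgs_numba` below), so the dense (N, N) factorization only pays off once the history spans the whole space; otherwise the sparse path works in the cheaper JJ-dimensional subspace. A hypothetical shape check, just to make the branch choice concrete:

```python
import numpy as np

L, N, J = 4, 100, 5                     # illustrative sizes, J history pairs
beta = np.zeros((L, N, 2 * J))          # (L, N, 2J)
JJ = beta.shape[2]
print("dense" if JJ >= N else "sparse")  # sparse: JJ = 10 < N = 100
```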
@@ -325,68 +287,49 @@ def perform(self, node, inputs, outputs):
                @ inv_sqrt_alpha_diag
            )

-            # Full inverse Hessian
            H_inv = sqrt_alpha_diag @ (IdN + middle_term) @ sqrt_alpha_diag

-            # Cholesky decomposition (upper triangular)
            Lchol = np.array([cholesky(H_inv[i], lower=False) for i in range(L)])

-            # Compute log determinant from Cholesky diagonal
            logdet = 2.0 * np.sum(np.log(np.abs(np.diagonal(Lchol, axis1=-2, axis2=-1))), axis=-1)

-            # Compute mean: mu = x - H_inv @ g
            mu = x - np.sum(H_inv * g[..., None, :], axis=-1)

-            # Sample: phi = mu + Lchol @ u.T, then transpose back
            phi_transposed = mu[..., None] + Lchol @ np.transpose(u, axes=(0, 2, 1))
            phi = np.transpose(phi_transposed, axes=(0, 2, 1))

        else:
-            # Sparse BFGS sampling branch
-
-            # QR decomposition of qr_input = inv_sqrt_alpha_diag @ beta
+            # Sparse BFGS sampling
            qr_input = inv_sqrt_alpha_diag @ beta

-            # NumPy QR decomposition (applied along batch dimension)
-            Q = np.zeros((L, qr_input.shape[1], qr_input.shape[2]))  # (L, N, JJ)
-            R = np.zeros((L, qr_input.shape[2], qr_input.shape[2]))  # (L, JJ, JJ)
+            Q = np.zeros((L, qr_input.shape[1], qr_input.shape[2]))
+            R = np.zeros((L, qr_input.shape[2], qr_input.shape[2]))
            for i in range(L):
                Q[i], R[i] = qr(qr_input[i], mode="economic")

-            # Identity matrix with regularization
            IdJJ = np.eye(R.shape[1])[None, ...]
            IdJJ = IdJJ + IdJJ * REGULARISATION_TERM

-            # Cholesky input: IdJJ + R @ gamma @ R.T
            Lchol_input = IdJJ + R @ gamma @ np.transpose(R, axes=(0, 2, 1))

-            # Cholesky decomposition (upper triangular)
            Lchol = np.array([cholesky(Lchol_input[i], lower=False) for i in range(L)])

-            # Compute log determinant: includes both Cholesky and alpha terms
            logdet_chol = 2.0 * np.sum(
                np.log(np.abs(np.diagonal(Lchol, axis1=-2, axis2=-1))), axis=-1
            )
            logdet_alpha = np.sum(np.log(alpha), axis=-1)
            logdet = logdet_chol + logdet_alpha

-            # Compute inverse Hessian for sparse case: H_inv = alpha_diag + beta @ gamma @ beta.T
            H_inv = alpha_diag + (beta @ gamma @ np.transpose(beta, axes=(0, 2, 1)))

-            # Compute mean: mu = x - H_inv @ g
            mu = x - np.sum(H_inv * g[..., None, :], axis=-1)

-            # Complex sampling transformation for sparse case
-            # First part: Q @ (Lchol - IdJJ)
            Q_Lchol_diff = Q @ (Lchol - IdJJ)

-            # Second part: Q.T @ u.T
            Qt_u = np.transpose(Q, axes=(0, 2, 1)) @ np.transpose(u, axes=(0, 2, 1))

-            # Combine: (Q @ (Lchol - IdJJ)) @ (Q.T @ u.T) + u.T
            combined = Q_Lchol_diff @ Qt_u + np.transpose(u, axes=(0, 2, 1))

-            # Final transformation: mu + sqrt_alpha_diag @ combined
            phi_transposed = mu[..., None] + sqrt_alpha_diag @ combined
            phi = np.transpose(phi_transposed, axes=(0, 2, 1))

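The `logdet` lines in both branches read the log determinant off the Cholesky diagonal, using log det H = 2 Σᵢ log Uᵢᵢ for a symmetric positive definite H with triangular factor U. A quick NumPy check of that identity on an arbitrary SPD matrix:

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(5, 5))
H = A @ A.T + 5.0 * np.eye(5)                 # SPD by construction
U = np.linalg.cholesky(H).T                   # upper-triangular factor
lhs = 2.0 * np.sum(np.log(np.abs(np.diag(U))))
rhs = np.linalg.slogdet(H)[1]
print(np.allclose(lhs, rhs))                  # True
```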
@@ -424,10 +367,9 @@ def numba_funcify_BfgsSampleOp(op, node, **kwargs):
        Numba-compiled function that performs conditional BFGS sampling
    """

-    # Regularization term constant
    REGULARISATION_TERM = 1e-8

-    @numba_basic.numba_njit(fastmath=True, parallel=True)
+    @numba_basic.numba_njit(fastmath=True, cache=True)
    def dense_bfgs_numba(
        x, g, alpha, beta, gamma, alpha_diag, inv_sqrt_alpha_diag, sqrt_alpha_diag, u
    ):
@@ -464,47 +406,37 @@ def dense_bfgs_numba(
        """
        L, M, N = u.shape

-        # Create identity matrix with regularization
        IdN = np.eye(N) + np.eye(N) * REGULARISATION_TERM

-        # Compute inverse Hessian using batched operations
        phi = np.empty((L, M, N), dtype=u.dtype)
        logdet = np.empty(L, dtype=u.dtype)

-        for batch_idx in numba.prange(L):  # Parallel over batch dimension
-            # Middle term computation for batch element batch_idx
-            # middle_term = inv_sqrt_alpha_diag @ beta @ gamma @ beta.T @ inv_sqrt_alpha_diag
-            beta_l = beta[batch_idx]  # (N, 2J)
-            gamma_l = gamma[batch_idx]  # (2J, 2J)
-            inv_sqrt_alpha_diag_l = inv_sqrt_alpha_diag[batch_idx]  # (N, N)
-            sqrt_alpha_diag_l = sqrt_alpha_diag[batch_idx]  # (N, N)
-
-            # Compute middle term step by step for efficiency
-            temp1 = inv_sqrt_alpha_diag_l @ beta_l  # (N, 2J)
-            temp2 = temp1 @ gamma_l  # (N, 2J)
-            temp3 = temp2 @ beta_l.T  # (N, N)
-            middle_term = temp3 @ inv_sqrt_alpha_diag_l  # (N, N)
-
-            # Full inverse Hessian: H_inv = sqrt_alpha_diag @ (IdN + middle_term) @ sqrt_alpha_diag
+        for l in range(L):  # noqa: E741
+            beta_l = beta[l]
+            gamma_l = gamma[l]
+            inv_sqrt_alpha_diag_l = inv_sqrt_alpha_diag[l]
+            sqrt_alpha_diag_l = sqrt_alpha_diag[l]
+
+            temp1 = inv_sqrt_alpha_diag_l @ beta_l
+            temp2 = temp1 @ gamma_l
+            temp3 = temp2 @ beta_l.T
+            middle_term = temp3 @ inv_sqrt_alpha_diag_l
+
            temp_matrix = IdN + middle_term
            H_inv_l = sqrt_alpha_diag_l @ temp_matrix @ sqrt_alpha_diag_l

-            # Cholesky decomposition (upper triangular)
            Lchol_l = np.linalg.cholesky(H_inv_l).T

-            # Log determinant from Cholesky diagonal
-            logdet[batch_idx] = 2.0 * np.sum(np.log(np.abs(np.diag(Lchol_l))))
+            logdet[l] = 2.0 * np.sum(np.log(np.abs(np.diag(Lchol_l))))

-            # Mean computation: mu = x - H_inv @ g
-            mu_l = x[batch_idx] - H_inv_l @ g[batch_idx]
+            mu_l = x[l] - H_inv_l @ g[l]

-            # Sample generation: phi = mu + Lchol @ u.T
            for m in range(M):
-                phi[batch_idx, m] = mu_l + Lchol_l @ u[batch_idx, m]
+                phi[l, m] = mu_l + Lchol_l @ u[l, m]

        return phi, logdet

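A note on the decorator change from `parallel=True` to `cache=True`: with the loops now using plain `range` instead of `numba.prange` (which also lets the top-level `import numba` go away), nothing requires the threading layer, and `cache=True` persists compiled machine code to disk so fresh processes skip recompilation. A minimal sketch with stock Numba, using a hypothetical stand-in kernel:

```python
import numba
import numpy as np

# cache=True writes the compiled function to disk; later interpreter
# sessions reload it instead of recompiling.
@numba.njit(fastmath=True, cache=True)
def row_norms(x):
    out = np.empty(x.shape[0])
    for i in range(x.shape[0]):
        out[i] = np.sqrt(np.sum(x[i] ** 2))
    return out

row_norms(np.ones((4, 3)))  # compiled once, cached thereafter
```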
-    @numba_basic.numba_njit(fastmath=True, parallel=True)
+    @numba_basic.numba_njit(fastmath=True, cache=True)
    def sparse_bfgs_numba(
        x, g, alpha, beta, gamma, alpha_diag, inv_sqrt_alpha_diag, sqrt_alpha_diag, u
    ):
@@ -545,38 +477,30 @@ def sparse_bfgs_numba(
        phi = np.empty((L, M, N), dtype=u.dtype)
        logdet = np.empty(L, dtype=u.dtype)

-        for batch_idx in numba.prange(L):  # Parallel over batch dimension
-            # QR decomposition of qr_input = inv_sqrt_alpha_diag @ beta
-            qr_input_l = inv_sqrt_alpha_diag[batch_idx] @ beta[batch_idx]
+        for l in range(L):  # noqa: E741
+            qr_input_l = inv_sqrt_alpha_diag[l] @ beta[l]
            Q_l, R_l = np.linalg.qr(qr_input_l)

-            # Identity matrix with regularization
            IdJJ = np.eye(JJ) + np.eye(JJ) * REGULARISATION_TERM

-            # Cholesky input: IdJJ + R @ gamma @ R.T
-            Lchol_input_l = IdJJ + R_l @ gamma[batch_idx] @ R_l.T
+            Lchol_input_l = IdJJ + R_l @ gamma[l] @ R_l.T

-            # Cholesky decomposition (upper triangular)
            Lchol_l = np.linalg.cholesky(Lchol_input_l).T

-            # Compute log determinant
            logdet_chol = 2.0 * np.sum(np.log(np.abs(np.diag(Lchol_l))))
-            logdet_alpha = np.sum(np.log(alpha[batch_idx]))
-            logdet[batch_idx] = logdet_chol + logdet_alpha
+            logdet_alpha = np.sum(np.log(alpha[l]))
+            logdet[l] = logdet_chol + logdet_alpha

-            # Inverse Hessian for sparse case
-            H_inv_l = alpha_diag[batch_idx] + beta[batch_idx] @ gamma[batch_idx] @ beta[batch_idx].T
+            H_inv_l = alpha_diag[l] + beta[l] @ gamma[l] @ beta[l].T

-            # Mean computation
-            mu_l = x[batch_idx] - H_inv_l @ g[batch_idx]
+            mu_l = x[l] - H_inv_l @ g[l]

-            # Complex sampling transformation for sparse case
            Q_Lchol_diff = Q_l @ (Lchol_l - IdJJ)

            for m in range(M):
-                Qt_u_lm = Q_l.T @ u[batch_idx, m]
-                combined = Q_Lchol_diff @ Qt_u_lm + u[batch_idx, m]
-                phi[batch_idx, m] = mu_l + sqrt_alpha_diag[batch_idx] @ combined
+                Qt_u_lm = Q_l.T @ u[l, m]
+                combined = Q_Lchol_diff @ Qt_u_lm + u[l, m]
+                phi[l, m] = mu_l + sqrt_alpha_diag[l] @ combined

        return phi, logdet

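The sparse branch's two-term `logdet` is an instance of Sylvester's determinant identity: with `D = diag(alpha)`, `Q, R = qr(D**-0.5 @ beta)` and `Q.T @ Q = I`, one has `det(D + beta @ gamma @ beta.T) = det(D) * det(I + R @ gamma @ R.T)`. A numerical spot-check with illustrative sizes:

```python
import numpy as np

rng = np.random.default_rng(0)
N, JJ = 6, 2
alpha = rng.uniform(0.5, 2.0, N)
beta = rng.normal(size=(N, JJ))
gamma = 0.3 * np.eye(JJ)                      # any symmetric (JJ, JJ) works here
Q, R = np.linalg.qr(np.diag(alpha**-0.5) @ beta, mode="reduced")
lhs = np.linalg.slogdet(np.diag(alpha) + beta @ gamma @ beta.T)[1]
rhs = np.sum(np.log(alpha)) + np.linalg.slogdet(np.eye(JJ) + R @ gamma @ R.T)[1]
print(np.allclose(lhs, rhs))                  # True
```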
@@ -604,7 +528,6 @@ def bfgs_sample_numba(
        L, M, N = u.shape
        JJ = beta.shape[2]

-        # Numba-optimized conditional compilation
        if JJ >= N:
            return dense_bfgs_numba(
                x, g, alpha, beta, gamma, alpha_diag, inv_sqrt_alpha_diag, sqrt_alpha_diag, u