from pytensor.link.numba.dispatch import numba_funcify


-# @numba_funcify.register(LogLike)  # DISABLED
-def _disabled_numba_funcify_LogLike(op, node, **kwargs):
-    """DISABLED: LogLike Op registration for Numba.
-
-    This registration is intentionally disabled because LogLike Op
-    cannot be compiled with Numba due to function closure limitations.
-
-    The error would be:
-    numba.core.errors.TypingError: Untyped global name 'actual_logp_func':
-    Cannot determine Numba type of <class 'function'>
-
-    Instead, use the scan-based approach in vectorized_logp module.
-    """
-    raise NotImplementedError(
-        "LogLike Op cannot be compiled with Numba due to function closure limitations. "
-        "Use scan-based vectorization instead."
-    )
-
-
class NumbaChiMatrixOp(Op):
    """Numba-optimized Chi matrix computation.

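The deleted registration above documents a genuine Numba restriction: a jitted function cannot call a plain Python function object captured from an enclosing scope, which is exactly the `TypingError` quoted in its docstring. A minimal sketch of the failing pattern and one common workaround, with purely illustrative names that are not taken from this codebase:

```python
import numpy as np
from numba import njit


def make_logp_kernel(actual_logp_func):
    @njit
    def logp_kernel(x):
        # First call triggers numba.core.errors.TypingError:
        # "Untyped global name 'actual_logp_func': Cannot determine Numba
        # type of <class 'function'>" -- the captured object is a plain
        # Python function that Numba cannot type.
        return actual_logp_func(x)

    return logp_kernel


# Closing over an already-jitted callable is fine, because Numba can type it:
@njit
def gaussian_logp(x):
    return -0.5 * np.sum(x * x)


@njit
def logp_kernel_ok(x):
    return gaussian_logp(x)
```

Rather than keep the dead registration around, this change drops it entirely and leaves the scan-based path in `vectorized_logp` as the Numba route for `LogLike`.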
@@ -78,7 +59,7 @@ def make_node(self, diff):

        output = pt.tensor(
            dtype=diff.dtype,
-            shape=(None, None, self.J),  # Only J is static
+            shape=(None, None, self.J),
        )
        return Apply(self, [diff], [output])

@@ -122,7 +103,6 @@ def __hash__(self):
def numba_funcify_ChiMatrixOp(op, node, **kwargs):
    """Numba implementation for ChiMatrix sliding window computation with smart parallelization.

-    Phase 6: Uses intelligent parallelization and optimized memory access patterns.
    Automatically selects between parallel and sequential versions based on problem size.

    Parameters
@@ -392,7 +372,7 @@ def numba_funcify_BfgsSampleOp(op, node, **kwargs):
392372 """
393373
394374 REGULARISATION_TERM = 1e-8
395- USE_CUSTOM_THRESHOLD = 100 # Use custom linear algebra for N < 100
375+ CUSTOM_THRESHOLD = 100
396376
397377 @numba_basic .numba_njit (
398378 fastmath = True , cache = True , error_model = "numpy" , boundscheck = False , inline = "never"
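Renaming `USE_CUSTOM_THRESHOLD` to `CUSTOM_THRESHOLD` keeps the same dispatch rule used in every hunk below: factor small matrices with the hand-rolled Numba kernels (`cholesky_small`, `qr_small`) and fall back to LAPACK via `np.linalg` above the cut-off, where BLAS-backed routines win. A self-contained sketch of that pattern; `cholesky_upper_small` here is an illustrative stand-in, not the project's implementation:

```python
import numpy as np
from numba import njit

CUSTOM_THRESHOLD = 100  # below this size, call overhead dominates LAPACK's advantage


@njit(cache=True, fastmath=True)
def cholesky_upper_small(A):
    # Unblocked Cholesky returning the upper factor U with A = U.T @ U.
    n = A.shape[0]
    U = np.zeros_like(A)
    for i in range(n):
        s = A[i, i]
        for k in range(i):
            s -= U[k, i] * U[k, i]
        U[i, i] = np.sqrt(s)
        for j in range(i + 1, n):
            t = A[i, j]
            for k in range(i):
                t -= U[k, i] * U[k, j]
            U[i, j] = t / U[i, i]
    return U


def upper_cholesky(A):
    # Size-based dispatch mirroring the `if N <= CUSTOM_THRESHOLD` branches below.
    if A.shape[0] <= CUSTOM_THRESHOLD:
        return cholesky_upper_small(A)
    return np.linalg.cholesky(A).T  # LAPACK lower factor, transposed to upper
```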
@@ -899,7 +879,7 @@ def dense_bfgs_with_memory_pool(
            matmul_inplace(sqrt_alpha_diag_l, temp_matrix_NN3, temp_matrix_NN)
            matmul_inplace(temp_matrix_NN, sqrt_alpha_diag_l, H_inv_buffer)

-            if N <= USE_CUSTOM_THRESHOLD:
+            if N <= CUSTOM_THRESHOLD:
                Lchol_l = cholesky_small(H_inv_buffer, upper=True)
            else:
                Lchol_l = np.linalg.cholesky(H_inv_buffer).T
@@ -968,7 +948,7 @@ def sparse_bfgs_with_memory_pool(
        for l in range(L):  # noqa: E741
            matmul_inplace(inv_sqrt_alpha_diag[l], beta[l], qr_input_buffer)

-            if N <= USE_CUSTOM_THRESHOLD:
+            if N <= CUSTOM_THRESHOLD:
                Q_l, R_l = qr_small(qr_input_buffer)
                copy_matrix_inplace(Q_l, Q_buffer)
                copy_matrix_inplace(R_l, R_buffer)
@@ -986,7 +966,7 @@ def sparse_bfgs_with_memory_pool(
                    temp_matrix_JJ2[i, j] = sum_val
            add_inplace(Id_JJ_reg, temp_matrix_JJ2, temp_matrix_JJ)

-            if JJ <= USE_CUSTOM_THRESHOLD:
+            if JJ <= CUSTOM_THRESHOLD:
                Lchol_l = cholesky_small(temp_matrix_JJ, upper=True)
            else:
                Lchol_l = np.linalg.cholesky(temp_matrix_JJ).T
@@ -1101,7 +1081,7 @@ def dense_bfgs_numba(
                sqrt_alpha_diag_l, matmul_contiguous(temp_matrix, sqrt_alpha_diag_l)
            )

-            if N <= USE_CUSTOM_THRESHOLD:
+            if N <= CUSTOM_THRESHOLD:
                # 3-5x speedup over BLAS
                Lchol_l = cholesky_small(H_inv_l, upper=True)
            else:
@@ -1188,8 +1168,7 @@ def sparse_bfgs_numba(
        for l in range(L):  # noqa: E741
            qr_input_l = inv_sqrt_alpha_diag[l] @ beta[l]

-            if N <= USE_CUSTOM_THRESHOLD:
-                # 3-5x speedup over BLAS
+            if N <= CUSTOM_THRESHOLD:
                Q_l, R_l = qr_small(qr_input_l)
            else:
                Q_l, R_l = np.linalg.qr(qr_input_l)
@@ -1203,10 +1182,9 @@ def sparse_bfgs_numba(

            Lchol_input_l = temp_RgammaRT.copy()
            for i in range(JJ):
-                Lchol_input_l[i, i] += IdJJ[i, i]  # Add identity efficiently
+                Lchol_input_l[i, i] += IdJJ[i, i]

-            if JJ <= USE_CUSTOM_THRESHOLD:
-                # 3-5x speedup over BLAS
+            if JJ <= CUSTOM_THRESHOLD:
                Lchol_l = cholesky_small(Lchol_input_l, upper=True)
            else:
                Lchol_l = np.linalg.cholesky(Lchol_input_l).T
@@ -1346,10 +1324,6 @@ def bfgs_sample_numba(
            x, g, alpha, beta, gamma, alpha_diag, inv_sqrt_alpha_diag, sqrt_alpha_diag, u
        )

-    # ===============================================================================
-    # Phase 6: Smart Parallelization
-    # ===============================================================================
-
    @numba_basic.numba_njit(
        dense_bfgs_signature,
        fastmath=True,
@@ -1426,7 +1400,7 @@ def dense_bfgs_parallel(
                sqrt_alpha_diag_l, matmul_contiguous(temp_matrix, sqrt_alpha_diag_l)
            )

-            if N <= USE_CUSTOM_THRESHOLD:
+            if N <= CUSTOM_THRESHOLD:
                Lchol_l = cholesky_small(H_inv_l, upper=True)
            else:
                Lchol_l = np.linalg.cholesky(H_inv_l).T
@@ -1504,7 +1478,7 @@ def sparse_bfgs_parallel(
            beta_l = ensure_contiguous_2d(beta[l])
            qr_input_l = matmul_contiguous(inv_sqrt_alpha_diag_l, beta_l)

-            if N <= USE_CUSTOM_THRESHOLD:
+            if N <= CUSTOM_THRESHOLD:
                Q_l, R_l = qr_small(qr_input_l)
            else:
                Q_l, R_l = np.linalg.qr(qr_input_l)
@@ -1520,7 +1494,7 @@ def sparse_bfgs_parallel(
            for i in range(JJ):
                Lchol_input_l[i, i] += IdJJ[i, i]

-            if JJ <= USE_CUSTOM_THRESHOLD:
+            if JJ <= CUSTOM_THRESHOLD:
                Lchol_l = cholesky_small(Lchol_input_l, upper=True)
            else:
                Lchol_l = np.linalg.cholesky(Lchol_input_l).T
@@ -1643,7 +1617,6 @@ def smart_dispatcher(
16431617 """
16441618 L , M , N = u .shape
16451619
1646- # This avoids thread overhead for small problems
16471620 if L >= 4 :
16481621 return bfgs_sample_parallel (
16491622 x , g , alpha , beta , gamma , alpha_diag , inv_sqrt_alpha_diag , sqrt_alpha_diag , u
@@ -1655,5 +1628,4 @@ def smart_dispatcher(

        return smart_dispatcher

-    # Phase 6: Return intelligent parallel dispatcher
    return create_parallel_dispatcher()
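For context, the dispatcher returned here encodes a simple batch-size rule: the thread-parallel kernels are only worth their start-up cost once at least four independent BFGS systems (the leading `L` dimension of `u`) are available; smaller batches run the sequential kernels. A stripped-down sketch of that pattern with made-up kernels (the real ones take the full `x, g, alpha, ...` argument list):

```python
import numpy as np
from numba import njit, prange

MIN_PARALLEL_BATCH = 4  # mirrors the `if L >= 4:` check above


@njit(parallel=True, cache=True)
def kernel_parallel(u):
    L = u.shape[0]
    out = np.empty(L)
    for i in prange(L):  # one batch element per thread
        out[i] = u[i].sum()
    return out


@njit(cache=True)
def kernel_sequential(u):
    L = u.shape[0]
    out = np.empty(L)
    for i in range(L):  # no threading start-up cost for tiny batches
        out[i] = u[i].sum()
    return out


def smart_dispatcher(u):
    # Choose the kernel from the number of independent problems in the batch.
    if u.shape[0] >= MIN_PARALLEL_BATCH:
        return kernel_parallel(u)
    return kernel_sequential(u)
```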