
Commit 0ade157

Better type checking and patterns
1 parent c675989 commit 0ade157

3 files changed: +32 -104 lines changed


pymc_extras/inference/pathfinder/numba_dispatch.py

Lines changed: 12 additions & 40 deletions
@@ -20,25 +20,6 @@
 from pytensor.link.numba.dispatch import numba_funcify
 
 
-# @numba_funcify.register(LogLike)  # DISABLED
-def _disabled_numba_funcify_LogLike(op, node, **kwargs):
-    """DISABLED: LogLike Op registration for Numba.
-
-    This registration is intentionally disabled because LogLike Op
-    cannot be compiled with Numba due to function closure limitations.
-
-    The error would be:
-    numba.core.errors.TypingError: Untyped global name 'actual_logp_func':
-    Cannot determine Numba type of <class 'function'>
-
-    Instead, use the scan-based approach in vectorized_logp module.
-    """
-    raise NotImplementedError(
-        "LogLike Op cannot be compiled with Numba due to function closure limitations. "
-        "Use scan-based vectorization instead."
-    )
-
-
 class NumbaChiMatrixOp(Op):
     """Numba-optimized Chi matrix computation.
 
@@ -78,7 +59,7 @@ def make_node(self, diff):
 
         output = pt.tensor(
             dtype=diff.dtype,
-            shape=(None, None, self.J),  # Only J is static
+            shape=(None, None, self.J),
         )
         return Apply(self, [diff], [output])
 
@@ -122,7 +103,6 @@ def __hash__(self):
 def numba_funcify_ChiMatrixOp(op, node, **kwargs):
     """Numba implementation for ChiMatrix sliding window computation with smart parallelization.
 
-    Phase 6: Uses intelligent parallelization and optimized memory access patterns.
     Automatically selects between parallel and sequential versions based on problem size.
 
     Parameters
@@ -392,7 +372,7 @@ def numba_funcify_BfgsSampleOp(op, node, **kwargs):
     """
 
     REGULARISATION_TERM = 1e-8
-    USE_CUSTOM_THRESHOLD = 100  # Use custom linear algebra for N < 100
+    CUSTOM_THRESHOLD = 100
 
     @numba_basic.numba_njit(
         fastmath=True, cache=True, error_model="numpy", boundscheck=False, inline="never"
@@ -899,7 +879,7 @@ def dense_bfgs_with_memory_pool(
             matmul_inplace(sqrt_alpha_diag_l, temp_matrix_NN3, temp_matrix_NN)
             matmul_inplace(temp_matrix_NN, sqrt_alpha_diag_l, H_inv_buffer)
 
-            if N <= USE_CUSTOM_THRESHOLD:
+            if N <= CUSTOM_THRESHOLD:
                 Lchol_l = cholesky_small(H_inv_buffer, upper=True)
             else:
                 Lchol_l = np.linalg.cholesky(H_inv_buffer).T
@@ -968,7 +948,7 @@ def sparse_bfgs_with_memory_pool(
         for l in range(L):  # noqa: E741
             matmul_inplace(inv_sqrt_alpha_diag[l], beta[l], qr_input_buffer)
 
-            if N <= USE_CUSTOM_THRESHOLD:
+            if N <= CUSTOM_THRESHOLD:
                 Q_l, R_l = qr_small(qr_input_buffer)
                 copy_matrix_inplace(Q_l, Q_buffer)
                 copy_matrix_inplace(R_l, R_buffer)
@@ -986,7 +966,7 @@ def sparse_bfgs_with_memory_pool(
                     temp_matrix_JJ2[i, j] = sum_val
             add_inplace(Id_JJ_reg, temp_matrix_JJ2, temp_matrix_JJ)
 
-            if JJ <= USE_CUSTOM_THRESHOLD:
+            if JJ <= CUSTOM_THRESHOLD:
                 Lchol_l = cholesky_small(temp_matrix_JJ, upper=True)
             else:
                 Lchol_l = np.linalg.cholesky(temp_matrix_JJ).T
@@ -1101,7 +1081,7 @@ def dense_bfgs_numba(
                 sqrt_alpha_diag_l, matmul_contiguous(temp_matrix, sqrt_alpha_diag_l)
             )
 
-            if N <= USE_CUSTOM_THRESHOLD:
+            if N <= CUSTOM_THRESHOLD:
                 # 3-5x speedup over BLAS
                 Lchol_l = cholesky_small(H_inv_l, upper=True)
             else:
@@ -1188,8 +1168,7 @@ def sparse_bfgs_numba(
         for l in range(L):  # noqa: E741
            qr_input_l = inv_sqrt_alpha_diag[l] @ beta[l]
 
-            if N <= USE_CUSTOM_THRESHOLD:
-                # 3-5x speedup over BLAS
+            if N <= CUSTOM_THRESHOLD:
                 Q_l, R_l = qr_small(qr_input_l)
             else:
                 Q_l, R_l = np.linalg.qr(qr_input_l)
@@ -1203,10 +1182,9 @@ def sparse_bfgs_numba(
 
             Lchol_input_l = temp_RgammaRT.copy()
             for i in range(JJ):
-                Lchol_input_l[i, i] += IdJJ[i, i]  # Add identity efficiently
+                Lchol_input_l[i, i] += IdJJ[i, i]
 
-            if JJ <= USE_CUSTOM_THRESHOLD:
-                # 3-5x speedup over BLAS
+            if JJ <= CUSTOM_THRESHOLD:
                 Lchol_l = cholesky_small(Lchol_input_l, upper=True)
             else:
                 Lchol_l = np.linalg.cholesky(Lchol_input_l).T
@@ -1346,10 +1324,6 @@ def bfgs_sample_numba(
             x, g, alpha, beta, gamma, alpha_diag, inv_sqrt_alpha_diag, sqrt_alpha_diag, u
         )
 
-    # ===============================================================================
-    # Phase 6: Smart Parallelization
-    # ===============================================================================
-
     @numba_basic.numba_njit(
         dense_bfgs_signature,
         fastmath=True,
@@ -1426,7 +1400,7 @@ def dense_bfgs_parallel(
                 sqrt_alpha_diag_l, matmul_contiguous(temp_matrix, sqrt_alpha_diag_l)
             )
 
-            if N <= USE_CUSTOM_THRESHOLD:
+            if N <= CUSTOM_THRESHOLD:
                 Lchol_l = cholesky_small(H_inv_l, upper=True)
             else:
                 Lchol_l = np.linalg.cholesky(H_inv_l).T
@@ -1504,7 +1478,7 @@ def sparse_bfgs_parallel(
             beta_l = ensure_contiguous_2d(beta[l])
             qr_input_l = matmul_contiguous(inv_sqrt_alpha_diag_l, beta_l)
 
-            if N <= USE_CUSTOM_THRESHOLD:
+            if N <= CUSTOM_THRESHOLD:
                 Q_l, R_l = qr_small(qr_input_l)
             else:
                 Q_l, R_l = np.linalg.qr(qr_input_l)
@@ -1520,7 +1494,7 @@ def sparse_bfgs_parallel(
             for i in range(JJ):
                 Lchol_input_l[i, i] += IdJJ[i, i]
 
-            if JJ <= USE_CUSTOM_THRESHOLD:
+            if JJ <= CUSTOM_THRESHOLD:
                 Lchol_l = cholesky_small(Lchol_input_l, upper=True)
             else:
                 Lchol_l = np.linalg.cholesky(Lchol_input_l).T
@@ -1643,7 +1617,6 @@ def smart_dispatcher(
         """
         L, M, N = u.shape
 
-        # This avoids thread overhead for small problems
        if L >= 4:
             return bfgs_sample_parallel(
                 x, g, alpha, beta, gamma, alpha_diag, inv_sqrt_alpha_diag, sqrt_alpha_diag, u
@@ -1655,5 +1628,4 @@ def smart_dispatcher(
 
        return smart_dispatcher
 
-    # Phase 6: Return intelligent parallel dispatcher
    return create_parallel_dispatcher()
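The renamed CUSTOM_THRESHOLD constant gates whether the Numba kernels above use the repository's hand-written small-matrix routines (cholesky_small, qr_small) or NumPy's LAPACK-backed functions. Below is a minimal, self-contained sketch of that pattern, assuming a hypothetical cholesky_small_upper helper as a stand-in for the repository's cholesky_small; it is illustrative only, not the commit's actual code.

import numpy as np
from numba import njit

CUSTOM_THRESHOLD = 100  # switch point between the custom and LAPACK paths


@njit(cache=True, fastmath=True)
def cholesky_small_upper(a):
    """Unblocked Cholesky returning the upper factor R with a = R.T @ R."""
    n = a.shape[0]
    r = np.zeros_like(a)
    for i in range(n):
        s = a[i, i]
        for k in range(i):
            s -= r[k, i] * r[k, i]
        r[i, i] = np.sqrt(s)
        for j in range(i + 1, n):
            t = a[i, j]
            for k in range(i):
                t -= r[k, i] * r[k, j]
            r[i, j] = t / r[i, i]
    return r


@njit(cache=True)
def upper_cholesky(a):
    """Dispatch on matrix size, mirroring the `N <= CUSTOM_THRESHOLD` branches above."""
    if a.shape[0] <= CUSTOM_THRESHOLD:
        return cholesky_small_upper(a)  # avoids LAPACK call overhead for small N
    return np.linalg.cholesky(a).T      # LAPACK path for larger matrices


# Usage: build a small SPD matrix and verify the factorization.
rng = np.random.default_rng(0)
m = rng.standard_normal((8, 8))
spd = m @ m.T + 8 * np.eye(8)
r = upper_cholesky(spd)
assert np.allclose(r.T @ r, spd)

The same size-based dispatch idea drives smart_dispatcher in the diff, which routes to the parallel kernels only when the batch dimension L is at least 4.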

pymc_extras/inference/pathfinder/pathfinder.py

Lines changed: 1 addition & 2 deletions
@@ -1776,7 +1776,7 @@ def multipath_pathfinder(
         TimeRemainingColumn(),
         TextColumn("/"),
         TimeElapsedColumn(),
-        console=Console(),  # Use default theme if default_progress_theme is None
+        console=Console(),
         disable=not progressbar,
     )
     with progress:
@@ -2031,7 +2031,6 @@ def fit_pathfinder(
         pathfinder_samples = mp_result.samples
     elif inference_backend == "numba":
         # Numba backend: Use PyTensor compilation with Numba mode
-        # Import Numba dispatch to register custom Op conversions
 
         numba_compile_kwargs = {"mode": "NUMBA", **compile_kwargs}
         mp_result = multipath_pathfinder(
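The numba branch simply merges `{"mode": "NUMBA"}` into the user-supplied compile kwargs before handing them to PyTensor. A minimal sketch of what that mode selection means, using a toy graph rather than pathfinder's logp graph:

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.vector("x", dtype="float64")
logp = -0.5 * pt.sum(x**2)  # standard-normal logp up to a constant

compile_kwargs = {}  # user-supplied options, as in fit_pathfinder
numba_compile_kwargs = {"mode": "NUMBA", **compile_kwargs}

# Same graph, compiled through PyTensor's Numba backend instead of the default one.
logp_fn = pytensor.function([x], logp, **numba_compile_kwargs)
print(logp_fn(np.ones(3)))  # -1.5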

pymc_extras/inference/pathfinder/vectorized_logp.py

Lines changed: 19 additions & 62 deletions
@@ -1,28 +1,9 @@
-# Copyright 2022 The PyMC Developers
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
 """
 Native PyTensor vectorized logp implementation.
 
-This module provides a PyTensor First approach to vectorizing log-probability
+This module provides a PyTensor-based approach to vectorizing log-probability
 computations, eliminating the need for custom LogLike Op and ensuring automatic
 backend compatibility through native PyTensor operations.
-
-Expert Guidance Applied:
-- Uses vectorize_graph instead of custom Ops (Jesse Grabowski's recommendation)
-- Eliminates numpy.apply_along_axis dependency
-- Leverages existing PyTensor functionality per "PyTensor First" principle
 """
 
 from collections.abc import Callable as CallableType
@@ -40,9 +21,8 @@ def create_vectorized_logp_graph(
     """
     Create a vectorized log-probability computation graph using native PyTensor operations.
 
-    IMPORTANT: This function now detects the interface type and compilation mode to handle
-    both compiled functions and symbolic expressions properly, with special handling for
-    Numba mode to avoid LogLike Op compilation issues.
+    This function determines the appropriate vectorization strategy based on the input type
+    and compilation mode.
 
     Parameters
     ----------
@@ -57,45 +37,35 @@
     -------
     Callable
         Function that takes a batch of parameter vectors and returns vectorized logp values
-
-    Notes
-    -----
-    This implementation follows PyTensor expert recommendations:
-    - "Can the perform method of that `Loglike` op be directly written in pytensor?" - Jesse Grabowski
-    - "PyTensor vectorize / vectorize_graph directly" - Ricardo
-    - Fixed interface mismatch between compiled functions and symbolic variables
-    - Automatic backend support through PyTensor's existing infrastructure
-    - Numba compatibility through scan-based approach
     """
+    from pytensor.compile.function.types import Function
 
     # For Numba mode, use OpFromGraph approach to avoid function closure issues
     if mode_name == "NUMBA":
         # Special handling for Numba: logp_func should be a PyMC model, not a compiled function
-        if hasattr(logp_func, "value_vars"):  # It's a PyMC model
+        if hasattr(logp_func, "value_vars"):
            return create_opfromgraph_logp(logp_func)
         else:
             raise ValueError(
                 "Numba backend requires PyMC model object, not compiled function. "
                 "Pass the model directly when using inference_backend='numba'."
             )
 
-    # Check if logp_func is a compiled function by testing its interface
-    phi_test = pt.vector("phi_test", dtype="float64")
+    # Use proper type checking to determine if logp_func is a compiled function
+    if isinstance(logp_func, Function):
+        # Compiled PyTensor function - use LogLike Op approach
+        from .pathfinder import LogLike  # Import the existing LogLike Op
 
-    try:
-        # Try to call logp_func with symbolic input
-        logP_scalar = logp_func(phi_test)
-        if hasattr(logP_scalar, "type"):  # It's a symbolic variable
-            use_symbolic_interface = True
-        else:
-            use_symbolic_interface = False
-    except (TypeError, AttributeError):
-        # logp_func is a compiled function that expects numeric input
-        # Fall back to LogLike Op approach for non-Numba modes
-        use_symbolic_interface = False
-
-    if use_symbolic_interface:
-        # Direct symbolic approach (ideal case)
+        def vectorized_logp(phi: TensorVariable) -> TensorVariable:
+            """Vectorized logp using LogLike Op for compiled functions."""
+            loglike_op = LogLike(logp_func)
+            result = loglike_op(phi)
+            return result
+
+        return vectorized_logp
+
+    else:
+        # Assume symbolic interface - use direct symbolic approach
         phi_scalar = pt.vector("phi_scalar", dtype="float64")
         logP_scalar = logp_func(phi_scalar)
 
@@ -116,19 +86,6 @@ def vectorized_logp(phi: TensorVariable) -> TensorVariable:
 
         return vectorized_logp
 
-    else:
-        # Fallback to LogLike Op for compiled functions (non-Numba modes only)
-        # This maintains compatibility while we transition to symbolic approach
-        from .pathfinder import LogLike  # Import the existing LogLike Op
-
-        def vectorized_logp(phi: TensorVariable) -> TensorVariable:
-            """Vectorized logp using LogLike Op fallback."""
-            loglike_op = LogLike(logp_func)
-            result = loglike_op(phi)
-            return result
-
-        return vectorized_logp
-
 
 def create_scan_based_logp_graph(logp_func: CallableType) -> CallableType:
     """
