Commit b2a71a2

Cleanup;

committed · 1 parent 0ade157 · commit b2a71a2

File tree

2 files changed (+2, -51 lines)

pymc_extras/inference/pathfinder/vectorized_logp.py

Lines changed: 2 additions & 48 deletions
@@ -40,9 +40,7 @@ def create_vectorized_logp_graph(
     """
     from pytensor.compile.function.types import Function

-    # For Numba mode, use OpFromGraph approach to avoid function closure issues
     if mode_name == "NUMBA":
-        # Special handling for Numba: logp_func should be a PyMC model, not a compiled function
         if hasattr(logp_func, "value_vars"):
             return create_opfromgraph_logp(logp_func)
         else:
@@ -51,10 +49,8 @@ def create_vectorized_logp_graph(
             "Pass the model directly when using inference_backend='numba'."
         )

-    # Use proper type checking to determine if logp_func is a compiled function
     if isinstance(logp_func, Function):
-        # Compiled PyTensor function - use LogLike Op approach
-        from .pathfinder import LogLike  # Import the existing LogLike Op
+        from .pathfinder import LogLike

         def vectorized_logp(phi: TensorVariable) -> TensorVariable:
             """Vectorized logp using LogLike Op for compiled functions."""
@@ -65,22 +61,18 @@ def vectorized_logp(phi: TensorVariable) -> TensorVariable:
         return vectorized_logp

     else:
-        # Assume symbolic interface - use direct symbolic approach
         phi_scalar = pt.vector("phi_scalar", dtype="float64")
         logP_scalar = logp_func(phi_scalar)

         def vectorized_logp(phi: TensorVariable) -> TensorVariable:
             """Vectorized logp using symbolic interface."""
-            # Use vectorize_graph to handle batch processing
             if phi.ndim == 2:
                 result = vectorize_graph(logP_scalar, replace={phi_scalar: phi})
             else:
-                # Multi-path case: (L, batch_size, num_params)
                 phi_reshaped = phi.reshape((-1, phi.shape[-1]))
                 result_flat = vectorize_graph(logP_scalar, replace={phi_scalar: phi_reshaped})
                 result = result_flat.reshape(phi.shape[:-1])

-            # Handle nan/inf values
             mask = pt.isnan(result) | pt.isinf(result)
             return pt.where(mask, -pt.inf, result)

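Note: the else-branch above leans on vectorize_graph to lift a scalar logp graph over a batch dimension. A minimal standalone sketch of that mechanism, with a toy quadratic logp standing in for the model's:

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.graph.replace import vectorize_graph

# Scalar graph: logp of a standard normal evaluated on one parameter vector
phi_scalar = pt.vector("phi_scalar", dtype="float64")   # shape (N,)
logP_scalar = -0.5 * pt.sum(phi_scalar**2)

# Replace the (N,) input with an (M, N) batch; the graph is rebuilt so each
# row produces one scalar, giving an output of shape (M,)
phi_batch = pt.matrix("phi_batch", dtype="float64")
logP_batch = vectorize_graph(logP_scalar, replace={phi_scalar: phi_batch})

f = pytensor.function([phi_batch], logP_batch)
print(f(np.ones((4, 2))))  # -> [-1. -1. -1. -1.]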
@@ -115,16 +107,12 @@ def scan_logp(phi: TensorVariable) -> TensorVariable:

         def scan_fn(phi_row):
             """Single row log-probability computation."""
-            # Call the compiled logp_func on individual parameter vectors
-            # This works with Numba because pt.scan handles the iteration
             return logp_func(phi_row)

-        # Handle different input shapes
         if phi.ndim == 2:
-            # Single path: (M, N) -> (M,)
             logP_result, _ = scan(fn=scan_fn, sequences=[phi], outputs_info=None, strict=True)
         elif phi.ndim == 3:
-            # Multiple paths: (L, M, N) -> (L, M)
+
             def scan_paths(phi_path):
                 logP_path, _ = scan(
                     fn=scan_fn, sequences=[phi_path], outputs_info=None, strict=True
@@ -135,7 +123,6 @@ def scan_paths(phi_path):
         else:
             raise ValueError(f"Expected 2D or 3D input, got {phi.ndim}D")

-        # Handle nan/inf values (same as LogLike Op)
         mask = pt.isnan(logP_result) | pt.isinf(logP_result)
         result = pt.where(mask, -pt.inf, logP_result)

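Note: the scan-based variant iterates a row-wise function over the batch instead of rewriting the graph. A minimal sketch with a toy stand-in for logp_func:

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor import scan

phi = pt.matrix("phi", dtype="float64")  # (M, N): M draws, N parameters

def scan_fn(phi_row):
    # Toy stand-in; the real scan_fn returns the model logp for one row
    return -0.5 * pt.sum(phi_row**2)

# scan applies scan_fn to each row, collecting an (M,) vector of logps
logP, _ = scan(fn=scan_fn, sequences=[phi], outputs_info=None, strict=True)

f = pytensor.function([phi], logP)
print(f(np.ones((3, 2))))  # -> [-1. -1. -1.]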
@@ -160,14 +147,12 @@ def create_direct_vectorized_logp(logp_func: CallableType) -> CallableType:
     Callable
         Function that takes a batch of parameter vectors and returns vectorized logp values
     """
-    # Use PyTensor's built-in vectorize
     vectorized_logp_func = pt.vectorize(logp_func, signature="(n)->()")

     def direct_logp(phi: TensorVariable) -> TensorVariable:
         """Compute log-probability using pt.vectorize."""
         logP_result = vectorized_logp_func(phi)

-        # Handle nan/inf values
         mask = pt.isnan(logP_result) | pt.isinf(logP_result)
         return pt.where(mask, -pt.inf, logP_result)

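Note: pt.vectorize with a gufunc-style "(n)->()" signature is the most compact of these strategies, since any extra leading dimensions broadcast automatically. A toy sketch of the same pattern:

import numpy as np
import pytensor
import pytensor.tensor as pt

def toy_logp(param_vec):
    # Stand-in for logp_func: one (n,) vector in, one scalar out
    return -0.5 * pt.sum(param_vec**2)

# "(n)->()" declares the core dims; leading batch dims broadcast, which is
# what lets the same graph serve (M, N) and (L, M, N) inputs
batched_logp = pt.vectorize(toy_logp, signature="(n)->()")

phi = pt.tensor3("phi", dtype="float64")  # (L, M, N)
f = pytensor.function([phi], batched_logp(phi))
print(f(np.ones((2, 3, 4))).shape)  # -> (2, 3)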
@@ -191,37 +176,28 @@ def extract_model_symbolic_graph(model):
        (param_vector, model_vars, model_logp, param_sizes, total_params)
    """
    with model:
-        # Get the model's symbolic computation graph
        model_vars = list(model.value_vars)
        model_logp = model.logp()

-        # Extract parameter dimensions and create flattened parameter vector
        param_sizes = []
        for var in model_vars:
            if hasattr(var.type, "shape") and var.type.shape is not None:
-                # Handle shaped variables
                if len(var.type.shape) == 0:
-                    # Scalar
                    param_sizes.append(1)
                else:
-                    # Get product of shape dimensions
                    size = 1
                    for dim in var.type.shape:
-                        # For PyTensor, shape dimensions are often just integers
                        if isinstance(dim, int):
                            size *= dim
                        elif hasattr(dim, "value") and dim.value is not None:
                            size *= dim.value
                        else:
-                            # Try to evaluate if it's a symbolic expression
                            try:
                                size *= int(dim.eval())
                            except (AttributeError, ValueError, Exception):
-                                # Default to 1 for unknown dimensions
                                size *= 1
                    param_sizes.append(size)
            else:
-                # Default to scalar
                param_sizes.append(1)

        total_params = sum(param_sizes)
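Note: the parameter-size bookkeeping above flattens every value variable into one vector. A minimal sketch of what it would yield for a hypothetical toy model (assumes PyMC's value_vars attribute):

import pymc as pm

with pm.Model() as toy_model:  # hypothetical model for illustration
    mu = pm.Normal("mu", 0.0, 1.0)               # scalar: contributes 1
    beta = pm.Normal("beta", 0.0, 1.0, shape=3)  # vector: contributes 3

print([v.name for v in toy_model.value_vars])  # -> ['mu', 'beta']
# extract_model_symbolic_graph(toy_model) would then compute
# param_sizes == [1, 3] and total_params == 4 under these assumptions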
@@ -254,7 +230,6 @@ def create_symbolic_parameter_mapping(param_vector, model_vars, param_sizes):
    start_idx = 0

    for var, size in zip(model_vars, param_sizes):
-        # Extract slice from parameter vector
        if size == 1:
            # Scalar case
            var_slice = param_vector[start_idx]
@@ -309,19 +284,14 @@ def create_opfromgraph_logp(model) -> CallableType:

    from pytensor.compile.builders import OpFromGraph

-    # Extract symbolic components - this is the critical step
    param_vector, model_vars, model_logp, param_sizes, total_params = extract_model_symbolic_graph(
        model
    )

-    # Create parameter mapping - replaces function closure with pure symbols
    substitutions = create_symbolic_parameter_mapping(param_vector, model_vars, param_sizes)

-    # Apply substitutions to create parameter-vector-based logp
-    # This uses PyTensor's symbolic graph manipulation instead of function calls
    symbolic_logp = graph.clone_replace(model_logp, substitutions)

-    # Create OpFromGraph - this is Numba-compatible because it's pure symbolic
    logp_op = OpFromGraph([param_vector], [symbolic_logp])

    def opfromgraph_logp(phi: TensorVariable) -> TensorVariable:
@@ -346,7 +316,6 @@ def compute_path(phi_path):
        else:
            raise ValueError(f"Expected 2D or 3D input, got {phi.ndim}D")

-        # Handle nan/inf values using PyTensor operations
        mask = pt.isnan(logP_result) | pt.isinf(logP_result)
        return pt.where(mask, -pt.inf, logP_result)

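Note: the point of the OpFromGraph construction above is that the whole logp becomes a single symbolic node with no Python closure, which is what Numba compilation requires. A minimal sketch with a toy logp graph standing in for the model's:

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.compile.builders import OpFromGraph

x = pt.vector("x", dtype="float64")
logp = -0.5 * pt.sum(x**2)             # stand-in for the model's logp graph
logp_op = OpFromGraph([x], [logp])     # pure graph, no Python closure

phi = pt.matrix("phi", dtype="float64")
# Apply the Op row by row with scan; the whole graph stays symbolic
rows, _ = pytensor.scan(fn=lambda row: logp_op(row), sequences=[phi])
f = pytensor.function([phi], rows)     # mode="NUMBA" would presumably also work
print(f(np.zeros((3, 2))))             # -> [0. 0. 0.]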
@@ -396,33 +365,19 @@ def create_symbolic_reconstruction_logp(model) -> CallableType:
    def symbolic_logp(phi: TensorVariable) -> TensorVariable:
        """Reconstruct logp computation symbolically for Numba compatibility."""

-        # Strategy: Replace the compiled function approach with direct symbolic computation
-        # This requires mapping parameter vectors back to model variables symbolically
-
-        # For simple models, we can reconstruct the logp directly
-        # This is a template - specific implementation depends on model structure
-
        if phi.ndim == 2:
            # Single path case: (M, N) -> (M,)

-            # Use PyTensor's built-in vectorization primitives instead of scan
-            # This avoids the function closure issue
            def compute_single_logp(param_vec):
                # Map parameter vector to model variables symbolically
-                # This is where we'd implement the symbolic equivalent of logp_func
-
-                # For demonstration - this needs to be model-specific
-                # In practice, this would use the model's symbolic graph
                return pt.sum(param_vec**2) * -0.5  # Simple quadratic form

-            # Use pt.vectorize for native PyTensor vectorization
            vectorized_fn = pt.vectorize(compute_single_logp, signature="(n)->()")
            logP_result = vectorized_fn(phi)

        elif phi.ndim == 3:
            # Multiple paths case: (L, M, N) -> (L, M)

-            # Reshape and vectorize, then reshape back
            L, M, N = phi.shape
            phi_reshaped = phi.reshape((-1, N))

@@ -436,7 +391,6 @@ def compute_single_logp(param_vec):
        else:
            raise ValueError(f"Expected 2D or 3D input, got {phi.ndim}D")

-        # Handle nan/inf values
        mask = pt.isnan(logP_result) | pt.isinf(logP_result)
        return pt.where(mask, -pt.inf, logP_result)

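Note: every code path in this module funnels its result through the same non-finite guard, replacing nan/inf logp values with -inf so failed evaluations behave like zero-probability draws. A standalone sketch of the guard:

import numpy as np
import pytensor
import pytensor.tensor as pt

logP = pt.vector("logP", dtype="float64")
mask = pt.isnan(logP) | pt.isinf(logP)      # flag non-finite entries
guarded = pt.where(mask, -pt.inf, logP)     # map them all to -inf

f = pytensor.function([logP], guarded)
print(f(np.array([0.0, np.nan, np.inf])))   # -> [  0. -inf -inf]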
tests/inference/pathfinder/conftest.py

Lines changed: 0 additions & 3 deletions
@@ -94,9 +94,6 @@ def get_available_backends():

    available = ["pymc"]  # PyMC should always be available

-    if importlib.util.find_spec("jax") is not None:
-        available.append("jax")
-
    if importlib.util.find_spec("numba") is not None:
        available.append("numba")
