@@ -40,9 +40,7 @@ def create_vectorized_logp_graph(
40
40
"""
41
41
from pytensor .compile .function .types import Function
42
42
43
- # For Numba mode, use OpFromGraph approach to avoid function closure issues
44
43
if mode_name == "NUMBA" :
45
- # Special handling for Numba: logp_func should be a PyMC model, not a compiled function
46
44
if hasattr (logp_func , "value_vars" ):
47
45
return create_opfromgraph_logp (logp_func )
48
46
else :
@@ -51,10 +49,8 @@ def create_vectorized_logp_graph(
51
49
"Pass the model directly when using inference_backend='numba'."
52
50
)
53
51
54
- # Use proper type checking to determine if logp_func is a compiled function
55
52
if isinstance (logp_func , Function ):
56
- # Compiled PyTensor function - use LogLike Op approach
57
- from .pathfinder import LogLike # Import the existing LogLike Op
53
+ from .pathfinder import LogLike
58
54
59
55
def vectorized_logp (phi : TensorVariable ) -> TensorVariable :
60
56
"""Vectorized logp using LogLike Op for compiled functions."""
@@ -65,22 +61,18 @@ def vectorized_logp(phi: TensorVariable) -> TensorVariable:
65
61
return vectorized_logp
66
62
67
63
else :
68
- # Assume symbolic interface - use direct symbolic approach
69
64
phi_scalar = pt .vector ("phi_scalar" , dtype = "float64" )
70
65
logP_scalar = logp_func (phi_scalar )
71
66
72
67
def vectorized_logp (phi : TensorVariable ) -> TensorVariable :
73
68
"""Vectorized logp using symbolic interface."""
74
- # Use vectorize_graph to handle batch processing
75
69
if phi .ndim == 2 :
76
70
result = vectorize_graph (logP_scalar , replace = {phi_scalar : phi })
77
71
else :
78
- # Multi-path case: (L, batch_size, num_params)
79
72
phi_reshaped = phi .reshape ((- 1 , phi .shape [- 1 ]))
80
73
result_flat = vectorize_graph (logP_scalar , replace = {phi_scalar : phi_reshaped })
81
74
result = result_flat .reshape (phi .shape [:- 1 ])
82
75
83
- # Handle nan/inf values
84
76
mask = pt .isnan (result ) | pt .isinf (result )
85
77
return pt .where (mask , - pt .inf , result )
86
78
@@ -115,16 +107,12 @@ def scan_logp(phi: TensorVariable) -> TensorVariable:
115
107
116
108
def scan_fn (phi_row ):
117
109
"""Single row log-probability computation."""
118
- # Call the compiled logp_func on individual parameter vectors
119
- # This works with Numba because pt.scan handles the iteration
120
110
return logp_func (phi_row )
121
111
122
- # Handle different input shapes
123
112
if phi .ndim == 2 :
124
- # Single path: (M, N) -> (M,)
125
113
logP_result , _ = scan (fn = scan_fn , sequences = [phi ], outputs_info = None , strict = True )
126
114
elif phi .ndim == 3 :
127
- # Multiple paths: (L, M, N) -> (L, M)
115
+
128
116
def scan_paths (phi_path ):
129
117
logP_path , _ = scan (
130
118
fn = scan_fn , sequences = [phi_path ], outputs_info = None , strict = True
@@ -135,7 +123,6 @@ def scan_paths(phi_path):
135
123
else :
136
124
raise ValueError (f"Expected 2D or 3D input, got { phi .ndim } D" )
137
125
138
- # Handle nan/inf values (same as LogLike Op)
139
126
mask = pt .isnan (logP_result ) | pt .isinf (logP_result )
140
127
result = pt .where (mask , - pt .inf , logP_result )
141
128
@@ -160,14 +147,12 @@ def create_direct_vectorized_logp(logp_func: CallableType) -> CallableType:
160
147
Callable
161
148
Function that takes a batch of parameter vectors and returns vectorized logp values
162
149
"""
163
- # Use PyTensor's built-in vectorize
164
150
vectorized_logp_func = pt .vectorize (logp_func , signature = "(n)->()" )
165
151
166
152
def direct_logp (phi : TensorVariable ) -> TensorVariable :
167
153
"""Compute log-probability using pt.vectorize."""
168
154
logP_result = vectorized_logp_func (phi )
169
155
170
- # Handle nan/inf values
171
156
mask = pt .isnan (logP_result ) | pt .isinf (logP_result )
172
157
return pt .where (mask , - pt .inf , logP_result )
173
158
@@ -191,37 +176,28 @@ def extract_model_symbolic_graph(model):
191
176
(param_vector, model_vars, model_logp, param_sizes, total_params)
192
177
"""
193
178
with model :
194
- # Get the model's symbolic computation graph
195
179
model_vars = list (model .value_vars )
196
180
model_logp = model .logp ()
197
181
198
- # Extract parameter dimensions and create flattened parameter vector
199
182
param_sizes = []
200
183
for var in model_vars :
201
184
if hasattr (var .type , "shape" ) and var .type .shape is not None :
202
- # Handle shaped variables
203
185
if len (var .type .shape ) == 0 :
204
- # Scalar
205
186
param_sizes .append (1 )
206
187
else :
207
- # Get product of shape dimensions
208
188
size = 1
209
189
for dim in var .type .shape :
210
- # For PyTensor, shape dimensions are often just integers
211
190
if isinstance (dim , int ):
212
191
size *= dim
213
192
elif hasattr (dim , "value" ) and dim .value is not None :
214
193
size *= dim .value
215
194
else :
216
- # Try to evaluate if it's a symbolic expression
217
195
try :
218
196
size *= int (dim .eval ())
219
197
except (AttributeError , ValueError , Exception ):
220
- # Default to 1 for unknown dimensions
221
198
size *= 1
222
199
param_sizes .append (size )
223
200
else :
224
- # Default to scalar
225
201
param_sizes .append (1 )
226
202
227
203
total_params = sum (param_sizes )
@@ -254,7 +230,6 @@ def create_symbolic_parameter_mapping(param_vector, model_vars, param_sizes):
254
230
start_idx = 0
255
231
256
232
for var , size in zip (model_vars , param_sizes ):
257
- # Extract slice from parameter vector
258
233
if size == 1 :
259
234
# Scalar case
260
235
var_slice = param_vector [start_idx ]
@@ -309,19 +284,14 @@ def create_opfromgraph_logp(model) -> CallableType:
309
284
310
285
from pytensor .compile .builders import OpFromGraph
311
286
312
- # Extract symbolic components - this is the critical step
313
287
param_vector , model_vars , model_logp , param_sizes , total_params = extract_model_symbolic_graph (
314
288
model
315
289
)
316
290
317
- # Create parameter mapping - replaces function closure with pure symbols
318
291
substitutions = create_symbolic_parameter_mapping (param_vector , model_vars , param_sizes )
319
292
320
- # Apply substitutions to create parameter-vector-based logp
321
- # This uses PyTensor's symbolic graph manipulation instead of function calls
322
293
symbolic_logp = graph .clone_replace (model_logp , substitutions )
323
294
324
- # Create OpFromGraph - this is Numba-compatible because it's pure symbolic
325
295
logp_op = OpFromGraph ([param_vector ], [symbolic_logp ])
326
296
327
297
def opfromgraph_logp (phi : TensorVariable ) -> TensorVariable :
@@ -346,7 +316,6 @@ def compute_path(phi_path):
346
316
else :
347
317
raise ValueError (f"Expected 2D or 3D input, got { phi .ndim } D" )
348
318
349
- # Handle nan/inf values using PyTensor operations
350
319
mask = pt .isnan (logP_result ) | pt .isinf (logP_result )
351
320
return pt .where (mask , - pt .inf , logP_result )
352
321
@@ -396,33 +365,19 @@ def create_symbolic_reconstruction_logp(model) -> CallableType:
396
365
def symbolic_logp (phi : TensorVariable ) -> TensorVariable :
397
366
"""Reconstruct logp computation symbolically for Numba compatibility."""
398
367
399
- # Strategy: Replace the compiled function approach with direct symbolic computation
400
- # This requires mapping parameter vectors back to model variables symbolically
401
-
402
- # For simple models, we can reconstruct the logp directly
403
- # This is a template - specific implementation depends on model structure
404
-
405
368
if phi .ndim == 2 :
406
369
# Single path case: (M, N) -> (M,)
407
370
408
- # Use PyTensor's built-in vectorization primitives instead of scan
409
- # This avoids the function closure issue
410
371
def compute_single_logp (param_vec ):
411
372
# Map parameter vector to model variables symbolically
412
- # This is where we'd implement the symbolic equivalent of logp_func
413
-
414
- # For demonstration - this needs to be model-specific
415
- # In practice, this would use the model's symbolic graph
416
373
return pt .sum (param_vec ** 2 ) * - 0.5 # Simple quadratic form
417
374
418
- # Use pt.vectorize for native PyTensor vectorization
419
375
vectorized_fn = pt .vectorize (compute_single_logp , signature = "(n)->()" )
420
376
logP_result = vectorized_fn (phi )
421
377
422
378
elif phi .ndim == 3 :
423
379
# Multiple paths case: (L, M, N) -> (L, M)
424
380
425
- # Reshape and vectorize, then reshape back
426
381
L , M , N = phi .shape
427
382
phi_reshaped = phi .reshape ((- 1 , N ))
428
383
@@ -436,7 +391,6 @@ def compute_single_logp(param_vec):
436
391
else :
437
392
raise ValueError (f"Expected 2D or 3D input, got { phi .ndim } D" )
438
393
439
- # Handle nan/inf values
440
394
mask = pt .isnan (logP_result ) | pt .isinf (logP_result )
441
395
return pt .where (mask , - pt .inf , logP_result )
442
396
0 commit comments