Commit 9f23160: Cruft removal
1 parent ae9ee59

3 files changed, +79 -278 lines


pymc_extras/inference/pathfinder/numba_dispatch.py

Lines changed: 10 additions & 10 deletions

@@ -8,9 +8,9 @@
 Numba backend (mode="NUMBA").
 
 Architecture follows PyTensor patterns from:
-- doc/extending/creating_a_numba_jax_op.rst
+- doc/extending/creating_a_numba_op.rst
 - pytensor/link/numba/dispatch/
-- Existing JAX dispatch in jax_dispatch.py
+- Reference implementation ensures mathematical consistency
 """
 
 import numpy as np
@@ -86,7 +86,7 @@ def make_node(self, diff):
     def perform(self, node, inputs, outputs):
         """NumPy fallback implementation for compatibility.
 
-        This matches the JAX implementation exactly to ensure
+        This matches the reference implementation exactly to ensure
         mathematical correctness as fallback.
 
         Parameters
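The `perform` fallback touched in this hunk follows PyTensor's storage-cell convention: `inputs` arrives as a list of NumPy arrays, and each result must be written into a one-element output cell. A minimal sketch of that contract with a hypothetical op (`DoubleOp` is illustrative, not part of pymc-extras):

```python
import numpy as np

class DoubleOp:
    """Hypothetical Op-like class illustrating PyTensor's perform() contract.

    `inputs` is a list of NumPy arrays; `outputs` is a list of one-element
    storage cells, and the result must be written into outputs[i][0].
    """

    def perform(self, node, inputs, outputs):
        (x,) = inputs
        # pure-NumPy fallback: correctness first, no backend-specific tricks
        outputs[0][0] = np.asarray(x * 2.0)

storage = [[None]]          # one output, one storage cell
DoubleOp().perform(None, [np.arange(3.0)], storage)
```

A Numba-optimized kernel registered for the same op can then replace this path without changing the mathematical result.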
@@ -136,7 +136,7 @@ def numba_funcify_ChiMatrixOp(op, node, **kwargs):
 
     Uses Numba's optimized loop fusion and memory locality improvements
     for efficient sliding window operations. This avoids the dynamic
-    indexing issues that block JAX compilation while providing better
+    indexing issues while providing better
     CPU performance through cache-friendly access patterns.
 
     Parameters
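As a rough illustration of the sliding-window loop pattern this docstring describes (static bounds, contiguous row copies, an outer loop Numba could parallelize), here is a pure-NumPy sketch. `chi_window`, its shapes, and the right-aligned zero padding are assumptions for illustration, not the actual `ChiMatrixOp` kernel:

```python
import numpy as np

def chi_window(diff: np.ndarray, J: int) -> np.ndarray:
    """Collect, for each row l of `diff` (shape (L, N)), the previous J rows
    into a right-aligned, zero-padded (L, J, N) history tensor.

    Static loop bounds and contiguous row copies keep this pattern
    JIT-friendly (the outer loop could become numba.prange).
    """
    L, N = diff.shape
    chi = np.zeros((L, J, N), dtype=diff.dtype)
    for l in range(L):
        lo = max(0, l - J + 1)      # first row inside the window
        w = l + 1 - lo              # actual window width (<= J)
        for j in range(w):
            # right-align: the most recent row lands in slot J - 1
            chi[l, J - w + j, :] = diff[lo + j]
    return chi
```

Each inner copy moves one contiguous row, which is the cache-friendly access pattern the docstring refers to.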
@@ -194,10 +194,10 @@ class NumbaBfgsSampleOp(Op):
     Handles conditional selection between dense and sparse BFGS sampling
     modes based on condition JJ >= N, using Numba's efficient conditional
     compilation instead of PyTensor's pt.switch. This avoids the dynamic
-    indexing issues that block JAX compilation while providing superior
+    indexing issues while providing superior
     CPU performance through Numba's optimizations.
 
-    The Op implements the same mathematical operations as the JAX version
+    The Op implements the same mathematical operations as the reference version
     but uses Numba-specific optimizations for CPU workloads:
     - Parallel processing with numba.prange
     - Optimized matrix operations and memory layouts
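The dense-versus-sparse selection described above replaces a symbolic `pt.switch` with an eager branch that a JIT compiler can specialize. A hedged sketch of the idea, using a log-determinant as the quantity of interest; `bfgs_logdet` and its factor shapes are hypothetical, and the low-rank branch uses the matrix determinant lemma rather than the actual pymc-extras code:

```python
import numpy as np

def bfgs_logdet(alpha: np.ndarray, beta: np.ndarray, gamma: np.ndarray) -> float:
    """log|diag(alpha) + beta @ gamma @ beta.T| with an eager branch.

    alpha: (N,) positive diagonal; beta: (N, JJ); gamma: (JJ, JJ) symmetric PD.
    When JJ >= N the dense N x N matrix is cheap to form; otherwise the
    matrix determinant lemma avoids materializing it.
    """
    N, JJ = beta.shape
    if JJ >= N:
        # dense branch: build the full matrix explicitly
        H = np.diag(alpha) + beta @ gamma @ beta.T
        return np.linalg.slogdet(H)[1]
    # low-rank branch, matrix determinant lemma:
    # det(D + B G B^T) = det(G^-1 + B^T D^-1 B) * det(G) * det(D)
    inner = np.linalg.inv(gamma) + beta.T @ (beta / alpha[:, None])
    return (np.linalg.slogdet(inner)[1]
            + np.linalg.slogdet(gamma)[1]
            + np.sum(np.log(alpha)))
```

Because the branch is resolved in plain Python before JIT compilation, each path compiles to straight-line code with static shapes, which is the point the docstring makes about avoiding dynamic indexing.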
@@ -257,10 +257,10 @@ def make_node(
         return Apply(self, inputs, [phi_out, logdet_out])
 
     def perform(self, node, inputs, outputs):
-        """NumPy fallback implementation using JAX logic.
+        """NumPy fallback implementation using reference logic.
 
         This provides the reference implementation for mathematical correctness,
-        copied directly from the JAX version to ensure identical behavior.
+        copied directly from the reference version to ensure identical behavior.
         The Numba-optimized version will be registered separately.
         """
         import numpy as np
@@ -348,7 +348,7 @@ def numba_funcify_BfgsSampleOp(op, node, **kwargs):
     """Numba implementation with optimized conditional matrix operations.
 
     Uses Numba's efficient conditional compilation for optimal performance,
-    avoiding the dynamic indexing issues that prevent JAX compilation while
+    avoiding the dynamic indexing issues while
     providing superior CPU performance through parallel processing and
     optimized memory access patterns.
@@ -512,7 +512,7 @@ def bfgs_sample_numba(
 
     Uses efficient conditional compilation to select between dense and sparse
     algorithms based on problem dimensions. This avoids the dynamic indexing
-    issues that prevent JAX compilation while providing optimal performance
+    issues while providing optimal performance
     for both cases.
 
     Parameters
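The `numba_funcify_*` functions changed in this commit plug into PyTensor's backend registry, which dispatches on Op type (the hook lives under `pytensor/link/numba/dispatch/`). The mechanism can be sketched without PyTensor using `functools.singledispatch`; `ChiMatrixLikeOp` and the cumulative-sum kernel are placeholders, not the real op:

```python
from functools import singledispatch

import numpy as np

class ChiMatrixLikeOp:
    """Placeholder standing in for a PyTensor Op class."""

@singledispatch
def funcify(op, **kwargs):
    # default: no backend implementation registered for this op type
    raise NotImplementedError(type(op).__name__)

@funcify.register(ChiMatrixLikeOp)
def funcify_chi_matrix(op, **kwargs):
    # the registered hook returns a plain callable; in PyTensor's Numba
    # backend an analogous callable is what gets JIT-compiled
    def chi_matrix(diff):
        return np.cumsum(diff, axis=0)  # placeholder kernel
    return chi_matrix

kernel = funcify(ChiMatrixLikeOp())
```

Because dispatch happens per Op class, the pure-NumPy `perform` fallback and the registered Numba kernel can coexist: the registry supplies the fast path, and `perform` remains the reference for correctness checks.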
