|
8 | 8 | Numba backend (mode="NUMBA").
|
9 | 9 |
|
10 | 10 | Architecture follows PyTensor patterns from:
|
11 |
| -- doc/extending/creating_a_numba_jax_op.rst |
| 11 | +- doc/extending/creating_a_numba_op.rst |
12 | 12 | - pytensor/link/numba/dispatch/
|
13 |
| -- Existing JAX dispatch in jax_dispatch.py |
| 13 | +- Reference implementation ensures mathematical consistency |
14 | 14 | """
|
15 | 15 |
|
16 | 16 | import numpy as np
|
@@ -86,7 +86,7 @@ def make_node(self, diff):
|
86 | 86 | def perform(self, node, inputs, outputs):
|
87 | 87 | """NumPy fallback implementation for compatibility.
|
88 | 88 |
|
89 |
| - This matches the JAX implementation exactly to ensure |
| 89 | + This matches the reference implementation exactly to ensure |
90 | 90 | mathematical correctness as fallback.
|
91 | 91 |
|
92 | 92 | Parameters
|
@@ -136,7 +136,7 @@ def numba_funcify_ChiMatrixOp(op, node, **kwargs):
|
136 | 136 |
|
137 | 137 | Uses Numba's optimized loop fusion and memory locality improvements
|
138 | 138 | for efficient sliding window operations. This avoids the dynamic
|
139 |
| - indexing issues that block JAX compilation while providing better |
| 139 | + indexing issues while providing better |
140 | 140 | CPU performance through cache-friendly access patterns.
|
141 | 141 |
|
142 | 142 | Parameters
|
@@ -194,10 +194,10 @@ class NumbaBfgsSampleOp(Op):
|
194 | 194 | Handles conditional selection between dense and sparse BFGS sampling
|
195 | 195 | modes based on condition JJ >= N, using Numba's efficient conditional
|
196 | 196 | compilation instead of PyTensor's pt.switch. This avoids the dynamic
|
197 |
| - indexing issues that block JAX compilation while providing superior |
| 197 | + indexing issues while providing superior |
198 | 198 | CPU performance through Numba's optimizations.
|
199 | 199 |
|
200 |
| - The Op implements the same mathematical operations as the JAX version |
| 200 | + The Op implements the same mathematical operations as the reference version |
201 | 201 | but uses Numba-specific optimizations for CPU workloads:
|
202 | 202 | - Parallel processing with numba.prange
|
203 | 203 | - Optimized matrix operations and memory layouts
|
@@ -257,10 +257,10 @@ def make_node(
|
257 | 257 | return Apply(self, inputs, [phi_out, logdet_out])
|
258 | 258 |
|
259 | 259 | def perform(self, node, inputs, outputs):
|
260 |
| - """NumPy fallback implementation using JAX logic. |
| 260 | + """NumPy fallback implementation using reference logic. |
261 | 261 |
|
262 | 262 | This provides the reference implementation for mathematical correctness,
|
263 |
| - copied directly from the JAX version to ensure identical behavior. |
| 263 | + copied directly from the reference version to ensure identical behavior. |
264 | 264 | The Numba-optimized version will be registered separately.
|
265 | 265 | """
|
266 | 266 | import numpy as np
|
@@ -348,7 +348,7 @@ def numba_funcify_BfgsSampleOp(op, node, **kwargs):
|
348 | 348 | """Numba implementation with optimized conditional matrix operations.
|
349 | 349 |
|
350 | 350 | Uses Numba's efficient conditional compilation for optimal performance,
|
351 |
| - avoiding the dynamic indexing issues that prevent JAX compilation while |
| 351 | + avoiding the dynamic indexing issues while |
352 | 352 | providing superior CPU performance through parallel processing and
|
353 | 353 | optimized memory access patterns.
|
354 | 354 |
|
@@ -512,7 +512,7 @@ def bfgs_sample_numba(
|
512 | 512 |
|
513 | 513 | Uses efficient conditional compilation to select between dense and sparse
|
514 | 514 | algorithms based on problem dimensions. This avoids the dynamic indexing
|
515 |
| - issues that prevent JAX compilation while providing optimal performance |
| 515 | + issues while providing optimal performance |
516 | 516 | for both cases.
|
517 | 517 |
|
518 | 518 | Parameters
|
|
0 commit comments