Skip to content

Commit 2f193a4

Browse files
authored
Fix pickling error in Numba exception handling & improve cache checks (#1083)
* attempt to fix recursion error (1) * better checks for caching * add better check if a numba function is cached or needs to be re-compiled * add case for non-numba functions
1 parent b6b1eb8 commit 2f193a4

File tree

4 files changed

+53
-35
lines changed

4 files changed

+53
-35
lines changed

uxarray/__init__.py

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,3 @@
1-
import uxarray.constants
2-
import sys
3-
#
4-
# # TODO: numba recursion limit ?
5-
6-
71
from .core.api import open_grid, open_dataset, open_mfdataset
82

93
from .core.dataset import UxDataset
@@ -25,23 +19,6 @@
2519
__version__ = "999"
2620

2721

28-
# Flag for enabling FMA instructions across the package
29-
def enable_fma():
30-
"""Enables Fused-Multiply-Add (FMA) instructions using the ``pyfma``
31-
package."""
32-
uxarray.constants.ENABLE_FMA = True
33-
34-
35-
def disable_fma():
36-
"""Disable Fused-Multiply-Add (FMA) instructions using the ``pyfma``
37-
package."""
38-
uxarray.constants.ENABLE_FMA = False
39-
40-
41-
disable_fma()
42-
sys.setrecursionlimit(10000)
43-
44-
4522
__all__ = (
4623
"open_grid",
4724
"open_dataset",
@@ -55,6 +32,4 @@ def disable_fma():
5532
"diverging",
5633
"sequential_blue",
5734
"sequential_green",
58-
"enable_fma",
59-
"disable_fma",
6035
)

uxarray/grid/geometry.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -827,9 +827,7 @@ def pole_point_inside_polygon(pole, face_edges_xyz, face_edges_lonlat):
827827
return ((north_intersections + south_intersections) % 2) != 0
828828

829829
else:
830-
raise ValueError(
831-
f"Invalid pole point query. Current location: {location}, query pole point: {pole}"
832-
)
830+
raise ValueError("Invalid pole point query.")
833831

834832

835833
@njit(cache=True)

uxarray/grid/grid.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,8 @@
9595
_check_normalization,
9696
)
9797

98+
from uxarray.utils.numba import is_numba_function_cached
99+
98100

99101
from uxarray.conventions import ugrid
100102

@@ -1367,13 +1369,12 @@ def bounds(self):
13671369
Dimensions ``(n_face", two, two)``
13681370
"""
13691371
if "bounds" not in self._ds:
1370-
if hasattr(compute_temp_latlon_array, "inspect_llvm"):
1371-
if len(compute_temp_latlon_array.inspect_llvm()) == 0:
1372-
warn(
1373-
"Necessary functions for computing face bounds are not translated yet with Numba. This initial"
1374-
"translation may take some time.",
1375-
RuntimeWarning,
1376-
)
1372+
if not is_numba_function_cached(compute_temp_latlon_array):
1373+
warn(
1374+
"Necessary functions for computing the bounds of each face are not yet compiled with Numba. "
1375+
"This initial execution will be significantly longer.",
1376+
RuntimeWarning,
1377+
)
13771378

13781379
_populate_bounds(self)
13791380

uxarray/utils/numba.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import os
2+
import pickle
3+
import numba
4+
5+
6+
def is_numba_function_cached(func):
7+
"""
8+
Determines if a numba function is cached and up-to-date.
9+
10+
Returns:
11+
- True if cache exists and is valid or the input is not a Numba function.
12+
- False if cache doesn't exist or needs recompilation
13+
"""
14+
15+
if not hasattr(func, "_cache"):
16+
return True
17+
18+
cache = func._cache
19+
cache_file = cache._cache_file
20+
21+
# Check if cache file exists
22+
full_path = os.path.join(cache._cache_path, cache_file._index_name)
23+
if not os.path.isfile(full_path):
24+
return False
25+
26+
try:
27+
# Load and check version
28+
with open(full_path, "rb") as f:
29+
version = pickle.load(f)
30+
if version != numba.__version__:
31+
return False
32+
33+
# Load and check source stamp
34+
data = f.read()
35+
stamp, _ = pickle.loads(data)
36+
37+
# Get current source stamp
38+
current_stamp = cache._impl.locator.get_source_stamp()
39+
40+
# Compare stamps
41+
return stamp == current_stamp
42+
43+
except (OSError, pickle.PickleError):
44+
return False

0 commit comments

Comments
 (0)