Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 0 additions & 25 deletions uxarray/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,3 @@
import uxarray.constants
import sys
#
# # TODO: numba recursion limit ?


from .core.api import open_grid, open_dataset, open_mfdataset

from .core.dataset import UxDataset
Expand All @@ -25,23 +19,6 @@
__version__ = "999"


# Flag for enabling FMA instructions across the package
def enable_fma():
"""Enables Fused-Multiply-Add (FMA) instructions using the ``pyfma``
package."""
uxarray.constants.ENABLE_FMA = True


def disable_fma():
"""Disable Fused-Multiply-Add (FMA) instructions using the ``pyfma``
package."""
uxarray.constants.ENABLE_FMA = False


disable_fma()
sys.setrecursionlimit(10000)


__all__ = (
"open_grid",
"open_dataset",
Expand All @@ -55,6 +32,4 @@ def disable_fma():
"diverging",
"sequential_blue",
"sequential_green",
"enable_fma",
"disable_fma",
)
4 changes: 1 addition & 3 deletions uxarray/grid/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -827,9 +827,7 @@ def pole_point_inside_polygon(pole, face_edges_xyz, face_edges_lonlat):
return ((north_intersections + south_intersections) % 2) != 0

else:
raise ValueError(
f"Invalid pole point query. Current location: {location}, query pole point: {pole}"
)
raise ValueError("Invalid pole point query.")


@njit(cache=True)
Expand Down
15 changes: 8 additions & 7 deletions uxarray/grid/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@
_check_normalization,
)

from uxarray.utils.numba import is_numba_function_cached


from uxarray.conventions import ugrid

Expand Down Expand Up @@ -1367,13 +1369,12 @@ def bounds(self):
Dimensions ``(n_face", two, two)``
"""
if "bounds" not in self._ds:
if hasattr(compute_temp_latlon_array, "inspect_llvm"):
if len(compute_temp_latlon_array.inspect_llvm()) == 0:
warn(
"Necessary functions for computing face bounds are not translated yet with Numba. This initial"
"translation may take some time.",
RuntimeWarning,
)
if not is_numba_function_cached(compute_temp_latlon_array):
warn(
"Necessary functions for computing the bounds of each face are not yet compiled with Numba. "
"This initial execution will be significantly longer.",
RuntimeWarning,
)

_populate_bounds(self)

Expand Down
44 changes: 44 additions & 0 deletions uxarray/utils/numba.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import os
import pickle
import numba


def is_numba_function_cached(func):
"""
Determines if a numba function is cached and up-to-date.

Returns:
- True if cache exists and is valid or the input is not a Numba function.
- False if cache doesn't exist or needs recompilation
"""

if not hasattr(func, "_cache"):
return True

cache = func._cache
cache_file = cache._cache_file

# Check if cache file exists
full_path = os.path.join(cache._cache_path, cache_file._index_name)
if not os.path.isfile(full_path):
return False

try:
# Load and check version
with open(full_path, "rb") as f:
version = pickle.load(f)
if version != numba.__version__:
return False

# Load and check source stamp
data = f.read()
stamp, _ = pickle.loads(data)

# Get current source stamp
current_stamp = cache._impl.locator.get_source_stamp()

# Compare stamps
return stamp == current_stamp

except (OSError, pickle.PickleError):
return False
Loading