Skip to content

Commit 3a589fa

Browse files
authored
Refactor clib to avoid checking GMT version repeatedly and only check once when loading the GMT library (#3254)
1 parent 1df8f19 commit 3a589fa

File tree

7 files changed

+81
-54
lines changed

7 files changed

+81
-54
lines changed

.github/ISSUE_TEMPLATE/5-bump_gmt_checklist.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ using the following command:
3535
**To-Do for bumping the minimum required GMT version**:
3636

3737
- [ ] Bump the minimum required GMT version (1 PR)
38-
- [ ] Update `required_version` in `pygmt/clib/session.py`
38+
- [ ] Update `required_gmt_version` in `pygmt/clib/__init__.py`
3939
- [ ] Update `test_get_default` in `pygmt/tests/test_clib.py`
4040
- [ ] Update minimum required versions in `doc/minversions.md`
4141
- [ ] Remove unsupported GMT version from `.github/workflows/ci_tests_legacy.yaml`

doc/conf.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,14 @@
77

88
# ruff: isort: off
99
from sphinx_gallery.sorting import ExplicitOrder, ExampleTitleSortKey
10-
import pygmt
10+
from pygmt.clib import required_gmt_version
1111
from pygmt import __commit__, __version__
1212
from pygmt.sphinx_gallery import PyGMTScraper
1313

1414
# ruff: isort: on
1515

1616
requires_python = metadata("pygmt")["Requires-Python"]
17-
with pygmt.clib.Session() as lib:
18-
requires_gmt = f">={lib.required_version}"
17+
requires_gmt = f">={required_gmt_version}"
1918

2019
extensions = [
2120
"myst_parser",

pygmt/clib/__init__.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,16 @@
55
interface. Access to the C library is done through ctypes.
66
"""
77

8-
from pygmt.clib.session import Session
8+
from packaging.version import Version
9+
from pygmt.clib.session import Session, __gmt_version__
10+
from pygmt.exceptions import GMTVersionError
911

10-
with Session() as lib:
11-
__gmt_version__ = lib.info["version"]
12+
required_gmt_version = "6.3.0"
13+
14+
# Check if the GMT version is older than the required version.
15+
if Version(__gmt_version__) < Version(required_gmt_version):
16+
msg = (
17+
f"Using an incompatible GMT version {__gmt_version__}. "
18+
f"Must be equal or newer than {required_gmt_version}."
19+
)
20+
raise GMTVersionError(msg)

pygmt/clib/loading.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,33 @@ def load_libgmt(lib_fullnames: Iterator[str] | None = None) -> ctypes.CDLL:
6464
return libgmt
6565

6666

67+
def get_gmt_version(libgmt: ctypes.CDLL) -> str:
68+
"""
69+
Get the GMT version string of the GMT shared library.
70+
71+
Parameters
72+
----------
73+
libgmt
74+
The GMT shared library.
75+
76+
Returns
77+
-------
78+
The GMT version string in *major.minor.patch* format.
79+
"""
80+
func = libgmt.GMT_Get_Version
81+
func.argtypes = (
82+
ctypes.c_void_p, # Unused parameter, so it can be None.
83+
ctypes.POINTER(ctypes.c_uint), # major
84+
ctypes.POINTER(ctypes.c_uint), # minor
85+
ctypes.POINTER(ctypes.c_uint), # patch
86+
)
87+
# The function return value is the current library version as a float, e.g., 6.5.
88+
func.restype = ctypes.c_float
89+
major, minor, patch = ctypes.c_uint(0), ctypes.c_uint(0), ctypes.c_uint(0)
90+
func(None, major, minor, patch)
91+
return f"{major.value}.{minor.value}.{patch.value}"
92+
93+
6794
def clib_names(os_name: str) -> list[str]:
6895
"""
6996
Return the name(s) of GMT's shared library for the current operating system.

pygmt/clib/session.py

Lines changed: 4 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,9 @@
2525
strings_to_ctypes_array,
2626
vectors_to_arrays,
2727
)
28-
from pygmt.clib.loading import load_libgmt
28+
from pygmt.clib.loading import get_gmt_version, load_libgmt
2929
from pygmt.datatypes import _GMT_DATASET, _GMT_GRID, _GMT_IMAGE
30-
from pygmt.exceptions import (
31-
GMTCLibError,
32-
GMTCLibNoSessionError,
33-
GMTInvalidInput,
34-
GMTVersionError,
35-
)
30+
from pygmt.exceptions import GMTCLibError, GMTCLibNoSessionError, GMTInvalidInput
3631
from pygmt.helpers import (
3732
_validate_data_input,
3833
data_kind,
@@ -98,6 +93,7 @@
9893

9994
# Load the GMT library outside the Session class to avoid repeated loading.
10095
_libgmt = load_libgmt()
96+
__gmt_version__ = get_gmt_version(_libgmt)
10197

10298

10399
class Session:
@@ -155,9 +151,6 @@ class Session:
155151
-55 -47 -24 -10 190 981 1 1 8 14 1 1
156152
"""
157153

158-
# The minimum supported GMT version.
159-
required_version = "6.3.0"
160-
161154
@property
162155
def session_pointer(self):
163156
"""
@@ -212,27 +205,11 @@ def info(self):
212205

213206
def __enter__(self):
214207
"""
215-
Create a GMT API session and check the libgmt version.
208+
Create a GMT API session.
216209
217210
Calls :meth:`pygmt.clib.Session.create`.
218-
219-
Raises
220-
------
221-
GMTVersionError
222-
If the version reported by libgmt is less than
223-
``Session.required_version``. Will destroy the session before
224-
raising the exception.
225211
"""
226212
self.create("pygmt-session")
227-
# Need to store the version info because 'get_default' won't work after
228-
# the session is destroyed.
229-
version = self.info["version"]
230-
if Version(version) < Version(self.required_version):
231-
self.destroy()
232-
raise GMTVersionError(
233-
f"Using an incompatible GMT version {version}. "
234-
f"Must be equal or newer than {self.required_version}."
235-
)
236213
return self
237214

238215
def __exit__(self, exc_type, exc_value, traceback):

pygmt/tests/test_clib.py

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -577,27 +577,24 @@ def mock_defaults(api, name, value): # noqa: ARG001
577577
ses.destroy()
578578

579579

580-
def test_fails_for_wrong_version():
580+
def test_fails_for_wrong_version(monkeypatch):
581581
"""
582-
Make sure the clib.Session raises an exception if GMT is too old.
582+
Make sure that importing clib raise an exception if GMT is too old.
583583
"""
584+
import importlib
584585

585-
# Mock GMT_Get_Default to return an old version
586-
def mock_defaults(api, name, value): # noqa: ARG001
587-
"""
588-
Return an old version.
589-
"""
590-
if name == b"API_VERSION":
591-
value.value = b"5.4.3"
592-
else:
593-
value.value = b"bla"
594-
return 0
586+
with monkeypatch.context() as mpatch:
587+
# Make sure the current GMT major version is 6.
588+
assert clib.__gmt_version__.split(".")[0] == "6"
595589

596-
lib = clib.Session()
597-
with mock(lib, "GMT_Get_Default", mock_func=mock_defaults):
590+
# Monkeypatch the version string returned by pygmt.clib.loading.get_gmt_version.
591+
mpatch.setattr(clib.loading, "get_gmt_version", lambda libgmt: "5.4.3") # noqa: ARG005
592+
593+
# Reload clib.session and check the __gmt_version__ string.
594+
importlib.reload(clib.session)
595+
assert clib.session.__gmt_version__ == "5.4.3"
596+
597+
# Should raise an exception when pygmt.clib is loaded/reloaded.
598598
with pytest.raises(GMTVersionError):
599-
with lib:
600-
assert lib.info["version"] != "5.4.3"
601-
# Make sure the session is closed when the exception is raised.
602-
with pytest.raises(GMTCLibNoSessionError):
603-
assert lib.session_pointer
599+
importlib.reload(clib)
600+
assert clib.__gmt_version__ == "5.4.3" # Make sure it's still the old version

pygmt/tests/test_clib_loading.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,13 @@
1111
from pathlib import PurePath
1212

1313
import pytest
14-
from pygmt.clib.loading import check_libgmt, clib_full_names, clib_names, load_libgmt
14+
from pygmt.clib.loading import (
15+
check_libgmt,
16+
clib_full_names,
17+
clib_names,
18+
get_gmt_version,
19+
load_libgmt,
20+
)
1521
from pygmt.clib.session import Session
1622
from pygmt.exceptions import GMTCLibError, GMTCLibNotFoundError, GMTOSError
1723

@@ -360,3 +366,15 @@ def test_clib_full_names_gmt_library_path_incorrect_path_included(
360366
# Windows: find_library() searches the library in PATH, so one more
361367
npath = 2 if sys.platform == "win32" else 1
362368
assert list(lib_fullpaths) == [gmt_lib_realpath] * npath + gmt_lib_names
369+
370+
371+
###############################################################################
372+
# Test get_gmt_version
373+
def test_get_gmt_version():
374+
"""
375+
Test if get_gmt_version returns a version string in major.minor.patch format.
376+
"""
377+
version = get_gmt_version(load_libgmt())
378+
assert isinstance(version, str)
379+
assert len(version.split(".")) == 3 # In major.minor.patch format
380+
assert version.split(".")[0] == "6" # Is GMT 6.x.x

0 commit comments

Comments
 (0)