Skip to content

Commit 176693c

Browse files
authored
clib.Session: Refactor the __getitem__ special method to avoid calling API function GMT_Get_Enum repeatedly (#3261)
1 parent 4862bf6 commit 176693c

File tree

2 files changed

+39
-24
lines changed

2 files changed

+39
-24
lines changed

pygmt/clib/session.py

Lines changed: 33 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,8 @@
9191
np.datetime64: "GMT_DATETIME",
9292
np.timedelta64: "GMT_LONG",
9393
}
94+
# Dictionary for storing the values of GMT constants.
95+
GMT_CONSTANTS = {}
9496

9597
# Load the GMT library outside the Session class to avoid repeated loading.
9698
_libgmt = load_libgmt()
@@ -239,23 +241,41 @@ def __exit__(self, exc_type, exc_value, traceback):
239241
"""
240242
self.destroy()
241243

242-
def __getitem__(self, name):
244+
def __getitem__(self, name: str) -> int:
245+
"""
246+
Get the value of a GMT constant.
247+
248+
Parameters
249+
----------
250+
name
251+
The name of the constant (e.g., ``"GMT_SESSION_EXTERNAL"``).
252+
253+
Returns
254+
-------
255+
value
256+
Integer value of the constant. Do not rely on this value because it might
257+
change.
258+
"""
259+
if name not in GMT_CONSTANTS:
260+
GMT_CONSTANTS[name] = self.get_enum(name)
261+
return GMT_CONSTANTS[name]
262+
263+
def get_enum(self, name: str) -> int:
243264
"""
244265
Get the value of a GMT constant (C enum) from gmt_resources.h.
245266
246-
Used to set configuration values for other API calls. Wraps
247-
``GMT_Get_Enum``.
267+
Used to set configuration values for other API calls. Wraps ``GMT_Get_Enum``.
248268
249269
Parameters
250270
----------
251-
name : str
252-
The name of the constant (e.g., ``"GMT_SESSION_EXTERNAL"``)
271+
name
272+
The name of the constant (e.g., ``"GMT_SESSION_EXTERNAL"``).
253273
254274
Returns
255275
-------
256-
constant : int
257-
Integer value of the constant. Do not rely on this value because it
258-
might change.
276+
value
277+
Integer value of the constant. Do not rely on this value because it might
278+
change.
259279
260280
Raises
261281
------
@@ -266,18 +286,15 @@ def __getitem__(self, name):
266286
"GMT_Get_Enum", argtypes=[ctp.c_void_p, ctp.c_char_p], restype=ctp.c_int
267287
)
268288

269-
# The C lib introduced the void API pointer to GMT_Get_Enum so that
270-
# it's consistent with other functions. It doesn't use the pointer so
271-
# we can pass in None (NULL pointer). We can't give it the actual
272-
# pointer because we need to call GMT_Get_Enum when creating a new API
273-
# session pointer (chicken-and-egg type of thing).
289+
# The C library introduced the void API pointer to GMT_Get_Enum so that it's
290+
# consistent with other functions. It doesn't use the pointer so we can pass
291+
# in None (NULL pointer). We can't give it the actual pointer because we need
292+
# to call GMT_Get_Enum when creating a new API session pointer (chicken-and-egg
293+
# type of thing).
274294
session = None
275-
276295
value = c_get_enum(session, name.encode())
277-
278296
if value is None or value == -99999:
279297
raise GMTCLibError(f"Constant '{name}' doesn't exist in libgmt.")
280-
281298
return value
282299

283300
def get_libgmt_func(self, name, argtypes=None, restype=None):

pygmt/tests/test_clib.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -65,15 +65,13 @@ def mock_get_libgmt_func(name, argtypes=None, restype=None):
6565

6666
def test_getitem():
6767
"""
68-
Test that I can get correct constants from the C lib.
68+
Test getting the GMT constants from the C library.
6969
"""
70-
ses = clib.Session()
71-
assert ses["GMT_SESSION_EXTERNAL"] != -99999
72-
assert ses["GMT_MODULE_CMD"] != -99999
73-
assert ses["GMT_PAD_DEFAULT"] != -99999
74-
assert ses["GMT_DOUBLE"] != -99999
75-
with pytest.raises(GMTCLibError):
76-
ses["A_WHOLE_LOT_OF_JUNK"]
70+
with clib.Session() as lib:
71+
for name in ["GMT_SESSION_EXTERNAL", "GMT_MODULE_CMD", "GMT_DOUBLE"]:
72+
assert lib[name] != -99999
73+
with pytest.raises(GMTCLibError):
74+
lib["A_WHOLE_LOT_OF_JUNK"]
7775

7876

7977
def test_create_destroy_session():

0 commit comments

Comments
 (0)