Skip to content

Commit 8bc2f56

Browse files
committed
Check the expected dtype
1 parent f9bf19c commit 8bc2f56

File tree

1 file changed

+47
-50
lines changed

1 file changed

+47
-50
lines changed

pygmt/tests/test_clib_to_numpy.py

Lines changed: 47 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,16 @@
77
import pandas as pd
88
import pytest
99
from pygmt.clib.conversion import _to_numpy
10-
from pygmt.clib.session import DTYPES
1110

1211

13-
def _check_result(result, supported):
12+
def _check_result(result, expected_dtype):
1413
"""
15-
Check the result of the _to_numpy function.
14+
A helper function to check if the result of the _to_numpy function is a C-contiguous
15+
NumPy array with the expected dtype.
1616
"""
17-
# Check that the result is a NumPy array and is C-contiguous.
1817
assert isinstance(result, np.ndarray)
1918
assert result.flags.c_contiguous
20-
# Check that the dtype is supported by PyGMT (or the GMT C API).
21-
assert (result.dtype.type in DTYPES) == supported
19+
assert result.dtype.type == expected_dtype
2220

2321

2422
########################################################################################
@@ -36,8 +34,8 @@ def test_to_numpy_python_types_numeric(data, expected_dtype):
3634
Test the _to_numpy function with Python built-in numeric types.
3735
"""
3836
result = _to_numpy(data)
39-
_check_result(result, supported=True)
40-
npt.assert_array_equal(result, np.array(data, dtype=expected_dtype), strict=True)
37+
_check_result(result, expected_dtype)
38+
npt.assert_array_equal(result, data)
4139

4240

4341
########################################################################################
@@ -60,28 +58,28 @@ def test_to_numpy_python_types_numeric(data, expected_dtype):
6058
# Reference: https://numpy.org/doc/2.1/reference/arrays.scalars.html
6159
########################################################################################
6260
@pytest.mark.parametrize(
63-
("dtype", "supported"),
61+
("dtype", "expected_dtype"),
6462
[
65-
(np.int8, True),
66-
(np.int16, True),
67-
(np.int32, True),
68-
(np.int64, True),
69-
(np.longlong, True),
70-
(np.uint8, True),
71-
(np.uint16, True),
72-
(np.uint32, True),
73-
(np.uint64, True),
74-
(np.ulonglong, True),
75-
(np.float16, False),
76-
(np.float32, True),
77-
(np.float64, True),
78-
(np.longdouble, False),
79-
(np.complex64, False),
80-
(np.complex128, False),
81-
(np.clongdouble, False),
63+
pytest.param(np.int8, np.int8, id="int8"),
64+
pytest.param(np.int16, np.int16, id="int16"),
65+
pytest.param(np.int32, np.int32, id="int32"),
66+
pytest.param(np.int64, np.int64, id="int64"),
67+
pytest.param(np.longlong, np.longlong, id="longlong"),
68+
pytest.param(np.uint8, np.uint8, id="uint8"),
69+
pytest.param(np.uint16, np.uint16, id="uint16"),
70+
pytest.param(np.uint32, np.uint32, id="uint32"),
71+
pytest.param(np.uint64, np.uint64, id="uint64"),
72+
pytest.param(np.ulonglong, np.ulonglong, id="ulonglong"),
73+
pytest.param(np.float16, np.float16, id="float16"),
74+
pytest.param(np.float32, np.float32, id="float32"),
75+
pytest.param(np.float64, np.float64, id="float64"),
76+
pytest.param(np.longdouble, np.longdouble, id="longdouble"),
77+
pytest.param(np.complex64, np.complex64, id="complex64"),
78+
pytest.param(np.complex128, np.complex128, id="complex128"),
79+
pytest.param(np.clongdouble, np.clongdouble, id="clongdouble"),
8280
],
8381
)
84-
def test_to_numpy_ndarray_numpy_dtypes_numeric(dtype, supported):
82+
def test_to_numpy_ndarray_numpy_dtypes_numeric(dtype, expected_dtype):
8583
"""
8684
Test the _to_numpy function with NumPy arrays of NumPy numeric dtypes.
8785
@@ -90,13 +88,13 @@ def test_to_numpy_ndarray_numpy_dtypes_numeric(dtype, supported):
9088
# 1-D array
9189
array = np.array([1, 2, 3], dtype=dtype)
9290
result = _to_numpy(array)
93-
_check_result(result, supported)
91+
_check_result(result, expected_dtype)
9492
npt.assert_array_equal(result, array, strict=True)
9593

9694
# 2-D array
9795
array = np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype)
9896
result = _to_numpy(array)
99-
_check_result(result, supported)
97+
_check_result(result, expected_dtype)
10098
npt.assert_array_equal(result, array, strict=True)
10199

102100

@@ -130,33 +128,32 @@ def test_to_numpy_ndarray_numpy_dtypes_numeric(dtype, supported):
130128
# 3. https://pandas.pydata.org/docs/user_guide/pyarrow.html
131129
########################################################################################
132130
@pytest.mark.parametrize(
133-
("dtype", "supported"),
131+
("dtype", "expected_dtype"),
134132
[
135-
(np.int8, True),
136-
(np.int16, True),
137-
(np.int32, True),
138-
(np.int64, True),
139-
(np.longlong, True),
140-
(np.uint8, True),
141-
(np.uint16, True),
142-
(np.uint32, True),
143-
(np.uint64, True),
144-
(np.ulonglong, True),
145-
(np.float16, False),
146-
(np.float32, True),
147-
(np.float64, True),
148-
(np.longdouble, False),
149-
(np.complex64, False),
150-
(np.complex128, False),
151-
(np.clongdouble, False),
133+
pytest.param(np.int8, np.int8, id="int8"),
134+
pytest.param(np.int16, np.int16, id="int16"),
135+
pytest.param(np.int32, np.int32, id="int32"),
136+
pytest.param(np.int64, np.int64, id="int64"),
137+
pytest.param(np.longlong, np.longlong, id="longlong"),
138+
pytest.param(np.uint8, np.uint8, id="uint8"),
139+
pytest.param(np.uint16, np.uint16, id="uint16"),
140+
pytest.param(np.uint32, np.uint32, id="uint32"),
141+
pytest.param(np.uint64, np.uint64, id="uint64"),
142+
pytest.param(np.ulonglong, np.ulonglong, id="ulonglong"),
143+
pytest.param(np.float16, np.float16, id="float16"),
144+
pytest.param(np.float32, np.float32, id="float32"),
145+
pytest.param(np.float64, np.float64, id="float64"),
146+
pytest.param(np.longdouble, np.longdouble, id="longdouble"),
147+
pytest.param(np.complex64, np.complex64, id="complex64"),
148+
pytest.param(np.complex128, np.complex128, id="complex128"),
149+
pytest.param(np.clongdouble, np.clongdouble, id="clongdouble"),
152150
],
153151
)
154-
def test_to_numpy_pandas_series_numpy_dtypes_numeric(dtype, supported):
152+
def test_to_numpy_pandas_series_numpy_dtypes_numeric(dtype, expected_dtype):
155153
"""
156154
Test the _to_numpy function with pandas.Series of NumPy numeric dtypes.
157155
"""
158156
series = pd.Series([1, 2, 3], dtype=dtype)
159-
assert series.dtype == dtype
160157
result = _to_numpy(series)
161-
_check_result(result, supported)
158+
_check_result(result, expected_dtype)
162159
npt.assert_array_equal(result, series)

0 commit comments

Comments
 (0)