Skip to content

Commit e75e894

Browse files
committed
Add tests for _to_ndarray with various numeric dtypes
1 parent 35230fa commit e75e894

File tree

1 file changed

+185
-0
lines changed

1 file changed

+185
-0
lines changed

pygmt/tests/test_clib_to_ndarray.py

Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
"""
2+
Test the _to_ndarray function in the clib.conversion module.
3+
"""
4+
5+
import numpy as np
6+
import numpy.testing as npt
7+
import pandas as pd
8+
import pytest
9+
from pygmt.clib.conversion import _to_ndarray
10+
11+
try:
12+
import pyarrow as pa
13+
14+
_HAS_PYARROW = True
15+
except ImportError:
16+
_HAS_PYARROW = False
17+
18+
19+
@pytest.fixture(scope="module", name="dtypes_numpy_numeric")
20+
def fixture_dtypes_numpy_numeric():
21+
"""
22+
List of NumPy numeric dtypes.
23+
24+
Reference: https://numpy.org/doc/stable/reference/arrays.scalars.html
25+
"""
26+
return [
27+
np.int8,
28+
np.int16,
29+
np.int32,
30+
np.int64,
31+
np.longlong,
32+
np.uint8,
33+
np.uint16,
34+
np.uint32,
35+
np.uint64,
36+
np.ulonglong,
37+
np.float16,
38+
np.float32,
39+
np.float64,
40+
np.longdouble,
41+
np.complex64,
42+
np.complex128,
43+
np.clongdouble,
44+
]
45+
46+
47+
@pytest.fixture(scope="module", name="dtypes_pandas_numeric")
48+
def fixture_dtypes_pandas_numeric():
49+
"""
50+
List of pandas numeric dtypes.
51+
52+
Reference: https://pandas.pydata.org/docs/reference/arrays.html
53+
"""
54+
return [
55+
pd.Int8Dtype(),
56+
pd.Int16Dtype(),
57+
pd.Int32Dtype(),
58+
pd.Int64Dtype(),
59+
pd.UInt8Dtype(),
60+
pd.UInt16Dtype(),
61+
pd.UInt32Dtype(),
62+
pd.UInt64Dtype(),
63+
pd.Float32Dtype(),
64+
pd.Float64Dtype(),
65+
]
66+
67+
68+
@pytest.fixture(scope="module", name="dtypes_pandas_numeric_pyarrow_backend")
69+
def fixture_dtypes_pandas_numeric_pyarrow_backend():
70+
"""
71+
List of pandas dtypes that use pyarrow as the backend.
72+
73+
Reference: https://pandas.pydata.org/docs/user_guide/pyarrow.html
74+
"""
75+
return [
76+
"int8[pyarrow]",
77+
"int16[pyarrow]",
78+
"int32[pyarrow]",
79+
"int64[pyarrow]",
80+
"uint8[pyarrow]",
81+
"uint16[pyarrow]",
82+
"uint32[pyarrow]",
83+
"uint64[pyarrow]",
84+
"float32[pyarrow]",
85+
"float64[pyarrow]",
86+
]
87+
88+
89+
@pytest.fixture(scope="module", name="dtypes_pyarrow_numeric")
90+
def fixture_dtypes_pyarrow_numeric():
91+
"""
92+
List of pyarrow numeric dtypes.
93+
94+
Reference: https://arrow.apache.org/docs/python/api/datatypes.html
95+
"""
96+
if not _HAS_PYARROW:
97+
return []
98+
return [
99+
pa.int8(),
100+
pa.int16(),
101+
pa.int32(),
102+
pa.int64(),
103+
pa.uint8(),
104+
pa.uint16(),
105+
pa.uint32(),
106+
pa.uint64(),
107+
# pa.float16(), # Need special handling.
108+
pa.float32(),
109+
pa.float64(),
110+
]
111+
112+
113+
def _check_result(result):
114+
"""
115+
A helper function to check the result of the _to_ndarray function.
116+
117+
Check the following:
118+
119+
1. The result is a NumPy array.
120+
2. The result is C-contiguous.
121+
3. The result dtype is not np.object_.
122+
"""
123+
assert isinstance(result, np.ndarray)
124+
assert result.flags.c_contiguous is True
125+
assert result.dtype != np.object_
126+
127+
128+
def test_to_ndarray_numpy_ndarray_numpy_numeric(dtypes_numpy_numeric):
129+
"""
130+
Test the _to_ndarray function with 1-D NumPy arrays.
131+
"""
132+
# 1-D array
133+
for dtype in dtypes_numpy_numeric:
134+
array = np.array([1, 2, 3], dtype=dtype)
135+
assert array.dtype == dtype
136+
result = _to_ndarray(array)
137+
_check_result(result)
138+
npt.assert_array_equal(result, array)
139+
140+
# 2-D array
141+
for dtype in dtypes_numpy_numeric:
142+
array = np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype)
143+
assert array.dtype == dtype
144+
result = _to_ndarray(array)
145+
_check_result(result)
146+
npt.assert_array_equal(result, array)
147+
148+
149+
def test_to_ndarray_pandas_series_numeric(
150+
dtypes_numpy_numeric, dtypes_pandas_numeric, dtypes_pandas_numeric_pyarrow_backend
151+
):
152+
"""
153+
Test the _to_ndarray function with pandas Series with NumPy dtypes, pandas dtypes,
154+
and pandas dtypes with pyarrow backend.
155+
"""
156+
for dtype in (
157+
dtypes_numpy_numeric
158+
+ dtypes_pandas_numeric
159+
+ dtypes_pandas_numeric_pyarrow_backend
160+
):
161+
series = pd.Series([1, 2, 3], dtype=dtype)
162+
assert series.dtype == dtype
163+
result = _to_ndarray(series)
164+
_check_result(result)
165+
npt.assert_array_equal(result, series)
166+
167+
168+
@pytest.mark.skipif(not _HAS_PYARROW, reason="pyarrow is not installed")
169+
def test_to_ndarray_pandas_series_pyarrow_dtype(dtypes_pyarrow_numeric):
170+
"""
171+
Test the _to_ndarray function with pandas Series with pyarrow dtypes.
172+
"""
173+
for dtype in dtypes_pyarrow_numeric:
174+
array = pa.array([1, 2, 3], type=dtype)
175+
assert array.type == dtype
176+
result = _to_ndarray(array)
177+
_check_result(result)
178+
npt.assert_array_equal(result, array)
179+
180+
# Special handling for float16.
181+
# Example from https://arrow.apache.org/docs/python/generated/pyarrow.float16.html
182+
array = pa.array(np.array([1.5, 2.5, 3.5], dtype=np.float16), type=pa.float16())
183+
result = _to_ndarray(array)
184+
_check_result(result)
185+
npt.assert_array_equal(result, array)

0 commit comments

Comments
 (0)