Skip to content

Commit 63df796

Browse files
committed
clib.conversion._to_numpy: Add tests for pandas.Series and pyarrow.array with pyarrow numeric dtypes
1 parent dd78693 commit 63df796

File tree

1 file changed

+146
-0
lines changed

1 file changed

+146
-0
lines changed

pygmt/tests/test_clib_to_numpy.py

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,13 @@
1111
from packaging.version import Version
1212
from pygmt.clib.conversion import _to_numpy
1313

14+
try:
15+
import pyarrow as pa
16+
17+
_HAS_PYARROW = True
18+
except ImportError:
19+
_HAS_PYARROW = False
20+
1421

1522
def _check_result(result, expected_dtype):
1623
"""
@@ -138,6 +145,11 @@ def test_to_numpy_ndarray_numpy_dtypes_numeric(dtype, expected_dtype):
138145
# - BooleanDtype
139146
# - ArrowDtype: a special dtype used to store data in the PyArrow format.
140147
#
148+
# PyArrow dtypes can be specified using the following formats:
149+
#
150+
# - The name of the dtype followed by "[pyarrow]" (e.g., "int8[pyarrow]")
151+
# - Specified using ``pd.ArrowDtype`` (e.g., ``pd.ArrowDtype(pa.int8())``)
152+
#
141153
# References:
142154
# 1. https://pandas.pydata.org/docs/reference/arrays.html
143155
# 2. https://pandas.pydata.org/docs/user_guide/basics.html#basics-dtypes
@@ -152,3 +164,137 @@ def test_to_numpy_pandas_series_numpy_dtypes_numeric(dtype, expected_dtype):
152164
result = _to_numpy(series)
153165
_check_result(result, expected_dtype)
154166
npt.assert_array_equal(result, series)
167+
168+
169+
@pytest.mark.skipif(not _HAS_PYARROW, reason="pyarrow is not installed")
@pytest.mark.parametrize(
    ("dtype", "expected_dtype"),
    [
        # Each case pairs the pandas dtype string "<name>[pyarrow]" with the NumPy
        # dtype of the same name; the test id is the pandas dtype string itself.
        pytest.param(f"{name}[pyarrow]", getattr(np, name), id=f"{name}[pyarrow]")
        for name in (
            "int8",
            "int16",
            "int32",
            "int64",
            "uint8",
            "uint16",
            "uint32",
            "uint64",
            "float16",
            "float32",
            "float64",
        )
    ],
)
def test_to_numpy_pandas_series_pyarrow_dtypes_numeric(dtype, expected_dtype):
    """
    Test the _to_numpy function with pandas.Series of PyArrow numeric dtypes.
    """
    series = pd.Series([1, 2, 3, 4, 5, 6], dtype=dtype)[::2]  # Not C-contiguous
    result = _to_numpy(series)
    _check_result(result, expected_dtype)
    npt.assert_array_equal(result, series)
194+
195+
196+
@pytest.mark.skipif(not _HAS_PYARROW, reason="pyarrow is not installed")
@pytest.mark.parametrize(
    ("dtype", "expected_dtype"),
    [
        # Every listed dtype is expected to be converted to np.float64 here.
        # NOTE: "float16[pyarrow]" is currently left out of this test.
        pytest.param(f"{name}[pyarrow]", np.float64, id=f"{name}[pyarrow]")
        for name in (
            "int8",
            "int16",
            "int32",
            "int64",
            "uint8",
            "uint16",
            "uint32",
            "uint64",
            "float32",
            "float64",
        )
    ],
)
def test_to_numpy_pandas_series_pyarrow_dtypes_numeric_with_na(dtype, expected_dtype):
    """
    Test the _to_numpy function with pandas.Series of PyArrow numeric dtypes and NA.
    """
    # Taking every other element keeps the pd.NA entry (index 2) in the series.
    series = pd.Series([1, 2, pd.NA, 4, 5, 6], dtype=dtype)[::2]
    assert series.isna().any()
    result = _to_numpy(series)
    _check_result(result, expected_dtype)
    # The missing value is expected to appear as NaN in the converted array.
    npt.assert_array_equal(result, np.array([1.0, np.nan, 5.0], dtype=expected_dtype))
222+
223+
224+
########################################################################################
225+
# Test the _to_numpy function with PyArrow arrays.
226+
#
227+
# PyArrow provides the following dtypes:
228+
#
229+
# - Numeric dtypes:
230+
# - int8, int16, int32, int64
231+
# - uint8, uint16, uint32, uint64
232+
# - float16, float32, float64
233+
#
234+
# In PyArrow, array types can be specified in two ways:
235+
#
236+
# - Using string aliases (e.g., "int8")
237+
# - Using pyarrow.DataType (e.g., ``pa.int8()``)
238+
#
239+
# Reference: https://arrow.apache.org/docs/python/api/datatypes.html
240+
########################################################################################
241+
@pytest.mark.skipif(not _HAS_PYARROW, reason="pyarrow is not installed")
@pytest.mark.parametrize(
    ("dtype", "expected_dtype"),
    [
        # The PyArrow type alias and the NumPy dtype share the same name, and that
        # name is also used as the test id.
        pytest.param(name, getattr(np, name), id=name)
        for name in (
            "int8",
            "int16",
            "int32",
            "int64",
            "uint8",
            "uint16",
            "uint32",
            "uint64",
            "float16",
            "float32",
            "float64",
        )
    ],
)
def test_to_numpy_pyarrow_array_pyarrow_dtypes_numeric(dtype, expected_dtype):
    """
    Check _to_numpy on PyArrow arrays for each PyArrow numeric dtype.
    """
    values = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
    # float16 needs special handling: the input is first turned into a NumPy
    # float16 array, following the example at
    # https://arrow.apache.org/docs/python/generated/pyarrow.float16.html
    source = np.array(values, dtype=np.float16) if dtype == "float16" else values
    array = pa.array(source, type=dtype)[::2]
    result = _to_numpy(array)
    _check_result(result, expected_dtype)
    npt.assert_array_equal(result, array)
270+
271+
272+
@pytest.mark.skipif(not _HAS_PYARROW, reason="pyarrow is not installed")
@pytest.mark.parametrize(
    ("dtype", "expected_dtype"),
    [
        # Integer dtypes are paired with np.float64 ...
        *(
            pytest.param(name, np.float64, id=name)
            for name in (
                "int8",
                "int16",
                "int32",
                "int64",
                "uint8",
                "uint16",
                "uint32",
                "uint64",
            )
        ),
        # ... while floating-point dtypes keep the NumPy dtype of the same name.
        *(
            pytest.param(name, getattr(np, name), id=name)
            for name in ("float16", "float32", "float64")
        ),
    ],
)
def test_to_numpy_pyarrow_array_pyarrow_dtypes_numeric_with_na(dtype, expected_dtype):
    """
    Check _to_numpy on PyArrow arrays of PyArrow numeric dtypes holding a missing value.
    """
    values = [1.0, 2.0, None, 4.0, 5.0, 6.0]
    # float16 needs special handling: the input is first turned into a NumPy
    # float16 array, following the example at
    # https://arrow.apache.org/docs/python/generated/pyarrow.float16.html
    source = np.array(values, dtype=np.float16) if dtype == "float16" else values
    array = pa.array(source, type=dtype)[::2]
    result = _to_numpy(array)
    _check_result(result, expected_dtype)
    npt.assert_array_equal(result, array)

0 commit comments

Comments
 (0)