Skip to content

Commit 2a2ab7a

Browse files
committed
clib.conversion._to_numpy: Add tests for pandas.Series and pyarrow.array with pyarrow numeric dtypes
1 parent eceff7f commit 2a2ab7a

File tree

1 file changed

+146
-0
lines changed

1 file changed

+146
-0
lines changed

pygmt/tests/test_clib_to_numpy.py

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,13 @@
88
import pytest
99
from pygmt.clib.conversion import _to_numpy
1010

11+
try:
12+
import pyarrow as pa
13+
14+
_HAS_PYARROW = True
15+
except ImportError:
16+
_HAS_PYARROW = False
17+
1118

1219
def _check_result(result, expected_dtype):
1320
"""
@@ -122,6 +129,11 @@ def test_to_numpy_ndarray_numpy_dtypes_numeric(dtype, expected_dtype):
122129
# - BooleanDtype
123130
# - ArrowDtype: a special dtype used to store data in the PyArrow format.
124131
#
132+
# PyArrow dtypes can be specified using the following formats:
133+
#
134+
# - Prefixed with the name of the dtype and "[pyarrow]" (e.g., "int8[pyarrow]")
135+
# - Specified using ``ArrowDType`` (e.g., "pd.ArrowDtype(pa.int8())")
136+
#
125137
# References:
126138
# 1. https://pandas.pydata.org/docs/reference/arrays.html
127139
# 2. https://pandas.pydata.org/docs/user_guide/basics.html#basics-dtypes
@@ -207,3 +219,137 @@ def test_to_numpy_pandas_series_pandas_dtypes_numeric_with_na(dtype, expected_dt
207219
result = _to_numpy(series)
208220
_check_result(result, expected_dtype)
209221
npt.assert_array_equal(result, np.array([1.0, np.nan, 3.0], dtype=expected_dtype))
222+
223+
224+
@pytest.mark.skipif(not _HAS_PYARROW, reason="pyarrow is not installed")
225+
@pytest.mark.parametrize(
226+
("dtype", "expected_dtype"),
227+
[
228+
pytest.param("int8[pyarrow]", np.int8, id="int8[pyarrow]"),
229+
pytest.param("int16[pyarrow]", np.int16, id="int16[pyarrow]"),
230+
pytest.param("int32[pyarrow]", np.int32, id="int32[pyarrow]"),
231+
pytest.param("int64[pyarrow]", np.int64, id="int64[pyarrow]"),
232+
pytest.param("uint8[pyarrow]", np.uint8, id="uint8[pyarrow]"),
233+
pytest.param("uint16[pyarrow]", np.uint16, id="uint16[pyarrow]"),
234+
pytest.param("uint32[pyarrow]", np.uint32, id="uint32[pyarrow]"),
235+
pytest.param("uint64[pyarrow]", np.uint64, id="uint64[pyarrow]"),
236+
pytest.param("float16[pyarrow]", np.float16, id="float16[pyarrow]"),
237+
pytest.param("float32[pyarrow]", np.float32, id="float32[pyarrow]"),
238+
pytest.param("float64[pyarrow]", np.float64, id="float64[pyarrow]"),
239+
],
240+
)
241+
def test_to_numpy_pandas_series_pyarrow_dtypes_numeric(dtype, expected_dtype):
242+
"""
243+
Test the _to_numpy function with pandas.Series of pandas numeric dtypes.
244+
"""
245+
series = pd.Series([1, 2, 3], dtype=dtype)
246+
result = _to_numpy(series)
247+
_check_result(result, expected_dtype)
248+
npt.assert_array_equal(result, series)
249+
250+
251+
@pytest.mark.skipif(not _HAS_PYARROW, reason="pyarrow is not installed")
252+
@pytest.mark.parametrize(
253+
("dtype", "expected_dtype"),
254+
[
255+
pytest.param("int8[pyarrow]", np.float64, id="int8[pyarrow]"),
256+
pytest.param("int16[pyarrow]", np.float64, id="int16[pyarrow]"),
257+
pytest.param("int32[pyarrow]", np.float64, id="int32[pyarrow]"),
258+
pytest.param("int64[pyarrow]", np.float64, id="int64[pyarrow]"),
259+
pytest.param("uint8[pyarrow]", np.float64, id="uint8[pyarrow]"),
260+
pytest.param("uint16[pyarrow]", np.float64, id="uint16[pyarrow]"),
261+
pytest.param("uint32[pyarrow]", np.float64, id="uint32[pyarrow]"),
262+
pytest.param("uint64[pyarrow]", np.float64, id="uint64[pyarrow]"),
263+
pytest.param("float16[pyarrow]", np.float16, id="float16[pyarrow]"),
264+
pytest.param("float32[pyarrow]", np.float32, id="float32[pyarrow]"),
265+
pytest.param("float64[pyarrow]", np.float64, id="float64[pyarrow]"),
266+
],
267+
)
268+
def test_to_numpy_pandas_series_pyarrow_dtypes_numeric_with_na(dtype, expected_dtype):
269+
"""
270+
Test the _to_numpy function with pandas.Series of pandas numeric dtypes and NA.
271+
"""
272+
series = pd.Series([1, pd.NA, 3], dtype=dtype)
273+
result = _to_numpy(series)
274+
_check_result(result, expected_dtype)
275+
npt.assert_array_equal(result, np.array([1.0, np.nan, 3.0], dtype=expected_dtype))
276+
277+
278+
########################################################################################
279+
# Test the _to_numpy function with PyArrow arrays.
280+
#
281+
# PyArrow provides the following dtypes:
282+
#
283+
# - Numeric dtypes:
284+
# - int8, int16, int32, int64
285+
# - uint8, uint16, uint32, uint64
286+
# - float16, float32, float64
287+
#
288+
# Reference: https://arrow.apache.org/docs/python/api/datatypes.html
289+
########################################################################################
290+
@pytest.mark.skipif(not _HAS_PYARROW, reason="pyarrow is not installed")
291+
@pytest.mark.parametrize(
292+
("dtype", "expected_dtype"),
293+
[
294+
pytest.param("int8", np.int8, id="int8"),
295+
pytest.param("int16", np.int16, id="int16"),
296+
pytest.param("int32", np.int32, id="int32"),
297+
pytest.param("int64", np.int64, id="int64"),
298+
pytest.param("uint8", np.uint8, id="uint8"),
299+
pytest.param("uint16", np.uint16, id="uint16"),
300+
pytest.param("uint32", np.uint32, id="uint32"),
301+
pytest.param("uint64", np.uint64, id="uint64"),
302+
pytest.param("float16", np.float16, id="float16"),
303+
pytest.param("float32", np.float32, id="float32"),
304+
pytest.param("float64", np.float64, id="float64"),
305+
],
306+
)
307+
def test_to_numpy_pyarrow_array_pyarrow_dtypes_numeric(dtype, expected_dtype):
308+
"""
309+
Test the _to_numpy function with PyArrow arrays of PyArrow numeric dtypes.
310+
"""
311+
if dtype == "float16":
312+
# float16 needs special handling
313+
# Example from https://arrow.apache.org/docs/python/generated/pyarrow.float16.html
314+
array = pa.array(np.array([1.0, 2.0, 3.0], dtype=np.float16), type=pa.float16())
315+
else:
316+
array = pa.array([1, 2, 3], type=dtype)
317+
assert array.type == dtype
318+
result = _to_numpy(array)
319+
_check_result(result, expected_dtype)
320+
npt.assert_array_equal(result, array)
321+
322+
323+
@pytest.mark.skipif(not _HAS_PYARROW, reason="pyarrow is not installed")
324+
@pytest.mark.parametrize(
325+
("dtype", "expected_dtype"),
326+
[
327+
pytest.param("int8", np.float64, id="int8"),
328+
pytest.param("int16", np.float64, id="int16"),
329+
pytest.param("int32", np.float64, id="int32"),
330+
pytest.param("int64", np.float64, id="int64"),
331+
pytest.param("uint8", np.float64, id="uint8"),
332+
pytest.param("uint16", np.float64, id="uint16"),
333+
pytest.param("uint32", np.float64, id="uint32"),
334+
pytest.param("uint64", np.float64, id="uint64"),
335+
pytest.param("float16", np.float16, id="float16"),
336+
pytest.param("float32", np.float32, id="float32"),
337+
pytest.param("float64", np.float64, id="float64"),
338+
],
339+
)
340+
def test_to_numpy_pyarrow_array_pyarrow_dtypes_numeric_with_na(dtype, expected_dtype):
341+
"""
342+
Test the _to_numpy function with PyArrow arrays of PyArrow numeric dtypes and NA.
343+
"""
344+
if dtype == "float16":
345+
# float16 needs special handling
346+
# Example from https://arrow.apache.org/docs/python/generated/pyarrow.float16.html
347+
array = pa.array(
348+
np.array([1.0, None, 3.0], dtype=np.float16), type=pa.float16()
349+
)
350+
else:
351+
array = pa.array([1, None, 3], type=dtype)
352+
assert array.type == dtype
353+
result = _to_numpy(array)
354+
_check_result(result, expected_dtype)
355+
npt.assert_array_equal(result, array)

0 commit comments

Comments
 (0)