Skip to content

Commit a65f6ae

Browse files
authored
clib.conversion._to_numpy: Add tests for pandas.Series with pyarrow numeric dtypes (#3585)
1 parent f3aa7b9 commit a65f6ae

File tree

1 file changed

+47
-4
lines changed

1 file changed

+47
-4
lines changed

pygmt/tests/test_clib_to_numpy.py

Lines changed: 47 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import pytest
1111
from packaging.version import Version
1212
from pygmt.clib.conversion import _to_numpy
13+
from pygmt.helpers.testing import skip_if_no
1314

1415
try:
1516
import pyarrow as pa
@@ -18,6 +19,9 @@
1819
except ImportError:
1920
_HAS_PYARROW = False
2021

22+
# Mark tests that require pyarrow
23+
pa_marks = {"marks": skip_if_no(package="pyarrow")}
24+
2125

2226
def _check_result(result, expected_dtype):
2327
"""
@@ -145,6 +149,11 @@ def test_to_numpy_ndarray_numpy_dtypes_numeric(dtype, expected_dtype):
145149
# - BooleanDtype
146150
# - ArrowDtype: a special dtype used to store data in the PyArrow format.
147151
#
152+
# In pandas, PyArrow types can be specified using the following formats:
153+
#
154+
# - Prefixed with the name of the dtype and "[pyarrow]" (e.g., "int8[pyarrow]")
155+
# - Specified using ``ArrowDType`` (e.g., "pd.ArrowDtype(pa.int8())")
156+
#
148157
# References:
149158
# 1. https://pandas.pydata.org/docs/reference/arrays.html
150159
# 2. https://pandas.pydata.org/docs/user_guide/basics.html#basics-dtypes
@@ -174,13 +183,30 @@ def test_to_numpy_pandas_series_numpy_dtypes_numeric(dtype, expected_dtype):
174183
pytest.param(pd.UInt64Dtype(), np.uint64, id="UInt64"),
175184
pytest.param(pd.Float32Dtype(), np.float32, id="Float32"),
176185
pytest.param(pd.Float64Dtype(), np.float64, id="Float64"),
186+
pytest.param("int8[pyarrow]", np.int8, id="int8[pyarrow]", **pa_marks),
187+
pytest.param("int16[pyarrow]", np.int16, id="int16[pyarrow]", **pa_marks),
188+
pytest.param("int32[pyarrow]", np.int32, id="int32[pyarrow]", **pa_marks),
189+
pytest.param("int64[pyarrow]", np.int64, id="int64[pyarrow]", **pa_marks),
190+
pytest.param("uint8[pyarrow]", np.uint8, id="uint8[pyarrow]", **pa_marks),
191+
pytest.param("uint16[pyarrow]", np.uint16, id="uint16[pyarrow]", **pa_marks),
192+
pytest.param("uint32[pyarrow]", np.uint32, id="uint32[pyarrow]", **pa_marks),
193+
pytest.param("uint64[pyarrow]", np.uint64, id="uint64[pyarrow]", **pa_marks),
194+
pytest.param("float16[pyarrow]", np.float16, id="float16[pyarrow]", **pa_marks),
195+
pytest.param("float32[pyarrow]", np.float32, id="float32[pyarrow]", **pa_marks),
196+
pytest.param("float64[pyarrow]", np.float64, id="float64[pyarrow]", **pa_marks),
177197
],
178198
)
179199
def test_to_numpy_pandas_series_pandas_dtypes_numeric(dtype, expected_dtype):
180200
"""
181-
Test the _to_numpy function with pandas.Series of pandas numeric dtypes.
201+
Test the _to_numpy function with pandas.Series of pandas/PyArrow numeric dtypes.
182202
"""
183-
series = pd.Series([1, 2, 3, 4, 5, 6], dtype=dtype)[::2] # Not C-contiguous
203+
data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
204+
if dtype == "float16[pyarrow]" and Version(pd.__version__) < Version("2.2"):
205+
# float16 needs special handling for pandas < 2.2.
206+
# Example from https://arrow.apache.org/docs/python/generated/pyarrow.float16.html
207+
data = np.array(data, dtype=np.float16)
208+
209+
series = pd.Series(data, dtype=dtype)[::2] # Not C-contiguous
184210
result = _to_numpy(series)
185211
_check_result(result, expected_dtype)
186212
npt.assert_array_equal(result, series)
@@ -199,13 +225,30 @@ def test_to_numpy_pandas_series_pandas_dtypes_numeric(dtype, expected_dtype):
199225
pytest.param(pd.UInt64Dtype(), np.float64, id="UInt64"),
200226
pytest.param(pd.Float32Dtype(), np.float32, id="Float32"),
201227
pytest.param(pd.Float64Dtype(), np.float64, id="Float64"),
228+
pytest.param("int8[pyarrow]", np.float64, id="int8[pyarrow]", **pa_marks),
229+
pytest.param("int16[pyarrow]", np.float64, id="int16[pyarrow]", **pa_marks),
230+
pytest.param("int32[pyarrow]", np.float64, id="int32[pyarrow]", **pa_marks),
231+
pytest.param("int64[pyarrow]", np.float64, id="int64[pyarrow]", **pa_marks),
232+
pytest.param("uint8[pyarrow]", np.float64, id="uint8[pyarrow]", **pa_marks),
233+
pytest.param("uint16[pyarrow]", np.float64, id="uint16[pyarrow]", **pa_marks),
234+
pytest.param("uint32[pyarrow]", np.float64, id="uint32[pyarrow]", **pa_marks),
235+
pytest.param("uint64[pyarrow]", np.float64, id="uint64[pyarrow]", **pa_marks),
236+
pytest.param("float16[pyarrow]", np.float16, id="float16[pyarrow]", **pa_marks),
237+
pytest.param("float32[pyarrow]", np.float32, id="float32[pyarrow]", **pa_marks),
238+
pytest.param("float64[pyarrow]", np.float64, id="float64[pyarrow]", **pa_marks),
202239
],
203240
)
204241
def test_to_numpy_pandas_series_pandas_dtypes_numeric_with_na(dtype, expected_dtype):
205242
"""
206-
Test the _to_numpy function with pandas.Series of pandas numeric dtypes and NA.
243+
Test the _to_numpy function with pandas.Series of pandas/PyArrow numeric dtypes and
244+
missing values (NA).
207245
"""
208-
series = pd.Series([1, 2, pd.NA, 4, 5, 6], dtype=dtype)[::2] # Not C-contiguous
246+
data = [1.0, 2.0, None, 4.0, 5.0, 6.0]
247+
if dtype == "float16[pyarrow]" and Version(pd.__version__) < Version("2.2"):
248+
# float16 needs special handling for pandas < 2.2.
249+
# Example from https://arrow.apache.org/docs/python/generated/pyarrow.float16.html
250+
data = np.array(data, dtype=np.float16)
251+
series = pd.Series(data, dtype=dtype)[::2] # Not C-contiguous
209252
assert series.isna().any()
210253
result = _to_numpy(series)
211254
_check_result(result, expected_dtype)

0 commit comments

Comments
 (0)