-
Notifications
You must be signed in to change notification settings - Fork 230
clib.conversion._to_numpy: Add tests for pandas.Series with pyarrow numeric dtypes #3585
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
a9b10d6
54d1a37
50e6872
1317e11
e09aa75
06e6958
94f335b
bba7296
f2e504b
4150993
b947b83
0a82e3f
98dfb29
898bcca
10526da
f3bfef1
c504bf2
cf397b7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,6 +11,13 @@ | |
from packaging.version import Version | ||
from pygmt.clib.conversion import _to_numpy | ||
|
||
try: | ||
import pyarrow as pa | ||
|
||
_HAS_PYARROW = True | ||
except ImportError: | ||
_HAS_PYARROW = False | ||
|
||
|
||
def _check_result(result, expected_dtype): | ||
""" | ||
|
@@ -138,6 +145,11 @@ def test_to_numpy_ndarray_numpy_dtypes_numeric(dtype, expected_dtype): | |
# - BooleanDtype | ||
# - ArrowDtype: a special dtype used to store data in the PyArrow format. | ||
# | ||
# PyArrow dtypes can be specified using the following formats: | ||
# | ||
# - The name of the dtype followed by "[pyarrow]" (e.g., "int8[pyarrow]") | ||
# - Specified using ``pd.ArrowDtype`` (e.g., "pd.ArrowDtype(pa.int8())") | ||
# | ||
# References: | ||
# 1. https://pandas.pydata.org/docs/reference/arrays.html | ||
# 2. https://pandas.pydata.org/docs/user_guide/basics.html#basics-dtypes | ||
|
@@ -152,3 +164,137 @@ def test_to_numpy_pandas_series_numpy_dtypes_numeric(dtype, expected_dtype): | |
result = _to_numpy(series) | ||
_check_result(result, expected_dtype) | ||
npt.assert_array_equal(result, series) | ||
|
||
|
||
@pytest.mark.skipif(not _HAS_PYARROW, reason="pyarrow is not installed") | ||
@pytest.mark.parametrize( | ||
("dtype", "expected_dtype"), | ||
[ | ||
pytest.param("int8[pyarrow]", np.int8, id="int8[pyarrow]"), | ||
pytest.param("int16[pyarrow]", np.int16, id="int16[pyarrow]"), | ||
pytest.param("int32[pyarrow]", np.int32, id="int32[pyarrow]"), | ||
pytest.param("int64[pyarrow]", np.int64, id="int64[pyarrow]"), | ||
pytest.param("uint8[pyarrow]", np.uint8, id="uint8[pyarrow]"), | ||
pytest.param("uint16[pyarrow]", np.uint16, id="uint16[pyarrow]"), | ||
pytest.param("uint32[pyarrow]", np.uint32, id="uint32[pyarrow]"), | ||
pytest.param("uint64[pyarrow]", np.uint64, id="uint64[pyarrow]"), | ||
pytest.param("float16[pyarrow]", np.float16, id="float16[pyarrow]"), | ||
pytest.param("float32[pyarrow]", np.float32, id="float32[pyarrow]"), | ||
pytest.param("float64[pyarrow]", np.float64, id="float64[pyarrow]"), | ||
], | ||
) | ||
def test_to_numpy_pandas_series_pyarrow_dtypes_numeric(dtype, expected_dtype): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This test is exactly the same as the test To merge them into a single test, we have to change the pytest.param to
which is too long to fit in one line and will make the pytest.params too long to read. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Managed to fit it into a single test like so: from pygmt.helpers.testing import skip_if_no
pa_marks = {"marks": skip_if_no(package="pyarrow")}
@pytest.mark.parametrize(
("dtype", "expected_dtype"),
[
...,
pytest.param("int8[pyarrow]", np.int8, id="int8[pyarrow]", **pa_marks),
pytest.param("int16[pyarrow]", np.int16, id="int16[pyarrow]", **pa_marks),
pytest.param("int32[pyarrow]", np.int32, id="int32[pyarrow]", **pa_marks),
pytest.param("int64[pyarrow]", np.int64, id="int64[pyarrow]", **pa_marks),
pytest.param("uint8[pyarrow]", np.uint8, id="uint8[pyarrow]", **pa_marks),
pytest.param("uint16[pyarrow]", np.uint16, id="uint16[pyarrow]", **pa_marks),
pytest.param("uint32[pyarrow]", np.uint32, id="uint32[pyarrow]", **pa_marks),
pytest.param("uint64[pyarrow]", np.uint64, id="uint64[pyarrow]", **pa_marks),
pytest.param("float16[pyarrow]", np.float16, id="float16[pyarrow]", **pa_marks),
pytest.param("float32[pyarrow]", np.float32, id="float32[pyarrow]", **pa_marks),
pytest.param("float64[pyarrow]", np.float64, id="float64[pyarrow]", **pa_marks),
],
) The longest line is just under 88 characters. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Nice changes. We probably should merge this PR into #3584, so that we can test all pandas dtypes (including pyarrow-backed) in a single PR. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Sure, I see that this branch is targeting |
||
""" | ||
Test the _to_numpy function with pandas.Series of PyArrow numeric dtypes. | ||
""" | ||
series = pd.Series([1, 2, 3, 4, 5, 6], dtype=dtype)[::2] # Not C-contiguous | ||
result = _to_numpy(series) | ||
_check_result(result, expected_dtype) | ||
npt.assert_array_equal(result, series) | ||
|
||
|
||
@pytest.mark.skipif(not _HAS_PYARROW, reason="pyarrow is not installed") | ||
@pytest.mark.parametrize( | ||
("dtype", "expected_dtype"), | ||
[ | ||
pytest.param("int8[pyarrow]", np.float64, id="int8[pyarrow]"), | ||
pytest.param("int16[pyarrow]", np.float64, id="int16[pyarrow]"), | ||
pytest.param("int32[pyarrow]", np.float64, id="int32[pyarrow]"), | ||
pytest.param("int64[pyarrow]", np.float64, id="int64[pyarrow]"), | ||
pytest.param("uint8[pyarrow]", np.float64, id="uint8[pyarrow]"), | ||
pytest.param("uint16[pyarrow]", np.float64, id="uint16[pyarrow]"), | ||
pytest.param("uint32[pyarrow]", np.float64, id="uint32[pyarrow]"), | ||
pytest.param("uint64[pyarrow]", np.float64, id="uint64[pyarrow]"), | ||
# pytest.param("float16[pyarrow]", np.float64, id="float16[pyarrow]"), | ||
seisman marked this conversation as resolved.
Show resolved
Hide resolved
|
||
pytest.param("float32[pyarrow]", np.float64, id="float32[pyarrow]"), | ||
pytest.param("float64[pyarrow]", np.float64, id="float64[pyarrow]"), | ||
], | ||
seisman marked this conversation as resolved.
Show resolved
Hide resolved
|
||
) | ||
def test_to_numpy_pandas_series_pyarrow_dtypes_numeric_with_na(dtype, expected_dtype): | ||
""" | ||
Test the _to_numpy function with pandas.Series of PyArrow numeric dtypes and NA. | ||
""" | ||
series = pd.Series([1, 2, pd.NA, 4, 5, 6], dtype=dtype)[::2] | ||
assert series.isna().any() | ||
result = _to_numpy(series) | ||
_check_result(result, expected_dtype) | ||
npt.assert_array_equal(result, np.array([1.0, np.nan, 5.0], dtype=expected_dtype)) | ||
|
||
|
||
######################################################################################## | ||
# Test the _to_numpy function with PyArrow arrays. | ||
# | ||
# PyArrow provides the following dtypes: | ||
# | ||
# - Numeric dtypes: | ||
# - int8, int16, int32, int64 | ||
# - uint8, uint16, uint32, uint64 | ||
# - float16, float32, float64 | ||
# | ||
# In PyArrow, array types can be specified in two ways: | ||
# | ||
# - Using string aliases (e.g., "int8") | ||
# - Using pyarrow.DataType (e.g., ``pa.int8()``) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Again, we can't use |
||
# | ||
# Reference: https://arrow.apache.org/docs/python/api/datatypes.html | ||
######################################################################################## | ||
@pytest.mark.skipif(not _HAS_PYARROW, reason="pyarrow is not installed") | ||
@pytest.mark.parametrize( | ||
("dtype", "expected_dtype"), | ||
[ | ||
pytest.param("int8", np.int8, id="int8"), | ||
pytest.param("int16", np.int16, id="int16"), | ||
pytest.param("int32", np.int32, id="int32"), | ||
pytest.param("int64", np.int64, id="int64"), | ||
pytest.param("uint8", np.uint8, id="uint8"), | ||
pytest.param("uint16", np.uint16, id="uint16"), | ||
pytest.param("uint32", np.uint32, id="uint32"), | ||
pytest.param("uint64", np.uint64, id="uint64"), | ||
pytest.param("float16", np.float16, id="float16"), | ||
pytest.param("float32", np.float32, id="float32"), | ||
pytest.param("float64", np.float64, id="float64"), | ||
], | ||
) | ||
def test_to_numpy_pyarrow_array_pyarrow_dtypes_numeric(dtype, expected_dtype): | ||
""" | ||
Test the _to_numpy function with PyArrow arrays of PyArrow numeric dtypes. | ||
""" | ||
data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] | ||
if dtype == "float16": # float16 needs special handling | ||
# Example from https://arrow.apache.org/docs/python/generated/pyarrow.float16.html | ||
data = np.array(data, dtype=np.float16) | ||
array = pa.array(data, type=dtype)[::2] | ||
result = _to_numpy(array) | ||
_check_result(result, expected_dtype) | ||
npt.assert_array_equal(result, array) | ||
|
||
|
||
@pytest.mark.skipif(not _HAS_PYARROW, reason="pyarrow is not installed") | ||
@pytest.mark.parametrize( | ||
("dtype", "expected_dtype"), | ||
[ | ||
pytest.param("int8", np.float64, id="int8"), | ||
pytest.param("int16", np.float64, id="int16"), | ||
pytest.param("int32", np.float64, id="int32"), | ||
pytest.param("int64", np.float64, id="int64"), | ||
pytest.param("uint8", np.float64, id="uint8"), | ||
pytest.param("uint16", np.float64, id="uint16"), | ||
pytest.param("uint32", np.float64, id="uint32"), | ||
pytest.param("uint64", np.float64, id="uint64"), | ||
pytest.param("float16", np.float16, id="float16"), | ||
pytest.param("float32", np.float32, id="float32"), | ||
pytest.param("float64", np.float64, id="float64"), | ||
], | ||
) | ||
def test_to_numpy_pyarrow_array_pyarrow_dtypes_numeric_with_na(dtype, expected_dtype): | ||
""" | ||
Test the _to_numpy function with PyArrow arrays of PyArrow numeric dtypes and NA. | ||
""" | ||
data = [1.0, 2.0, None, 4.0, 5.0, 6.0] | ||
if dtype == "float16": # float16 needs special handling | ||
# Example from https://arrow.apache.org/docs/python/generated/pyarrow.float16.html | ||
data = np.array(data, dtype=np.float16) | ||
array = pa.array(data, type=dtype)[::2] | ||
result = _to_numpy(array) | ||
_check_result(result, expected_dtype) | ||
npt.assert_array_equal(result, array) |
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can't use `pd.ArrowDtype(pa.int8())` here because `pa` is not defined
when `pyarrow` is not installed. So we have to use the string aliases.