diff --git a/pygmt/tests/test_clib_to_numpy.py b/pygmt/tests/test_clib_to_numpy.py index 12b8c1d782d..3624ed2be8d 100644 --- a/pygmt/tests/test_clib_to_numpy.py +++ b/pygmt/tests/test_clib_to_numpy.py @@ -11,6 +11,13 @@ from packaging.version import Version from pygmt.clib.conversion import _to_numpy +try: + import pyarrow as pa + + _HAS_PYARROW = True +except ImportError: + _HAS_PYARROW = False + def _check_result(result, expected_dtype): """ @@ -121,7 +128,7 @@ def test_to_numpy_ndarray_numpy_dtypes_numeric(dtype, expected_dtype): # # 1. NumPy dtypes (see above) # 2. pandas dtypes -# 3. PyArrow dtypes +# 3. PyArrow types (see below) # # pandas provides following dtypes: # @@ -152,3 +159,82 @@ def test_to_numpy_pandas_series_numpy_dtypes_numeric(dtype, expected_dtype): result = _to_numpy(series) _check_result(result, expected_dtype) npt.assert_array_equal(result, series) + + +######################################################################################## +# Test the _to_numpy function with PyArrow arrays. +# +# PyArrow provides the following types: +# +# - Numeric types: +# - int8, int16, int32, int64 +# - uint8, uint16, uint32, uint64 +# - float16, float32, float64 +# +# In PyArrow, array types can be specified in two ways: +# +# - Using string aliases (e.g., "int8") +# - Using pyarrow.DataType (e.g., ``pa.int8()``) +# +# Reference: https://arrow.apache.org/docs/python/api/datatypes.html +######################################################################################## +@pytest.mark.skipif(not _HAS_PYARROW, reason="pyarrow is not installed") +@pytest.mark.parametrize( + ("dtype", "expected_dtype"), + [ + pytest.param("int8", np.int8, id="int8"), + pytest.param("int16", np.int16, id="int16"), + pytest.param("int32", np.int32, id="int32"), + pytest.param("int64", np.int64, id="int64"), + pytest.param("uint8", np.uint8, id="uint8"), + pytest.param("uint16", np.uint16, id="uint16"), + pytest.param("uint32", np.uint32, id="uint32"), + pytest.param("uint64", np.uint64, id="uint64"), + pytest.param("float16", np.float16, id="float16"), + pytest.param("float32", np.float32, id="float32"), + pytest.param("float64", np.float64, id="float64"), + ], +) +def test_to_numpy_pyarrow_array_pyarrow_dtypes_numeric(dtype, expected_dtype): + """ + Test the _to_numpy function with PyArrow arrays of PyArrow numeric types. + """ + data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] + if dtype == "float16": # float16 needs special handling + # Example from https://arrow.apache.org/docs/python/generated/pyarrow.float16.html + data = np.array(data, dtype=np.float16) + array = pa.array(data, type=dtype)[::2] + result = _to_numpy(array) + _check_result(result, expected_dtype) + npt.assert_array_equal(result, array) + + +@pytest.mark.skipif(not _HAS_PYARROW, reason="pyarrow is not installed") +@pytest.mark.parametrize( + ("dtype", "expected_dtype"), + [ + pytest.param("int8", np.float64, id="int8"), + pytest.param("int16", np.float64, id="int16"), + pytest.param("int32", np.float64, id="int32"), + pytest.param("int64", np.float64, id="int64"), + pytest.param("uint8", np.float64, id="uint8"), + pytest.param("uint16", np.float64, id="uint16"), + pytest.param("uint32", np.float64, id="uint32"), + pytest.param("uint64", np.float64, id="uint64"), + pytest.param("float16", np.float16, id="float16"), + pytest.param("float32", np.float32, id="float32"), + pytest.param("float64", np.float64, id="float64"), + ], +) +def test_to_numpy_pyarrow_array_pyarrow_dtypes_numeric_with_na(dtype, expected_dtype): + """ + Test the _to_numpy function with PyArrow arrays of PyArrow numeric types and NA. + """ + data = [1.0, 2.0, None, 4.0, 5.0, 6.0] + if dtype == "float16": # float16 needs special handling + # Example from https://arrow.apache.org/docs/python/generated/pyarrow.float16.html + data = np.array(data, dtype=np.float16) + array = pa.array(data, type=dtype)[::2] + result = _to_numpy(array) + _check_result(result, expected_dtype) + npt.assert_array_equal(result, array)