|
11 | 11 | from packaging.version import Version
|
12 | 12 | from pygmt.clib.conversion import _to_numpy
|
13 | 13 |
|
| 14 | +try: |
| 15 | + import pyarrow as pa |
| 16 | + |
| 17 | + _HAS_PYARROW = True |
| 18 | +except ImportError: |
| 19 | + _HAS_PYARROW = False |
| 20 | + |
14 | 21 |
|
15 | 22 | def _check_result(result, expected_dtype):
|
16 | 23 | """
|
@@ -121,7 +128,7 @@ def test_to_numpy_ndarray_numpy_dtypes_numeric(dtype, expected_dtype):
|
121 | 128 | #
|
122 | 129 | # 1. NumPy dtypes (see above)
|
123 | 130 | # 2. pandas dtypes
|
124 |
| -# 3. PyArrow dtypes |
| 131 | +# 3. PyArrow types (see below) |
125 | 132 | #
|
126 | 133 | # pandas provides following dtypes:
|
127 | 134 | #
|
@@ -203,3 +210,82 @@ def test_to_numpy_pandas_series_pandas_dtypes_numeric_with_na(dtype, expected_dt
|
203 | 210 | result = _to_numpy(series)
|
204 | 211 | _check_result(result, expected_dtype)
|
205 | 212 | npt.assert_array_equal(result, np.array([1.0, np.nan, 5.0], dtype=expected_dtype))
|
| 213 | + |
| 214 | + |
| 215 | +######################################################################################## |
| 216 | +# Test the _to_numpy function with PyArrow arrays. |
| 217 | +# |
| 218 | +# PyArrow provides the following types: |
| 219 | +# |
| 220 | +# - Numeric types: |
| 221 | +# - int8, int16, int32, int64 |
| 222 | +# - uint8, uint16, uint32, uint64 |
| 223 | +# - float16, float32, float64 |
| 224 | +# |
| 225 | +# In PyArrow, array types can be specified in two ways: |
| 226 | +# |
| 227 | +# - Using string aliases (e.g., "int8") |
| 228 | +# - Using pyarrow.DataType (e.g., ``pa.int8()``) |
| 229 | +# |
| 230 | +# Reference: https://arrow.apache.org/docs/python/api/datatypes.html |
| 231 | +######################################################################################## |
@pytest.mark.skipif(not _HAS_PYARROW, reason="pyarrow is not installed")
@pytest.mark.parametrize(
    ("dtype", "expected_dtype"),
    [
        # Every PyArrow type alias below is paired with the NumPy dtype of the same
        # name, and the alias doubles as the test id.
        pytest.param(name, getattr(np, name), id=name)
        for name in (
            "int8",
            "int16",
            "int32",
            "int64",
            "uint8",
            "uint16",
            "uint32",
            "uint64",
            "float16",
            "float32",
            "float64",
        )
    ],
)
def test_to_numpy_pyarrow_array_pyarrow_dtypes_numeric(dtype, expected_dtype):
    """
    Check that _to_numpy converts PyArrow arrays of each PyArrow numeric type.
    """
    values = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
    # For float16 the input is handed to PyArrow as a NumPy float16 array instead of
    # a plain list; see the example at
    # https://arrow.apache.org/docs/python/generated/pyarrow.float16.html
    source = np.array(values, dtype=np.float16) if dtype == "float16" else values
    # Slice with a step of 2 so that _to_numpy receives a sliced PyArrow array.
    array = pa.array(source, type=dtype)[::2]
    result = _to_numpy(array)
    _check_result(result, expected_dtype)
    npt.assert_array_equal(result, array)
| 261 | + |
| 262 | + |
@pytest.mark.skipif(not _HAS_PYARROW, reason="pyarrow is not installed")
@pytest.mark.parametrize(
    ("dtype", "expected_dtype"),
    [
        # Integer types: the expected result dtype is float64 for all of them
        # (presumably because the missing value is represented as NaN; confirm
        # against _to_numpy).
        *(
            pytest.param(name, np.float64, id=name)
            for name in (
                "int8",
                "int16",
                "int32",
                "int64",
                "uint8",
                "uint16",
                "uint32",
                "uint64",
            )
        ),
        # Floating-point types: the expected result dtype is the NumPy dtype of the
        # same name.
        *(
            pytest.param(name, getattr(np, name), id=name)
            for name in ("float16", "float32", "float64")
        ),
    ],
)
def test_to_numpy_pyarrow_array_pyarrow_dtypes_numeric_with_na(dtype, expected_dtype):
    """
    Check that _to_numpy converts PyArrow arrays of numeric types that contain NA.
    """
    values = [1.0, 2.0, None, 4.0, 5.0, 6.0]
    # For float16 the input is handed to PyArrow as a NumPy float16 array instead of
    # a plain list; see the example at
    # https://arrow.apache.org/docs/python/generated/pyarrow.float16.html
    # NOTE(review): on this path the None passes through NumPy first, so it likely
    # reaches PyArrow as NaN rather than as a null - confirm this is intended.
    source = np.array(values, dtype=np.float16) if dtype == "float16" else values
    # Slice with a step of 2; the missing entry (index 2) stays in the sliced array.
    array = pa.array(source, type=dtype)[::2]
    result = _to_numpy(array)
    _check_result(result, expected_dtype)
    npt.assert_array_equal(result, array)
0 commit comments