|
8 | 8 | import pytest
|
9 | 9 | from pygmt.clib.conversion import _to_numpy
|
10 | 10 |
|
| 11 | +try: |
| 12 | + import pyarrow as pa |
| 13 | + |
| 14 | + _HAS_PYARROW = True |
| 15 | +except ImportError: |
| 16 | + _HAS_PYARROW = False |
| 17 | + |
11 | 18 |
|
12 | 19 | def _check_result(result, expected_dtype):
|
13 | 20 | """
|
@@ -122,6 +129,11 @@ def test_to_numpy_ndarray_numpy_dtypes_numeric(dtype, expected_dtype):
|
122 | 129 | # - BooleanDtype
|
123 | 130 | # - ArrowDtype: a special dtype used to store data in the PyArrow format.
|
124 | 131 | #
|
| 132 | +# PyArrow dtypes can be specified using the following formats: |
| 133 | +# |
| 134 | +# - Prefixed with the name of the dtype and "[pyarrow]" (e.g., "int8[pyarrow]") |
| 135 | +# - Specified using ``ArrowDType`` (e.g., "pd.ArrowDtype(pa.int8())") |
| 136 | +# |
125 | 137 | # References:
|
126 | 138 | # 1. https://pandas.pydata.org/docs/reference/arrays.html
|
127 | 139 | # 2. https://pandas.pydata.org/docs/user_guide/basics.html#basics-dtypes
|
@@ -207,3 +219,137 @@ def test_to_numpy_pandas_series_pandas_dtypes_numeric_with_na(dtype, expected_dt
|
207 | 219 | result = _to_numpy(series)
|
208 | 220 | _check_result(result, expected_dtype)
|
209 | 221 | npt.assert_array_equal(result, np.array([1.0, np.nan, 3.0], dtype=expected_dtype))
|
| 222 | + |
| 223 | + |
| 224 | +@pytest.mark.skipif(not _HAS_PYARROW, reason="pyarrow is not installed") |
| 225 | +@pytest.mark.parametrize( |
| 226 | + ("dtype", "expected_dtype"), |
| 227 | + [ |
| 228 | + pytest.param("int8[pyarrow]", np.int8, id="int8[pyarrow]"), |
| 229 | + pytest.param("int16[pyarrow]", np.int16, id="int16[pyarrow]"), |
| 230 | + pytest.param("int32[pyarrow]", np.int32, id="int32[pyarrow]"), |
| 231 | + pytest.param("int64[pyarrow]", np.int64, id="int64[pyarrow]"), |
| 232 | + pytest.param("uint8[pyarrow]", np.uint8, id="uint8[pyarrow]"), |
| 233 | + pytest.param("uint16[pyarrow]", np.uint16, id="uint16[pyarrow]"), |
| 234 | + pytest.param("uint32[pyarrow]", np.uint32, id="uint32[pyarrow]"), |
| 235 | + pytest.param("uint64[pyarrow]", np.uint64, id="uint64[pyarrow]"), |
| 236 | + pytest.param("float16[pyarrow]", np.float16, id="float16[pyarrow]"), |
| 237 | + pytest.param("float32[pyarrow]", np.float32, id="float32[pyarrow]"), |
| 238 | + pytest.param("float64[pyarrow]", np.float64, id="float64[pyarrow]"), |
| 239 | + ], |
| 240 | +) |
| 241 | +def test_to_numpy_pandas_series_pyarrow_dtypes_numeric(dtype, expected_dtype): |
| 242 | + """ |
| 243 | + Test the _to_numpy function with pandas.Series of pandas numeric dtypes. |
| 244 | + """ |
| 245 | + series = pd.Series([1, 2, 3], dtype=dtype) |
| 246 | + result = _to_numpy(series) |
| 247 | + _check_result(result, expected_dtype) |
| 248 | + npt.assert_array_equal(result, series) |
| 249 | + |
| 250 | + |
| 251 | +@pytest.mark.skipif(not _HAS_PYARROW, reason="pyarrow is not installed") |
| 252 | +@pytest.mark.parametrize( |
| 253 | + ("dtype", "expected_dtype"), |
| 254 | + [ |
| 255 | + pytest.param("int8[pyarrow]", np.float64, id="int8[pyarrow]"), |
| 256 | + pytest.param("int16[pyarrow]", np.float64, id="int16[pyarrow]"), |
| 257 | + pytest.param("int32[pyarrow]", np.float64, id="int32[pyarrow]"), |
| 258 | + pytest.param("int64[pyarrow]", np.float64, id="int64[pyarrow]"), |
| 259 | + pytest.param("uint8[pyarrow]", np.float64, id="uint8[pyarrow]"), |
| 260 | + pytest.param("uint16[pyarrow]", np.float64, id="uint16[pyarrow]"), |
| 261 | + pytest.param("uint32[pyarrow]", np.float64, id="uint32[pyarrow]"), |
| 262 | + pytest.param("uint64[pyarrow]", np.float64, id="uint64[pyarrow]"), |
| 263 | + pytest.param("float16[pyarrow]", np.float16, id="float16[pyarrow]"), |
| 264 | + pytest.param("float32[pyarrow]", np.float32, id="float32[pyarrow]"), |
| 265 | + pytest.param("float64[pyarrow]", np.float64, id="float64[pyarrow]"), |
| 266 | + ], |
| 267 | +) |
| 268 | +def test_to_numpy_pandas_series_pyarrow_dtypes_numeric_with_na(dtype, expected_dtype): |
| 269 | + """ |
| 270 | + Test the _to_numpy function with pandas.Series of pandas numeric dtypes and NA. |
| 271 | + """ |
| 272 | + series = pd.Series([1, pd.NA, 3], dtype=dtype) |
| 273 | + result = _to_numpy(series) |
| 274 | + _check_result(result, expected_dtype) |
| 275 | + npt.assert_array_equal(result, np.array([1.0, np.nan, 3.0], dtype=expected_dtype)) |
| 276 | + |
| 277 | + |
| 278 | +######################################################################################## |
| 279 | +# Test the _to_numpy function with PyArrow arrays. |
| 280 | +# |
| 281 | +# PyArrow provides the following dtypes: |
| 282 | +# |
| 283 | +# - Numeric dtypes: |
| 284 | +# - int8, int16, int32, int64 |
| 285 | +# - uint8, uint16, uint32, uint64 |
| 286 | +# - float16, float32, float64 |
| 287 | +# |
| 288 | +# Reference: https://arrow.apache.org/docs/python/api/datatypes.html |
| 289 | +######################################################################################## |
| 290 | +@pytest.mark.skipif(not _HAS_PYARROW, reason="pyarrow is not installed") |
| 291 | +@pytest.mark.parametrize( |
| 292 | + ("dtype", "expected_dtype"), |
| 293 | + [ |
| 294 | + pytest.param("int8", np.int8, id="int8"), |
| 295 | + pytest.param("int16", np.int16, id="int16"), |
| 296 | + pytest.param("int32", np.int32, id="int32"), |
| 297 | + pytest.param("int64", np.int64, id="int64"), |
| 298 | + pytest.param("uint8", np.uint8, id="uint8"), |
| 299 | + pytest.param("uint16", np.uint16, id="uint16"), |
| 300 | + pytest.param("uint32", np.uint32, id="uint32"), |
| 301 | + pytest.param("uint64", np.uint64, id="uint64"), |
| 302 | + pytest.param("float16", np.float16, id="float16"), |
| 303 | + pytest.param("float32", np.float32, id="float32"), |
| 304 | + pytest.param("float64", np.float64, id="float64"), |
| 305 | + ], |
| 306 | +) |
| 307 | +def test_to_numpy_pyarrow_array_pyarrow_dtypes_numeric(dtype, expected_dtype): |
| 308 | + """ |
| 309 | + Test the _to_numpy function with PyArrow arrays of PyArrow numeric dtypes. |
| 310 | + """ |
| 311 | + if dtype == "float16": |
| 312 | + # float16 needs special handling |
| 313 | + # Example from https://arrow.apache.org/docs/python/generated/pyarrow.float16.html |
| 314 | + array = pa.array(np.array([1.0, 2.0, 3.0], dtype=np.float16), type=pa.float16()) |
| 315 | + else: |
| 316 | + array = pa.array([1, 2, 3], type=dtype) |
| 317 | + assert array.type == dtype |
| 318 | + result = _to_numpy(array) |
| 319 | + _check_result(result, expected_dtype) |
| 320 | + npt.assert_array_equal(result, array) |
| 321 | + |
| 322 | + |
| 323 | +@pytest.mark.skipif(not _HAS_PYARROW, reason="pyarrow is not installed") |
| 324 | +@pytest.mark.parametrize( |
| 325 | + ("dtype", "expected_dtype"), |
| 326 | + [ |
| 327 | + pytest.param("int8", np.float64, id="int8"), |
| 328 | + pytest.param("int16", np.float64, id="int16"), |
| 329 | + pytest.param("int32", np.float64, id="int32"), |
| 330 | + pytest.param("int64", np.float64, id="int64"), |
| 331 | + pytest.param("uint8", np.float64, id="uint8"), |
| 332 | + pytest.param("uint16", np.float64, id="uint16"), |
| 333 | + pytest.param("uint32", np.float64, id="uint32"), |
| 334 | + pytest.param("uint64", np.float64, id="uint64"), |
| 335 | + pytest.param("float16", np.float16, id="float16"), |
| 336 | + pytest.param("float32", np.float32, id="float32"), |
| 337 | + pytest.param("float64", np.float64, id="float64"), |
| 338 | + ], |
| 339 | +) |
| 340 | +def test_to_numpy_pyarrow_array_pyarrow_dtypes_numeric_with_na(dtype, expected_dtype): |
| 341 | + """ |
| 342 | + Test the _to_numpy function with PyArrow arrays of PyArrow numeric dtypes and NA. |
| 343 | + """ |
| 344 | + if dtype == "float16": |
| 345 | + # float16 needs special handling |
| 346 | + # Example from https://arrow.apache.org/docs/python/generated/pyarrow.float16.html |
| 347 | + array = pa.array( |
| 348 | + np.array([1.0, None, 3.0], dtype=np.float16), type=pa.float16() |
| 349 | + ) |
| 350 | + else: |
| 351 | + array = pa.array([1, None, 3], type=dtype) |
| 352 | + assert array.type == dtype |
| 353 | + result = _to_numpy(array) |
| 354 | + _check_result(result, expected_dtype) |
| 355 | + npt.assert_array_equal(result, array) |
0 commit comments