@@ -34,6 +34,9 @@ def timestamp(unit: str, tz: str | None = None):
34
34
35
35
_HAS_PYARROW = False
36
36
37
+ # Mark tests that require pyarrow
38
+ pa_marks = {"marks" : skip_if_no (package = "pyarrow" )}
39
+
37
40
38
41
def _check_result (result , expected_dtype ):
39
42
"""
@@ -173,22 +176,105 @@ def test_to_numpy_numpy_string(dtype):
173
176
# - BooleanDtype
174
177
# - ArrowDtype: a special dtype used to store data in the PyArrow format.
175
178
#
179
+ # In pandas, PyArrow types can be specified using the following formats:
180
+ #
181
+ # - A string made of the dtype name followed by "[pyarrow]" (e.g., "int8[pyarrow]")
182
+ # - Specified using ``ArrowDtype`` (e.g., "pd.ArrowDtype(pa.int8())")
183
+ #
176
184
# References:
177
185
# 1. https://pandas.pydata.org/docs/reference/arrays.html
178
186
# 2. https://pandas.pydata.org/docs/user_guide/basics.html#basics-dtypes
179
187
# 3. https://pandas.pydata.org/docs/user_guide/pyarrow.html
180
188
########################################################################################
181
- @pytest .mark .parametrize (("dtype" , "expected_dtype" ), np_dtype_params )
189
+ @pytest .mark .parametrize (
190
+ ("dtype" , "expected_dtype" ),
191
+ [
192
+ * np_dtype_params ,
193
+ pytest .param (pd .Int8Dtype (), np .int8 , id = "Int8" ),
194
+ pytest .param (pd .Int16Dtype (), np .int16 , id = "Int16" ),
195
+ pytest .param (pd .Int32Dtype (), np .int32 , id = "Int32" ),
196
+ pytest .param (pd .Int64Dtype (), np .int64 , id = "Int64" ),
197
+ pytest .param (pd .UInt8Dtype (), np .uint8 , id = "UInt8" ),
198
+ pytest .param (pd .UInt16Dtype (), np .uint16 , id = "UInt16" ),
199
+ pytest .param (pd .UInt32Dtype (), np .uint32 , id = "UInt32" ),
200
+ pytest .param (pd .UInt64Dtype (), np .uint64 , id = "UInt64" ),
201
+ pytest .param (pd .Float32Dtype (), np .float32 , id = "Float32" ),
202
+ pytest .param (pd .Float64Dtype (), np .float64 , id = "Float64" ),
203
+ pytest .param ("int8[pyarrow]" , np .int8 , id = "int8[pyarrow]" , ** pa_marks ),
204
+ pytest .param ("int16[pyarrow]" , np .int16 , id = "int16[pyarrow]" , ** pa_marks ),
205
+ pytest .param ("int32[pyarrow]" , np .int32 , id = "int32[pyarrow]" , ** pa_marks ),
206
+ pytest .param ("int64[pyarrow]" , np .int64 , id = "int64[pyarrow]" , ** pa_marks ),
207
+ pytest .param ("uint8[pyarrow]" , np .uint8 , id = "uint8[pyarrow]" , ** pa_marks ),
208
+ pytest .param ("uint16[pyarrow]" , np .uint16 , id = "uint16[pyarrow]" , ** pa_marks ),
209
+ pytest .param ("uint32[pyarrow]" , np .uint32 , id = "uint32[pyarrow]" , ** pa_marks ),
210
+ pytest .param ("uint64[pyarrow]" , np .uint64 , id = "uint64[pyarrow]" , ** pa_marks ),
211
+ pytest .param ("float16[pyarrow]" , np .float16 , id = "float16[pyarrow]" , ** pa_marks ),
212
+ pytest .param ("float32[pyarrow]" , np .float32 , id = "float32[pyarrow]" , ** pa_marks ),
213
+ pytest .param ("float64[pyarrow]" , np .float64 , id = "float64[pyarrow]" , ** pa_marks ),
214
+ ],
215
+ )
182
216
def test_to_numpy_pandas_numeric (dtype , expected_dtype ):
183
217
"""
184
218
Test the _to_numpy function with pandas.Series of numeric dtypes.
185
219
"""
186
- series = pd .Series ([1 , 2 , 3 , 4 , 5 , 6 ], dtype = dtype )[::2 ] # Not C-contiguous
220
+ data = [1.0 , 2.0 , 3.0 , 4.0 , 5.0 , 6.0 ]
221
+ if dtype == "float16[pyarrow]" and Version (pd .__version__ ) < Version ("2.2" ):
222
+ # For pandas < 2.2, float16[pyarrow] needs the data as a NumPy float16 array.
223
+ # Example from https://arrow.apache.org/docs/python/generated/pyarrow.float16.html
224
+ data = np .array (data , dtype = np .float16 )
225
+ series = pd .Series (data , dtype = dtype )[::2 ] # Not C-contiguous
187
226
result = _to_numpy (series )
188
227
_check_result (result , expected_dtype )
189
228
npt .assert_array_equal (result , series )
190
229
191
230
231
+ @pytest .mark .parametrize (
232
+ ("dtype" , "expected_dtype" ),
233
+ [
234
+ pytest .param (np .float16 , np .float16 , id = "float16" ),
235
+ pytest .param (np .float32 , np .float32 , id = "float32" ),
236
+ pytest .param (np .float64 , np .float64 , id = "float64" ),
237
+ pytest .param (np .longdouble , np .longdouble , id = "longdouble" ),
238
+ pytest .param (pd .Int8Dtype (), np .float64 , id = "Int8" ),
239
+ pytest .param (pd .Int16Dtype (), np .float64 , id = "Int16" ),
240
+ pytest .param (pd .Int32Dtype (), np .float64 , id = "Int32" ),
241
+ pytest .param (pd .Int64Dtype (), np .float64 , id = "Int64" ),
242
+ pytest .param (pd .UInt8Dtype (), np .float64 , id = "UInt8" ),
243
+ pytest .param (pd .UInt16Dtype (), np .float64 , id = "UInt16" ),
244
+ pytest .param (pd .UInt32Dtype (), np .float64 , id = "UInt32" ),
245
+ pytest .param (pd .UInt64Dtype (), np .float64 , id = "UInt64" ),
246
+ pytest .param (pd .Float32Dtype (), np .float32 , id = "Float32" ),
247
+ pytest .param (pd .Float64Dtype (), np .float64 , id = "Float64" ),
248
+ pytest .param ("int8[pyarrow]" , np .float64 , id = "int8[pyarrow]" , ** pa_marks ),
249
+ pytest .param ("int16[pyarrow]" , np .float64 , id = "int16[pyarrow]" , ** pa_marks ),
250
+ pytest .param ("int32[pyarrow]" , np .float64 , id = "int32[pyarrow]" , ** pa_marks ),
251
+ pytest .param ("int64[pyarrow]" , np .float64 , id = "int64[pyarrow]" , ** pa_marks ),
252
+ pytest .param ("uint8[pyarrow]" , np .float64 , id = "uint8[pyarrow]" , ** pa_marks ),
253
+ pytest .param ("uint16[pyarrow]" , np .float64 , id = "uint16[pyarrow]" , ** pa_marks ),
254
+ pytest .param ("uint32[pyarrow]" , np .float64 , id = "uint32[pyarrow]" , ** pa_marks ),
255
+ pytest .param ("uint64[pyarrow]" , np .float64 , id = "uint64[pyarrow]" , ** pa_marks ),
256
+ pytest .param ("float16[pyarrow]" , np .float16 , id = "float16[pyarrow]" , ** pa_marks ),
257
+ pytest .param ("float32[pyarrow]" , np .float32 , id = "float32[pyarrow]" , ** pa_marks ),
258
+ pytest .param ("float64[pyarrow]" , np .float64 , id = "float64[pyarrow]" , ** pa_marks ),
259
+ ],
260
+ )
261
+ def test_to_numpy_pandas_numeric_with_na (dtype , expected_dtype ):
262
+ """
263
+ Test the _to_numpy function with pandas.Series of NumPy/pandas/PyArrow numeric
264
+ dtypes and missing values (NA).
265
+ """
266
+ data = [1.0 , 2.0 , None , 4.0 , 5.0 , 6.0 ]
267
+ if dtype == "float16[pyarrow]" and Version (pd .__version__ ) < Version ("2.2" ):
268
+ # For pandas < 2.2, float16[pyarrow] needs the data as a NumPy float16 array.
269
+ # Example from https://arrow.apache.org/docs/python/generated/pyarrow.float16.html
270
+ data = np .array (data , dtype = np .float16 )
271
+ series = pd .Series (data , dtype = dtype )[::2 ] # Not C-contiguous
272
+ assert series .isna ().any ()
273
+ result = _to_numpy (series )
274
+ _check_result (result , expected_dtype )
275
+ npt .assert_array_equal (result , np .array ([1.0 , np .nan , 5.0 ], dtype = expected_dtype ))
276
+
277
+
192
278
@pytest .mark .parametrize (
193
279
"dtype" ,
194
280
[
0 commit comments