1818import numpy as np
1919import pytest
2020
21- from pandas .core .dtypes .cast import can_hold_element
2221from pandas .core .dtypes .dtypes import NumpyEADtype
2322
2423import pandas as pd
2524import pandas ._testing as tm
2625from pandas .api .types import is_object_dtype
2726from pandas .core .arrays .numpy_ import NumpyExtensionArray
28- from pandas .core .internals import blocks
2927from pandas .tests .extension import base
3028
31-
32- def _can_hold_element_patched (obj , element ) -> bool :
33- if isinstance (element , NumpyExtensionArray ):
34- element = element .to_numpy ()
35- return can_hold_element (obj , element )
36-
37-
3829orig_assert_attr_equal = tm .assert_attr_equal
3930
4031
@@ -78,7 +69,6 @@ def allow_in_pandas(monkeypatch):
7869 """
7970 with monkeypatch .context () as m :
8071 m .setattr (NumpyExtensionArray , "_typ" , "extension" )
81- m .setattr (blocks , "can_hold_element" , _can_hold_element_patched )
8272 m .setattr (tm .asserters , "assert_attr_equal" , _assert_attr_equal )
8373 yield
8474
@@ -175,15 +165,7 @@ def skip_numpy_object(dtype, request):
175165skip_nested = pytest .mark .usefixtures ("skip_numpy_object" )
176166
177167
178- class BaseNumPyTests :
179- pass
180-
181-
182- class TestCasting (BaseNumPyTests , base .BaseCastingTests ):
183- pass
184-
185-
186- class TestConstructors (BaseNumPyTests , base .BaseConstructorsTests ):
168+ class TestNumpyExtensionArray (base .ExtensionTests ):
187169 @pytest .mark .skip (reason = "We don't register our dtype" )
188170 # We don't want to register. This test should probably be split in two.
189171 def test_from_dtype (self , data ):
@@ -194,8 +176,6 @@ def test_series_constructor_scalar_with_index(self, data, dtype):
194176 # ValueError: Length of passed values is 1, index implies 3.
195177 super ().test_series_constructor_scalar_with_index (data , dtype )
196178
197-
198- class TestDtype (BaseNumPyTests , base .BaseDtypeTests ):
199179 def test_check_dtype (self , data , request , using_infer_string ):
200180 if data .dtype .numpy_dtype == "object" :
201181 request .applymarker (
@@ -214,26 +194,11 @@ def test_is_not_object_type(self, dtype, request):
214194 else :
215195 super ().test_is_not_object_type (dtype )
216196
217-
218- class TestGetitem (BaseNumPyTests , base .BaseGetitemTests ):
219197 @skip_nested
220198 def test_getitem_scalar (self , data ):
221199 # AssertionError
222200 super ().test_getitem_scalar (data )
223201
224-
225- class TestGroupby (BaseNumPyTests , base .BaseGroupbyTests ):
226- pass
227-
228-
229- class TestInterface (BaseNumPyTests , base .BaseInterfaceTests ):
230- @skip_nested
231- def test_array_interface (self , data ):
232- # NumPy array shape inference
233- super ().test_array_interface (data )
234-
235-
236- class TestMethods (BaseNumPyTests , base .BaseMethodsTests ):
237202 @skip_nested
238203 def test_shift_fill_value (self , data ):
239204 # np.array shape inference. Shift implementation fails.
@@ -251,7 +216,9 @@ def test_fillna_copy_series(self, data_missing):
251216
252217 @skip_nested
253218 def test_searchsorted (self , data_for_sorting , as_series ):
254- # Test setup fails.
219+ # TODO: NumpyExtensionArray.searchsorted calls ndarray.searchsorted which
220+ # isn't quite what we want in nested data cases. Instead we need to
221+ # adapt something like libindex._bin_search.
255222 super ().test_searchsorted (data_for_sorting , as_series )
256223
257224 @pytest .mark .xfail (reason = "NumpyExtensionArray.diff may fail on dtype" )
@@ -270,38 +237,60 @@ def test_insert_invalid(self, data, invalid_scalar):
270237 # NumpyExtensionArray[object] can hold anything, so skip
271238 super ().test_insert_invalid (data , invalid_scalar )
272239
273-
274- class TestArithmetics (BaseNumPyTests , base .BaseArithmeticOpsTests ):
275240 divmod_exc = None
276241 series_scalar_exc = None
277242 frame_scalar_exc = None
278243 series_array_exc = None
279244
280- @skip_nested
281245 def test_divmod (self , data ):
246+ divmod_exc = None
247+ if data .dtype .kind == "O" :
248+ divmod_exc = TypeError
249+ self .divmod_exc = divmod_exc
282250 super ().test_divmod (data )
283251
284- @skip_nested
285- def test_arith_series_with_scalar (self , data , all_arithmetic_operators ):
252+ def test_divmod_series_array (self , data ):
253+ ser = pd .Series (data )
254+ exc = None
255+ if data .dtype .kind == "O" :
256+ exc = TypeError
257+ self .divmod_exc = exc
258+ self ._check_divmod_op (ser , divmod , data )
259+
260+ def test_arith_series_with_scalar (self , data , all_arithmetic_operators , request ):
261+ opname = all_arithmetic_operators
262+ series_scalar_exc = None
263+ if data .dtype .numpy_dtype == object :
264+ if opname in ["__mul__" , "__rmul__" ]:
265+ mark = pytest .mark .xfail (
266+ reason = "the Series.combine step raises but not the Series method."
267+ )
268+ request .node .add_marker (mark )
269+ series_scalar_exc = TypeError
270+ self .series_scalar_exc = series_scalar_exc
286271 super ().test_arith_series_with_scalar (data , all_arithmetic_operators )
287272
288- def test_arith_series_with_array (self , data , all_arithmetic_operators , request ):
273+ def test_arith_series_with_array (self , data , all_arithmetic_operators ):
289274 opname = all_arithmetic_operators
275+ series_array_exc = None
290276 if data .dtype .numpy_dtype == object and opname not in ["__add__" , "__radd__" ]:
291- mark = pytest . mark . xfail ( reason = "Fails for object dtype" )
292- request . applymarker ( mark )
277+ series_array_exc = TypeError
278+ self . series_array_exc = series_array_exc
293279 super ().test_arith_series_with_array (data , all_arithmetic_operators )
294280
295- @skip_nested
296- def test_arith_frame_with_scalar (self , data , all_arithmetic_operators ):
281+ def test_arith_frame_with_scalar (self , data , all_arithmetic_operators , request ):
282+ opname = all_arithmetic_operators
283+ frame_scalar_exc = None
284+ if data .dtype .numpy_dtype == object :
285+ if opname in ["__mul__" , "__rmul__" ]:
286+ mark = pytest .mark .xfail (
287+ reason = "the Series.combine step raises but not the Series method."
288+ )
289+ request .node .add_marker (mark )
290+ frame_scalar_exc = TypeError
291+ self .frame_scalar_exc = frame_scalar_exc
297292 super ().test_arith_frame_with_scalar (data , all_arithmetic_operators )
298293
299-
300- class TestPrinting (BaseNumPyTests , base .BasePrintingTests ):
301- pass
302-
303-
304- class TestReduce (BaseNumPyTests , base .BaseReduceTests ):
305294 def _supports_reduction (self , ser : pd .Series , op_name : str ) -> bool :
306295 if ser .dtype .kind == "O" :
307296 return op_name in ["sum" , "min" , "max" , "any" , "all" ]
@@ -328,8 +317,6 @@ def check_reduce(self, ser: pd.Series, op_name: str, skipna: bool):
328317 def test_reduce_frame (self , data , all_numeric_reductions , skipna ):
329318 pass
330319
331-
332- class TestMissing (BaseNumPyTests , base .BaseMissingTests ):
333320 @skip_nested
334321 def test_fillna_series (self , data_missing ):
335322 # Non-scalar "scalar" values.
@@ -340,12 +327,6 @@ def test_fillna_frame(self, data_missing):
340327 # Non-scalar "scalar" values.
341328 super ().test_fillna_frame (data_missing )
342329
343-
344- class TestReshaping (BaseNumPyTests , base .BaseReshapingTests ):
345- pass
346-
347-
348- class TestSetitem (BaseNumPyTests , base .BaseSetitemTests ):
349330 @skip_nested
350331 def test_setitem_invalid (self , data , invalid_scalar ):
351332 # object dtype can hold anything, so doesn't raise
@@ -431,11 +412,25 @@ def test_setitem_with_expansion_dataframe_column(self, data, full_indexer):
431412 expected = pd .DataFrame ({"data" : data .to_numpy ()})
432413 tm .assert_frame_equal (result , expected , check_column_type = False )
433414
415+ @pytest .mark .xfail (reason = "NumpyEADtype is unpacked" )
416+ def test_index_from_listlike_with_dtype (self , data ):
417+ super ().test_index_from_listlike_with_dtype (data )
434418
435- @skip_nested
436- class TestParsing (BaseNumPyTests , base .BaseParsingTests ):
437- pass
419+ @skip_nested
420+ @pytest .mark .parametrize ("engine" , ["c" , "python" ])
421+ def test_EA_types (self , engine , data , request ):
422+ super ().test_EA_types (engine , data , request )
423+
424+ @pytest .mark .xfail (reason = "Expect NumpyEA, get np.ndarray" )
425+ def test_compare_array (self , data , comparison_op ):
426+ super ().test_compare_array (data , comparison_op )
427+
428+ def test_compare_scalar (self , data , comparison_op , request ):
429+ if data .dtype .kind == "f" or comparison_op .__name__ in ["eq" , "ne" ]:
430+ mark = pytest .mark .xfail (reason = "Expect NumpyEA, get np.ndarray" )
431+ request .applymarker (mark )
432+ super ().test_compare_scalar (data , comparison_op )
438433
439434
440- class Test2DCompat (BaseNumPyTests , base .NDArrayBacked2DTests ):
435+ class Test2DCompat (base .NDArrayBacked2DTests ):
441436 pass
0 commit comments