12
12
13
13
14
14
class BaseOpsUtil (BaseExtensionTests ):
15
+ series_scalar_exc : type [Exception ] | None = TypeError
16
+ frame_scalar_exc : type [Exception ] | None = TypeError
17
+ series_array_exc : type [Exception ] | None = TypeError
18
+ divmod_exc : type [Exception ] | None = TypeError
19
+
20
+ def _get_expected_exception (
21
+ self , op_name : str , obj , other
22
+ ) -> type [Exception ] | None :
23
+ # Find the Exception, if any we expect to raise calling
24
+ # obj.__op_name__(other)
25
+
26
+ # The self.obj_bar_exc pattern isn't great in part because it can depend
27
+ # on op_name or dtypes, but we use it here for backward-compatibility.
28
+ if op_name in ["__divmod__" , "__rdivmod__" ]:
29
+ return self .divmod_exc
30
+ if isinstance (obj , pd .Series ) and isinstance (other , pd .Series ):
31
+ return self .series_array_exc
32
+ elif isinstance (obj , pd .Series ):
33
+ return self .series_scalar_exc
34
+ else :
35
+ return self .frame_scalar_exc
36
+
15
37
def _cast_pointwise_result (self , op_name : str , obj , other , pointwise_result ):
16
38
# In _check_op we check that the result of a pointwise operation
17
39
# (found via _combine) matches the result of the vectorized
@@ -24,17 +46,21 @@ def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result):
24
46
def get_op_from_name (self , op_name : str ):
25
47
return tm .get_op_from_name (op_name )
26
48
27
- def check_opname (self , ser : pd .Series , op_name : str , other , exc = Exception ):
28
- op = self .get_op_from_name (op_name )
29
-
30
- self ._check_op (ser , op , other , op_name , exc )
31
-
32
- # Subclasses are not expected to need to override _check_op or _combine.
49
+ # Subclasses are not expected to need to override check_opname, _check_op,
50
+ # _check_divmod_op, or _combine.
33
51
# Ideally any relevant overriding can be done in _cast_pointwise_result,
34
52
# get_op_from_name, and the specification of `exc`. If you find a use
35
53
# case that still requires overriding _check_op or _combine, please let
36
54
# us know at github.com/pandas-dev/pandas/issues
37
55
@final
56
+ def check_opname (self , ser : pd .Series , op_name : str , other ):
57
+ exc = self ._get_expected_exception (op_name , ser , other )
58
+ op = self .get_op_from_name (op_name )
59
+
60
+ self ._check_op (ser , op , other , op_name , exc )
61
+
62
+ # see comment on check_opname
63
+ @final
38
64
def _combine (self , obj , other , op ):
39
65
if isinstance (obj , pd .DataFrame ):
40
66
if len (obj .columns ) != 1 :
@@ -44,11 +70,14 @@ def _combine(self, obj, other, op):
44
70
expected = obj .combine (other , op )
45
71
return expected
46
72
47
- # see comment on _combine
73
+ # see comment on check_opname
48
74
@final
49
75
def _check_op (
50
76
self , ser : pd .Series , op , other , op_name : str , exc = NotImplementedError
51
77
):
78
+ # Check that the Series/DataFrame arithmetic/comparison method matches
79
+ # the pointwise result from _combine.
80
+
52
81
if exc is None :
53
82
result = op (ser , other )
54
83
expected = self ._combine (ser , other , op )
@@ -59,8 +88,14 @@ def _check_op(
59
88
with pytest .raises (exc ):
60
89
op (ser , other )
61
90
62
- def _check_divmod_op (self , ser : pd .Series , op , other , exc = Exception ):
63
- # divmod has multiple return values, so check separately
91
+ # see comment on check_opname
92
+ @final
93
+ def _check_divmod_op (self , ser : pd .Series , op , other ):
94
+ # check that divmod behavior matches behavior of floordiv+mod
95
+ if op is divmod :
96
+ exc = self ._get_expected_exception ("__divmod__" , ser , other )
97
+ else :
98
+ exc = self ._get_expected_exception ("__rdivmod__" , ser , other )
64
99
if exc is None :
65
100
result_div , result_mod = op (ser , other )
66
101
if op is divmod :
@@ -96,26 +131,24 @@ def test_arith_series_with_scalar(self, data, all_arithmetic_operators):
96
131
# series & scalar
97
132
op_name = all_arithmetic_operators
98
133
ser = pd .Series (data )
99
- self .check_opname (ser , op_name , ser .iloc [0 ], exc = self . series_scalar_exc )
134
+ self .check_opname (ser , op_name , ser .iloc [0 ])
100
135
101
136
def test_arith_frame_with_scalar (self , data , all_arithmetic_operators ):
102
137
# frame & scalar
103
138
op_name = all_arithmetic_operators
104
139
df = pd .DataFrame ({"A" : data })
105
- self .check_opname (df , op_name , data [0 ], exc = self . frame_scalar_exc )
140
+ self .check_opname (df , op_name , data [0 ])
106
141
107
142
def test_arith_series_with_array (self , data , all_arithmetic_operators ):
108
143
# ndarray & other series
109
144
op_name = all_arithmetic_operators
110
145
ser = pd .Series (data )
111
- self .check_opname (
112
- ser , op_name , pd .Series ([ser .iloc [0 ]] * len (ser )), exc = self .series_array_exc
113
- )
146
+ self .check_opname (ser , op_name , pd .Series ([ser .iloc [0 ]] * len (ser )))
114
147
115
148
def test_divmod (self , data ):
116
149
ser = pd .Series (data )
117
- self ._check_divmod_op (ser , divmod , 1 , exc = self . divmod_exc )
118
- self ._check_divmod_op (1 , ops .rdivmod , ser , exc = self . divmod_exc )
150
+ self ._check_divmod_op (ser , divmod , 1 )
151
+ self ._check_divmod_op (1 , ops .rdivmod , ser )
119
152
120
153
def test_divmod_series_array (self , data , data_for_twos ):
121
154
ser = pd .Series (data )
0 commit comments