3
3
4
4
import numpy as np
5
5
import pandas as pd
6
+ import pytest
6
7
from typing_extensions import assert_type
7
8
8
9
from tests import (
@@ -44,6 +45,7 @@ def test_string_accessors_boolean_series():
44
45
_check (assert_type (s .str .endswith ("e" ), "pd.Series[bool]" ))
45
46
_check (assert_type (s .str .endswith (("e" , "f" )), "pd.Series[bool]" ))
46
47
_check (assert_type (s .str .fullmatch ("apple" ), "pd.Series[bool]" ))
48
+ _check (assert_type (s .str .fullmatch (re .compile (r"apple" )), "pd.Series[bool]" ))
47
49
_check (assert_type (s .str .isalnum (), "pd.Series[bool]" ))
48
50
_check (assert_type (s .str .isalpha (), "pd.Series[bool]" ))
49
51
_check (assert_type (s .str .isdecimal (), "pd.Series[bool]" ))
@@ -54,6 +56,7 @@ def test_string_accessors_boolean_series():
54
56
_check (assert_type (s .str .istitle (), "pd.Series[bool]" ))
55
57
_check (assert_type (s .str .isupper (), "pd.Series[bool]" ))
56
58
_check (assert_type (s .str .match ("pp" ), "pd.Series[bool]" ))
59
+ _check (assert_type (s .str .match (re .compile (r"pp" )), "pd.Series[bool]" ))
57
60
58
61
59
62
def test_string_accessors_boolean_index ():
@@ -72,6 +75,7 @@ def test_string_accessors_boolean_index():
72
75
_check (assert_type (idx .str .endswith ("e" ), np_ndarray_bool ))
73
76
_check (assert_type (idx .str .endswith (("e" , "f" )), np_ndarray_bool ))
74
77
_check (assert_type (idx .str .fullmatch ("apple" ), np_ndarray_bool ))
78
+ _check (assert_type (idx .str .fullmatch (re .compile (r"apple" )), np_ndarray_bool ))
75
79
_check (assert_type (idx .str .isalnum (), np_ndarray_bool ))
76
80
_check (assert_type (idx .str .isalpha (), np_ndarray_bool ))
77
81
_check (assert_type (idx .str .isdecimal (), np_ndarray_bool ))
@@ -82,6 +86,7 @@ def test_string_accessors_boolean_index():
82
86
_check (assert_type (idx .str .istitle (), np_ndarray_bool ))
83
87
_check (assert_type (idx .str .isupper (), np_ndarray_bool ))
84
88
_check (assert_type (idx .str .match ("pp" ), np_ndarray_bool ))
89
+ _check (assert_type (idx .str .match (re .compile (r"pp" )), np_ndarray_bool ))
85
90
86
91
87
92
def test_string_accessors_integer_series ():
@@ -94,6 +99,10 @@ def test_string_accessors_integer_series():
94
99
_check (assert_type (s .str .count ("pp" ), "pd.Series[int]" ))
95
100
_check (assert_type (s .str .len (), "pd.Series[int]" ))
96
101
102
+ # unlike findall, find doesn't accept a compiled pattern
103
+ with pytest .raises (TypeError ):
104
+ s .str .find (re .compile (r"p" )) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
105
+
97
106
98
107
def test_string_accessors_integer_index ():
99
108
idx = pd .Index (DATA )
@@ -105,6 +114,10 @@ def test_string_accessors_integer_index():
105
114
_check (assert_type (idx .str .count ("pp" ), "pd.Index[int]" ))
106
115
_check (assert_type (idx .str .len (), "pd.Index[int]" ))
107
116
117
+ # unlike findall, find doesn't accept a compiled pattern
118
+ with pytest .raises (TypeError ):
119
+ idx .str .find (re .compile (r"p" )) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
120
+
108
121
109
122
def test_string_accessors_string_series ():
110
123
s = pd .Series (DATA )
@@ -123,6 +136,9 @@ def test_string_accessors_string_series():
123
136
_check (assert_type (s .str .removesuffix ("e" ), "pd.Series[str]" ))
124
137
_check (assert_type (s .str .repeat (2 ), "pd.Series[str]" ))
125
138
_check (assert_type (s .str .replace ("a" , "X" ), "pd.Series[str]" ))
139
+ _check (
140
+ assert_type (s .str .replace (re .compile (r"a" ), "X" , regex = True ), "pd.Series[str]" )
141
+ )
126
142
_check (assert_type (s .str .rjust (80 ), "pd.Series[str]" ))
127
143
_check (assert_type (s .str .rstrip (), "pd.Series[str]" ))
128
144
_check (assert_type (s .str .slice_replace (0 , 2 , "XX" ), "pd.Series[str]" ))
@@ -158,6 +174,9 @@ def test_string_accessors_string_index():
158
174
_check (assert_type (idx .str .removesuffix ("e" ), "pd.Index[str]" ))
159
175
_check (assert_type (idx .str .repeat (2 ), "pd.Index[str]" ))
160
176
_check (assert_type (idx .str .replace ("a" , "X" ), "pd.Index[str]" ))
177
+ _check (
178
+ assert_type (idx .str .replace (re .compile (r"a" ), "X" , regex = True ), "pd.Index[str]" )
179
+ )
161
180
_check (assert_type (idx .str .rjust (80 ), "pd.Index[str]" ))
162
181
_check (assert_type (idx .str .rstrip (), "pd.Index[str]" ))
163
182
_check (assert_type (idx .str .slice_replace (0 , 2 , "XX" ), "pd.Index[str]" ))
@@ -190,29 +209,49 @@ def test_string_accessors_list_series():
190
209
s = pd .Series (DATA )
191
210
_check = functools .partial (check , klass = pd .Series , dtype = list )
192
211
_check (assert_type (s .str .findall ("pp" ), "pd.Series[list[str]]" ))
212
+ _check (assert_type (s .str .findall (re .compile (r"pp" )), "pd.Series[list[str]]" ))
193
213
_check (assert_type (s .str .split ("a" ), "pd.Series[list[str]]" ))
214
+ _check (assert_type (s .str .split (re .compile (r"a" )), "pd.Series[list[str]]" ))
194
215
# GH 194
195
216
_check (assert_type (s .str .split ("a" , expand = False ), "pd.Series[list[str]]" ))
196
217
_check (assert_type (s .str .rsplit ("a" ), "pd.Series[list[str]]" ))
197
218
_check (assert_type (s .str .rsplit ("a" , expand = False ), "pd.Series[list[str]]" ))
198
219
220
+ # rsplit doesn't accept compiled pattern
221
+ # it doesn't raise at runtime but produces a nan
222
+ bad_rsplit_result = s .str .rsplit (
223
+ re .compile (r"a" ) # type: ignore[call-overload] # pyright: ignore[reportArgumentType]
224
+ )
225
+ assert bad_rsplit_result .isna ().all ()
226
+
199
227
200
228
def test_string_accessors_list_index ():
201
229
idx = pd .Index (DATA )
202
230
_check = functools .partial (check , klass = pd .Index , dtype = list )
203
231
_check (assert_type (idx .str .findall ("pp" ), "pd.Index[list[str]]" ))
232
+ _check (assert_type (idx .str .findall (re .compile (r"pp" )), "pd.Index[list[str]]" ))
204
233
_check (assert_type (idx .str .split ("a" ), "pd.Index[list[str]]" ))
234
+ _check (assert_type (idx .str .split (re .compile (r"a" )), "pd.Index[list[str]]" ))
205
235
# GH 194
206
236
_check (assert_type (idx .str .split ("a" , expand = False ), "pd.Index[list[str]]" ))
207
237
_check (assert_type (idx .str .rsplit ("a" ), "pd.Index[list[str]]" ))
208
238
_check (assert_type (idx .str .rsplit ("a" , expand = False ), "pd.Index[list[str]]" ))
209
239
240
+ # rsplit doesn't accept compiled pattern
241
+ # it doesn't raise at runtime but produces a nan
242
+ bad_rsplit_result = idx .str .rsplit (
243
+ re .compile (r"a" ) # type: ignore[call-overload] # pyright: ignore[reportArgumentType]
244
+ )
245
+ assert bad_rsplit_result .isna ().all ()
246
+
210
247
211
248
def test_string_accessors_expanding_series ():
212
249
s = pd .Series (["a1" , "b2" , "c3" ])
213
250
_check = functools .partial (check , klass = pd .DataFrame )
214
251
_check (assert_type (s .str .extract (r"([ab])?(\d)" ), pd .DataFrame ))
252
+ _check (assert_type (s .str .extract (re .compile (r"([ab])?(\d)" )), pd .DataFrame ))
215
253
_check (assert_type (s .str .extractall (r"([ab])?(\d)" ), pd .DataFrame ))
254
+ _check (assert_type (s .str .extractall (re .compile (r"([ab])?(\d)" )), pd .DataFrame ))
216
255
_check (assert_type (s .str .get_dummies (), pd .DataFrame ))
217
256
_check (assert_type (s .str .partition ("p" ), pd .DataFrame ))
218
257
_check (assert_type (s .str .rpartition ("p" ), pd .DataFrame ))
@@ -231,7 +270,15 @@ def test_string_accessors_expanding_index():
231
270
232
271
# These ones are the odd ones out?
233
272
check (assert_type (idx .str .extractall (r"([ab])?(\d)" ), pd .DataFrame ), pd .DataFrame )
273
+ check (
274
+ assert_type (idx .str .extractall (re .compile (r"([ab])?(\d)" )), pd .DataFrame ),
275
+ pd .DataFrame ,
276
+ )
234
277
check (assert_type (idx .str .extract (r"([ab])?(\d)" ), pd .DataFrame ), pd .DataFrame )
278
+ check (
279
+ assert_type (idx .str .extract (re .compile (r"([ab])?(\d)" )), pd .DataFrame ),
280
+ pd .DataFrame ,
281
+ )
235
282
236
283
237
284
def test_series_overloads_partition ():
0 commit comments