@@ -203,8 +203,12 @@ def _append_end_assert(pattern):
203
203
else :
204
204
return pattern if pattern .endswith (rb"\Z" ) else pattern + rb"\Z"
205
205
206
+ def _is_bytes_like (object ):
207
+ return isinstance (object , (bytes , bytearray , memoryview ))
208
+
206
209
class SRE_Pattern ():
207
210
def __init__ (self , pattern , flags ):
211
+ self .__binary = isinstance (pattern , bytes )
208
212
self .pattern = pattern
209
213
self .flags = flags
210
214
flags_str = []
@@ -220,9 +224,18 @@ def __init__(self, pattern, flags):
220
224
self .groupindex [group_name ] = self .__compiled_regexes [self .pattern ].groups [group_name ]
221
225
222
226
227
+ def __check_input_type (self , input ):
228
+ if not isinstance (input , str ) and not _is_bytes_like (input ):
229
+ raise TypeError ("expected string or bytes-like object" )
230
+ if not self .__binary and _is_bytes_like (input ):
231
+ raise TypeError ("cannot use a string pattern on a bytes-like object" )
232
+ if self .__binary and isinstance (input , str ):
233
+ raise TypeError ("cannot use a bytes pattern on a string-like object" )
234
+
235
+
223
236
def __tregex_compile (self , pattern ):
224
237
if pattern not in self .__compiled_regexes :
225
- tregex_engine = TREGEX_ENGINE_STR if isinstance ( pattern , str ) else TREGEX_ENGINE_BYTES
238
+ tregex_engine = TREGEX_ENGINE_BYTES if self . __binary else TREGEX_ENGINE_STR
226
239
try :
227
240
self .__compiled_regexes [pattern ] = tregex_call_compile (tregex_engine , pattern , self .flags_str )
228
241
except ValueError as e :
@@ -266,12 +279,15 @@ def _search(self, pattern, string, pos, endpos):
266
279
return None
267
280
268
281
def search (self , string , pos = 0 , endpos = None ):
282
+ self .__check_input_type (string )
269
283
return self ._search (self .pattern , string , pos , default (endpos , - 1 ))
270
284
271
285
def match (self , string , pos = 0 , endpos = None ):
286
+ self .__check_input_type (string )
272
287
return self ._search (_prepend_begin_assert (self .pattern ), string , pos , default (endpos , - 1 ))
273
288
274
289
def fullmatch (self , string , pos = 0 , endpos = None ):
290
+ self .__check_input_type (string )
275
291
return self ._search (_append_end_assert (_prepend_begin_assert (self .pattern )), string , pos , default (endpos , - 1 ))
276
292
277
293
def __sanitize_out_type (self , elem ):
@@ -283,6 +299,7 @@ def __sanitize_out_type(self, elem):
283
299
return str (elem )
284
300
285
301
def findall (self , string , pos = 0 , endpos = - 1 ):
302
+ self .__check_input_type (string )
286
303
if endpos > len (string ):
287
304
endpos = len (string )
288
305
elif endpos < 0 :
@@ -312,20 +329,20 @@ def group(match_result, group_nr, string):
312
329
return string [group_start :group_end ]
313
330
314
331
n = len (repl )
315
- result = ""
332
+ result = b"" if self . __binary else ""
316
333
start = 0
317
- backslash = '\\ '
334
+ backslash = b' \\ ' if self . __binary else '\\ '
318
335
pos = repl .find (backslash , start )
319
336
while pos != - 1 and start < n :
320
337
if pos + 1 < n :
321
338
if repl [pos + 1 ].isdigit () and match_result .groupCount > 0 :
322
- group_nr = int (repl [pos + 1 ])
339
+ group_nr = int (repl [pos + 1 ]. decode ( 'ascii' )) if self . __binary else int ( repl [ pos + 1 ] )
323
340
group_str = group (match_result , group_nr , string )
324
341
if group_str is None :
325
342
raise ValueError ("invalid group reference %s at position %s" % (group_nr , pos ))
326
343
result += repl [start :pos ] + group_str
327
344
start = pos + 2
328
- elif repl [pos + 1 ] == 'g' :
345
+ elif repl [pos + 1 ] == ( b 'g' if self . __binary else 'g' ) :
329
346
group_ref , group_ref_end , digits_only = self .__extract_groupname (repl , pos + 2 )
330
347
if group_ref :
331
348
group_str = group (match_result , int (group_ref ) if digits_only else pattern .groups [group_ref ], string )
@@ -345,26 +362,30 @@ def group(match_result, group_nr, string):
345
362
346
363
347
364
def __extract_groupname (self , repl , pos ):
348
- if repl [pos ] == '<' :
365
+ if repl [pos ] == ( b '<' if self . __binary else '<' ) :
349
366
digits_only = True
350
367
n = len (repl )
351
368
i = pos + 1
352
- while i < n and repl [i ] != '>' :
369
+ while i < n and repl [i ] != ( b '>' if self . __binary else '>' ) :
353
370
digits_only = digits_only and repl [i ].isdigit ()
354
371
i += 1
355
372
if i < n :
356
373
# found '>'
357
- return repl [pos + 1 : i ], i , digits_only
374
+ group_ref = repl [pos + 1 : i ]
375
+ group_ref_str = group_ref .decode ('ascii' ) if self .__binary else group_ref
376
+ return group_ref_str , i , digits_only
358
377
return None , pos , False
359
378
360
379
361
380
def sub (self , repl , string , count = 0 ):
381
+ self .__check_input_type (string )
362
382
n = 0
363
383
pattern = self .__tregex_compile (self .pattern )
364
384
result = []
365
385
pos = 0
366
- is_string_rep = isinstance (repl , str ) or isinstance (repl , bytes ) or isinstance ( repl , bytearray )
386
+ is_string_rep = isinstance (repl , str ) or _is_bytes_like (repl )
367
387
if is_string_rep :
388
+ self .__check_input_type (repl )
368
389
repl = _process_escape_sequences (repl )
369
390
while (count == 0 or n < count ) and pos <= len (string ):
370
391
match_result = tregex_call_exec (pattern .exec , string , pos )
@@ -386,7 +407,10 @@ def sub(self, repl, string, count=0):
386
407
result .append (string [pos ])
387
408
pos = pos + 1
388
409
result .append (string [pos :])
389
- return "" .join (result )
410
+ if self .__binary :
411
+ return b"" .join (result )
412
+ else :
413
+ return "" .join (result )
390
414
391
415
def split (self , string , maxsplit = 0 ):
392
416
n = 0
0 commit comments