Skip to content

Commit 050a8ed

Browse files
committed
Handle Pattern.match using the sticky flag
1 parent 509938f commit 050a8ed

File tree

1 file changed

+23
-22
lines changed

1 file changed

+23
-22
lines changed

graalpython/lib-graalpython/_sre.py

Lines changed: 23 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -70,14 +70,20 @@ def setup(sre_compiler, error_class, flags_table):
7070

7171
def configure_fallback_compiler(mode):
7272
def fallback_compiler(pattern, flags):
73+
sticky = False
7374
bit_flags = 0
7475
for flag in flags:
75-
bit_flags = bit_flags | FLAGS[flag]
76+
# Handle internal stick(y) flag used to signal matching only at the start of input.
77+
if flag == "y":
78+
sticky = True
79+
else:
80+
bit_flags = bit_flags | FLAGS[flag]
7681

7782
compiled_pattern = sre_compiler(pattern if mode == "str" else _str_to_bytes(pattern), bit_flags)
7883

7984
def executable_pattern(regex_object, input, from_index):
80-
result = compiled_pattern.search(input, from_index)
85+
search_method = compiled_pattern.match if sticky else compiled_pattern.search
86+
result = search_method(input, from_index)
8187
is_match = result is not None
8288
group_count = 1 + compiled_pattern.groups
8389
return _RegexResult(
@@ -191,12 +197,6 @@ def lastindex(self):
191197
def __repr__(self):
192198
return "<re.Match object; span=%r, match=%r>" % (self.span(), self.group())
193199

194-
def _prepend_begin_assert(pattern):
195-
if isinstance(pattern, str):
196-
return pattern if pattern.startswith(r"\A") else r"\A" + pattern
197-
else:
198-
return pattern if pattern.startswith(rb"\A") else rb"\A" + pattern
199-
200200
def _append_end_assert(pattern):
201201
if isinstance(pattern, str):
202202
return pattern if pattern.endswith(r"\Z") else pattern + r"\Z"
@@ -217,11 +217,10 @@ def __init__(self, pattern, flags):
217217
flags_str.append(char)
218218
self.flags_str = "".join(flags_str)
219219
self.__compiled_regexes = dict()
220-
self.__tregex_compile(pattern)
221220
self.groupindex = dict()
222-
if self.__compiled_regexes[self.pattern].groups is not None:
223-
for group_name in dir(self.__compiled_regexes[self.pattern].groups):
224-
self.groupindex[group_name] = self.__compiled_regexes[self.pattern].groups[group_name]
221+
if self.__tregex_compile(self.pattern).groups is not None:
222+
for group_name in dir(self.__tregex_compile(self.pattern).groups):
223+
self.groupindex[group_name] = self.__tregex_compile(self.pattern).groups[group_name]
225224

226225

227226
def __check_input_type(self, input):
@@ -233,11 +232,13 @@ def __check_input_type(self, input):
233232
raise TypeError("cannot use a bytes pattern on a string-like object")
234233

235234

236-
def __tregex_compile(self, pattern):
237-
if pattern not in self.__compiled_regexes:
235+
def __tregex_compile(self, pattern, flags=None):
236+
if flags is None:
237+
flags = self.flags_str
238+
if (pattern, flags) not in self.__compiled_regexes:
238239
tregex_engine = TREGEX_ENGINE_BYTES if self.__binary else TREGEX_ENGINE_STR
239240
try:
240-
self.__compiled_regexes[pattern] = tregex_call_compile(tregex_engine, pattern, self.flags_str)
241+
self.__compiled_regexes[(pattern, flags)] = tregex_call_compile(tregex_engine, pattern, flags)
241242
except ValueError as e:
242243
message = str(e)
243244
boundary = message.rfind(" at position ")
@@ -247,7 +248,7 @@ def __tregex_compile(self, pattern):
247248
position = int(message[boundary + len(" at position "):])
248249
message = message[:boundary]
249250
raise error(message, pattern, position)
250-
return self.__compiled_regexes[pattern]
251+
return self.__compiled_regexes[(pattern, flags)]
251252

252253

253254
def __repr__(self):
@@ -267,12 +268,12 @@ def __repr__(self):
267268
sflags = "|".join(flag_items)
268269
return "re.compile(%s%s%s)" % (self.pattern, sep, sflags)
269270

270-
def _search(self, pattern, string, pos, endpos):
271-
pattern = self.__tregex_compile(pattern)
271+
def _search(self, pattern, string, pos, endpos, sticky=False):
272+
pattern = self.__tregex_compile(pattern, self.flags_str + ("y" if sticky else ""))
272273
if endpos == -1 or endpos >= len(string):
273-
result = tregex_call_exec(pattern.exec, string, pos)
274+
result = tregex_call_exec(pattern.exec, string, min(pos, len(string) + 1))
274275
else:
275-
result = tregex_call_exec(pattern.exec, string[:endpos], pos)
276+
result = tregex_call_exec(pattern.exec, string[:endpos], min(pos, endpos % len(string) + 1))
276277
if result.isMatch:
277278
return SRE_Match(self, pos, endpos, result)
278279
else:
@@ -284,11 +285,11 @@ def search(self, string, pos=0, endpos=None):
284285

285286
def match(self, string, pos=0, endpos=None):
286287
self.__check_input_type(string)
287-
return self._search(_prepend_begin_assert(self.pattern), string, pos, default(endpos, -1))
288+
return self._search(self.pattern, string, pos, default(endpos, -1), sticky=True)
288289

289290
def fullmatch(self, string, pos=0, endpos=None):
290291
self.__check_input_type(string)
291-
return self._search(_append_end_assert(_prepend_begin_assert(self.pattern)), string, pos, default(endpos, -1))
292+
return self._search(_append_end_assert(self.pattern), string, pos, default(endpos, -1), sticky=True)
292293

293294
def __sanitize_out_type(self, elem):
294295
"""Helper function for findall and split. Ensures that the type of the elements of the

0 commit comments

Comments
 (0)