Skip to content

Commit b363e6d

Browse files
committed
Check input type in TRegex-compiled regular expressions
Also support memoryviews as arguments for bytes patterns.
1 parent de646ae commit b363e6d

File tree

2 files changed

+53
-10
lines changed

2 files changed

+53
-10
lines changed

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/SREModuleBuiltins.java

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,11 @@
5151
import com.oracle.graal.python.builtins.Builtin;
5252
import com.oracle.graal.python.builtins.CoreFunctions;
5353
import com.oracle.graal.python.builtins.PythonBuiltins;
54+
import com.oracle.graal.python.builtins.objects.bytes.BytesNodes;
5455
import com.oracle.graal.python.builtins.objects.bytes.BytesUtils;
5556
import com.oracle.graal.python.builtins.objects.bytes.PIBytesLike;
5657
import com.oracle.graal.python.builtins.objects.common.SequenceStorageNodes;
58+
import com.oracle.graal.python.builtins.objects.memoryview.PMemoryView;
5759
import com.oracle.graal.python.builtins.objects.str.PString;
5860
import com.oracle.graal.python.nodes.function.PythonBuiltinBaseNode;
5961
import com.oracle.graal.python.nodes.function.PythonBuiltinNode;
@@ -96,6 +98,7 @@ protected List<? extends NodeFactory<? extends PythonBuiltinBaseNode>> getNodeFa
9698
abstract static class ProcessEscapeSequences extends PythonUnaryBuiltinNode {
9799

98100
@Child private SequenceStorageNodes.ToByteArrayNode toByteArrayNode;
101+
@Child private BytesNodes.ToBytesNode toBytesNode;
99102

100103
@CompilationFinal private Pattern namedCaptGroupPattern;
101104

@@ -123,6 +126,15 @@ Object run(PIBytesLike str) {
123126
return str;
124127
}
125128

129+
@Specialization
130+
Object run(PMemoryView memoryView) {
131+
byte[] bytes = doBytes(getToBytesNode().execute(memoryView));
132+
if (bytes != null) {
133+
return factory().createByteArray(bytes);
134+
}
135+
return memoryView;
136+
}
137+
126138
@TruffleBoundary(transferToInterpreterOnException = false, allowInlining = true)
127139
private byte[] doBytes(byte[] str) {
128140
try {
@@ -156,6 +168,13 @@ private SequenceStorageNodes.ToByteArrayNode getToByteArrayNode() {
156168
return toByteArrayNode;
157169
}
158170

171+
private BytesNodes.ToBytesNode getToBytesNode() {
172+
if (toBytesNode == null) {
173+
CompilerDirectives.transferToInterpreterAndInvalidate();
174+
toBytesNode = insert(BytesNodes.ToBytesNode.create());
175+
}
176+
return toBytesNode;
177+
}
159178
}
160179

161180
@Builtin(name = "tregex_call_compile", fixedNumOfPositionalArgs = 3)

graalpython/lib-graalpython/_sre.py

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -203,8 +203,12 @@ def _append_end_assert(pattern):
203203
else:
204204
return pattern if pattern.endswith(rb"\Z") else pattern + rb"\Z"
205205

206+
def _is_bytes_like(object):
207+
return isinstance(object, (bytes, bytearray, memoryview))
208+
206209
class SRE_Pattern():
207210
def __init__(self, pattern, flags):
211+
self.__binary = isinstance(pattern, bytes)
208212
self.pattern = pattern
209213
self.flags = flags
210214
flags_str = []
@@ -220,9 +224,18 @@ def __init__(self, pattern, flags):
220224
self.groupindex[group_name] = self.__compiled_regexes[self.pattern].groups[group_name]
221225

222226

227+
def __check_input_type(self, input):
228+
if not isinstance(input, str) and not _is_bytes_like(input):
229+
raise TypeError("expected string or bytes-like object")
230+
if not self.__binary and _is_bytes_like(input):
231+
raise TypeError("cannot use a string pattern on a bytes-like object")
232+
if self.__binary and isinstance(input, str):
233+
raise TypeError("cannot use a bytes pattern on a string-like object")
234+
235+
223236
def __tregex_compile(self, pattern):
224237
if pattern not in self.__compiled_regexes:
225-
tregex_engine = TREGEX_ENGINE_STR if isinstance(pattern, str) else TREGEX_ENGINE_BYTES
238+
tregex_engine = TREGEX_ENGINE_BYTES if self.__binary else TREGEX_ENGINE_STR
226239
try:
227240
self.__compiled_regexes[pattern] = tregex_call_compile(tregex_engine, pattern, self.flags_str)
228241
except ValueError as e:
@@ -266,12 +279,15 @@ def _search(self, pattern, string, pos, endpos):
266279
return None
267280

268281
def search(self, string, pos=0, endpos=None):
282+
self.__check_input_type(string)
269283
return self._search(self.pattern, string, pos, default(endpos, -1))
270284

271285
def match(self, string, pos=0, endpos=None):
286+
self.__check_input_type(string)
272287
return self._search(_prepend_begin_assert(self.pattern), string, pos, default(endpos, -1))
273288

274289
def fullmatch(self, string, pos=0, endpos=None):
290+
self.__check_input_type(string)
275291
return self._search(_append_end_assert(_prepend_begin_assert(self.pattern)), string, pos, default(endpos, -1))
276292

277293
def __sanitize_out_type(self, elem):
@@ -283,6 +299,7 @@ def __sanitize_out_type(self, elem):
283299
return str(elem)
284300

285301
def findall(self, string, pos=0, endpos=-1):
302+
self.__check_input_type(string)
286303
if endpos > len(string):
287304
endpos = len(string)
288305
elif endpos < 0:
@@ -312,20 +329,20 @@ def group(match_result, group_nr, string):
312329
return string[group_start:group_end]
313330

314331
n = len(repl)
315-
result = ""
332+
result = b"" if self.__binary else ""
316333
start = 0
317-
backslash = '\\'
334+
backslash = b'\\' if self.__binary else '\\'
318335
pos = repl.find(backslash, start)
319336
while pos != -1 and start < n:
320337
if pos+1 < n:
321338
if repl[pos + 1].isdigit() and match_result.groupCount > 0:
322-
group_nr = int(repl[pos+1])
339+
group_nr = int(repl[pos+1].decode('ascii')) if self.__binary else int(repl[pos+1])
323340
group_str = group(match_result, group_nr, string)
324341
if group_str is None:
325342
raise ValueError("invalid group reference %s at position %s" % (group_nr, pos))
326343
result += repl[start:pos] + group_str
327344
start = pos + 2
328-
elif repl[pos + 1] == 'g':
345+
elif repl[pos + 1] == (b'g' if self.__binary else 'g'):
329346
group_ref, group_ref_end, digits_only = self.__extract_groupname(repl, pos + 2)
330347
if group_ref:
331348
group_str = group(match_result, int(group_ref) if digits_only else pattern.groups[group_ref], string)
@@ -345,26 +362,30 @@ def group(match_result, group_nr, string):
345362

346363

347364
def __extract_groupname(self, repl, pos):
348-
if repl[pos] == '<':
365+
if repl[pos] == (b'<' if self.__binary else '<'):
349366
digits_only = True
350367
n = len(repl)
351368
i = pos + 1
352-
while i < n and repl[i] != '>':
369+
while i < n and repl[i] != (b'>' if self.__binary else '>'):
353370
digits_only = digits_only and repl[i].isdigit()
354371
i += 1
355372
if i < n:
356373
# found '>'
357-
return repl[pos + 1 : i], i, digits_only
374+
group_ref = repl[pos + 1 : i]
375+
group_ref_str = group_ref.decode('ascii') if self.__binary else group_ref
376+
return group_ref_str, i, digits_only
358377
return None, pos, False
359378

360379

361380
def sub(self, repl, string, count=0):
381+
self.__check_input_type(string)
362382
n = 0
363383
pattern = self.__tregex_compile(self.pattern)
364384
result = []
365385
pos = 0
366-
is_string_rep = isinstance(repl, str) or isinstance(repl, bytes) or isinstance(repl, bytearray)
386+
is_string_rep = isinstance(repl, str) or _is_bytes_like(repl)
367387
if is_string_rep:
388+
self.__check_input_type(repl)
368389
repl = _process_escape_sequences(repl)
369390
while (count == 0 or n < count) and pos <= len(string):
370391
match_result = tregex_call_exec(pattern.exec, string, pos)
@@ -386,7 +407,10 @@ def sub(self, repl, string, count=0):
386407
result.append(string[pos])
387408
pos = pos + 1
388409
result.append(string[pos:])
389-
return "".join(result)
410+
if self.__binary:
411+
return b"".join(result)
412+
else:
413+
return "".join(result)
390414

391415
def split(self, string, maxsplit=0):
392416
n = 0

0 commit comments

Comments
 (0)