Skip to content

Commit dc2eb11

Browse files
committed
Fix group replacement handling
1 parent 00c3c14 commit dc2eb11

File tree

3 files changed

+20
-152
lines changed

3 files changed

+20
-152
lines changed

graalpython/com.oracle.graal.python.test/src/tests/unittest_tags/test_re.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@
7070
*graalpython.lib-python.3.test.test_re.ReTests.test_re_escape_non_ascii_bytes
7171
*graalpython.lib-python.3.test.test_re.ReTests.test_re_findall
7272
*graalpython.lib-python.3.test.test_re.ReTests.test_re_groupref_exists
73+
*graalpython.lib-python.3.test.test_re.ReTests.test_re_groupref_overflow
7374
*graalpython.lib-python.3.test.test_re.ReTests.test_re_match
7475
*graalpython.lib-python.3.test.test_re.ReTests.test_re_split
7576
*graalpython.lib-python.3.test.test_re.ReTests.test_re_subn
@@ -80,6 +81,8 @@
8081
*graalpython.lib-python.3.test.test_re.ReTests.test_search_star_plus
8182
*graalpython.lib-python.3.test.test_re.ReTests.test_special_escapes
8283
*graalpython.lib-python.3.test.test_re.ReTests.test_stack_overflow
84+
*graalpython.lib-python.3.test.test_re.ReTests.test_sub_template_numeric_escape
8385
*graalpython.lib-python.3.test.test_re.ReTests.test_symbolic_groups
86+
*graalpython.lib-python.3.test.test_re.ReTests.test_symbolic_refs
8487
*graalpython.lib-python.3.test.test_re.ReTests.test_unlimited_zero_width_repeat
8588
*graalpython.lib-python.3.test.test_re.ReTests.test_weakref

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/SREModuleBuiltins.java

Lines changed: 0 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -44,38 +44,27 @@
4444
import static com.oracle.graal.python.runtime.exception.PythonErrorType.ValueError;
4545

4646
import java.io.UnsupportedEncodingException;
47-
import java.nio.charset.StandardCharsets;
4847
import java.util.List;
4948

5049
import com.oracle.graal.python.PythonLanguage;
5150
import com.oracle.graal.python.builtins.Builtin;
5251
import com.oracle.graal.python.builtins.CoreFunctions;
5352
import com.oracle.graal.python.builtins.PythonBuiltins;
54-
import com.oracle.graal.python.builtins.objects.bytes.BytesUtils;
55-
import com.oracle.graal.python.builtins.objects.bytes.PBytesLike;
56-
import com.oracle.graal.python.builtins.objects.common.SequenceStorageNodes;
57-
import com.oracle.graal.python.builtins.objects.common.SequenceStorageNodesFactory.ToByteArrayNodeGen;
5853
import com.oracle.graal.python.builtins.objects.function.PFunction;
5954
import com.oracle.graal.python.builtins.objects.object.PythonObjectLibrary;
60-
import com.oracle.graal.python.builtins.objects.str.PString;
61-
import com.oracle.graal.python.nodes.ErrorMessages;
6255
import com.oracle.graal.python.nodes.call.CallNode;
6356
import com.oracle.graal.python.nodes.function.PythonBuiltinBaseNode;
6457
import com.oracle.graal.python.nodes.function.builtins.PythonTernaryBuiltinNode;
65-
import com.oracle.graal.python.nodes.function.builtins.PythonUnaryBuiltinNode;
6658
import com.oracle.graal.python.nodes.truffle.PythonArithmeticTypes;
6759
import com.oracle.graal.python.nodes.util.CastToJavaStringNode;
6860
import com.oracle.graal.python.runtime.ExecutionContext.IndirectCallContext;
6961
import com.oracle.graal.python.runtime.PythonContext;
7062
import com.oracle.graal.python.runtime.PythonCore;
7163
import com.oracle.graal.python.runtime.PythonOptions;
72-
import com.oracle.graal.python.runtime.exception.PythonErrorType;
73-
import com.oracle.truffle.api.CompilerAsserts;
7464
import com.oracle.truffle.api.CompilerDirectives;
7565
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
7666
import com.oracle.truffle.api.dsl.Cached;
7767
import com.oracle.truffle.api.dsl.CachedContext;
78-
import com.oracle.truffle.api.dsl.Fallback;
7968
import com.oracle.truffle.api.dsl.GenerateNodeFactory;
8069
import com.oracle.truffle.api.dsl.NodeFactory;
8170
import com.oracle.truffle.api.dsl.Specialization;
@@ -105,79 +94,6 @@ public void initialize(PythonCore core) {
10594
super.initialize(core);
10695
}
10796

108-
/**
109-
* Replaces any <it>quoted</it> escape sequence like {@code "\\n"} (two characters; backslash +
110-
* 'n') by its single character like {@code "\n"} (one character; newline).
111-
*/
112-
@Builtin(name = "_process_escape_sequences", minNumOfPositionalArgs = 1)
113-
@GenerateNodeFactory
114-
abstract static class ProcessEscapeSequences extends PythonUnaryBuiltinNode {
115-
116-
@Child private SequenceStorageNodes.ToByteArrayNode toByteArrayNode;
117-
118-
@Specialization
119-
Object run(PString str) {
120-
return run(str.getValue());
121-
}
122-
123-
@Specialization
124-
@TruffleBoundary(transferToInterpreterOnException = false, allowInlining = true)
125-
Object run(String str) {
126-
if (containsBackslash(str)) {
127-
StringBuilder sb = BytesUtils.decodeEscapes(getCore(), str, true);
128-
return sb.toString();
129-
}
130-
return str;
131-
}
132-
133-
@Specialization
134-
Object run(PBytesLike str) {
135-
byte[] bytes = doBytes(getToByteArrayNode().execute(str.getSequenceStorage()));
136-
return factory().createByteArray(bytes);
137-
}
138-
139-
@Specialization(guards = "bufferLib.isBuffer(buffer)", limit = "3")
140-
Object run(Object buffer,
141-
@CachedLibrary("buffer") PythonObjectLibrary bufferLib) {
142-
byte[] bytes;
143-
try {
144-
bytes = bufferLib.getBufferBytes(buffer);
145-
} catch (UnsupportedMessageException e) {
146-
throw CompilerDirectives.shouldNotReachHere();
147-
}
148-
return factory().createByteArray(bytes);
149-
}
150-
151-
@TruffleBoundary(transferToInterpreterOnException = false, allowInlining = true)
152-
private byte[] doBytes(byte[] str) {
153-
StringBuilder sb = BytesUtils.decodeEscapes(getCore(), new String(str, StandardCharsets.US_ASCII), true);
154-
return sb.toString().getBytes(StandardCharsets.US_ASCII);
155-
}
156-
157-
private static boolean containsBackslash(String str) {
158-
CompilerAsserts.neverPartOfCompilation();
159-
for (int i = 0; i < str.length(); i++) {
160-
if (str.charAt(i) == '\\') {
161-
return true;
162-
}
163-
}
164-
return false;
165-
}
166-
167-
@Fallback
168-
Object run(Object o) {
169-
throw raise(PythonErrorType.TypeError, ErrorMessages.EXPECTED_S_NOT_P, "string", o);
170-
}
171-
172-
private SequenceStorageNodes.ToByteArrayNode getToByteArrayNode() {
173-
if (toByteArrayNode == null) {
174-
CompilerDirectives.transferToInterpreterAndInvalidate();
175-
toByteArrayNode = insert(ToByteArrayNodeGen.create());
176-
}
177-
return toByteArrayNode;
178-
}
179-
}
180-
18197
abstract static class ToRegexSourceNode extends Node {
18298

18399
public abstract Source execute(Object pattern, String flags);

graalpython/lib-graalpython/_sre.py

Lines changed: 17 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -394,64 +394,6 @@ def findall(self, string, pos=0, endpos=-1):
394394
pos = result.getEnd(0) + no_progress
395395
return matchlist
396396

397-
def __replace_groups(self, repl, string, match_result, pattern):
398-
def group(pattern, match_result, group_nr, string):
399-
if group_nr >= pattern.groupCount:
400-
return None
401-
group_start = match_result.getStart(group_nr)
402-
group_end = match_result.getEnd(group_nr)
403-
return string[group_start:group_end]
404-
405-
n = len(repl)
406-
result = b"" if self.__binary else ""
407-
start = 0
408-
backslash = b'\\' if self.__binary else '\\'
409-
pos = repl.find(backslash, start)
410-
while pos != -1 and start < n:
411-
if pos+1 < n:
412-
c = repl[pos + 1:pos + 2].decode('ascii') if self.__binary else repl[pos + 1]
413-
if c.isdigit() and pattern.groupCount > 0:
414-
# TODO: Should handle backreferences longer than 1 digit and fall back to octal escapes.
415-
group_nr = int(c)
416-
group_str = group(pattern, match_result, group_nr, string)
417-
if group_str is None:
418-
raise error("invalid group reference %s" % group_nr)
419-
result += repl[start:pos] + group_str
420-
start = pos + 2
421-
elif c == 'g':
422-
group_ref, group_ref_end, digits_only = self.__extract_groupname(repl, pos + 2)
423-
if group_ref:
424-
group_str = group(pattern, match_result, int(group_ref) if digits_only else pattern.groups[group_ref], string)
425-
if group_str is None:
426-
raise error("invalid group reference %s" % group_ref)
427-
result += repl[start:pos] + group_str
428-
start = group_ref_end + 1
429-
elif c == '\\':
430-
result += repl[start:pos] + backslash
431-
start = pos + 2
432-
else:
433-
assert False, "unexpected escape in re.sub"
434-
pos = repl.find(backslash, start)
435-
result += repl[start:]
436-
return result
437-
438-
439-
def __extract_groupname(self, repl, pos):
440-
if repl[pos] == (b'<' if self.__binary else '<'):
441-
digits_only = True
442-
n = len(repl)
443-
i = pos + 1
444-
while i < n and repl[i] != (b'>' if self.__binary else '>'):
445-
digits_only = digits_only and repl[i].isdigit()
446-
i += 1
447-
if i < n:
448-
# found '>'
449-
group_ref = repl[pos + 1 : i]
450-
group_ref_str = group_ref.decode('ascii') if self.__binary else group_ref
451-
return group_ref_str, i, digits_only
452-
return None, pos, False
453-
454-
455397
def sub(self, repl, string, count=0):
456398
return self.subn(repl, string, count)[0]
457399

@@ -461,13 +403,20 @@ def subn(self, repl, string, count=0):
461403
pattern = self.__tregex_compile(self.pattern)
462404
result = []
463405
pos = 0
464-
is_string_rep = isinstance(repl, str) or _is_bytes_like(repl)
465-
if is_string_rep:
406+
literal = False
407+
if not callable(repl):
466408
self.__check_input_type(repl)
467-
try:
468-
repl = _process_escape_sequences(repl)
469-
except ValueError as e:
470-
raise error(str(e))
409+
if isinstance(repl, str):
410+
literal = '\\' not in repl
411+
else:
412+
literal = b'\\' not in repl
413+
if not literal:
414+
import sre_parse
415+
template = sre_parse.parse_template(repl, self)
416+
417+
def repl(match):
418+
return sre_parse.expand_template(template, match)
419+
471420
while (count == 0 or n < count) and pos <= len(string):
472421
match_result = tregex_call_exec(pattern.exec, string, pos)
473422
if not match_result.isMatch:
@@ -476,8 +425,8 @@ def subn(self, repl, string, count=0):
476425
start = match_result.getStart(0)
477426
end = match_result.getEnd(0)
478427
result.append(string[pos:start])
479-
if is_string_rep:
480-
result.append(self.__replace_groups(repl, string, match_result, pattern))
428+
if literal:
429+
result.append(repl)
481430
else:
482431
_srematch = SRE_Match(self, pos, -1, match_result, string, pattern)
483432
_repl = repl(_srematch)
@@ -489,9 +438,9 @@ def subn(self, repl, string, count=0):
489438
pos = pos + 1
490439
result.append(string[pos:])
491440
if self.__binary:
492-
return (b"".join(result), n)
441+
return b"".join(result), n
493442
else:
494-
return ("".join(result), n)
443+
return "".join(result), n
495444

496445
def split(self, string, maxsplit=0):
497446
n = 0

0 commit comments

Comments
 (0)