Skip to content

Commit 5bc9a99

Browse files
committed
Merge branch 'tim/sklearn-linreg' of ssh://ol-bitbucket.us.oracle.com:7999/g/graalpython into tim/sklearn-linreg
2 parents b3187c9 + 5e23c12 commit 5bc9a99

File tree

2 files changed

+62
-6
lines changed

2 files changed

+62
-6
lines changed

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/SREModuleBuiltins.java

Lines changed: 52 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,18 +38,67 @@
3838
*/
3939
package com.oracle.graal.python.builtins.modules;
4040

41+
import java.util.ArrayList;
42+
import java.util.List;
43+
import java.util.regex.Matcher;
44+
import java.util.regex.Pattern;
45+
46+
import com.oracle.graal.python.builtins.Builtin;
4147
import com.oracle.graal.python.builtins.CoreFunctions;
4248
import com.oracle.graal.python.builtins.PythonBuiltins;
49+
import com.oracle.graal.python.builtins.objects.str.PString;
4350
import com.oracle.graal.python.nodes.function.PythonBuiltinNode;
51+
import com.oracle.graal.python.nodes.function.builtins.PythonUnaryBuiltinNode;
52+
import com.oracle.graal.python.runtime.exception.PythonErrorType;
53+
import com.oracle.truffle.api.CompilerDirectives;
54+
import com.oracle.truffle.api.CompilerDirectives.CompilationFinal;
55+
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
56+
import com.oracle.truffle.api.dsl.Fallback;
57+
import com.oracle.truffle.api.dsl.GenerateNodeFactory;
4458
import com.oracle.truffle.api.dsl.NodeFactory;
45-
46-
import java.util.ArrayList;
47-
import java.util.List;
59+
import com.oracle.truffle.api.dsl.Specialization;
4860

4961
@CoreFunctions(defineModule = "_sre")
5062
public class SREModuleBuiltins extends PythonBuiltins {
5163
@Override
5264
protected List<? extends NodeFactory<? extends PythonBuiltinNode>> getNodeFactories() {
5365
return new ArrayList<>();
5466
}
67+
68+
/**
69+
* Called from C when they actually want a {@code const char*} for a Python string
70+
*/
71+
@Builtin(name = "tregex_preprocess", fixedNumOfArguments = 1)
72+
@GenerateNodeFactory
73+
abstract static class TregexPreprocessNode extends PythonUnaryBuiltinNode {
74+
@CompilationFinal private Pattern pattern;
75+
76+
@Specialization
77+
Object run(PString str) {
78+
return run(str.getValue());
79+
}
80+
81+
@Specialization
82+
Object run(String str) {
83+
str.replaceAll("[^\\[]?#[^\\]]*\n", "");
84+
if (pattern == null) {
85+
CompilerDirectives.transferToInterpreterAndInvalidate();
86+
pattern = Pattern.compile("(?<CMT>#[^\\]]*\n)");
87+
}
88+
return replaceAll(str);
89+
}
90+
91+
@TruffleBoundary
92+
private String replaceAll(String r) {
93+
Matcher matcher = pattern.matcher(r);
94+
return matcher.replaceAll("");
95+
}
96+
97+
@Fallback
98+
Object run(Object o) {
99+
throw raise(PythonErrorType.TypeError, "expected string, not %p", o);
100+
}
101+
102+
}
103+
55104
}

graalpython/lib-graalpython/_sre.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def lastindex(self):
103103

104104
class SRE_Pattern():
105105
def __init__(self, pattern, flags, code, groups=0, groupindex=None, indexgroup=None):
106-
self.pattern = self._decode_string(pattern)
106+
self.pattern = self._decode_string(pattern, flags)
107107
self.flags = flags
108108
self.code = code
109109
self.num_groups = groups
@@ -115,15 +115,22 @@ def __init__(self, pattern, flags, code, groups=0, groupindex=None, indexgroup=N
115115
jsflags.append(jsflag)
116116
self.jsflags = "".join(jsflags)
117117

118-
def _decode_string(self, string):
118+
def _decode_string(self, string, flags):
119119
if isinstance(string, str):
120120
pattern = string
121121
elif isinstance(string, bytes):
122122
pattern = string.decode()
123123
else:
124124
raise TypeError("invalid search pattern {!r}".format(string))
125125
# TODO: fix this in the regex engine
126-
return pattern.replace(r'\"', '"').replace(r"\'", "'")
126+
pattern = pattern.replace(r'\"', '"').replace(r"\'", "'")
127+
128+
# TODO: that's not nearly complete but should be sufficient for now
129+
from sre_compile import SRE_FLAG_VERBOSE
130+
if flags & SRE_FLAG_VERBOSE:
131+
pattern = _sre.tregex_preprocess(pattern)
132+
return pattern
133+
127134

128135
def __repr__(self):
129136
flags = self.flags

0 commit comments

Comments
 (0)