Skip to content

Commit a8fb176

Browse files
committed
Fix parsing nested tuples in PyArg_ParseTuple
1 parent acc4d43 commit a8fb176

File tree

3 files changed

+102
-25
lines changed

3 files changed

+102
-25
lines changed

graalpython/com.oracle.graal.python.test/src/tests/cpyext/test_modsupport.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,16 @@ def _reference_parse_O(args):
7070
raise TypeError
7171

7272

73+
def _reference_parse_tuple(args):
74+
try:
75+
t = args[0][0]
76+
if len(t) != 2:
77+
raise TypeError
78+
return t[0], t[1]
79+
except Exception:
80+
raise TypeError
81+
82+
7383
class Indexable:
7484
def __int__(self):
7585
return 456
@@ -78,6 +88,26 @@ def __index__(self):
7888
return 123
7989

8090

91+
class MySeq:
92+
def __len__(self):
93+
return 2
94+
95+
def __getitem__(self, item):
96+
if item == 0:
97+
return 'x'
98+
elif item == 1:
99+
return 'y'
100+
else:
101+
raise IndexError
102+
103+
104+
class BadSeq:
105+
def __len__(self):
106+
return 2
107+
108+
def __getitem__(self, item):
109+
raise IndexError
110+
81111
class TestModsupport(CPyExtTestCase):
82112
def compile_module(self, name):
83113
type(self).mro()[1].__dict__["test_%s" % name].create_module(name)
@@ -151,6 +181,34 @@ def compile_module(self, name):
151181
cmpfunc=unhandled_error_compare
152182
)
153183

184+
test_parseargs_tuple = CPyExtFunction(
185+
_reference_parse_tuple,
186+
lambda: (
187+
((("a", "b"),),),
188+
((["a", "b"],),),
189+
((MySeq(),),),
190+
((["a"],),),
191+
((["a", "b", "c"],),),
192+
((1,),),
193+
((BadSeq(),),),
194+
),
195+
code='''
196+
static PyObject* wrap_PyArg_ParseTuple(PyObject* argTuple) {
197+
PyObject* a = NULL;
198+
PyObject* b = NULL;
199+
if (PyArg_ParseTuple(argTuple, "(OO)", &a, &b) == 0) {
200+
return NULL;
201+
}
202+
return Py_BuildValue("(OO)", a, b);
203+
}
204+
''',
205+
resultspec="O",
206+
argspec="O",
207+
arguments=["PyObject* argTuple"],
208+
callfunction="wrap_PyArg_ParseTuple",
209+
cmpfunc=unhandled_error_compare
210+
)
211+
154212
test_parseargs_O_conv = CPyExtFunction(
155213
lambda args: True if args[0][0] else False,
156214
lambda: (

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/cext/common/CExtParseArgumentsNode.java

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,10 @@
7575
import com.oracle.graal.python.builtins.objects.tuple.PTuple;
7676
import com.oracle.graal.python.lib.PyObjectIsTrueNode;
7777
import com.oracle.graal.python.lib.PyObjectSizeNode;
78+
import com.oracle.graal.python.lib.PySequenceCheckNode;
7879
import com.oracle.graal.python.nodes.ErrorMessages;
7980
import com.oracle.graal.python.nodes.PGuards;
81+
import com.oracle.graal.python.nodes.builtins.TupleNodes;
8082
import com.oracle.graal.python.nodes.classes.IsSubtypeNode;
8183
import com.oracle.graal.python.nodes.object.GetClassNode;
8284
import com.oracle.graal.python.nodes.object.IsBuiltinClassProfile;
@@ -139,6 +141,7 @@ public abstract class CExtParseArgumentsNode {
139141
static final char FORMAT_LOWER_W = 'w';
140142
static final char FORMAT_LOWER_P = 'p';
141143
static final char FORMAT_PAR_OPEN = '(';
144+
static final char FORMAT_PAR_CLOSE = ')';
142145

143146
@GenerateUncached
144147
@ImportStatic(PGuards.class)
@@ -241,15 +244,8 @@ private static ParserState convertArg(ParserState state, Object kwds, char[] for
241244
case FORMAT_LOWER_W:
242245
case FORMAT_LOWER_P:
243246
case FORMAT_PAR_OPEN:
247+
case FORMAT_PAR_CLOSE:
244248
return convertArgNode.execute(state, kwds, c, format, format_idx, kwdnames, varargs);
245-
case ')':
246-
if (state.v.prev == null) {
247-
CompilerDirectives.transferToInterpreter();
248-
raiseNode.raiseIntWithoutFrame(0, PythonBuiltinClassType.SystemError, ErrorMessages.LEFT_BRACKET_WO_RIGHT_BRACKET_IN_ARG);
249-
throw ParseArgumentsException.raise();
250-
} else {
251-
return state.close();
252-
}
253249
case '|':
254250
if (state.restOptional) {
255251
raiseNode.raiseIntWithoutFrame(0, SystemError, "Invalid format string (| specified twice)", c);
@@ -1039,8 +1035,10 @@ static ParserState doPredicate(ParserState state, Object kwds, @SuppressWarnings
10391035
}
10401036

10411037
@Specialization(guards = "c == FORMAT_PAR_OPEN")
1042-
static ParserState doPredicate(ParserState state, Object kwds, @SuppressWarnings("unused") char c, @SuppressWarnings("unused") char[] format, @SuppressWarnings("unused") int format_idx,
1038+
static ParserState doParOpen(ParserState state, Object kwds, @SuppressWarnings("unused") char c, @SuppressWarnings("unused") char[] format, @SuppressWarnings("unused") int format_idx,
10431039
Object kwdnames, @SuppressWarnings("unused") Object varargs,
1040+
@Cached PySequenceCheckNode sequenceCheckNode,
1041+
@Cached TupleNodes.ConstructTupleNode constructTupleNode,
10441042
@Cached PythonObjectFactory factory,
10451043
@Shared("getArgNode") @Cached GetArgNode getArgNode,
10461044
@Shared("raiseNode") @Cached PRaiseNativeNode raiseNode) throws InteropException, ParseArgumentsException {
@@ -1049,14 +1047,33 @@ static ParserState doPredicate(ParserState state, Object kwds, @SuppressWarnings
10491047
if (skipOptionalArg(arg, state.restOptional)) {
10501048
return state.open(new PositionalArgStack(factory.createEmptyTuple(), state.v));
10511049
} else {
1052-
// n.b.: there is a small gap in this check: In theory, there could be
1053-
// native subclass of tuple. But since we do not support this anyway, the
1054-
// instanceof test is just the most efficient way to do it.
1055-
if (!(arg instanceof PTuple)) {
1050+
if (!sequenceCheckNode.execute(arg)) {
10561051
throw raise(raiseNode, TypeError, ErrorMessages.EXPECTED_S_GOT_P, "tuple", arg);
10571052
}
1058-
return state.open(new PositionalArgStack((PTuple) arg, state.v));
1053+
try {
1054+
return state.open(new PositionalArgStack(constructTupleNode.execute(null, arg), state.v));
1055+
} catch (PException e) {
1056+
throw raise(raiseNode, TypeError, "failed to convert sequence");
1057+
}
1058+
}
1059+
}
1060+
1061+
@Specialization(guards = "c == FORMAT_PAR_CLOSE")
1062+
static ParserState doParClose(ParserState state, @SuppressWarnings("unused") Object kwds, @SuppressWarnings("unused") char c, @SuppressWarnings("unused") char[] format,
1063+
@SuppressWarnings("unused") int format_idx,
1064+
@SuppressWarnings("unused") Object kwdnames, @SuppressWarnings("unused") Object varargs,
1065+
@Cached SequenceStorageNodes.LenNode lenNode,
1066+
@Shared("raiseNode") @Cached PRaiseNativeNode raiseNode) throws ParseArgumentsException {
1067+
if (state.v.prev == null) {
1068+
CompilerDirectives.transferToInterpreter();
1069+
raiseNode.raiseIntWithoutFrame(0, PythonBuiltinClassType.SystemError, ErrorMessages.LEFT_BRACKET_WO_RIGHT_BRACKET_IN_ARG);
1070+
throw ParseArgumentsException.raise();
1071+
}
1072+
int len = lenNode.execute(state.v.argv.getSequenceStorage());
1073+
if (len != state.v.argnum) {
1074+
throw raise(raiseNode, TypeError, "must be sequence of length %d, not %d", state.v.argnum, len);
10591075
}
1076+
return state.close();
10601077
}
10611078

10621079
private static ParseArgumentsException raise(PRaiseNativeNode raiseNode, PythonBuiltinClassType errType, String format, Object... arguments) {

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/nodes/builtins/TupleNodes.java

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -48,55 +48,57 @@
4848
import com.oracle.graal.python.lib.PyObjectGetIter;
4949
import com.oracle.graal.python.nodes.PGuards;
5050
import com.oracle.graal.python.nodes.PNodeWithContext;
51-
import com.oracle.graal.python.nodes.SpecialMethodNames;
5251
import com.oracle.graal.python.nodes.object.GetClassNode;
5352
import com.oracle.graal.python.runtime.object.PythonObjectFactory;
5453
import com.oracle.truffle.api.CompilerDirectives;
5554
import com.oracle.truffle.api.dsl.Cached;
55+
import com.oracle.truffle.api.dsl.Cached.Shared;
5656
import com.oracle.truffle.api.dsl.Fallback;
57-
import com.oracle.truffle.api.dsl.ImportStatic;
57+
import com.oracle.truffle.api.dsl.GenerateUncached;
5858
import com.oracle.truffle.api.dsl.Specialization;
59+
import com.oracle.truffle.api.frame.Frame;
5960
import com.oracle.truffle.api.frame.VirtualFrame;
6061

6162
public abstract class TupleNodes {
6263

63-
@ImportStatic({PGuards.class, SpecialMethodNames.class})
64+
@GenerateUncached
6465
public abstract static class ConstructTupleNode extends PNodeWithContext {
65-
@Child private PythonObjectFactory factory = PythonObjectFactory.create();
66-
6766
public final PTuple execute(VirtualFrame frame, Object value) {
6867
return execute(frame, PythonBuiltinClassType.PTuple, value);
6968
}
7069

71-
public abstract PTuple execute(VirtualFrame frame, Object cls, Object value);
70+
public abstract PTuple execute(Frame frame, Object cls, Object value);
7271

7372
@Specialization(guards = "isNoValue(none)")
74-
PTuple tuple(Object cls, @SuppressWarnings("unused") PNone none) {
73+
static PTuple tuple(Object cls, @SuppressWarnings("unused") PNone none,
74+
@Shared("factory") @Cached PythonObjectFactory factory) {
7575
return factory.createEmptyTuple(cls);
7676
}
7777

7878
@Specialization
79-
PTuple tuple(Object cls, String arg) {
79+
static PTuple tuple(Object cls, String arg,
80+
@Shared("factory") @Cached PythonObjectFactory factory) {
8081
return factory.createTuple(cls, StringUtils.toCharacterArray(arg));
8182
}
8283

8384
@Specialization(guards = {"cannotBeOverridden(cls)", "cannotBeOverridden(iterable, getClassNode)"}, limit = "1")
84-
PTuple tuple(@SuppressWarnings("unused") Object cls, PTuple iterable,
85+
static PTuple tuple(@SuppressWarnings("unused") Object cls, PTuple iterable,
8586
@SuppressWarnings("unused") @Cached GetClassNode getClassNode) {
8687
return iterable;
8788
}
8889

8990
@Specialization(guards = {"!isNoValue(iterable)", "createNewTuple(cls, iterable, getClassNode)"}, limit = "1")
90-
PTuple tuple(VirtualFrame frame, Object cls, Object iterable,
91+
static PTuple tuple(VirtualFrame frame, Object cls, Object iterable,
9192
@SuppressWarnings("unused") @Cached GetClassNode getClassNode,
93+
@Shared("factory") @Cached PythonObjectFactory factory,
9294
@Cached CreateStorageFromIteratorNode storageNode,
9395
@Cached PyObjectGetIter getIter) {
9496
Object iterObj = getIter.execute(frame, iterable);
9597
return factory.createTuple(cls, storageNode.execute(frame, iterObj));
9698
}
9799

98100
@Fallback
99-
public PTuple tuple(@SuppressWarnings("unused") Object cls, Object value) {
101+
static PTuple tuple(@SuppressWarnings("unused") Object cls, Object value) {
100102
CompilerDirectives.transferToInterpreter();
101103
throw new RuntimeException("tuple does not support iterable object " + value);
102104
}

0 commit comments

Comments
 (0)