Skip to content

Commit 1e32f02

Browse files
committed
intrinsified itertools.cycle()
1 parent e5fa2f0 commit 1e32f02

File tree

7 files changed

+390
-68
lines changed

7 files changed

+390
-68
lines changed

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/Python3Core.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,7 @@
200200
import com.oracle.graal.python.builtins.objects.itertools.CombinationsBuiltins;
201201
import com.oracle.graal.python.builtins.objects.itertools.CompressBuiltins;
202202
import com.oracle.graal.python.builtins.objects.itertools.CountBuiltins;
203+
import com.oracle.graal.python.builtins.objects.itertools.CycleBuiltins;
203204
import com.oracle.graal.python.builtins.objects.itertools.DropwhileBuiltins;
204205
import com.oracle.graal.python.builtins.objects.itertools.FilterfalseBuiltins;
205206
import com.oracle.graal.python.builtins.objects.itertools.GroupByBuiltins;
@@ -519,6 +520,7 @@ private static PythonBuiltins[] initializeBuiltins(boolean nativeAccessAllowed)
519520
new DropwhileBuiltins(),
520521
new ChainBuiltins(),
521522
new CountBuiltins(),
523+
new CycleBuiltins(),
522524
new FilterfalseBuiltins(),
523525
new GroupByBuiltins(),
524526
new GrouperBuiltins(),

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/PythonBuiltinClassType.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,7 @@ public enum PythonBuiltinClassType implements TruffleObject {
215215
PCombinations("combinations", "itertools"),
216216
PCombinationsWithReplacement("combinations_with_replacement", "itertools"),
217217
PCompress("compress", "itertools"),
218+
PCycle("cycle", "itertools"),
218219
PDropwhile("dropwhile", "itertools"),
219220
PFilterfalse("filterfalse", "itertools"),
220221
PGroupBy("groupby", "itertools"),

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/ItertoolsModuleBuiltins.java

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
import com.oracle.graal.python.builtins.objects.itertools.PCombinationsWithReplacement;
5050
import com.oracle.graal.python.builtins.objects.itertools.PCompress;
5151
import com.oracle.graal.python.builtins.objects.itertools.PCount;
52+
import com.oracle.graal.python.builtins.objects.itertools.PCycle;
5253
import com.oracle.graal.python.builtins.objects.itertools.PDropwhile;
5354
import com.oracle.graal.python.builtins.objects.itertools.PFilterfalse;
5455
import com.oracle.graal.python.builtins.objects.itertools.PGroupBy;
@@ -83,6 +84,7 @@
8384
import com.oracle.graal.python.nodes.call.special.CallVarargsMethodNode;
8485
import com.oracle.graal.python.nodes.function.PythonBuiltinBaseNode;
8586
import com.oracle.graal.python.nodes.function.PythonBuiltinNode;
87+
import com.oracle.graal.python.nodes.function.builtins.PythonBinaryBuiltinNode;
8688
import com.oracle.graal.python.nodes.function.builtins.PythonBinaryClinicBuiltinNode;
8789
import com.oracle.graal.python.nodes.function.builtins.PythonTernaryBuiltinNode;
8890
import com.oracle.graal.python.nodes.function.builtins.PythonTernaryClinicBuiltinNode;
@@ -102,6 +104,7 @@
102104
import com.oracle.truffle.api.profiles.BranchProfile;
103105
import com.oracle.truffle.api.profiles.ConditionProfile;
104106
import com.oracle.truffle.api.profiles.LoopConditionProfile;
107+
import java.util.ArrayList;
105108

106109
@CoreFunctions(defineModule = "itertools")
107110
public final class ItertoolsModuleBuiltins extends PythonBuiltins {
@@ -284,6 +287,42 @@ protected Object notype(Object cls, Object[] arguments, PKeyword[] keywords,
284287
}
285288
}
286289

290+
@Builtin(name = "cycle", minNumOfPositionalArgs = 2, constructsClass = PythonBuiltinClassType.PCycle, doc = "Make an iterator returning elements from the iterable and\n" +
291+
" saving a copy of each. When the iterable is exhausted, return\n" +
292+
" elements from the saved copy. Repeats indefinitely.\n\n" +
293+
" Equivalent to :\n\n" +
294+
" def cycle(iterable):\n" +
295+
" \tsaved = []\n" +
296+
" \tfor element in iterable:\n" +
297+
" \t\tyield element\n" +
298+
" \t\tsaved.append(element)\n" +
299+
" \twhile saved:\n" +
300+
" \t\tfor element in saved:\n" +
301+
" \t\t\tyield element")
302+
@GenerateNodeFactory
303+
public abstract static class CycleNode extends PythonBinaryBuiltinNode {
304+
305+
@SuppressWarnings("unused")
306+
@Specialization(guards = "isTypeNode.execute(cls)")
307+
protected PCycle construct(VirtualFrame frame, Object cls, Object iterable,
308+
@Cached PyObjectGetIter getIter,
309+
@Cached IsTypeNode isTypeNode) {
310+
PCycle self = factory().createCycle(cls);
311+
self.setSaved(new ArrayList<>());
312+
self.setIterable(getIter.execute(frame, iterable));
313+
self.setIndex(0);
314+
self.setFirstpass(false);
315+
return self;
316+
}
317+
318+
@Specialization(guards = "!isTypeNode.execute(cls)")
319+
@SuppressWarnings("unused")
320+
protected Object notype(Object cls, Object iterable,
321+
@SuppressWarnings("unused") @Cached IsTypeNode isTypeNode) {
322+
throw raise(TypeError, ErrorMessages.IS_NOT_TYPE_OBJ, "'cls'", cls);
323+
}
324+
}
325+
287326
@Builtin(name = "dropwhile", minNumOfPositionalArgs = 3, constructsClass = PythonBuiltinClassType.PDropwhile, doc = "dropwhile(predicate, iterable) --> dropwhile object\n\n" +
288327
"Drop items from the iterable while predicate(item) is true.\n" +
289328
"Afterwards, return every element until the iterable is exhausted.")
@@ -565,6 +604,7 @@ Object construct(VirtualFrame frame, Object cls, Object iterable, Object rArg,
565604
negRprofile.enter();
566605
throw raise(ValueError, MUST_BE_NON_NEGATIVE, "r");
567606
}
607+
// XXX could be generator
568608
int len = sizeNode.execute(frame, iterable);
569609
return construct(cls, iterable, r, len, nrProfile, indicesLoopProfile, cyclesLoopProfile);
570610
}
@@ -684,6 +724,7 @@ private static void constructOneRepeat(VirtualFrame frame, PProduct self, Object
684724
private static void construct(VirtualFrame frame, PProduct self, PSequence[] gears, PyObjectSizeNode sizeNode) {
685725
self.setGears(gears);
686726
for (int i = 0; i < gears.length; i++) {
727+
// XXX could be generator
687728
if (sizeNode.execute(frame, gears[i]) == 0) {
688729
self.setIndices(null);
689730
self.setLst(null);
@@ -1005,7 +1046,7 @@ protected Object construct(Object cls, Object iterable, Object[] args,
10051046
@GenerateNodeFactory
10061047
public abstract static class ZipLongestNode extends PythonBuiltinNode {
10071048
@Specialization(guards = "isTypeNode.execute(cls)")
1008-
Object constructNoFillValue(VirtualFrame frame, Object cls, Object[] args, PNone fillValue,
1049+
Object constructNoFillValue(VirtualFrame frame, Object cls, Object[] args, @SuppressWarnings("unused") PNone fillValue,
10091050
@Cached PyObjectGetIter getIterNode,
10101051
@Cached LoopConditionProfile loopProfile,
10111052
@SuppressWarnings("unused") @Cached IsTypeNode isTypeNode) {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,250 @@
1+
/*
2+
* Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved.
3+
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4+
*
5+
* The Universal Permissive License (UPL), Version 1.0
6+
*
7+
* Subject to the condition set forth below, permission is hereby granted to any
8+
* person obtaining a copy of this software, associated documentation and/or
9+
* data (collectively the "Software"), free of charge and under any and all
10+
* copyright rights in the Software, and any and all patent rights owned or
11+
* freely licensable by each licensor hereunder covering either (i) the
12+
* unmodified Software as contributed to or provided by such licensor, or (ii)
13+
* the Larger Works (as defined below), to deal in both
14+
*
15+
* (a) the Software, and
16+
*
17+
* (b) any piece of software and/or hardware listed in the lrgrwrks.txt file if
18+
* one is included with the Software each a "Larger Work" to which the Software
19+
* is contributed by such licensors),
20+
*
21+
* without restriction, including without limitation the rights to copy, create
22+
* derivative works of, display, perform, and distribute the Software and make,
23+
* use, sell, offer for sale, import, export, have made, and have sold the
24+
* Software and the Larger Work(s), and to sublicense the foregoing rights on
25+
* either these or other terms.
26+
*
27+
* This license is subject to the following condition:
28+
*
29+
* The above copyright notice and either this complete permission notice or at a
30+
* minimum a reference to the UPL must be included in all copies or substantial
31+
* portions of the Software.
32+
*
33+
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
34+
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
35+
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
36+
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
37+
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
38+
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
39+
* SOFTWARE.
40+
*/
41+
package com.oracle.graal.python.builtins.objects.itertools;
42+
43+
import static com.oracle.graal.python.builtins.PythonBuiltinClassType.StopIteration;
44+
import static com.oracle.graal.python.builtins.PythonBuiltinClassType.TypeError;
45+
import static com.oracle.graal.python.nodes.ErrorMessages.IS_NOT_A;
46+
import static com.oracle.graal.python.nodes.ErrorMessages.STATE_ARGUMENT_D_MUST_BE_A_S;
47+
import static com.oracle.graal.python.nodes.SpecialMethodNames.__ITER__;
48+
import static com.oracle.graal.python.nodes.SpecialMethodNames.__NEXT__;
49+
import static com.oracle.graal.python.nodes.SpecialMethodNames.__REDUCE__;
50+
import static com.oracle.graal.python.nodes.SpecialMethodNames.__SETSTATE__;
51+
52+
import com.oracle.graal.python.builtins.Builtin;
53+
import java.util.List;
54+
55+
import com.oracle.graal.python.builtins.CoreFunctions;
56+
import com.oracle.graal.python.builtins.PythonBuiltinClassType;
57+
import com.oracle.graal.python.builtins.PythonBuiltins;
58+
import com.oracle.graal.python.builtins.modules.BuiltinFunctions;
59+
import com.oracle.graal.python.builtins.objects.PNone;
60+
import com.oracle.graal.python.builtins.objects.common.SequenceStorageNodes.ToArrayNode;
61+
import com.oracle.graal.python.builtins.objects.list.PList;
62+
import com.oracle.graal.python.builtins.objects.object.PythonObject;
63+
import com.oracle.graal.python.builtins.objects.tuple.PTuple;
64+
import com.oracle.graal.python.builtins.objects.tuple.TupleBuiltins.GetItemNode;
65+
import com.oracle.graal.python.builtins.objects.tuple.TupleBuiltins.LenNode;
66+
import com.oracle.graal.python.lib.PyNumberAsSizeNode;
67+
import com.oracle.graal.python.lib.PyObjectGetIter;
68+
import com.oracle.graal.python.lib.PyObjectLookupAttr;
69+
import com.oracle.graal.python.nodes.call.special.CallUnaryMethodNode;
70+
import com.oracle.graal.python.nodes.function.PythonBuiltinBaseNode;
71+
import com.oracle.graal.python.nodes.function.builtins.PythonBinaryBuiltinNode;
72+
import com.oracle.graal.python.nodes.function.builtins.PythonUnaryBuiltinNode;
73+
import com.oracle.graal.python.nodes.object.GetClassNode;
74+
import com.oracle.graal.python.nodes.object.IsBuiltinClassProfile;
75+
import com.oracle.graal.python.runtime.exception.PException;
76+
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
77+
import com.oracle.truffle.api.dsl.Cached;
78+
import com.oracle.truffle.api.dsl.GenerateNodeFactory;
79+
import com.oracle.truffle.api.dsl.NodeFactory;
80+
import com.oracle.truffle.api.dsl.Specialization;
81+
import com.oracle.truffle.api.frame.VirtualFrame;
82+
import com.oracle.truffle.api.profiles.BranchProfile;
83+
import java.util.ArrayList;
84+
import java.util.Arrays;
85+
86+
@CoreFunctions(extendClasses = {PythonBuiltinClassType.PCycle})
87+
public final class CycleBuiltins extends PythonBuiltins {
88+
89+
@Override
90+
protected List<? extends NodeFactory<? extends PythonBuiltinBaseNode>> getNodeFactories() {
91+
return CycleBuiltinsFactory.getFactories();
92+
}
93+
94+
@Builtin(name = __ITER__, minNumOfPositionalArgs = 1)
95+
@GenerateNodeFactory
96+
public abstract static class IterNode extends PythonUnaryBuiltinNode {
97+
@Specialization
98+
static Object iter(PCycle self) {
99+
return self;
100+
}
101+
}
102+
103+
@Builtin(name = __NEXT__, minNumOfPositionalArgs = 1)
104+
@GenerateNodeFactory
105+
public abstract static class NextNode extends PythonUnaryBuiltinNode {
106+
@Specialization
107+
Object next(VirtualFrame frame, PCycle self,
108+
@Cached BuiltinFunctions.NextNode nextNode,
109+
@Cached IsBuiltinClassProfile isStopIterationProfile,
110+
@Cached BranchProfile iterableProfile,
111+
@Cached BranchProfile firstPassProfile,
112+
@Cached BranchProfile savedProfile) {
113+
if (self.getIterable() != null) {
114+
iterableProfile.enter();
115+
try {
116+
Object item = nextNode.execute(frame, self.getIterable(), PNone.NO_VALUE);
117+
if (!self.isFirstpass()) {
118+
firstPassProfile.enter();
119+
add(self.getSaved(), item);
120+
}
121+
return item;
122+
} catch (PException e) {
123+
e.expectStopIteration(isStopIterationProfile);
124+
self.setIterable(null);
125+
}
126+
}
127+
if (isEmpty(self.getSaved())) {
128+
savedProfile.enter();
129+
throw raise(StopIteration);
130+
}
131+
Object item = get(self.getSaved(), self.getIndex());
132+
self.setIndex(self.getIndex() + 1);
133+
if (self.getIndex() >= size(self.getSaved())) {
134+
self.setIndex(0);
135+
}
136+
return item;
137+
}
138+
139+
@TruffleBoundary
140+
private boolean isEmpty(List<Object> l) {
141+
return l.isEmpty();
142+
}
143+
144+
@TruffleBoundary
145+
private Object add(List<Object> l, Object item) {
146+
return l.add(item);
147+
}
148+
149+
@TruffleBoundary
150+
private Object get(List<Object> l, int idx) {
151+
return l.get(idx);
152+
}
153+
154+
@TruffleBoundary
155+
private int size(List<Object> l) {
156+
return l.size();
157+
}
158+
}
159+
160+
@Builtin(name = __REDUCE__, minNumOfPositionalArgs = 1)
161+
@GenerateNodeFactory
162+
public abstract static class ReduceNode extends PythonUnaryBuiltinNode {
163+
@Specialization(guards = "hasIterable(self)")
164+
Object reduce(PCycle self,
165+
@Cached GetClassNode getClass) {
166+
Object type = getClass.execute(self);
167+
PTuple iterableTuple = factory().createTuple(new Object[]{self.getIterable()});
168+
PTuple tuple = factory().createTuple(new Object[]{getSavedList(self), self.isFirstpass()});
169+
return factory().createTuple(new Object[]{type, iterableTuple, tuple});
170+
}
171+
172+
@Specialization(guards = "!hasIterable(self)")
173+
Object reduceNoIterable(VirtualFrame frame, PCycle self,
174+
@Cached GetClassNode getClass,
175+
@Cached PyObjectLookupAttr lookupAttrNode,
176+
@Cached CallUnaryMethodNode callNode,
177+
@Cached PyObjectGetIter getIterNode,
178+
@Cached BranchProfile indexProfile) {
179+
Object type = getClass.execute(self);
180+
PList savedList = getSavedList(self);
181+
Object it = getIterNode.execute(frame, savedList);
182+
if (self.getIndex() > 0) {
183+
indexProfile.enter();
184+
Object setStateCallable = lookupAttrNode.execute(frame, it, __SETSTATE__);
185+
callNode.executeObject(frame, setStateCallable, self.getIndex());
186+
}
187+
PTuple iteratorTuple = factory().createTuple(new Object[]{it});
188+
PTuple tuple = factory().createTuple(new Object[]{savedList, true});
189+
return factory().createTuple(new Object[]{type, iteratorTuple, tuple});
190+
}
191+
192+
PList getSavedList(PCycle self) {
193+
return factory().createList(toArray(self.getSaved()));
194+
}
195+
196+
@TruffleBoundary
197+
private static Object[] toArray(List<Object> l) {
198+
return l.toArray(new Object[l.size()]);
199+
}
200+
201+
protected boolean hasIterable(PCycle self) {
202+
return self.getIterable() != null;
203+
}
204+
}
205+
206+
@Builtin(name = __SETSTATE__, minNumOfPositionalArgs = 2)
207+
@GenerateNodeFactory
208+
public abstract static class SetStateNode extends PythonBinaryBuiltinNode {
209+
abstract Object execute(VirtualFrame frame, PythonObject self, Object state);
210+
211+
@Specialization
212+
Object setState(VirtualFrame frame, PCycle self, Object state,
213+
@Cached LenNode lenNode,
214+
@Cached GetItemNode getItemNode,
215+
@Cached IsBuiltinClassProfile isTypeErrorProfile,
216+
@Cached ToArrayNode toArrayNode,
217+
@Cached PyNumberAsSizeNode asSizeNode,
218+
@Cached BranchProfile isNotTupleProfile) {
219+
if (!((state instanceof PTuple) && ((int) lenNode.execute(frame, state) == 2))) {
220+
isNotTupleProfile.enter();
221+
throw raise(TypeError, IS_NOT_A, "state", "2-tuple");
222+
}
223+
Object obj = getItemNode.execute(frame, state, 0);
224+
if (!(obj instanceof PList)) {
225+
throw raise(TypeError, STATE_ARGUMENT_D_MUST_BE_A_S, 1, "Plist");
226+
}
227+
PList saved = (PList) obj;
228+
229+
boolean firstPass;
230+
try {
231+
firstPass = asSizeNode.executeLossy(frame, getItemNode.execute(frame, state, 1)) != 0;
232+
} catch (PException e) {
233+
e.expectTypeError(isTypeErrorProfile);
234+
throw raise(TypeError, STATE_ARGUMENT_D_MUST_BE_A_S, 2, "int");
235+
}
236+
237+
Object[] savedArray = toArrayNode.execute(saved.getSequenceStorage());
238+
self.setSaved(toList(savedArray));
239+
self.setFirstpass(firstPass);
240+
self.setIndex(0);
241+
return PNone.NONE;
242+
}
243+
244+
@TruffleBoundary
245+
private static ArrayList<Object> toList(Object[] savedArray) {
246+
return new ArrayList<>(Arrays.asList(savedArray));
247+
}
248+
}
249+
250+
}

0 commit comments

Comments
 (0)