Skip to content

Commit 3cda2d4

Browse files
committed
fixes in accumulate.__reduce__
1 parent 9ebb247 commit 3cda2d4

File tree

1 file changed

+39
-15
lines changed

1 file changed

+39
-15
lines changed

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/itertools/AccumulateBuiltins.java

Lines changed: 39 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,6 @@ Object next(VirtualFrame frame, PAccumulate self,
9898
@Cached BranchProfile hasInitialProfile,
9999
@Cached BranchProfile markerProfile,
100100
@Cached ConditionProfile hasFuncProfile) {
101-
102101
if (self.getInitial() != null) {
103102
hasInitialProfile.enter();
104103
self.setTotal(self.getInitial());
@@ -123,48 +122,73 @@ Object next(VirtualFrame frame, PAccumulate self,
123122
@Builtin(name = __REDUCE__, minNumOfPositionalArgs = 1)
124123
@GenerateNodeFactory
125124
public abstract static class ReduceNode extends PythonUnaryBuiltinNode {
126-
@Specialization
125+
@Specialization(guards = "hasFunc(self)")
127126
Object reduce(VirtualFrame frame, PAccumulate self,
128127
@Cached GetClassNode getClassNode,
129128
@Cached BranchProfile hasInitialProfile,
130-
@Cached BranchProfile hasTotalProfile,
129+
@Cached BranchProfile totalNoneProfile,
130+
@Cached BranchProfile totalMarkerProfile,
131131
@Cached BranchProfile elseProfile,
132132
@Cached PyObjectGetIter getIter) {
133+
return reduce(self, self.getFunc(), hasInitialProfile, getClassNode, totalNoneProfile, getIter, frame, totalMarkerProfile, elseProfile);
134+
}
135+
136+
@Specialization(guards = "!hasFunc(self)")
137+
Object reduceNoFunc(VirtualFrame frame, PAccumulate self,
138+
@Cached GetClassNode getClassNode,
139+
@Cached BranchProfile hasInitialProfile,
140+
@Cached BranchProfile totalNoneProfile,
141+
@Cached BranchProfile totalMarkerProfile,
142+
@Cached BranchProfile elseProfile,
143+
@Cached PyObjectGetIter getIter) {
144+
return reduce(self, PNone.NONE, hasInitialProfile, getClassNode, totalNoneProfile, getIter, frame, totalMarkerProfile, elseProfile);
145+
}
146+
147+
private Object reduce(PAccumulate self, Object func, BranchProfile hasInitialProfile, GetClassNode getClassNode, BranchProfile totalNoneProfile, PyObjectGetIter getIter, VirtualFrame frame,
148+
BranchProfile totalMarkerProfile, BranchProfile elseProfile) {
133149
if (self.getInitial() != null) {
134150
hasInitialProfile.enter();
151+
135152
Object type = getClassNode.execute(self);
136-
PTuple inititalTuple = factory().createTuple(new Object[]{self.getInitial()});
137153
PChain chain = factory().createChain(PythonBuiltinClassType.PChain);
138-
chain.setSource(inititalTuple);
139-
chain.setActive(self.getIterable());
154+
chain.setSource(getIter.execute(frame, factory().createList(new Object[]{self.getIterable()})));
155+
PTuple initialTuple = factory().createTuple(new Object[]{self.getInitial()});
156+
chain.setActive(getIter.execute(frame, initialTuple));
140157

141-
PTuple tuple = factory().createTuple(new Object[]{chain, self.getFunc()});
158+
PTuple tuple = factory().createTuple(new Object[]{chain, func});
142159
return factory().createTuple(new Object[]{type, tuple, PNone.NONE});
143160
} else if (self.getTotal() == PNone.NONE) {
144-
hasTotalProfile.enter();
161+
totalNoneProfile.enter();
162+
145163
PChain chain = factory().createChain(PythonBuiltinClassType.PChain);
146164
PList noneList = factory().createList(new Object[]{PNone.NONE});
147165
Object noneIter = getIter.execute(frame, noneList);
148166
chain.setSource(getIter.execute(frame, factory().createList(new Object[]{noneIter, self.getIterable()})));
149167
chain.setActive(PNone.NONE);
150-
151168
PAccumulate accumulate = factory().createAccumulate(PythonBuiltinClassType.PAccumulate);
152169
accumulate.setIterable(chain);
153-
accumulate.setFunc(self.getFunc());
170+
accumulate.setFunc(func);
154171

155172
PTuple tuple = factory().createTuple(new Object[]{accumulate, 1, PNone.NONE});
156173
return factory().createTuple(new Object[]{PythonBuiltinClassType.PIslice, tuple});
174+
} else if (self.getTotal() != null) {
175+
totalMarkerProfile.enter();
176+
177+
Object type = getClassNode.execute(self);
178+
PTuple tuple = factory().createTuple(new Object[]{self.getIterable(), func});
179+
return factory().createTuple(new Object[]{type, tuple, self.getTotal()});
157180
} else {
158181
elseProfile.enter();
159-
Object type = getClassNode.execute(self);
160182

161-
Object func = self.getFunc() != null ? self.getFunc() : PNone.NONE;
183+
Object type = getClassNode.execute(self);
162184
PTuple tuple = factory().createTuple(new Object[]{self.getIterable(), func});
163-
164-
Object total = self.getTotal() != null ? self.getTotal() : PNone.NONE;
165-
return factory().createTuple(new Object[]{type, tuple, total});
185+
return factory().createTuple(new Object[]{type, tuple});
166186
}
167187
}
188+
189+
protected static boolean hasFunc(PAccumulate self) {
190+
return self.getFunc() != null;
191+
}
168192
}
169193

170194
@Builtin(name = __SETSTATE__, minNumOfPositionalArgs = 2)

0 commit comments

Comments
 (0)