42
42
43
43
import static com .oracle .graal .python .builtins .PythonBuiltinClassType .TypeError ;
44
44
import static com .oracle .graal .python .nodes .ErrorMessages .IS_NOT_A ;
45
- import static com .oracle .graal .python .nodes .ErrorMessages .STATE_ARGUMENT_D_MUST_BE_A_S ;
46
45
import static com .oracle .graal .python .nodes .SpecialMethodNames .__ITER__ ;
47
46
import static com .oracle .graal .python .nodes .SpecialMethodNames .__NEXT__ ;
48
47
import static com .oracle .graal .python .nodes .SpecialMethodNames .__REDUCE__ ;
@@ -97,53 +96,79 @@ Object next(VirtualFrame frame, PGroupBy self,
97
96
@ Cached BuiltinFunctions .NextNode nextNode ,
98
97
@ Cached CallNode callNode ,
99
98
@ Cached PyObjectRichCompareBool .EqNode eqNode ,
99
+ @ Cached BranchProfile eqProfile ,
100
100
@ Cached ConditionProfile hasFuncProfile ,
101
101
@ Cached LoopConditionProfile loopConditionProfile ) {
102
102
self .setCurrGrouper (null );
103
- Object marker = self .getMarker ();
104
- while (loopConditionProfile .profile (!(self .getCurrKey () != marker && (self .getTgtKey () == marker || !eqNode .execute (frame , self .getTgtKey (), self .getCurrKey ()))))) {
103
+ while (loopConditionProfile .profile (doGroupByStep (frame , self , eqProfile , eqNode ))) {
105
104
self .groupByStep (frame , nextNode , callNode , hasFuncProfile );
106
105
}
107
106
self .setTgtKey (self .getCurrKey ());
108
107
PGrouper grouper = factory ().createGrouper (self , self .getTgtKey ());
109
108
return factory ().createTuple (new Object []{self .getCurrKey (), grouper });
110
109
}
110
+
111
+ protected boolean doGroupByStep (VirtualFrame frame , PGroupBy self , BranchProfile eqProfile , PyObjectRichCompareBool .EqNode eqNode ) {
112
+ if (self .getCurrKey () == null ) {
113
+ return true ;
114
+ } else if (self .getTgtKey () == null ) {
115
+ return false ;
116
+ } else {
117
+ eqProfile .enter ();
118
+ if (!eqNode .execute (frame , self .getTgtKey (), self .getCurrKey ())) {
119
+ return false ;
120
+ }
121
+ }
122
+ return true ;
123
+ }
111
124
}
112
125
113
126
@ Builtin (name = __REDUCE__ , minNumOfPositionalArgs = 1 )
114
127
@ GenerateNodeFactory
115
128
public abstract static class ReduceNode extends PythonUnaryBuiltinNode {
116
- @ Specialization (guards = "isMarkerStillSet (self)" )
129
+ @ Specialization (guards = { "!valuesSet (self)", "isNull(self.getKeyFunc())" } )
117
130
Object reduce (PGroupBy self ,
118
131
@ Cached GetClassNode getClassNode ) {
132
+ return reduce (self , PNone .NONE , getClassNode );
133
+ }
119
134
120
- Object type = getClassNode .execute (self );
121
- Object keyFunc = self .getKeyFunc () == null ? PNone .NONE : self .getKeyFunc ();
135
+ @ Specialization (guards = {"!valuesSet(self)" , "!isNull(self.getKeyFunc())" })
136
+ Object reduceNoFunc (PGroupBy self ,
137
+ @ Cached GetClassNode getClassNode ) {
138
+ return reduce (self , self .getKeyFunc (), getClassNode );
139
+ }
122
140
141
+ private Object reduce (PGroupBy self , Object keyFunc , GetClassNode getClassNode ) {
142
+ Object type = getClassNode .execute (self );
123
143
PTuple tuple = factory ().createTuple (new Object []{self .getIt (), keyFunc });
124
144
return factory ().createTuple (new Object []{type , tuple });
125
145
}
126
146
127
- @ Specialization (guards = "!isMarkerStillSet (self)" )
147
+ @ Specialization (guards = { "valuesSet (self)", "isNull(self.getKeyFunc())" } )
128
148
Object reduceMarkerNotSet (PGroupBy self ,
129
149
@ Cached GetClassNode getClassNode ) {
150
+ return reduceOther (self , PNone .NONE , getClassNode );
151
+ }
152
+
153
+ @ Specialization (guards = {"valuesSet(self)" , "!isNull(self.getKeyFunc())" })
154
+ Object reduceMarkerNotSetNoFunc (PGroupBy self ,
155
+ @ Cached GetClassNode getClassNode ) {
156
+ return reduceOther (self , self .getKeyFunc (), getClassNode );
157
+ }
130
158
159
+ private Object reduceOther (PGroupBy self , Object keyFunc , GetClassNode getClassNode ) {
131
160
Object type = getClassNode .execute (self );
132
- Object keyFunc = self .getKeyFunc () == null ? PNone .NONE : self .getKeyFunc ();
133
-
134
- Object currValue = self .getCurrValue ();
135
- Object tgtKey = self .getTgtKey ();
136
- Object currKey = self .getCurrKey ();
137
- Object currGrouper = self .getCurrGrouper () == null ? PNone .NONE : self .getCurrGrouper ();
138
- Object marker = self .getMarker ();
161
+ PTuple tuple1 = factory ().createTuple (new Object []{self .getIt (), keyFunc });
162
+ PTuple tuple2 = factory ().createTuple (new Object []{self .getCurrValue (), self .getTgtKey (), self .getCurrKey ()});
163
+ return factory ().createTuple (new Object []{type , tuple1 , tuple2 });
164
+ }
139
165
140
- PTuple tuple = factory ().createTuple (new Object []{self .getIt (), keyFunc , currValue , tgtKey , currKey , currGrouper , marker });
141
- PTuple emptyTuple = factory ().createTuple (new Object []{factory ().createEmptyTuple ()});
142
- return factory ().createTuple (new Object []{type , emptyTuple , tuple });
166
+ protected boolean valuesSet (PGroupBy self ) {
167
+ return self .getTgtKey () != null && self .getCurrKey () != null && self .getCurrValue () != null ;
143
168
}
144
169
145
- protected boolean isMarkerStillSet ( PGroupBy self ) {
146
- return self . getCurrValue () == self . getMarker () || self . getTgtKey () == self . getMarker () || self . getCurrKey () == self . getMarker () ;
170
+ protected boolean isNull ( Object obj ) {
171
+ return obj == null ;
147
172
}
148
173
149
174
}
@@ -155,41 +180,21 @@ public abstract static class SetStateNode extends PythonBinaryBuiltinNode {
155
180
Object setState (VirtualFrame frame , PGroupBy self , Object state ,
156
181
@ Cached TupleBuiltins .LenNode lenNode ,
157
182
@ Cached TupleBuiltins .GetItemNode getItemNode ,
158
- @ Cached BranchProfile isNotTupleProfile ,
159
- @ Cached BranchProfile isNotGroupByProfile ) {
160
- if (!(state instanceof PTuple ) || (int ) lenNode .execute (frame , state ) != 7 ) {
183
+ @ Cached BranchProfile isNotTupleProfile ) {
184
+ if (!(state instanceof PTuple ) || (int ) lenNode .execute (frame , state ) != 3 ) {
161
185
isNotTupleProfile .enter ();
162
- throw raise (TypeError , IS_NOT_A , "state" , "7 -tuple" );
186
+ throw raise (TypeError , IS_NOT_A , "state" , "3 -tuple" );
163
187
}
164
- Object iterable = getItemNode .execute (frame , state , 0 );
165
- self .setIt (iterable );
166
-
167
- Object keyFunc = getItemNode .execute (frame , state , 1 );
168
- self .setKeyFunc (keyFunc instanceof PNone ? null : keyFunc );
169
188
170
- Object currValue = getItemNode .execute (frame , state , 2 );
189
+ Object currValue = getItemNode .execute (frame , state , 0 );
171
190
self .setCurrValue (currValue );
172
191
173
- Object tgtKey = getItemNode .execute (frame , state , 3 );
192
+ Object tgtKey = getItemNode .execute (frame , state , 1 );
174
193
self .setTgtKey (tgtKey );
175
194
176
- Object currKey = getItemNode .execute (frame , state , 4 );
195
+ Object currKey = getItemNode .execute (frame , state , 2 );
177
196
self .setCurrKey (currKey );
178
197
179
- Object currGrouper = getItemNode .execute (frame , state , 5 );
180
- if (currGrouper instanceof PNone ) {
181
- self .setCurrGrouper (null );
182
- } else {
183
- if (!(currGrouper instanceof PGrouper )) {
184
- isNotGroupByProfile .enter ();
185
- throw raise (TypeError , STATE_ARGUMENT_D_MUST_BE_A_S , 6 , "PGrouper" );
186
- }
187
- self .setCurrGrouper ((PGrouper ) currGrouper );
188
- }
189
-
190
- Object marker = getItemNode .execute (frame , state , 6 );
191
- self .setMarker (marker );
192
-
193
198
return PNone .NONE ;
194
199
}
195
200
}
0 commit comments