38
38
import com .oracle .graal .python .builtins .Builtin ;
39
39
import com .oracle .graal .python .builtins .CoreFunctions ;
40
40
import com .oracle .graal .python .builtins .PythonBuiltins ;
41
+ import com .oracle .graal .python .builtins .objects .PNone ;
41
42
import com .oracle .graal .python .builtins .objects .PNotImplemented ;
43
+ import com .oracle .graal .python .builtins .objects .common .EconomicMapStorage ;
42
44
import com .oracle .graal .python .builtins .objects .common .HashingStorage ;
45
+ import com .oracle .graal .python .builtins .objects .common .HashingStorage .Equivalence ;
43
46
import com .oracle .graal .python .builtins .objects .common .HashingStorageNodes ;
47
+ import com .oracle .graal .python .builtins .objects .common .HashingStorageNodes .PythonEquivalence ;
48
+ import com .oracle .graal .python .builtins .objects .common .PHashingCollection ;
49
+ import com .oracle .graal .python .builtins .objects .set .FrozenSetBuiltinsFactory .BinaryUnionNodeGen ;
50
+ import com .oracle .graal .python .nodes .PBaseNode ;
51
+ import com .oracle .graal .python .nodes .control .GetIteratorNode ;
52
+ import com .oracle .graal .python .nodes .control .GetNextNode ;
44
53
import com .oracle .graal .python .nodes .function .PythonBuiltinBaseNode ;
54
+ import com .oracle .graal .python .nodes .function .PythonBuiltinNode ;
45
55
import com .oracle .graal .python .nodes .function .builtins .PythonBinaryBuiltinNode ;
46
56
import com .oracle .graal .python .nodes .function .builtins .PythonUnaryBuiltinNode ;
57
+ import com .oracle .graal .python .runtime .exception .PException ;
58
+ import com .oracle .truffle .api .CompilerDirectives ;
59
+ import com .oracle .truffle .api .CompilerDirectives .CompilationFinal ;
47
60
import com .oracle .truffle .api .dsl .Cached ;
48
61
import com .oracle .truffle .api .dsl .Fallback ;
49
62
import com .oracle .truffle .api .dsl .GenerateNodeFactory ;
50
63
import com .oracle .truffle .api .dsl .NodeFactory ;
51
64
import com .oracle .truffle .api .dsl .Specialization ;
65
+ import com .oracle .truffle .api .profiles .ConditionProfile ;
66
+ import com .oracle .truffle .api .profiles .ValueProfile ;
52
67
53
68
@ CoreFunctions (extendClasses = {PFrozenSet .class , PSet .class })
54
69
public final class FrozenSetBuiltins extends PythonBuiltins {
@@ -116,35 +131,51 @@ Object run(PBaseSet self, PBaseSet other) {
116
131
@ Builtin (name = __AND__ , fixedNumOfArguments = 2 )
117
132
@ GenerateNodeFactory
118
133
abstract static class AndNode extends PythonBinaryBuiltinNode {
134
+ @ Child private HashingStorageNodes .IntersectNode intersectNode ;
135
+
119
136
@ Specialization
120
- PBaseSet doPBaseSet (PSet left , PBaseSet right ,
121
- @ Cached ("create()" ) HashingStorageNodes .IntersectNode intersectNode ) {
122
- HashingStorage intersectedStorage = intersectNode .execute (left .getDictStorage (), right .getDictStorage ());
137
+ PBaseSet doPBaseSet (PSet left , PBaseSet right ) {
138
+ HashingStorage intersectedStorage = getIntersectNode ().execute (left .getDictStorage (), right .getDictStorage ());
123
139
return factory ().createSet (intersectedStorage );
124
140
}
125
141
126
142
@ Specialization
127
- PBaseSet doPBaseSet (PFrozenSet left , PBaseSet right ,
128
- @ Cached ("create()" ) HashingStorageNodes .IntersectNode intersectNode ) {
129
- HashingStorage intersectedStorage = intersectNode .execute (left .getDictStorage (), right .getDictStorage ());
143
+ PBaseSet doPBaseSet (PFrozenSet left , PBaseSet right ) {
144
+ HashingStorage intersectedStorage = getIntersectNode ().execute (left .getDictStorage (), right .getDictStorage ());
130
145
return factory ().createFrozenSet (intersectedStorage );
131
146
}
147
+
148
+ private HashingStorageNodes .IntersectNode getIntersectNode () {
149
+ if (intersectNode == null ) {
150
+ CompilerDirectives .transferToInterpreterAndInvalidate ();
151
+ intersectNode = insert (HashingStorageNodes .IntersectNode .create ());
152
+ }
153
+ return intersectNode ;
154
+ }
132
155
}
133
156
134
157
@ Builtin (name = __SUB__ , fixedNumOfArguments = 2 )
135
158
@ GenerateNodeFactory
136
159
abstract static class SubNode extends PythonBinaryBuiltinNode {
160
+ @ Child private HashingStorageNodes .DiffNode diffNode ;
161
+
162
+ private HashingStorageNodes .DiffNode getDiffNode () {
163
+ if (diffNode == null ) {
164
+ CompilerDirectives .transferToInterpreterAndInvalidate ();
165
+ diffNode = HashingStorageNodes .DiffNode .create ();
166
+ }
167
+ return diffNode ;
168
+ }
169
+
137
170
@ Specialization
138
- PBaseSet doPBaseSet (PSet left , PBaseSet right ,
139
- @ Cached ("create()" ) HashingStorageNodes .DiffNode diffNode ) {
140
- HashingStorage storage = diffNode .execute (left .getDictStorage (), right .getDictStorage ());
171
+ PBaseSet doPBaseSet (PSet left , PBaseSet right ) {
172
+ HashingStorage storage = getDiffNode ().execute (left .getDictStorage (), right .getDictStorage ());
141
173
return factory ().createSet (storage );
142
174
}
143
175
144
176
@ Specialization
145
- PBaseSet doPBaseSet (PFrozenSet left , PBaseSet right ,
146
- @ Cached ("create()" ) HashingStorageNodes .DiffNode diffNode ) {
147
- HashingStorage storage = diffNode .execute (left .getDictStorage (), right .getDictStorage ());
177
+ PBaseSet doPBaseSet (PFrozenSet left , PBaseSet right ) {
178
+ HashingStorage storage = getDiffNode ().execute (left .getDictStorage (), right .getDictStorage ());
148
179
return factory ().createSet (storage );
149
180
}
150
181
}
@@ -158,4 +189,105 @@ boolean contains(PBaseSet self, Object key,
158
189
return containsKeyNode .execute (self .getDictStorage (), key );
159
190
}
160
191
}
192
+
193
+ @ Builtin (name = "union" , minNumOfArguments = 1 , takesVariableArguments = true )
194
+ @ GenerateNodeFactory
195
+ abstract static class UnionNode extends PythonBuiltinNode {
196
+
197
+ @ Child private BinaryUnionNode binaryUnionNode ;
198
+
199
+ @ CompilationFinal private ValueProfile setTypeProfile ;
200
+
201
+ @ Specialization (guards = {"args.length == len" , "args.length < 32" }, limit = "3" )
202
+ PBaseSet doCached (PBaseSet self , Object [] args ,
203
+ @ Cached ("args.length" ) int len ,
204
+ @ Cached ("create()" ) HashingStorageNodes .CopyNode copyNode ) {
205
+ PBaseSet result = create (self , copyNode .execute (self .getDictStorage ()));
206
+ for (int i = 0 ; i < len ; i ++) {
207
+ getBinaryUnionNode ().execute (result , result .getDictStorage (), args [i ]);
208
+ }
209
+ return result ;
210
+ }
211
+
212
+ @ Specialization (replaces = "doCached" )
213
+ PBaseSet doGeneric (PBaseSet self , Object [] args ,
214
+ @ Cached ("create()" ) HashingStorageNodes .CopyNode copyNode ) {
215
+ PBaseSet result = create (self , copyNode .execute (self .getDictStorage ()));
216
+ for (int i = 0 ; i < args .length ; i ++) {
217
+ getBinaryUnionNode ().execute (result , result .getDictStorage (), args [i ]);
218
+ }
219
+ return result ;
220
+ }
221
+
222
+ private PBaseSet create (PBaseSet left , HashingStorage storage ) {
223
+ if (getSetTypeProfile ().profile (left ) instanceof PFrozenSet ) {
224
+ return factory ().createFrozenSet (storage );
225
+ }
226
+ return factory ().createSet (storage );
227
+ }
228
+
229
+ private BinaryUnionNode getBinaryUnionNode () {
230
+ if (binaryUnionNode == null ) {
231
+ CompilerDirectives .transferToInterpreterAndInvalidate ();
232
+ binaryUnionNode = insert (BinaryUnionNode .create ());
233
+ }
234
+ return binaryUnionNode ;
235
+ }
236
+
237
+ private ValueProfile getSetTypeProfile () {
238
+ if (setTypeProfile == null ) {
239
+ CompilerDirectives .transferToInterpreterAndInvalidate ();
240
+ setTypeProfile = ValueProfile .createClassProfile ();
241
+ }
242
+ return setTypeProfile ;
243
+ }
244
+
245
+ }
246
+
247
+ abstract static class BinaryUnionNode extends PBaseNode {
248
+ @ Child private Equivalence equivalenceNode ;
249
+
250
+ public abstract PBaseSet execute (PBaseSet container , HashingStorage left , Object right );
251
+
252
+ @ Specialization
253
+ PBaseSet doHashingCollection (PBaseSet container , EconomicMapStorage selfStorage , PHashingCollection other ) {
254
+ for (Object key : other .getDictStorage ().keys ()) {
255
+ selfStorage .setItem (key , PNone .NO_VALUE , getEquivalence ());
256
+ }
257
+ return container ;
258
+ }
259
+
260
+ @ Specialization
261
+ PBaseSet doIterable (PBaseSet container , HashingStorage dictStorage , Object iterable ,
262
+ @ Cached ("create()" ) GetIteratorNode getIteratorNode ,
263
+ @ Cached ("create()" ) GetNextNode next ,
264
+ @ Cached ("createBinaryProfile()" ) ConditionProfile errorProfile ,
265
+ @ Cached ("create()" ) HashingStorageNodes .SetItemNode setItemNode ) {
266
+
267
+ Object iterator = getIteratorNode .executeWith (iterable );
268
+ while (true ) {
269
+ Object value ;
270
+ try {
271
+ value = next .execute (iterator );
272
+ } catch (PException e ) {
273
+ e .expectStopIteration (getCore (), errorProfile );
274
+ return container ;
275
+ }
276
+ setItemNode .execute (container , dictStorage , value , PNone .NO_VALUE );
277
+ }
278
+ }
279
+
280
+ protected Equivalence getEquivalence () {
281
+ if (equivalenceNode == null ) {
282
+ CompilerDirectives .transferToInterpreterAndInvalidate ();
283
+ equivalenceNode = insert (new PythonEquivalence ());
284
+ }
285
+ return equivalenceNode ;
286
+ }
287
+
288
+ public static BinaryUnionNode create () {
289
+ return BinaryUnionNodeGen .create ();
290
+ }
291
+
292
+ }
161
293
}
0 commit comments