Skip to content

Commit 2a58ec4

Browse files
committed
more efficient set.update
1 parent 972acb8 commit 2a58ec4

File tree

1 file changed

+87
-0
lines changed
  • graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/set

1 file changed

+87
-0
lines changed

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/set/SetBuiltins.java

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,15 +44,26 @@
4444
import com.oracle.graal.python.builtins.objects.common.HashingCollectionNodes.GetHashingStorageNode;
4545
import com.oracle.graal.python.builtins.objects.common.HashingStorage;
4646
import com.oracle.graal.python.builtins.objects.common.HashingStorageLibrary;
47+
import com.oracle.graal.python.builtins.objects.common.PHashingCollection;
48+
import com.oracle.graal.python.builtins.objects.dict.PDictView;
49+
import com.oracle.graal.python.builtins.objects.object.PythonObjectLibrary;
50+
import com.oracle.graal.python.builtins.objects.str.PString;
4751
import com.oracle.graal.python.nodes.ErrorMessages;
4852
import com.oracle.graal.python.nodes.PGuards;
4953
import com.oracle.graal.python.nodes.SpecialMethodNames;
54+
import com.oracle.graal.python.nodes.control.GetNextNode;
5055
import com.oracle.graal.python.nodes.function.PythonBuiltinBaseNode;
5156
import com.oracle.graal.python.nodes.function.PythonBuiltinNode;
5257
import com.oracle.graal.python.nodes.function.builtins.PythonBinaryBuiltinNode;
5358
import com.oracle.graal.python.nodes.function.builtins.PythonUnaryBuiltinNode;
59+
import com.oracle.graal.python.nodes.object.IsBuiltinClassProfile;
5460
import com.oracle.graal.python.runtime.PythonCore;
61+
import com.oracle.graal.python.runtime.PythonOptions;
62+
import com.oracle.graal.python.runtime.exception.PException;
5563
import com.oracle.graal.python.runtime.exception.PythonErrorType;
64+
import com.oracle.graal.python.runtime.sequence.PSequence;
65+
import com.oracle.graal.python.runtime.sequence.storage.SequenceStorage;
66+
import com.oracle.truffle.api.dsl.Bind;
5667
import com.oracle.truffle.api.dsl.Cached;
5768
import com.oracle.truffle.api.dsl.Fallback;
5869
import com.oracle.truffle.api.dsl.GenerateNodeFactory;
@@ -61,6 +72,7 @@
6172
import com.oracle.truffle.api.dsl.Specialization;
6273
import com.oracle.truffle.api.frame.VirtualFrame;
6374
import com.oracle.truffle.api.library.CachedLibrary;
75+
import com.oracle.truffle.api.nodes.Node;
6476
import com.oracle.truffle.api.profiles.BranchProfile;
6577
import com.oracle.truffle.api.profiles.ConditionProfile;
6678

@@ -188,6 +200,74 @@ PBaseSet doGeneric(VirtualFrame frame, PSet self, Object[] args,
188200
}
189201
}
190202

203+
@ImportStatic({PGuards.class, PythonOptions.class})
204+
public abstract static class UpdateSingleNode extends Node {
205+
206+
public abstract HashingStorage execute(VirtualFrame frame, HashingStorage storage, Object other);
207+
208+
@Specialization
209+
static HashingStorage update(HashingStorage storage, PHashingCollection other,
210+
@CachedLibrary(limit = "1") HashingStorageLibrary lib) {
211+
HashingStorage dictStorage = other.getDictStorage();
212+
return lib.addAllToOther(dictStorage, storage);
213+
}
214+
215+
@Specialization
216+
static HashingStorage update(HashingStorage storage, PDictView.PDictKeysView other,
217+
@CachedLibrary(limit = "1") HashingStorageLibrary lib) {
218+
HashingStorage dictStorage = other.getWrappedDict().getDictStorage();
219+
return lib.addAllToOther(dictStorage, storage);
220+
}
221+
222+
static boolean isBuiltinSequence(Object other, PythonObjectLibrary lib) {
223+
return other instanceof PSequence && !(other instanceof PString) && lib.getLazyPythonClass(other) instanceof PythonBuiltinClassType;
224+
}
225+
226+
static SequenceStorage getSequenceStorage(PSequence sequence, Class<? extends PSequence> clazz) {
227+
return clazz.cast(sequence).getSequenceStorage();
228+
}
229+
230+
@Specialization(guards = {"isBuiltinSequence(other, otherLib)", "other.getClass() == sequenceClass",
231+
"sequenceStorage.getClass() == storageClass"}, limit = "getCallSiteInlineCacheMaxDepth()")
232+
static HashingStorage doIterable(VirtualFrame frame, HashingStorage storage, @SuppressWarnings("unused") PSequence other,
233+
@SuppressWarnings("unused") @CachedLibrary("other") PythonObjectLibrary otherLib,
234+
@SuppressWarnings("unused") @Cached("other.getClass()") Class<? extends PSequence> sequenceClass,
235+
@Bind("getSequenceStorage(other, sequenceClass)") SequenceStorage sequenceStorage,
236+
@SuppressWarnings("unused") @Cached("sequenceStorage.getClass()") Class<? extends SequenceStorage> storageClass,
237+
@Cached("createBinaryProfile()") ConditionProfile hasFrame,
238+
@CachedLibrary(limit = "2") HashingStorageLibrary lib) {
239+
SequenceStorage profiledSequenceStorage = storageClass.cast(sequenceStorage);
240+
int length = profiledSequenceStorage.length();
241+
HashingStorage curStorage = storage;
242+
for (int i = 0; i < length; i++) {
243+
Object key = profiledSequenceStorage.getItemNormalized(i);
244+
curStorage = lib.setItemWithFrame(curStorage, key, PNone.NONE, hasFrame, frame);
245+
}
246+
return curStorage;
247+
}
248+
249+
@Specialization(guards = {"!isPHashingCollection(other)", "!isDictKeysView(other)", "!isBuiltinSequence(other, otherLib)"}, limit = "getCallSiteInlineCacheMaxDepth()")
250+
static HashingStorage doIterable(VirtualFrame frame, HashingStorage storage, Object other,
251+
@CachedLibrary("other") PythonObjectLibrary otherLib,
252+
@Cached GetNextNode nextNode,
253+
@Cached IsBuiltinClassProfile errorProfile,
254+
@Cached("createBinaryProfile()") ConditionProfile hasFrame,
255+
@CachedLibrary(limit = "2") HashingStorageLibrary lib) {
256+
HashingStorage curStorage = storage;
257+
Object iterator = otherLib.getIteratorWithFrame(other, frame);
258+
while (true) {
259+
Object key;
260+
try {
261+
key = nextNode.execute(frame, iterator);
262+
} catch (PException e) {
263+
e.expectStopIteration(errorProfile);
264+
return curStorage;
265+
}
266+
curStorage = lib.setItemWithFrame(curStorage, key, PNone.NONE, hasFrame, frame);
267+
}
268+
}
269+
}
270+
191271
@Builtin(name = "update", minNumOfPositionalArgs = 1, takesVarArgs = true)
192272
@GenerateNodeFactory
193273
public abstract static class UpdateNode extends PythonBuiltinNode {
@@ -198,6 +278,13 @@ static PNone doSet(VirtualFrame frame, PSet self, PNone other) {
198278
return PNone.NONE;
199279
}
200280

281+
@Specialization(guards = "args.length == 1")
282+
static PNone doCached(VirtualFrame frame, PSet self, Object[] args,
283+
@Cached UpdateSingleNode update) {
284+
self.setDictStorage(update.execute(frame, self.getDictStorage(), args[0]));
285+
return PNone.NONE;
286+
}
287+
201288
@Specialization(guards = {"args.length == len", "args.length < 32"}, limit = "3")
202289
static PNone doCached(VirtualFrame frame, PSet self, Object[] args,
203290
@Cached("args.length") int len,

0 commit comments

Comments
 (0)