Skip to content

Commit 6fb7a57

Browse files
committed
[GR-28698] Efficient sort and set.update.
PullRequest: graalpython/1634
2 parents 96eb044 + 2a58ec4 commit 6fb7a57

File tree

5 files changed

+196
-9
lines changed

5 files changed

+196
-9
lines changed

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/common/HashingStorageLibrary.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,6 @@
7575
* argument. Thus, storages must ensure to cache the hash value if a call could be observed.
7676
*/
7777
@GenerateLibrary
78-
@SuppressWarnings("unused")
7978
public abstract class HashingStorageLibrary extends Library {
8079
/**
8180
* @return the length of {@code self}
@@ -299,7 +298,7 @@ public final <T> T forEach(HashingStorage self, ForEachNode<T> node, T arg) {
299298
* @return {@code true} if the storage has elements with a potential side effect, otherwise
300299
* {@code false}.
301300
*/
302-
public boolean hasSideEffect(HashingStorage self) {
301+
public boolean hasSideEffect(@SuppressWarnings("unused") HashingStorage self) {
303302
return false;
304303
}
305304

@@ -473,7 +472,7 @@ public Iterator<T> getIterator() {
473472

474473
@Override
475474
public HashingStorageIterator<T> iterator() {
476-
return new HashingStorageIterator<T>(this.iterator);
475+
return new HashingStorageIterator<>(this.iterator);
477476
}
478477
}
479478

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/common/LocalsStorage.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@
7272
import com.oracle.truffle.api.profiles.ConditionProfile;
7373

7474
@ExportLibrary(HashingStorageLibrary.class)
75-
public class LocalsStorage extends HashingStorage {
75+
public final class LocalsStorage extends HashingStorage {
7676
/* This won't be the real (materialized) frame but a clone of it. */
7777
protected final MaterializedFrame frame;
7878
private int len = -1;

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/list/ListBuiltins.java

Lines changed: 104 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@
5151
import static com.oracle.graal.python.runtime.exception.PythonErrorType.TypeError;
5252

5353
import java.math.BigInteger;
54+
import java.util.Arrays;
55+
import java.util.Comparator;
5456
import java.util.List;
5557

5658
import com.oracle.graal.python.PythonLanguage;
@@ -81,6 +83,7 @@
8183
import com.oracle.graal.python.builtins.objects.object.PythonObjectLibrary;
8284
import com.oracle.graal.python.builtins.objects.range.PIntRange;
8385
import com.oracle.graal.python.builtins.objects.str.PString;
86+
import com.oracle.graal.python.builtins.objects.str.StringUtils;
8487
import com.oracle.graal.python.builtins.objects.tuple.PTuple;
8588
import com.oracle.graal.python.nodes.ErrorMessages;
8689
import com.oracle.graal.python.nodes.PGuards;
@@ -105,10 +108,13 @@
105108
import com.oracle.graal.python.runtime.exception.PException;
106109
import com.oracle.graal.python.runtime.exception.PythonErrorType;
107110
import com.oracle.graal.python.runtime.sequence.PSequence;
111+
import com.oracle.graal.python.runtime.sequence.storage.BoolSequenceStorage;
112+
import com.oracle.graal.python.runtime.sequence.storage.ByteSequenceStorage;
108113
import com.oracle.graal.python.runtime.sequence.storage.DoubleSequenceStorage;
109114
import com.oracle.graal.python.runtime.sequence.storage.EmptySequenceStorage;
110115
import com.oracle.graal.python.runtime.sequence.storage.IntSequenceStorage;
111116
import com.oracle.graal.python.runtime.sequence.storage.LongSequenceStorage;
117+
import com.oracle.graal.python.runtime.sequence.storage.ObjectSequenceStorage;
112118
import com.oracle.graal.python.runtime.sequence.storage.SequenceStorage;
113119
import com.oracle.graal.python.runtime.sequence.storage.SequenceStorageFactory;
114120
import com.oracle.graal.python.util.PythonUtils;
@@ -124,6 +130,7 @@
124130
import com.oracle.truffle.api.dsl.TypeSystemReference;
125131
import com.oracle.truffle.api.frame.VirtualFrame;
126132
import com.oracle.truffle.api.library.CachedLibrary;
133+
import com.oracle.truffle.api.nodes.Node;
127134
import com.oracle.truffle.api.nodes.UnexpectedResultException;
128135
import com.oracle.truffle.api.profiles.ConditionProfile;
129136

@@ -861,6 +868,92 @@ public static ListReverseNode create() {
861868
}
862869
}
863870

871+
abstract static class SimpleSortNode extends Node {
872+
873+
protected static final String SORT = "_sort";
874+
875+
protected abstract void execute(VirtualFrame frame, PList list, SequenceStorage storage);
876+
877+
@Specialization
878+
@TruffleBoundary
879+
void sort(@SuppressWarnings("unused") PList list, BoolSequenceStorage storage) {
880+
int length = storage.length();
881+
int trueValues = 0;
882+
boolean[] array = storage.getInternalBoolArray();
883+
for (int i = 0; i < length; i++) {
884+
if (array[i]) {
885+
trueValues++;
886+
}
887+
}
888+
Arrays.fill(array, 0, length - trueValues, false);
889+
Arrays.fill(array, length - trueValues, length, true);
890+
}
891+
892+
@Specialization
893+
@TruffleBoundary
894+
void sort(@SuppressWarnings("unused") PList list, ByteSequenceStorage storage) {
895+
Arrays.sort(storage.getInternalByteArray(), 0, storage.length());
896+
}
897+
898+
@Specialization
899+
@TruffleBoundary
900+
void sort(@SuppressWarnings("unused") PList list, IntSequenceStorage storage) {
901+
Arrays.sort(storage.getInternalIntArray(), 0, storage.length());
902+
}
903+
904+
@Specialization
905+
@TruffleBoundary
906+
void sort(@SuppressWarnings("unused") PList list, LongSequenceStorage storage) {
907+
Arrays.sort(storage.getInternalLongArray(), 0, storage.length());
908+
}
909+
910+
@Specialization
911+
@TruffleBoundary
912+
void sort(@SuppressWarnings("unused") PList list, DoubleSequenceStorage storage) {
913+
Arrays.sort(storage.getInternalDoubleArray(), 0, storage.length());
914+
}
915+
916+
private static final class StringComparator implements Comparator<Object> {
917+
public int compare(Object o1, Object o2) {
918+
return StringUtils.compareToUnicodeAware((String) o1, (String) o2);
919+
}
920+
}
921+
922+
private static final StringComparator COMPARATOR = new StringComparator();
923+
924+
@Specialization(guards = "isStringOnly(storage)")
925+
@TruffleBoundary
926+
void sort(@SuppressWarnings("unused") PList list, ObjectSequenceStorage storage) {
927+
Arrays.sort(storage.getInternalArray(), 0, storage.length(), COMPARATOR);
928+
}
929+
930+
@TruffleBoundary
931+
protected static boolean isStringOnly(ObjectSequenceStorage storage) {
932+
int length = storage.length();
933+
Object[] array = storage.getInternalArray();
934+
for (int i = 0; i < length; i++) {
935+
Object value = array[i];
936+
if (!(value instanceof String)) {
937+
return false;
938+
}
939+
}
940+
return true;
941+
}
942+
943+
protected static boolean isSimpleType(SequenceStorage storage) {
944+
return storage instanceof BoolSequenceStorage || storage instanceof ByteSequenceStorage || storage instanceof IntSequenceStorage || storage instanceof LongSequenceStorage ||
945+
storage instanceof DoubleSequenceStorage || (storage instanceof ObjectSequenceStorage && isStringOnly((ObjectSequenceStorage) storage));
946+
}
947+
948+
@Specialization(guards = "!isSimpleType(storage)")
949+
void defaultSort(VirtualFrame frame, PList list, @SuppressWarnings("unused") SequenceStorage storage,
950+
@Cached("create(SORT)") GetAttributeNode sort,
951+
@Cached CallNode callSort) {
952+
Object sortMethod = sort.executeObject(frame, list);
953+
callSort.execute(frame, sortMethod, PythonUtils.EMPTY_OBJECT_ARRAY, PKeyword.EMPTY_KEYWORDS);
954+
}
955+
}
956+
864957
// list.sort(key=, reverse=)
865958
@Builtin(name = SORT, minNumOfPositionalArgs = 1, takesVarArgs = true, takesVarKeywordArgs = true, needsFrame = true)
866959
@GenerateNodeFactory
@@ -878,10 +971,10 @@ protected static boolean maySideEffect(PList list, PKeyword[] keywords) {
878971
return true;
879972
}
880973
if (keywords.length > 0) {
881-
if (keywords[0].getName().equals(KEY)) {
974+
if (KEY.equals(keywords[0].getName())) {
882975
return true;
883976
}
884-
if (keywords.length > 1 && keywords[1].getName().equals(KEY)) {
977+
if (keywords.length > 1 && KEY.equals(keywords[1].getName())) {
885978
return true;
886979
}
887980
}
@@ -901,6 +994,14 @@ Object none(VirtualFrame frame, PList list, Object[] arguments, PKeyword[] keywo
901994
return PNone.NONE;
902995
}
903996

997+
@Specialization(guards = {"isSortable(list, lenNode)", "arguments.length == 0", "keywords.length == 0", "!maySideEffect(list, keywords)"})
998+
Object simple(VirtualFrame frame, PList list, @SuppressWarnings("unused") Object[] arguments, @SuppressWarnings("unused") PKeyword[] keywords,
999+
@Cached SimpleSortNode simpleSort,
1000+
@SuppressWarnings("unused") @Cached SequenceStorageNodes.LenNode lenNode) {
1001+
simpleSort.execute(frame, list, list.getSequenceStorage());
1002+
return PNone.NONE;
1003+
}
1004+
9041005
@Specialization(guards = {"isSortable(list, lenNode)", "maySideEffect(list, keywords)"})
9051006
Object withKey(VirtualFrame frame, PList list, Object[] arguments, PKeyword[] keywords,
9061007
@Cached("create(SORT)") GetAttributeNode sort,
@@ -921,7 +1022,7 @@ Object defaultSort(VirtualFrame frame, PList list, Object[] arguments, PKeyword[
9211022
@Cached CallNode callSort,
9221023
@SuppressWarnings("unused") @Cached SequenceStorageNodes.LenNode lenNode) {
9231024
Object sortMethod = sort.executeObject(frame, list);
924-
callSort.execute(sortMethod, arguments, keywords);
1025+
callSort.execute(frame, sortMethod, arguments, keywords);
9251026
return PNone.NONE;
9261027
}
9271028

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/set/SetBuiltins.java

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,15 +44,26 @@
4444
import com.oracle.graal.python.builtins.objects.common.HashingCollectionNodes.GetHashingStorageNode;
4545
import com.oracle.graal.python.builtins.objects.common.HashingStorage;
4646
import com.oracle.graal.python.builtins.objects.common.HashingStorageLibrary;
47+
import com.oracle.graal.python.builtins.objects.common.PHashingCollection;
48+
import com.oracle.graal.python.builtins.objects.dict.PDictView;
49+
import com.oracle.graal.python.builtins.objects.object.PythonObjectLibrary;
50+
import com.oracle.graal.python.builtins.objects.str.PString;
4751
import com.oracle.graal.python.nodes.ErrorMessages;
4852
import com.oracle.graal.python.nodes.PGuards;
4953
import com.oracle.graal.python.nodes.SpecialMethodNames;
54+
import com.oracle.graal.python.nodes.control.GetNextNode;
5055
import com.oracle.graal.python.nodes.function.PythonBuiltinBaseNode;
5156
import com.oracle.graal.python.nodes.function.PythonBuiltinNode;
5257
import com.oracle.graal.python.nodes.function.builtins.PythonBinaryBuiltinNode;
5358
import com.oracle.graal.python.nodes.function.builtins.PythonUnaryBuiltinNode;
59+
import com.oracle.graal.python.nodes.object.IsBuiltinClassProfile;
5460
import com.oracle.graal.python.runtime.PythonCore;
61+
import com.oracle.graal.python.runtime.PythonOptions;
62+
import com.oracle.graal.python.runtime.exception.PException;
5563
import com.oracle.graal.python.runtime.exception.PythonErrorType;
64+
import com.oracle.graal.python.runtime.sequence.PSequence;
65+
import com.oracle.graal.python.runtime.sequence.storage.SequenceStorage;
66+
import com.oracle.truffle.api.dsl.Bind;
5667
import com.oracle.truffle.api.dsl.Cached;
5768
import com.oracle.truffle.api.dsl.Fallback;
5869
import com.oracle.truffle.api.dsl.GenerateNodeFactory;
@@ -61,6 +72,7 @@
6172
import com.oracle.truffle.api.dsl.Specialization;
6273
import com.oracle.truffle.api.frame.VirtualFrame;
6374
import com.oracle.truffle.api.library.CachedLibrary;
75+
import com.oracle.truffle.api.nodes.Node;
6476
import com.oracle.truffle.api.profiles.BranchProfile;
6577
import com.oracle.truffle.api.profiles.ConditionProfile;
6678

@@ -188,6 +200,74 @@ PBaseSet doGeneric(VirtualFrame frame, PSet self, Object[] args,
188200
}
189201
}
190202

203+
@ImportStatic({PGuards.class, PythonOptions.class})
204+
public abstract static class UpdateSingleNode extends Node {
205+
206+
public abstract HashingStorage execute(VirtualFrame frame, HashingStorage storage, Object other);
207+
208+
@Specialization
209+
static HashingStorage update(HashingStorage storage, PHashingCollection other,
210+
@CachedLibrary(limit = "1") HashingStorageLibrary lib) {
211+
HashingStorage dictStorage = other.getDictStorage();
212+
return lib.addAllToOther(dictStorage, storage);
213+
}
214+
215+
@Specialization
216+
static HashingStorage update(HashingStorage storage, PDictView.PDictKeysView other,
217+
@CachedLibrary(limit = "1") HashingStorageLibrary lib) {
218+
HashingStorage dictStorage = other.getWrappedDict().getDictStorage();
219+
return lib.addAllToOther(dictStorage, storage);
220+
}
221+
222+
static boolean isBuiltinSequence(Object other, PythonObjectLibrary lib) {
223+
return other instanceof PSequence && !(other instanceof PString) && lib.getLazyPythonClass(other) instanceof PythonBuiltinClassType;
224+
}
225+
226+
static SequenceStorage getSequenceStorage(PSequence sequence, Class<? extends PSequence> clazz) {
227+
return clazz.cast(sequence).getSequenceStorage();
228+
}
229+
230+
@Specialization(guards = {"isBuiltinSequence(other, otherLib)", "other.getClass() == sequenceClass",
231+
"sequenceStorage.getClass() == storageClass"}, limit = "getCallSiteInlineCacheMaxDepth()")
232+
static HashingStorage doIterable(VirtualFrame frame, HashingStorage storage, @SuppressWarnings("unused") PSequence other,
233+
@SuppressWarnings("unused") @CachedLibrary("other") PythonObjectLibrary otherLib,
234+
@SuppressWarnings("unused") @Cached("other.getClass()") Class<? extends PSequence> sequenceClass,
235+
@Bind("getSequenceStorage(other, sequenceClass)") SequenceStorage sequenceStorage,
236+
@SuppressWarnings("unused") @Cached("sequenceStorage.getClass()") Class<? extends SequenceStorage> storageClass,
237+
@Cached("createBinaryProfile()") ConditionProfile hasFrame,
238+
@CachedLibrary(limit = "2") HashingStorageLibrary lib) {
239+
SequenceStorage profiledSequenceStorage = storageClass.cast(sequenceStorage);
240+
int length = profiledSequenceStorage.length();
241+
HashingStorage curStorage = storage;
242+
for (int i = 0; i < length; i++) {
243+
Object key = profiledSequenceStorage.getItemNormalized(i);
244+
curStorage = lib.setItemWithFrame(curStorage, key, PNone.NONE, hasFrame, frame);
245+
}
246+
return curStorage;
247+
}
248+
249+
@Specialization(guards = {"!isPHashingCollection(other)", "!isDictKeysView(other)", "!isBuiltinSequence(other, otherLib)"}, limit = "getCallSiteInlineCacheMaxDepth()")
250+
static HashingStorage doIterable(VirtualFrame frame, HashingStorage storage, Object other,
251+
@CachedLibrary("other") PythonObjectLibrary otherLib,
252+
@Cached GetNextNode nextNode,
253+
@Cached IsBuiltinClassProfile errorProfile,
254+
@Cached("createBinaryProfile()") ConditionProfile hasFrame,
255+
@CachedLibrary(limit = "2") HashingStorageLibrary lib) {
256+
HashingStorage curStorage = storage;
257+
Object iterator = otherLib.getIteratorWithFrame(other, frame);
258+
while (true) {
259+
Object key;
260+
try {
261+
key = nextNode.execute(frame, iterator);
262+
} catch (PException e) {
263+
e.expectStopIteration(errorProfile);
264+
return curStorage;
265+
}
266+
curStorage = lib.setItemWithFrame(curStorage, key, PNone.NONE, hasFrame, frame);
267+
}
268+
}
269+
}
270+
191271
@Builtin(name = "update", minNumOfPositionalArgs = 1, takesVarArgs = true)
192272
@GenerateNodeFactory
193273
public abstract static class UpdateNode extends PythonBuiltinNode {
@@ -198,6 +278,13 @@ static PNone doSet(VirtualFrame frame, PSet self, PNone other) {
198278
return PNone.NONE;
199279
}
200280

281+
@Specialization(guards = "args.length == 1")
282+
static PNone doCached(VirtualFrame frame, PSet self, Object[] args,
283+
@Cached UpdateSingleNode update) {
284+
self.setDictStorage(update.execute(frame, self.getDictStorage(), args[0]));
285+
return PNone.NONE;
286+
}
287+
201288
@Specialization(guards = {"args.length == len", "args.length < 32"}, limit = "3")
202289
static PNone doCached(VirtualFrame frame, PSet self, Object[] args,
203290
@Cached("args.length") int len,

graalpython/lib-graalpython/type.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def __dir__(klass):
5353
# we merge classes is unimportant
5454
for base in bases:
5555
names.update(_classdir(base))
56-
return sorted(list(names))
56+
return sorted(names)
5757
_classdir = __dir__
5858

5959

@@ -70,7 +70,7 @@ def __dir__(obj):
7070
klass = getattr(obj, '__class__', None)
7171
if klass is not None:
7272
names.update(_classdir(klass))
73-
return sorted(list(names))
73+
return sorted(names)
7474
_objectdir = __dir__
7575

7676

0 commit comments

Comments
 (0)