Skip to content

Commit ca8745f

Browse files
committed
ObjectHashMap: move get into a GetNode
1 parent a6c8dbf commit ca8745f

File tree

3 files changed

+103
-69
lines changed

3 files changed

+103
-69
lines changed

graalpython/com.oracle.graal.python.test/src/com/oracle/graal/python/test/objects/ObjectHashMapTests.java

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
import java.util.stream.Collectors;
5656
import java.util.stream.StreamSupport;
5757

58+
import com.oracle.truffle.api.profiles.ConditionProfile;
5859
import org.junit.Assert;
5960
import org.junit.Test;
6061

@@ -173,14 +174,14 @@ public void testLongHashMapStressTest() {
173174
assertEquals(map.size(), copy.size());
174175
for (Object key : oldKeys) {
175176
assertEquals(key.toString(), //
176-
map.get(null, key, getKeyHash(key), GET_PROFILES), //
177-
copy.get(null, key, getKeyHash(key), GET_PROFILES));
177+
get(map, key, getKeyHash(key)), //
178+
get(copy, key, getKeyHash(key)));
178179
}
179180

180181
map.clear();
181182
assertEquals(0, map.size());
182183
for (Object key : oldKeys) {
183-
assertNull(key.toString(), map.get(null, key, getKeyHash(key), GET_PROFILES));
184+
assertNull(key.toString(), get(map, key, getKeyHash(key)));
184185
}
185186
}
186187

@@ -214,7 +215,7 @@ private static void removeAll(ObjectHashMap map) {
214215
}
215216
for (Long key : keys) {
216217
map.remove(null, key, PyObjectHashNode.hash(key), RM_PROFILES);
217-
assertNull(map.get(null, key, PyObjectHashNode.hash(key), GET_PROFILES));
218+
assertNull(get(map, key, PyObjectHashNode.hash(key)));
218219
}
219220
}
220221

@@ -276,7 +277,7 @@ static <T> void assertEqual(String message, LinkedHashMap<T, Object> expected, O
276277
assertEquals(message + "; hash in DictKey: " + key, hash, it.getKey().getPythonHash());
277278

278279
Object expectedVal = expected.get(key);
279-
Object actualVal = actual.get(null, key, hash, GET_PROFILES);
280+
Object actualVal = get(actual, key, hash);
280281
assertEquals(message + "; value under key: " + key, expectedVal, actualVal);
281282
assertEquals(message + "; value in DictKey: " + key, expectedVal, it.getValue());
282283

@@ -317,4 +318,11 @@ public static Object newValue() {
317318
private static long getKeyHash(Object key) {
318319
return key instanceof Long ? PyObjectHashNode.hash((Long) key) : ((DictKey) key).hash;
319320
}
321+
322+
private static Object get(ObjectHashMap map, Object key, long hash) {
323+
return ObjectHashMap.GetNode.doGet(null, map, key, hash,//
324+
ConditionProfile.getUncached(), ConditionProfile.getUncached(), ConditionProfile.getUncached(),//
325+
ConditionProfile.getUncached(), ConditionProfile.getUncached(), ConditionProfile.getUncached(),//
326+
new EqNodeStub());
327+
}
320328
}

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/common/EconomicMapStorage.java

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -114,31 +114,27 @@ static class GetItemWithState {
114114
@Specialization
115115
static Object getItemTruffleString(EconomicMapStorage self, TruffleString key, ThreadState state,
116116
@Shared("tsHash") @Cached TruffleString.HashCodeNode hashCodeNode,
117-
@Shared("getProfiles") @Cached ObjectHashMap.GetProfiles profiles,
118-
@Shared("gotState") @Cached ConditionProfile gotState) {
119-
VirtualFrame frame = gotState.profile(state == null) ? null : PArguments.frameForCall(state);
120-
DictKey newKey = new DictKey(key, PyObjectHashNode.hash(key, hashCodeNode));
121-
return self.map.get(frame, newKey, profiles);
117+
@Shared("getNode") @Cached ObjectHashMap.GetNode getNode) {
118+
return getNode.get(state, self.map, key, PyObjectHashNode.hash(key, hashCodeNode));
122119
}
123120

124121
@Specialization(guards = {"isBuiltinString(key, isBuiltinClassProfile)"}, limit = "1")
125122
static Object getItemPString(EconomicMapStorage self, PString key, ThreadState state,
126123
@Shared("stringMaterialize") @Cached StringMaterializeNode stringMaterializeNode,
127124
@Shared("tsHash") @Cached TruffleString.HashCodeNode hashCodeNode,
128-
@Shared("getProfiles") @Cached ObjectHashMap.GetProfiles profiles,
129-
@Shared("gotState") @Cached ConditionProfile gotState,
125+
@Shared("getNode") @Cached ObjectHashMap.GetNode getNode,
130126
@Shared("builtinProfile") @Cached @SuppressWarnings("unused") IsBuiltinClassProfile isBuiltinClassProfile) {
131127
final TruffleString k = stringMaterializeNode.execute(key);
132-
return getItemTruffleString(self, k, state, hashCodeNode, profiles, gotState);
128+
return getItemTruffleString(self, k, state, hashCodeNode, getNode);
133129
}
134130

135131
@Specialization(replaces = {"getItemTruffleString", "getItemPString"})
136132
static Object getItemGeneric(EconomicMapStorage self, Object key, ThreadState state,
137-
@Shared("getProfiles") @Cached ObjectHashMap.GetProfiles profiles,
138133
@Shared("hashNode") @Cached PyObjectHashNode hashNode,
134+
@Shared("getNode") @Cached ObjectHashMap.GetNode getNode,
139135
@Shared("gotState") @Cached ConditionProfile gotState) {
140136
VirtualFrame frame = gotState.profile(state == null) ? null : PArguments.frameForCall(state);
141-
return self.map.get(frame, key, hashNode.execute(frame, key), profiles);
137+
return getNode.get(state, self.map, key, hashNode.execute(frame, key));
142138
}
143139
}
144140

@@ -332,7 +328,7 @@ public static class EqualsWithState {
332328
@Specialization
333329
static boolean equalSameType(EconomicMapStorage self, EconomicMapStorage other, ThreadState state,
334330
@CachedLibrary("self") HashingStorageLibrary thisLib,
335-
@Shared("getProfiles") @Cached ObjectHashMap.GetProfiles profiles,
331+
@Shared("getNode") @Cached ObjectHashMap.GetNode getNode,
336332
@Shared("eqNode") @Cached PyObjectRichCompareBool.EqNode eqNode,
337333
@Shared("selfEntriesLoop") @Cached LoopConditionProfile loopProfile,
338334
@Shared("selfEntriesLoopExit") @Cached LoopConditionProfile earlyExitProfile,
@@ -348,7 +344,7 @@ static boolean equalSameType(EconomicMapStorage self, EconomicMapStorage other,
348344
if (CompilerDirectives.hasNextTier()) {
349345
counter++;
350346
}
351-
Object otherValue = other.map.get(frame, cursor.getKey(), profiles);
347+
Object otherValue = getNode.get(state, other.map, cursor.getKey());
352348
if (earlyExitProfile.profile(!(otherValue == null || !eqNode.execute(frame, otherValue, getValue(cursor))))) {
353349
// if->continue such that the "true" count of the profile represents the
354350
// loop iterations and the "false" count the early exit
@@ -407,7 +403,7 @@ static int compareSameType(EconomicMapStorage self, EconomicMapStorage other, Th
407403
@CachedLibrary("self") HashingStorageLibrary thisLib,
408404
@Shared("selfEntriesLoop") @Cached LoopConditionProfile loopProfile,
409405
@Shared("selfEntriesLoopExit") @Cached LoopConditionProfile earlyExitProfile,
410-
@Shared("getProfiles") @Cached ObjectHashMap.GetProfiles profiles,
406+
@Shared("getNode") @Cached ObjectHashMap.GetNode getNode,
411407
@Shared("gotState") @Cached ConditionProfile gotState) {
412408
int size = self.map.size();
413409
int size2 = other.map.size();
@@ -422,7 +418,7 @@ static int compareSameType(EconomicMapStorage self, EconomicMapStorage other, Th
422418
if (CompilerDirectives.hasNextTier()) {
423419
counter++;
424420
}
425-
if (earlyExitProfile.profile(other.map.get(frame, getDictKey(cursor), profiles) != null)) {
421+
if (earlyExitProfile.profile(getNode.get(state, other.map, getDictKey(cursor)) != null)) {
426422
continue;
427423
}
428424
return 1;
@@ -481,7 +477,7 @@ public static class IntersectWithState {
481477
static HashingStorage intersectSameType(EconomicMapStorage self, EconomicMapStorage other, ThreadState state,
482478
@CachedLibrary("self") HashingStorageLibrary thisLib,
483479
@Shared("putProfiles") @Cached ObjectHashMap.PutProfiles putProfiles,
484-
@Shared("getProfiles") @Cached ObjectHashMap.GetProfiles getProfiles,
480+
@Shared("getNode") @Cached ObjectHashMap.GetNode getNode,
485481
@Shared("selfEntriesLoop") @Cached LoopConditionProfile loopProfile,
486482
@Shared("gotState") @Cached ConditionProfile gotState) {
487483
EconomicMapStorage result = EconomicMapStorage.create();
@@ -492,7 +488,7 @@ static HashingStorage intersectSameType(EconomicMapStorage self, EconomicMapStor
492488
loopProfile.profileCounted(size);
493489
LoopNode.reportLoopCount(thisLib, size);
494490
while (loopProfile.inject(advance(cursor))) {
495-
if (other.map.get(frame, getDictKey(cursor), getProfiles) != null) {
491+
if (getNode.get(state, other.map, getDictKey(cursor)) != null) {
496492
result.map.put(frame, getDictKey(cursor), getValue(cursor), putProfiles);
497493
}
498494
}
@@ -528,7 +524,7 @@ public static class DiffWithState {
528524
static HashingStorage diffSameType(EconomicMapStorage self, EconomicMapStorage other, ThreadState state,
529525
@CachedLibrary("self") HashingStorageLibrary thisLib,
530526
@Shared("putProfiles") @Cached ObjectHashMap.PutProfiles putProfiles,
531-
@Shared("getProfiles") @Cached ObjectHashMap.GetProfiles getProfiles,
527+
@Shared("getNode") @Cached ObjectHashMap.GetNode getNode,
532528
@Shared("selfEntriesLoop") @Cached LoopConditionProfile loopProfile,
533529
@Shared("gotState") @Cached ConditionProfile gotState) {
534530
EconomicMapStorage result = EconomicMapStorage.create();
@@ -539,7 +535,7 @@ static HashingStorage diffSameType(EconomicMapStorage self, EconomicMapStorage o
539535
loopProfile.profileCounted(size);
540536
LoopNode.reportLoopCount(thisLib, size);
541537
while (loopProfile.inject(advance(cursor))) {
542-
if (other.map.get(frame, getDictKey(cursor), getProfiles) == null) {
538+
if (getNode.get(state, other.map, getDictKey(cursor)) == null) {
543539
result.map.put(frame, getDictKey(cursor), getValue(cursor), putProfiles);
544540
}
545541
}

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/common/ObjectHashMap.java

Lines changed: 76 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -47,11 +47,18 @@
4747

4848
import com.oracle.graal.python.builtins.objects.common.HashingStorageLibrary.ForEachNode;
4949
import com.oracle.graal.python.builtins.objects.common.HashingStorageLibrary.HashingStorageIterable;
50+
import com.oracle.graal.python.builtins.objects.common.ObjectHashMapFactory.GetNodeGen;
51+
import com.oracle.graal.python.builtins.objects.function.PArguments;
52+
import com.oracle.graal.python.builtins.objects.function.PArguments.ThreadState;
5053
import com.oracle.graal.python.lib.PyObjectRichCompareBool;
5154
import com.oracle.graal.python.util.PythonUtils;
5255
import com.oracle.truffle.api.CompilerAsserts;
5356
import com.oracle.truffle.api.CompilerDirectives;
5457
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
58+
import com.oracle.truffle.api.dsl.Cached;
59+
import com.oracle.truffle.api.dsl.GenerateUncached;
60+
import com.oracle.truffle.api.dsl.Specialization;
61+
import com.oracle.truffle.api.frame.Frame;
5562
import com.oracle.truffle.api.frame.VirtualFrame;
5663
import com.oracle.truffle.api.nodes.LoopNode;
5764
import com.oracle.truffle.api.nodes.Node;
@@ -149,11 +156,11 @@ private void markCollision(int compactIndex) {
149156
indices[compactIndex] = indices[compactIndex] | COLLISION_MASK;
150157
}
151158

152-
private boolean isCollision(int index) {
159+
private static boolean isCollision(int index) {
153160
return (index & COLLISION_MASK) != 0;
154161
}
155162

156-
private int unwrapIndex(int value) {
163+
private static int unwrapIndex(int value) {
157164
return value & ~COLLISION_MASK;
158165
}
159166

@@ -498,57 +505,75 @@ public int size() {
498505
return size;
499506
}
500507

501-
public Object get(VirtualFrame frame, DictKey key, GetProfiles profiles) {
502-
return get(frame, key.getValue(), key.getPythonHash(), profiles);
503-
}
508+
@GenerateUncached
509+
public abstract static class GetNode extends Node {
510+
public final Object get(ThreadState state, ObjectHashMap map, DictKey key) {
511+
return execute(state, map, key.getValue(), key.getPythonHash());
512+
}
504513

505-
public Object get(VirtualFrame frame, Object key, long keyHash, GetProfiles profiles) {
506-
assert checkInternalState();
507-
int compactIndex = getIndex(keyHash);
508-
int index = indices[compactIndex];
509-
if (profiles.foundNullKey.profile(index == EMPTY_INDEX)) {
510-
return null;
511-
}
512-
if (profiles.foundSameHashKey.profile(index != DUMMY_INDEX)) {
513-
int unwrappedIndex = unwrapIndex(index);
514-
Object foundValue = getValue(unwrappedIndex);
515-
if (profiles.foundEqKey.profile(keysEqual(frame, unwrappedIndex, key, keyHash, profiles))) {
516-
return foundValue;
517-
} else if (!isCollision(index)) {
518-
return null;
519-
}
514+
public final Object get(ThreadState state, ObjectHashMap map, Object key, long keyHash) {
515+
return execute(state, map, key, keyHash);
520516
}
521517

522-
// collision: intentionally counted loop
523-
long perturb = keyHash;
524-
int searchLimit = getBucketsCount() + PERTURB_SHIFTS_COUT;
525-
int i = 0;
526-
try {
527-
for (; i < searchLimit; i++) {
528-
perturb >>>= PERTURB_SHIFT;
529-
compactIndex = nextIndex(compactIndex, perturb);
530-
index = indices[compactIndex];
531-
if (profiles.collisionFoundNoValue.profile(index == EMPTY_INDEX)) {
518+
abstract Object execute(ThreadState state, ObjectHashMap map, Object key, long keyHash);
519+
520+
// "public" for testing...
521+
@Specialization
522+
public static Object doGet(ThreadState state, ObjectHashMap map, Object key, long keyHash,
523+
@Cached("createCountingProfile()") ConditionProfile foundNullKey,
524+
@Cached("createCountingProfile()") ConditionProfile foundSameHashKey,
525+
@Cached("createCountingProfile()") ConditionProfile foundEqKey,
526+
@Cached("createCountingProfile()") ConditionProfile collisionFoundNoValue,
527+
@Cached("createCountingProfile()") ConditionProfile collisionFoundEqKey,
528+
@Cached ConditionProfile hasState,
529+
@Cached PyObjectRichCompareBool.EqNode eqNode) {
530+
assert map.checkInternalState();
531+
int compactIndex = map.getIndex(keyHash);
532+
int index = map.indices[compactIndex];
533+
if (foundNullKey.profile(index == EMPTY_INDEX)) {
534+
return null;
535+
}
536+
if (foundSameHashKey.profile(index != DUMMY_INDEX)) {
537+
int unwrappedIndex = unwrapIndex(index);
538+
Object foundValue = map.getValue(unwrappedIndex);
539+
if (foundEqKey.profile(map.keysEqual(state, unwrappedIndex, key, keyHash, eqNode, hasState))) {
540+
return foundValue;
541+
} else if (!isCollision(index)) {
532542
return null;
533543
}
534-
if (index != DUMMY_INDEX) {
535-
int unwrappedIndex = unwrapIndex(index);
536-
Object foundValue = getValue(unwrappedIndex);
537-
if (profiles.collisionFoundEqKey.profile(keysEqual(frame, unwrappedIndex, key, keyHash, profiles))) {
538-
return foundValue;
544+
}
545+
546+
// collision: intentionally counted loop
547+
long perturb = keyHash;
548+
int searchLimit = map.getBucketsCount() + PERTURB_SHIFTS_COUT;
549+
int i = 0;
550+
try {
551+
for (; i < searchLimit; i++) {
552+
perturb >>>= PERTURB_SHIFT;
553+
compactIndex = map.nextIndex(compactIndex, perturb);
554+
index = map.indices[compactIndex];
555+
if (collisionFoundNoValue.profile(index == EMPTY_INDEX)) {
556+
return null;
557+
}
558+
if (index != DUMMY_INDEX) {
559+
int unwrappedIndex = unwrapIndex(index);
560+
Object foundValue = map.getValue(unwrappedIndex);
561+
if (collisionFoundEqKey.profile(map.keysEqual(state, unwrappedIndex, key, keyHash, eqNode, hasState))) {
562+
return foundValue;
563+
}
564+
}
565+
if (!isCollision(index)) {
566+
return null;
539567
}
540568
}
541-
if (!isCollision(index)) {
542-
return null;
543-
}
569+
} finally {
570+
LoopNode.reportLoopCount(eqNode, i);
544571
}
545-
} finally {
546-
LoopNode.reportLoopCount(profiles, i);
572+
// all values are dummies? Not possible, since we should have compacted the
573+
// hashes/keysAndValues arrays in "remove". We always keep some head-room, so there must be
574+
// at least few empty slots, and we must have hit one.
575+
throw CompilerDirectives.shouldNotReachHere();
547576
}
548-
// all values are dummies? Not possible, since we should have compacted the
549-
// hashes/keysAndValues arrays in "remove". We always keep some head-room, so there must be
550-
// at least few empty slots, and we must have hit one.
551-
throw CompilerDirectives.shouldNotReachHere();
552577
}
553578

554579
public void put(VirtualFrame frame, DictKey key, Object value, PutProfiles profiles) {
@@ -716,11 +741,16 @@ public void remove(VirtualFrame frame, Object key, long keyHash, RemoveProfiles
716741
throw CompilerDirectives.shouldNotReachHere();
717742
}
718743

719-
private boolean keysEqual(VirtualFrame frame, int index, Object key, long keyHash, GetProfiles profiles) {
744+
private boolean keysEqual(Frame frame, int index, Object key, long keyHash, GetProfiles profiles) {
720745
return hashes[index] == keyHash && profiles.eqNode.execute(frame, getKey(index), key);
721746
}
722747

723-
private boolean keysEqual(VirtualFrame frame, int index, Object key, long keyHash, PyObjectRichCompareBool.EqNode eqNode) {
748+
private boolean keysEqual(Frame frame, int index, Object key, long keyHash, PyObjectRichCompareBool.EqNode eqNode) {
749+
return hashes[index] == keyHash && eqNode.execute(frame, getKey(index), key);
750+
}
751+
752+
private boolean keysEqual(ThreadState state, int index, Object key, long keyHash, PyObjectRichCompareBool.EqNode eqNode, ConditionProfile hasState) {
753+
VirtualFrame frame = hasState.profile(state == null) ? null : PArguments.frameForCall(state);
724754
return hashes[index] == keyHash && eqNode.execute(frame, getKey(index), key);
725755
}
726756

0 commit comments

Comments
 (0)