|
47 | 47 |
|
48 | 48 | import com.oracle.graal.python.builtins.objects.common.HashingStorageLibrary.ForEachNode;
|
49 | 49 | import com.oracle.graal.python.builtins.objects.common.HashingStorageLibrary.HashingStorageIterable;
|
| 50 | +import com.oracle.graal.python.builtins.objects.common.ObjectHashMapFactory.GetNodeGen; |
| 51 | +import com.oracle.graal.python.builtins.objects.function.PArguments; |
| 52 | +import com.oracle.graal.python.builtins.objects.function.PArguments.ThreadState; |
50 | 53 | import com.oracle.graal.python.lib.PyObjectRichCompareBool;
|
51 | 54 | import com.oracle.graal.python.util.PythonUtils;
|
52 | 55 | import com.oracle.truffle.api.CompilerAsserts;
|
53 | 56 | import com.oracle.truffle.api.CompilerDirectives;
|
54 | 57 | import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
|
| 58 | +import com.oracle.truffle.api.dsl.Cached; |
| 59 | +import com.oracle.truffle.api.dsl.GenerateUncached; |
| 60 | +import com.oracle.truffle.api.dsl.Specialization; |
| 61 | +import com.oracle.truffle.api.frame.Frame; |
55 | 62 | import com.oracle.truffle.api.frame.VirtualFrame;
|
56 | 63 | import com.oracle.truffle.api.nodes.LoopNode;
|
57 | 64 | import com.oracle.truffle.api.nodes.Node;
|
@@ -149,11 +156,11 @@ private void markCollision(int compactIndex) {
|
149 | 156 | indices[compactIndex] = indices[compactIndex] | COLLISION_MASK;
|
150 | 157 | }
|
151 | 158 |
|
152 |
| - private boolean isCollision(int index) { |
| 159 | + private static boolean isCollision(int index) { |
153 | 160 | return (index & COLLISION_MASK) != 0;
|
154 | 161 | }
|
155 | 162 |
|
156 |
| - private int unwrapIndex(int value) { |
| 163 | + private static int unwrapIndex(int value) { |
157 | 164 | return value & ~COLLISION_MASK;
|
158 | 165 | }
|
159 | 166 |
|
@@ -498,57 +505,75 @@ public int size() {
|
498 | 505 | return size;
|
499 | 506 | }
|
500 | 507 |
|
501 |
| - public Object get(VirtualFrame frame, DictKey key, GetProfiles profiles) { |
502 |
| - return get(frame, key.getValue(), key.getPythonHash(), profiles); |
503 |
| - } |
| 508 | + @GenerateUncached |
| 509 | + public abstract static class GetNode extends Node { |
| 510 | + public final Object get(ThreadState state, ObjectHashMap map, DictKey key) { |
| 511 | + return execute(state, map, key.getValue(), key.getPythonHash()); |
| 512 | + } |
504 | 513 |
|
505 |
| - public Object get(VirtualFrame frame, Object key, long keyHash, GetProfiles profiles) { |
506 |
| - assert checkInternalState(); |
507 |
| - int compactIndex = getIndex(keyHash); |
508 |
| - int index = indices[compactIndex]; |
509 |
| - if (profiles.foundNullKey.profile(index == EMPTY_INDEX)) { |
510 |
| - return null; |
511 |
| - } |
512 |
| - if (profiles.foundSameHashKey.profile(index != DUMMY_INDEX)) { |
513 |
| - int unwrappedIndex = unwrapIndex(index); |
514 |
| - Object foundValue = getValue(unwrappedIndex); |
515 |
| - if (profiles.foundEqKey.profile(keysEqual(frame, unwrappedIndex, key, keyHash, profiles))) { |
516 |
| - return foundValue; |
517 |
| - } else if (!isCollision(index)) { |
518 |
| - return null; |
519 |
| - } |
| 514 | + public final Object get(ThreadState state, ObjectHashMap map, Object key, long keyHash) { |
| 515 | + return execute(state, map, key, keyHash); |
520 | 516 | }
|
521 | 517 |
|
522 |
| - // collision: intentionally counted loop |
523 |
| - long perturb = keyHash; |
524 |
| - int searchLimit = getBucketsCount() + PERTURB_SHIFTS_COUT; |
525 |
| - int i = 0; |
526 |
| - try { |
527 |
| - for (; i < searchLimit; i++) { |
528 |
| - perturb >>>= PERTURB_SHIFT; |
529 |
| - compactIndex = nextIndex(compactIndex, perturb); |
530 |
| - index = indices[compactIndex]; |
531 |
| - if (profiles.collisionFoundNoValue.profile(index == EMPTY_INDEX)) { |
| 518 | + abstract Object execute(ThreadState state, ObjectHashMap map, Object key, long keyHash); |
| 519 | + |
| 520 | + // "public" for testing... |
| 521 | + @Specialization |
| 522 | + public static Object doGet(ThreadState state, ObjectHashMap map, Object key, long keyHash, |
| 523 | + @Cached("createCountingProfile()") ConditionProfile foundNullKey, |
| 524 | + @Cached("createCountingProfile()") ConditionProfile foundSameHashKey, |
| 525 | + @Cached("createCountingProfile()") ConditionProfile foundEqKey, |
| 526 | + @Cached("createCountingProfile()") ConditionProfile collisionFoundNoValue, |
| 527 | + @Cached("createCountingProfile()") ConditionProfile collisionFoundEqKey, |
| 528 | + @Cached ConditionProfile hasState, |
| 529 | + @Cached PyObjectRichCompareBool.EqNode eqNode) { |
| 530 | + assert map.checkInternalState(); |
| 531 | + int compactIndex = map.getIndex(keyHash); |
| 532 | + int index = map.indices[compactIndex]; |
| 533 | + if (foundNullKey.profile(index == EMPTY_INDEX)) { |
| 534 | + return null; |
| 535 | + } |
| 536 | + if (foundSameHashKey.profile(index != DUMMY_INDEX)) { |
| 537 | + int unwrappedIndex = unwrapIndex(index); |
| 538 | + Object foundValue = map.getValue(unwrappedIndex); |
| 539 | + if (foundEqKey.profile(map.keysEqual(state, unwrappedIndex, key, keyHash, eqNode, hasState))) { |
| 540 | + return foundValue; |
| 541 | + } else if (!isCollision(index)) { |
532 | 542 | return null;
|
533 | 543 | }
|
534 |
| - if (index != DUMMY_INDEX) { |
535 |
| - int unwrappedIndex = unwrapIndex(index); |
536 |
| - Object foundValue = getValue(unwrappedIndex); |
537 |
| - if (profiles.collisionFoundEqKey.profile(keysEqual(frame, unwrappedIndex, key, keyHash, profiles))) { |
538 |
| - return foundValue; |
| 544 | + } |
| 545 | + |
| 546 | + // collision: intentionally counted loop |
| 547 | + long perturb = keyHash; |
| 548 | + int searchLimit = map.getBucketsCount() + PERTURB_SHIFTS_COUT; |
| 549 | + int i = 0; |
| 550 | + try { |
| 551 | + for (; i < searchLimit; i++) { |
| 552 | + perturb >>>= PERTURB_SHIFT; |
| 553 | + compactIndex = map.nextIndex(compactIndex, perturb); |
| 554 | + index = map.indices[compactIndex]; |
| 555 | + if (collisionFoundNoValue.profile(index == EMPTY_INDEX)) { |
| 556 | + return null; |
| 557 | + } |
| 558 | + if (index != DUMMY_INDEX) { |
| 559 | + int unwrappedIndex = unwrapIndex(index); |
| 560 | + Object foundValue = map.getValue(unwrappedIndex); |
| 561 | + if (collisionFoundEqKey.profile(map.keysEqual(state, unwrappedIndex, key, keyHash, eqNode, hasState))) { |
| 562 | + return foundValue; |
| 563 | + } |
| 564 | + } |
| 565 | + if (!isCollision(index)) { |
| 566 | + return null; |
539 | 567 | }
|
540 | 568 | }
|
541 |
| - if (!isCollision(index)) { |
542 |
| - return null; |
543 |
| - } |
| 569 | + } finally { |
| 570 | + LoopNode.reportLoopCount(eqNode, i); |
544 | 571 | }
|
545 |
| - } finally { |
546 |
| - LoopNode.reportLoopCount(profiles, i); |
| 572 | + // all values are dummies? Not possible, since we should have compacted the |
| 573 | + // hashes/keysAndValues arrays in "remove". We always keep some head-room, so there must be |
| 574 | + // at least few empty slots, and we must have hit one. |
| 575 | + throw CompilerDirectives.shouldNotReachHere(); |
547 | 576 | }
|
548 |
| - // all values are dummies? Not possible, since we should have compacted the |
549 |
| - // hashes/keysAndValues arrays in "remove". We always keep some head-room, so there must be |
550 |
| - // at least few empty slots, and we must have hit one. |
551 |
| - throw CompilerDirectives.shouldNotReachHere(); |
552 | 577 | }
|
553 | 578 |
|
554 | 579 | public void put(VirtualFrame frame, DictKey key, Object value, PutProfiles profiles) {
|
@@ -716,11 +741,16 @@ public void remove(VirtualFrame frame, Object key, long keyHash, RemoveProfiles
|
716 | 741 | throw CompilerDirectives.shouldNotReachHere();
|
717 | 742 | }
|
718 | 743 |
|
719 |
| - private boolean keysEqual(VirtualFrame frame, int index, Object key, long keyHash, GetProfiles profiles) { |
| 744 | + private boolean keysEqual(Frame frame, int index, Object key, long keyHash, GetProfiles profiles) { |
720 | 745 | return hashes[index] == keyHash && profiles.eqNode.execute(frame, getKey(index), key);
|
721 | 746 | }
|
722 | 747 |
|
723 |
| - private boolean keysEqual(VirtualFrame frame, int index, Object key, long keyHash, PyObjectRichCompareBool.EqNode eqNode) { |
| 748 | + private boolean keysEqual(Frame frame, int index, Object key, long keyHash, PyObjectRichCompareBool.EqNode eqNode) { |
| 749 | + return hashes[index] == keyHash && eqNode.execute(frame, getKey(index), key); |
| 750 | + } |
| 751 | + |
| 752 | + private boolean keysEqual(ThreadState state, int index, Object key, long keyHash, PyObjectRichCompareBool.EqNode eqNode, ConditionProfile hasState) { |
| 753 | + VirtualFrame frame = hasState.profile(state == null) ? null : PArguments.frameForCall(state); |
724 | 754 | return hashes[index] == keyHash && eqNode.execute(frame, getKey(index), key);
|
725 | 755 | }
|
726 | 756 |
|
|
0 commit comments