Skip to content

Commit 1852766

Browse files
committed
New dict/set storage based more closely on the PyPy/CPython design
The new storage is more PE friendly, improves some benchmarks, and provides a basis for further improvements.
1 parent b9ee4ae commit 1852766

File tree

17 files changed

+1432
-1342
lines changed

17 files changed

+1432
-1342
lines changed
Lines changed: 320 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,320 @@
1+
/*
2+
* Copyright (c) 2020, 2022, Oracle and/or its affiliates. All rights reserved.
3+
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4+
*
5+
* The Universal Permissive License (UPL), Version 1.0
6+
*
7+
* Subject to the condition set forth below, permission is hereby granted to any
8+
* person obtaining a copy of this software, associated documentation and/or
9+
* data (collectively the "Software"), free of charge and under any and all
10+
* copyright rights in the Software, and any and all patent rights owned or
11+
* freely licensable by each licensor hereunder covering either (i) the
12+
* unmodified Software as contributed to or provided by such licensor, or (ii)
13+
* the Larger Works (as defined below), to deal in both
14+
*
15+
* (a) the Software, and
16+
*
17+
* (b) any piece of software and/or hardware listed in the lrgrwrks.txt file if
18+
* one is included with the Software each a "Larger Work" to which the Software
19+
* is contributed by such licensors),
20+
*
21+
* without restriction, including without limitation the rights to copy, create
22+
* derivative works of, display, perform, and distribute the Software and make,
23+
* use, sell, offer for sale, import, export, have made, and have sold the
24+
* Software and the Larger Work(s), and to sublicense the foregoing rights on
25+
* either these or other terms.
26+
*
27+
* This license is subject to the following condition:
28+
*
29+
* The above copyright notice and either this complete permission notice or at a
30+
* minimum a reference to the UPL must be included in all copies or substantial
31+
* portions of the Software.
32+
*
33+
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
34+
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
35+
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
36+
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
37+
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
38+
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
39+
* SOFTWARE.
40+
*/
41+
package com.oracle.graal.python.test.objects;
42+
43+
import static org.junit.Assert.assertArrayEquals;
44+
import static org.junit.Assert.assertEquals;
45+
import static org.junit.Assert.assertNull;
46+
47+
import java.util.ArrayList;
48+
import java.util.Collections;
49+
import java.util.Iterator;
50+
import java.util.LinkedHashMap;
51+
import java.util.List;
52+
import java.util.Random;
53+
import java.util.Spliterator;
54+
import java.util.Spliterators;
55+
import java.util.stream.Collectors;
56+
import java.util.stream.StreamSupport;
57+
58+
import org.junit.Assert;
59+
import org.junit.Test;
60+
61+
import com.oracle.graal.python.builtins.objects.common.HashingStorageLibrary.ForEachNode;
62+
import com.oracle.graal.python.builtins.objects.common.ObjectHashMap;
63+
import com.oracle.graal.python.builtins.objects.common.ObjectHashMap.GetProfiles;
64+
import com.oracle.graal.python.builtins.objects.common.ObjectHashMap.MapCursor;
65+
import com.oracle.graal.python.builtins.objects.common.ObjectHashMap.PutProfiles;
66+
import com.oracle.graal.python.builtins.objects.common.ObjectHashMap.RemoveProfiles;
67+
import com.oracle.graal.python.lib.PyObjectHashNode;
68+
import com.oracle.graal.python.lib.PyObjectRichCompareBool;
69+
import com.oracle.truffle.api.frame.Frame;
70+
import com.oracle.truffle.api.interop.TruffleObject;
71+
import com.oracle.truffle.api.profiles.LoopConditionProfile;
72+
73+
public class ObjectHashMapTests {
74+
public static final class DictKey implements TruffleObject {
75+
final long hash;
76+
77+
DictKey(long hash) {
78+
this.hash = hash;
79+
}
80+
}
81+
82+
private static final class EqNodeStub extends PyObjectRichCompareBool.EqNode {
83+
@Override
84+
public boolean execute(Frame frame, Object a, Object b) {
85+
// Sanity check: we do not use any other keys in the tests
86+
assert a instanceof Long || a instanceof DictKey;
87+
assert b instanceof Long || b instanceof DictKey;
88+
// the hashmap should never call __eq__ unless the hashes match
89+
assertEquals("keys: " + a + ", " + b, getKeyHash(a), getKeyHash(b));
90+
return a.equals(b);
91+
}
92+
}
93+
94+
private static final ObjectHashMap.PutProfiles PUT_PROFILES = new PutProfiles(false, new EqNodeStub());
95+
private static final ObjectHashMap.GetProfiles GET_PROFILES = new GetProfiles(false, new EqNodeStub());
96+
private static final ObjectHashMap.RemoveProfiles RM_PROFILES = new RemoveProfiles(false, new EqNodeStub());
97+
98+
@Test
99+
public void testCollisionsByPuttingManyKeysWithSameHash() {
100+
ObjectHashMap map = new ObjectHashMap();
101+
LinkedHashMap<DictKey, Object> expected = new LinkedHashMap<>();
102+
for (int i = 0; i < 100; i++) {
103+
DictKey key = new DictKey(42);
104+
Object value = newValue();
105+
expected.put(key, value);
106+
map.put(null, key, 42, value, PUT_PROFILES);
107+
assertEqual(i, expected, map);
108+
}
109+
for (int i = 0; i < 55; i++) {
110+
DictKey key = expected.keySet().stream().skip(22).findFirst().get();
111+
expected.remove(key);
112+
map.remove(null, key, 42, RM_PROFILES);
113+
assertEqual(i, expected, map);
114+
}
115+
for (int i = 0; i < 10; i++) {
116+
DictKey key = expected.keySet().stream().skip(10 + i).findFirst().get();
117+
Object value = newValue();
118+
expected.put(key, value);
119+
map.put(null, key, 42, value, PUT_PROFILES);
120+
assertEqual(i, expected, map);
121+
}
122+
}
123+
124+
@Test
125+
public void testCollisionsByPuttingAndRemovingTheSameKey() {
126+
ObjectHashMap map = new ObjectHashMap();
127+
LinkedHashMap<DictKey, Object> expected = new LinkedHashMap<>();
128+
DictKey key = new DictKey(42);
129+
for (int i = 0; i < 100; i++) {
130+
Object value = newValue();
131+
map.put(null, key, 42, value, PUT_PROFILES);
132+
expected.put(key, value);
133+
assertEqual(i, expected, map);
134+
135+
map.remove(null, key, 42, RM_PROFILES);
136+
expected.remove(key);
137+
assertEqual(i, expected, map);
138+
}
139+
}
140+
141+
@Test
142+
public void testCollisionsByPuttingAndRemovingTheSameKeys() {
143+
ObjectHashMap map = new ObjectHashMap();
144+
LinkedHashMap<DictKey, Object> expected = new LinkedHashMap<>();
145+
DictKey[] keys = new DictKey[]{new DictKey(42), new DictKey(1)};
146+
for (int i = 0; i < 100; i++) {
147+
Object value = newValue();
148+
final DictKey toPut = keys[i % keys.length];
149+
map.put(null, toPut, toPut.hash, value, PUT_PROFILES);
150+
expected.put(toPut, value);
151+
assertEqual(i, expected, map);
152+
153+
final DictKey toRemove = keys[(i + 1) % keys.length];
154+
map.remove(null, toRemove, toRemove.hash, RM_PROFILES);
155+
expected.remove(toRemove);
156+
assertEqual(i, expected, map);
157+
}
158+
}
159+
160+
@Test
161+
public void testLongHashMapStressTest() {
162+
ObjectHashMap map = new ObjectHashMap();
163+
164+
// put/remove many random (with fixed seed) keys, check consistency against LinkedHashMap
165+
testBasics(map);
166+
removeAll(map);
167+
testBasics(map);
168+
169+
// Basic tests of other methods
170+
Object[] oldKeys = iteratorToArray(map.keys().getIterator());
171+
172+
ObjectHashMap copy = map.copy();
173+
assertEquals(map.size(), copy.size());
174+
for (Object key : oldKeys) {
175+
assertEquals(key.toString(), //
176+
map.get(null, key, getKeyHash(key), GET_PROFILES), //
177+
copy.get(null, key, getKeyHash(key), GET_PROFILES));
178+
}
179+
180+
map.clear();
181+
assertEquals(0, map.size());
182+
for (Object key : oldKeys) {
183+
assertNull(key.toString(), map.get(null, key, getKeyHash(key), GET_PROFILES));
184+
}
185+
}
186+
187+
private static void testBasics(ObjectHashMap map) {
188+
LinkedHashMap<Long, Object> expected = new LinkedHashMap<>();
189+
Random rand = new Random(42);
190+
191+
putValues(map, expected, rand, 100);
192+
removeValues(map, expected, rand, 33);
193+
putValues(map, expected, rand, 44);
194+
overrideValues(map, expected, rand, 55);
195+
removeValues(map, expected, rand, 66);
196+
overrideValues(map, expected, rand, 11);
197+
putValues(map, expected, rand, 22);
198+
199+
removeAll(map, expected);
200+
201+
putValues(map, expected, rand, 50);
202+
removeValues(map, expected, rand, 33);
203+
overrideValues(map, expected, rand, 10);
204+
205+
putValues(map, expected, rand, 300);
206+
removeValues(map, expected, rand, 10);
207+
overrideValues(map, expected, rand, 10);
208+
}
209+
210+
private static void removeAll(ObjectHashMap map) {
211+
ArrayList<Long> keys = new ArrayList<>();
212+
for (Object key : map.keys()) {
213+
keys.add((Long) key);
214+
}
215+
for (Long key : keys) {
216+
map.remove(null, key, PyObjectHashNode.hash(key), RM_PROFILES);
217+
assertNull(map.get(null, key, PyObjectHashNode.hash(key), GET_PROFILES));
218+
}
219+
}
220+
221+
private static void removeAll(ObjectHashMap map, LinkedHashMap<Long, Object> expected) {
222+
for (Long key : expected.keySet().toArray(new Long[0])) {
223+
map.remove(null, key, PyObjectHashNode.hash(key), RM_PROFILES);
224+
expected.remove(key);
225+
assertEqual(Long.toString(key), expected, map);
226+
}
227+
}
228+
229+
private static void removeValues(ObjectHashMap map, LinkedHashMap<Long, Object> expected, Random rand, int count) {
230+
for (int i = 0; i < count; i++) {
231+
int index = rand.nextInt(expected.size() - 1);
232+
long key = expected.keySet().stream().skip(index).findFirst().get();
233+
map.remove(null, key, PyObjectHashNode.hash(key), RM_PROFILES);
234+
expected.remove(key);
235+
assertEqual(i, expected, map);
236+
}
237+
}
238+
239+
private static void overrideValues(ObjectHashMap map, LinkedHashMap<Long, Object> expected, Random rand, int count) {
240+
for (int i = 0; i < count; i++) {
241+
Object value = newValue();
242+
int index = rand.nextInt(expected.size() - 1);
243+
long key = expected.keySet().stream().skip(index).findFirst().get();
244+
map.put(null, key, PyObjectHashNode.hash(key), value, PUT_PROFILES);
245+
expected.put(key, value);
246+
assertEqual(i, expected, map);
247+
}
248+
}
249+
250+
private static void putValues(ObjectHashMap map, LinkedHashMap<Long, Object> expected, Random rand, int count) {
251+
for (int i = 0; i < count; i++) {
252+
Object value = newValue();
253+
long key = rand.nextLong();
254+
map.put(null, key, PyObjectHashNode.hash(key), value, PUT_PROFILES);
255+
expected.put(key, value);
256+
assertEqual(i, expected, map);
257+
}
258+
}
259+
260+
static <T> void assertEqual(int iter, LinkedHashMap<T, Object> expected, ObjectHashMap actual) {
261+
assertEqual(Integer.toString(iter), expected, actual);
262+
}
263+
264+
static <T> void assertEqual(String message, LinkedHashMap<T, Object> expected, ObjectHashMap actual) {
265+
assertEquals(message, expected.size(), actual.size());
266+
267+
// Check getEntries and build array of keys/values
268+
MapCursor it = actual.getEntries();
269+
ArrayList<ObjectHashMap.DictKey> keys = new ArrayList<>();
270+
ArrayList<Object> valuesList = new ArrayList<>();
271+
for (T key : expected.keySet()) {
272+
Assert.assertTrue(message + "; the actual is shorter ", it.advance());
273+
274+
assertEquals(message, key, it.getKey().getValue());
275+
long hash = getKeyHash(key);
276+
assertEquals(message + "; hash in DictKey: " + key, hash, it.getKey().getPythonHash());
277+
278+
Object expectedVal = expected.get(key);
279+
Object actualVal = actual.get(null, key, hash, GET_PROFILES);
280+
assertEquals(message + "; value under key: " + key, expectedVal, actualVal);
281+
assertEquals(message + "; value in DictKey: " + key, expectedVal, it.getValue());
282+
283+
keys.add(it.getKey());
284+
valuesList.add(it.getValue());
285+
}
286+
Assert.assertFalse(message + "; the actual is longer", it.advance());
287+
288+
// Using the array of keys/values, check other methods
289+
List<Object> keysValues = keys.stream().map(ObjectHashMap.DictKey::getValue).collect(Collectors.toList());
290+
assertArrayEquals(message, keysValues.toArray(), iteratorToArray(actual.keys().getIterator()));
291+
292+
List<Object> keysValuesReversed = new ArrayList<>(keysValues);
293+
Collections.reverse(keysValuesReversed);
294+
assertArrayEquals(message, keysValuesReversed.toArray(), iteratorToArray(actual.reverseKeys().getIterator()));
295+
296+
actual.forEachUntyped(new ForEachNode<>() {
297+
@Override
298+
public Object execute(Object key, Object indexObj) {
299+
int index = (int) indexObj;
300+
assertEquals(message + "; for index " + index, keysValues.get(index), key);
301+
return index + 1;
302+
}
303+
}, 0, LoopConditionProfile.getUncached());
304+
}
305+
306+
private static Object[] iteratorToArray(Iterator<Object> iterator) {
307+
var spliterator = Spliterators.spliteratorUnknownSize(iterator, Spliterator.ORDERED);
308+
return StreamSupport.stream(spliterator, false).toArray();
309+
}
310+
311+
private static int valueCounter = 0;
312+
313+
public static Object newValue() {
314+
return "Val: " + (valueCounter++);
315+
}
316+
317+
private static long getKeyHash(Object key) {
318+
return key instanceof Long ? PyObjectHashNode.hash((Long) key) : ((DictKey) key).hash;
319+
}
320+
}

graalpython/com.oracle.graal.python.test/src/tests/test_map_strategies.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@ def __eq__(self, other):
4949
set_strategy = __graalpython__.set_storage_strategy
5050
FACTORIES = [
5151
lambda: set_strategy(dict(), 'empty'),
52-
lambda: set_strategy(dict(), 'hashmap'),
5352
lambda: set_strategy(dict(), 'dynamicobject'),
5453
lambda: set_strategy(dict(), 'economicmap'),
5554
]

graalpython/com.oracle.graal.python.test/src/tests/test_set_strategies.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@ def __eq__(self, other):
4949
set_strategy = __graalpython__.set_storage_strategy
5050
FACTORIES = [
5151
lambda: set_strategy(set(), 'empty'),
52-
lambda: set_strategy(set(), 'hashmap'),
5352
lambda: set_strategy(set(), 'dynamicobject'),
5453
lambda: set_strategy(set(), 'economicmap'),
5554
]

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/GraalPythonModuleBuiltins.java

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,6 @@
8787
import com.oracle.graal.python.builtins.objects.common.DynamicObjectStorage;
8888
import com.oracle.graal.python.builtins.objects.common.EconomicMapStorage;
8989
import com.oracle.graal.python.builtins.objects.common.EmptyStorage;
90-
import com.oracle.graal.python.builtins.objects.common.HashMapStorage;
9190
import com.oracle.graal.python.builtins.objects.common.HashingStorage;
9291
import com.oracle.graal.python.builtins.objects.common.HashingStorageLibrary;
9392
import com.oracle.graal.python.builtins.objects.dict.PDict;
@@ -602,8 +601,6 @@ private HashingStorage getStrategy(TruffleString tname, PythonLanguage lang) {
602601
switch (name) {
603602
case "empty":
604603
return EmptyStorage.INSTANCE;
605-
case "hashmap":
606-
return new HashMapStorage();
607604
case "dynamicobject":
608605
return new DynamicObjectStorage(lang);
609606
case "economicmap":

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/cext/hpy/GraalHPyContext.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,8 @@
8080
import java.util.concurrent.atomic.AtomicReference;
8181
import java.util.logging.Level;
8282

83+
import com.oracle.graal.python.builtins.objects.common.EconomicMapStorage;
84+
import com.oracle.truffle.api.strings.TruffleString;
8385
import org.graalvm.nativeimage.ImageInfo;
8486

8587
import com.oracle.graal.python.PythonLanguage;
@@ -204,7 +206,6 @@
204206
import com.oracle.graal.python.builtins.objects.cext.hpy.GraalHPyNodesFactory.PCallHPyFunctionNodeGen;
205207
import com.oracle.graal.python.builtins.objects.cext.hpy.HPyExternalFunctionNodes.HPyCheckFunctionResultNode;
206208
import com.oracle.graal.python.builtins.objects.common.EmptyStorage;
207-
import com.oracle.graal.python.builtins.objects.common.HashMapStorage;
208209
import com.oracle.graal.python.builtins.objects.common.HashingStorage;
209210
import com.oracle.graal.python.builtins.objects.common.HashingStorageLibrary;
210211
import com.oracle.graal.python.builtins.objects.dict.PDict;
@@ -292,7 +293,6 @@
292293
import com.oracle.truffle.api.profiles.ConditionProfile;
293294
import com.oracle.truffle.api.source.Source;
294295
import com.oracle.truffle.api.source.Source.SourceBuilder;
295-
import com.oracle.truffle.api.strings.TruffleString;
296296
import com.oracle.truffle.llvm.spi.NativeTypeLibrary;
297297
import com.oracle.truffle.nfi.api.SignatureLibrary;
298298

@@ -1914,8 +1914,8 @@ public int ctxSetItem(long hSequence, long hKey, long hValue) {
19141914
dict.setDictStorage(dictStorage);
19151915
}
19161916

1917-
if (dictStorage instanceof HashMapStorage) {
1918-
((HashMapStorage) dictStorage).put((TruffleString) key, value);
1917+
if (dictStorage instanceof EconomicMapStorage) {
1918+
((EconomicMapStorage) dictStorage).putUncached((TruffleString) key, value);
19191919
return 0;
19201920
}
19211921
// fall through to generic case

0 commit comments

Comments
 (0)