Skip to content

Commit 79c348f

Browse files
committed
HashMapStorage should invoke __eq__ on non-supported keys in getItem and delItem
1 parent 33393de commit 79c348f

File tree

3 files changed

+105
-5
lines changed

3 files changed

+105
-5
lines changed

graalpython/com.oracle.graal.python.test/src/tests/test_set_strategies.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,4 +138,28 @@ def test_and():
138138
else:
139139
assert k1 not in res
140140
assert k2 not in res
141-
assert len(res) == 0
141+
assert len(res) == 0
142+
143+
144+
def test_find_custom_key():
145+
class MyWeirdKey(str):
146+
def __init__(self):
147+
self.log = []
148+
def __eq__(self, other):
149+
self.log.append('called __eq__ with %r' % other)
150+
return True
151+
def __hash__(self):
152+
return 'a'.__hash__()
153+
for f in FACTORIES:
154+
# Set with any value that has the same hash contains the weird key
155+
s = f()
156+
s.add('b')
157+
s.add('a')
158+
key = MyWeirdKey()
159+
assert key in s
160+
assert key.log == ["called __eq__ with 'a'"]
161+
# But empty set does not contain the weird key
162+
s = f()
163+
key = MyWeirdKey()
164+
assert key not in s
165+
assert key.log == []

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/common/HashMapStorage.java

Lines changed: 73 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
import com.oracle.graal.python.builtins.objects.common.HashingStorageLibrary.ForEachNode;
4949
import com.oracle.graal.python.builtins.objects.common.HashingStorageLibrary.HashingStorageIterable;
5050
import com.oracle.graal.python.builtins.objects.function.PArguments.ThreadState;
51+
import com.oracle.graal.python.builtins.objects.ints.PInt;
5152
import com.oracle.graal.python.builtins.objects.object.PythonObjectLibrary;
5253
import com.oracle.graal.python.nodes.PGuards;
5354
import com.oracle.graal.python.nodes.object.IsBuiltinClassProfile;
@@ -103,6 +104,56 @@ static boolean isSupportedKey(Object obj, IsBuiltinClassProfile isBuiltinClassPr
103104
return PGuards.isBuiltinString(obj, isBuiltinClassProfile);
104105
}
105106

107+
private static final class CustomKey {
108+
private final Object value;
109+
private final int hash;
110+
private PythonObjectLibrary lib;
111+
private final ThreadState state;
112+
113+
private CustomKey(Object value, int hash, ThreadState state) {
114+
this.value = value;
115+
this.hash = hash;
116+
this.state = state;
117+
}
118+
119+
@Override
120+
public boolean equals(Object other) {
121+
if (this == other) {
122+
return true;
123+
}
124+
if (other == null) {
125+
return false;
126+
}
127+
Object otherValue = other;
128+
PythonObjectLibrary otherLib = PythonObjectLibrary.getUncached();
129+
if (other instanceof CustomKey) {
130+
otherValue = ((CustomKey) other).value;
131+
otherLib = ((CustomKey) other).getPythonObjLib();
132+
if (hash != ((CustomKey) other).hash) {
133+
return false;
134+
}
135+
} else if (hash != other.hashCode()) {
136+
return false;
137+
}
138+
// Hopefully it will be uncommon that the object we search for will have the same hash
139+
// as some of the items in the storage (it may even equal to some of those items), so
140+
// the uncached equals call does not hurt that much here
141+
return getPythonObjLib().equalsWithState(value, otherValue, otherLib, state);
142+
}
143+
144+
@Override
145+
public int hashCode() {
146+
return hash;
147+
}
148+
149+
PythonObjectLibrary getPythonObjLib() {
150+
if (lib == null) {
151+
lib = PythonObjectLibrary.getFactory().getUncached(value);
152+
}
153+
return lib;
154+
}
155+
}
156+
106157
@Override
107158
@ExportMessage
108159
public int length() {
@@ -138,8 +189,15 @@ static Object getItemNotSupportedKey(@SuppressWarnings("unused") HashMapStorage
138189
@SuppressWarnings("unused") @Cached IsBuiltinClassProfile profile,
139190
@CachedLibrary("key") PythonObjectLibrary lib,
140191
@Exclusive @Cached("createBinaryProfile()") ConditionProfile gotState) {
141-
// must call __hash__ for potential side-effect
142-
getHashWithState(key, lib, state, gotState);
192+
// we must still search the map for items that may have the same hash and that may
193+
// return true from key.__eq__, we use artificial object with overridden Java level
194+
// equals and hashCode methods to perform this search
195+
long hash = getHashWithState(key, lib, state, gotState);
196+
if (PInt.isIntRange(hash)) {
197+
CustomKey keyObj = new CustomKey(key, (int) hash, state);
198+
return get(self.values, keyObj);
199+
}
200+
// else the hashes cannot possibly match
143201
return null;
144202
}
145203
}
@@ -198,9 +256,20 @@ private static void remove(LinkedHashMap<Object, Object> values, Object key) {
198256
values.remove(key);
199257
}
200258

201-
@Specialization(guards = "!isSupportedKey(key, profile)")
259+
@Specialization(guards = "!isSupportedKey(key, profile)", limit = "3")
202260
static HashingStorage delItemNonSupportedKey(HashMapStorage self, @SuppressWarnings("unused") Object key, @SuppressWarnings("unused") ThreadState state,
203-
@SuppressWarnings("unused") @Cached IsBuiltinClassProfile profile) {
261+
@SuppressWarnings("unused") @Cached IsBuiltinClassProfile profile,
262+
@CachedLibrary("key") PythonObjectLibrary lib,
263+
@Exclusive @Cached("createBinaryProfile()") ConditionProfile gotState) {
264+
// we must still search the map for items that may have the same hash and that may
265+
// return true from key.__eq__, we use artificial object with overridden Java level
266+
// equals and hashCode methods to perform this search
267+
long hash = getHashWithState(key, lib, state, gotState);
268+
if (PInt.isIntRange(hash)) {
269+
CustomKey keyObj = new CustomKey(key, (int) hash, state);
270+
remove(self.values, keyObj);
271+
}
272+
// else the hashes cannot possibly match
204273
return self;
205274
}
206275
}

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/common/HashingStorage.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -652,6 +652,13 @@ protected static long getHashWithState(Object key, PythonObjectLibrary lib, Thre
652652
return lib.hashWithState(key, state);
653653
}
654654

655+
protected static boolean keysEqualWithState(Object selfKey, Object otherKey, PythonObjectLibrary selfLib, PythonObjectLibrary otherLib, ThreadState state, ConditionProfile gotState) {
656+
if (gotState.profile(state == null)) {
657+
return selfLib.equals(selfKey, otherKey, otherLib);
658+
}
659+
return selfLib.equalsWithState(selfKey, otherKey, otherLib, state);
660+
}
661+
655662
/**
656663
* Adds all items from the given mapping object to storage. It is the caller responsibility to
657664
* ensure, that mapping has the 'keys' attribute.

0 commit comments

Comments
 (0)