Skip to content

Commit 33a3d0a

Browse files
committed
Properly promote borrowed dict keys in PyDict_Next
1 parent 7519642 commit 33a3d0a

File tree

3 files changed

+75
-49
lines changed

3 files changed

+75
-49
lines changed

graalpython/com.oracle.graal.python.cext/src/dictobject.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,6 @@ int _PyDict_Next(PyObject *d, Py_ssize_t *ppos, PyObject **pkey, PyObject **pval
7979
}
8080
return 0;
8181
}
82-
(*ppos)++;
8382
if (pkey != NULL) {
8483
*pkey = PyTuple_GetItem(tresult, 0);
8584
}
@@ -89,6 +88,7 @@ int _PyDict_Next(PyObject *d, Py_ssize_t *ppos, PyObject **pkey, PyObject **pval
8988
if (phash != NULL) {
9089
*phash = PyLong_AsSsize_t(PyTuple_GetItem(tresult, 2));
9190
}
91+
*ppos = PyLong_AsSsize_t(PyTuple_GetItem(tresult, 3));
9292
Py_DECREF(tresult);
9393
return 1;
9494

graalpython/com.oracle.graal.python.test/src/tests/unittest_tags/test_capi.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
*graalpython.lib-python.3.test.test_capi.Test_testcapi.test_config
3737
*graalpython.lib-python.3.test.test_capi.Test_testcapi.test_datetime_capi
3838
*graalpython.lib-python.3.test.test_capi.Test_testcapi.test_decref_doesnt_leak
39+
*graalpython.lib-python.3.test.test_capi.Test_testcapi.test_dict_iteration
3940
*graalpython.lib-python.3.test.test_capi.Test_testcapi.test_empty_argparse
4041
*graalpython.lib-python.3.test.test_capi.Test_testcapi.test_incref_decref_API
4142
*graalpython.lib-python.3.test.test_capi.Test_testcapi.test_incref_doesnt_leak

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/cext/PythonCextDictBuiltins.java

Lines changed: 73 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
import static com.oracle.graal.python.builtins.objects.cext.capi.transitions.ArgDescriptor.Py_hash_t;
5252
import static com.oracle.graal.python.builtins.objects.cext.capi.transitions.ArgDescriptor.Py_ssize_t;
5353
import static com.oracle.graal.python.builtins.objects.cext.capi.transitions.ArgDescriptor.Void;
54+
import static com.oracle.graal.python.builtins.objects.common.HashingStorageNodes.HashingStorageIterator;
5455
import static com.oracle.graal.python.nodes.ErrorMessages.BAD_ARG_TO_INTERNAL_FUNC_WAS_S_P;
5556
import static com.oracle.graal.python.nodes.ErrorMessages.HASH_MISMATCH;
5657
import static com.oracle.graal.python.nodes.ErrorMessages.OBJ_P_HAS_NO_ATTR_S;
@@ -68,9 +69,9 @@
6869
import com.oracle.graal.python.builtins.modules.cext.PythonCextBuiltins.PromoteBorrowedValue;
6970
import com.oracle.graal.python.builtins.objects.PNone;
7071
import com.oracle.graal.python.builtins.objects.cext.capi.PythonNativePointer;
72+
import com.oracle.graal.python.builtins.objects.common.EconomicMapStorage;
7173
import com.oracle.graal.python.builtins.objects.common.HashingCollectionNodes.SetItemNode;
7274
import com.oracle.graal.python.builtins.objects.common.HashingStorage;
73-
import com.oracle.graal.python.builtins.objects.common.HashingStorageNodes;
7475
import com.oracle.graal.python.builtins.objects.common.HashingStorageNodes.HashingStorageCopy;
7576
import com.oracle.graal.python.builtins.objects.common.HashingStorageNodes.HashingStorageGetItem;
7677
import com.oracle.graal.python.builtins.objects.common.HashingStorageNodes.HashingStorageGetItemWithHash;
@@ -94,11 +95,8 @@
9495
import com.oracle.graal.python.lib.PyObjectGetAttr;
9596
import com.oracle.graal.python.lib.PyObjectHashNode;
9697
import com.oracle.graal.python.lib.PyObjectLookupAttr;
97-
import com.oracle.graal.python.lib.PyObjectSizeNode;
9898
import com.oracle.graal.python.nodes.builtins.ListNodes.ConstructListNode;
9999
import com.oracle.graal.python.nodes.call.CallNode;
100-
import com.oracle.graal.python.nodes.classes.IsSubtypeNode;
101-
import com.oracle.graal.python.nodes.object.InlinedGetClassNode;
102100
import com.oracle.graal.python.nodes.util.CastToJavaLongExactNode;
103101
import com.oracle.graal.python.runtime.exception.PException;
104102
import com.oracle.graal.python.runtime.sequence.storage.SequenceStorage;
@@ -108,6 +106,7 @@
108106
import com.oracle.truffle.api.dsl.Specialization;
109107
import com.oracle.truffle.api.nodes.Node;
110108
import com.oracle.truffle.api.profiles.BranchProfile;
109+
import com.oracle.truffle.api.profiles.InlinedBranchProfile;
111110
import com.oracle.truffle.api.profiles.LoopConditionProfile;
112111

113112
public final class PythonCextDictBuiltins {
@@ -124,67 +123,93 @@ Object run() {
124123
@CApiBuiltin(ret = PyObjectTransfer, args = {PyObject, Py_ssize_t}, call = Ignored)
125124
abstract static class PyTruffleDict_Next extends CApiBinaryBuiltinNode {
126125

127-
@Specialization(guards = "pos < size(dict, sizeNode)", limit = "1")
126+
@Specialization
128127
Object run(PDict dict, long pos,
129128
@Bind("this") Node inliningTarget,
130-
@SuppressWarnings("unused") @Cached PyObjectSizeNode sizeNode,
129+
@Cached InlinedBranchProfile needsRewriteProfile,
130+
@Cached InlinedBranchProfile economicMapProfile,
131+
@Cached HashingStorageLen lenNode,
131132
@Cached HashingStorageGetIterator getIterator,
132133
@Cached HashingStorageIteratorNext itNext,
133134
@Cached HashingStorageIteratorKey itKey,
134135
@Cached HashingStorageIteratorValue itValue,
135136
@Cached HashingStorageIteratorKeyHash itKeyHash,
136137
@Cached PromoteBorrowedValue promoteKeyNode,
137138
@Cached PromoteBorrowedValue promoteValueNode,
138-
@Cached SetItemNode setItemNode,
139-
@Cached LoopConditionProfile loopProfile) {
139+
@Cached HashingStorageSetItem setItem) {
140+
/*
141+
* We need to promote primitive values and strings to object types for borrowing to work
142+
* correctly. This is very hard to do mid-iteration, so we do all the promotion for the
143+
* whole dict at once in the first call (which is required to start with position 0). In
144+
* order to not violate the ordering, we construct a completely new storage.
145+
*/
146+
if (pos == 0) {
147+
HashingStorage storage = dict.getDictStorage();
148+
int len = lenNode.execute(storage);
149+
if (len > 0) {
150+
boolean needsRewrite = false;
151+
if (storage instanceof EconomicMapStorage) {
152+
economicMapProfile.enter(inliningTarget);
153+
HashingStorageIterator it = getIterator.execute(storage);
154+
while (itNext.execute(storage, it)) {
155+
if (promoteKeyNode.execute(itKey.execute(storage, it)) != null || promoteValueNode.execute(itValue.execute(storage, it)) != null) {
156+
needsRewrite = true;
157+
break;
158+
}
159+
}
160+
} else {
161+
/*
162+
* Other storages always have string keys or have complex iterators, just
163+
* convert them to economic map
164+
*/
165+
needsRewrite = true;
166+
}
167+
if (needsRewrite) {
168+
needsRewriteProfile.enter(inliningTarget);
169+
EconomicMapStorage newStorage = EconomicMapStorage.create(len);
170+
HashingStorageIterator it = getIterator.execute(storage);
171+
while (itNext.execute(storage, it)) {
172+
Object key = itKey.execute(storage, it);
173+
Object value = itValue.execute(storage, it);
174+
Object promotedKey = promoteKeyNode.execute(key);
175+
if (promotedKey != null) {
176+
key = promotedKey;
177+
}
178+
Object promotedValue = promoteValueNode.execute(value);
179+
if (promotedValue != null) {
180+
value = promotedValue;
181+
}
182+
setItem.execute(null, newStorage, key, value);
183+
}
184+
dict.setDictStorage(newStorage);
185+
}
186+
}
187+
}
140188

141189
HashingStorage storage = dict.getDictStorage();
142-
HashingStorageNodes.HashingStorageIterator it = getIterator.execute(storage);
143-
loopProfile.profileCounted(pos);
144-
for (int i = 0; loopProfile.inject(i <= pos); i++) {
145-
if (!itNext.execute(storage, it)) {
146-
return getNativeNull();
147-
}
190+
HashingStorageIterator it = getIterator.execute(storage);
191+
/*
192+
* The iterator index starts from -1, but pos starts from 0, so we subtract 1 here and
193+
* add it back later when computing new pos.
194+
*/
195+
it.setState((int) pos - 1);
196+
boolean hasNext = itNext.execute(storage, it);
197+
if (!hasNext) {
198+
return getNativeNull();
148199
}
149200
Object key = itKey.execute(storage, it);
150201
Object value = itValue.execute(storage, it);
151-
Object promotedKey = promoteKeyNode.execute(key);
152-
Object promotedValue = promoteValueNode.execute(value);
153-
if (promotedKey != null) {
154-
key = promotedKey;
155-
// TODO: replace key with promoted value (also, re-enable
156-
// 'test_capi.py::test_dict_iteration' once fixed)
157-
}
158-
if (promotedValue != null) {
159-
setItemNode.execute(null, inliningTarget, dict, key, value = promotedValue);
160-
}
161-
return factory().createTuple(new Object[]{key, value, itKeyHash.execute(storage, it)});
202+
assert promoteKeyNode.execute(key) == null;
203+
assert promoteValueNode.execute(value) == null;
204+
long hash = itKeyHash.execute(storage, it);
205+
int newPos = it.getState() + 1;
206+
return factory().createTuple(new Object[]{key, value, hash, newPos});
162207
}
163208

164-
@Specialization(guards = "isGreaterPosOrNative(inliningTarget, pos, dict, sizeNode, getClassNode, isSubtypeNode)", limit = "1")
165-
Object run(@SuppressWarnings("unused") Object dict, @SuppressWarnings("unused") long pos,
166-
@SuppressWarnings("unused") @Bind("this") Node inliningTarget,
167-
@SuppressWarnings("unused") @Cached PyObjectSizeNode sizeNode,
168-
@SuppressWarnings("unused") @Cached InlinedGetClassNode getClassNode,
169-
@SuppressWarnings("unused") @Cached IsSubtypeNode isSubtypeNode) {
209+
@Fallback
210+
Object run(@SuppressWarnings("unused") Object dict, @SuppressWarnings("unused") Object pos) {
170211
return getNativeNull();
171212
}
172-
173-
protected boolean isGreaterPosOrNative(Node inliningTarget, long pos, Object obj, PyObjectSizeNode sizeNode, InlinedGetClassNode getClassNode, IsSubtypeNode isSubtypeNode) {
174-
return (isDict(obj) && pos >= size(obj, sizeNode)) || (!isDict(obj) && !isDictSubtype(inliningTarget, obj, getClassNode, isSubtypeNode));
175-
}
176-
177-
protected boolean isDict(Object obj) {
178-
return obj instanceof PDict;
179-
}
180-
181-
protected int size(Object dict, PyObjectSizeNode sizeNode) {
182-
return sizeNode.execute(null, dict);
183-
}
184-
185-
protected boolean isDictSubtype(Node inliningTarget, Object obj, InlinedGetClassNode getClassNode, IsSubtypeNode isSubtypeNode) {
186-
return isSubtypeNode.execute(getClassNode.execute(inliningTarget, obj), PythonBuiltinClassType.PDict);
187-
}
188213
}
189214

190215
@CApiBuiltin(ret = PyObjectTransfer, args = {PyObject, PyObject, PyObject}, call = Direct)
@@ -465,7 +490,7 @@ static int merge(PDict a, PDict b, @SuppressWarnings("unused") int override,
465490
@Cached HashingStorageSetItemWithHash setAItem,
466491
@Cached LoopConditionProfile loopProfile) {
467492
HashingStorage bStorage = b.getDictStorage();
468-
HashingStorageNodes.HashingStorageIterator bIt = getBIter.execute(bStorage);
493+
HashingStorageIterator bIt = getBIter.execute(bStorage);
469494
HashingStorage aStorage = a.getDictStorage();
470495
while (loopProfile.profile(itBNext.execute(bStorage, bIt))) {
471496
Object key = itBKey.execute(bStorage, bIt);

0 commit comments

Comments
 (0)