Skip to content

Commit bc954a2

Browse files
committed
[GR-45421] [GR-45312] Fix/work around various issues encountered in packages.
PullRequest: graalpython/2725
2 parents 124dfc1 + 423da84 commit bc954a2

File tree

4 files changed

+62
-3
lines changed

4 files changed

+62
-3
lines changed

graalpython/com.oracle.graal.python.test/src/tests/cpyext/test_object.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,39 @@ def test_index(self):
153153
tester = TestIndex()
154154
assert [0, 1][tester] == 1
155155

156+
def test_slots_binops(self):
157+
TestSlotsBinop = CPyExtType("TestSlotsBinop",
158+
"""
159+
PyObject* test_int_impl(PyObject* self) {
160+
PyErr_SetString(PyExc_RuntimeError, "Should not call __int__");
161+
return NULL;
162+
}
163+
PyObject* test_index_impl(PyObject* self) {
164+
PyErr_SetString(PyExc_RuntimeError, "Should not call __index__");
165+
return NULL;
166+
}
167+
PyObject* test_mul_impl(PyObject* a, PyObject* b) {
168+
return PyLong_FromLong(42);
169+
}
170+
""",
171+
nb_int="test_int_impl",
172+
nb_index="test_index_impl",
173+
nb_multiply="test_mul_impl"
174+
)
175+
assert [4, 2] * TestSlotsBinop() == 42
176+
177+
def test_index(self):
178+
TestIndex = CPyExtType("TestIndex",
179+
"""
180+
PyObject* test_index(PyObject* self) {
181+
return PyLong_FromLong(1);
182+
}
183+
""",
184+
nb_index="test_index"
185+
)
186+
tester = TestIndex()
187+
assert [0, 1][tester] == 1
188+
156189
def test_getattro(self):
157190
return # TODO: not working yet
158191
# XXX: Cludge to get type into C

graalpython/com.oracle.graal.python.test/src/tests/test_builtin.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2018, 2022, Oracle and/or its affiliates.
1+
# Copyright (c) 2018, 2023, Oracle and/or its affiliates.
22
# Copyright (C) 1996-2020 Python Software Foundation
33
#
44
# Licensed under the PYTHON SOFTWARE FOUNDATION LICENSE VERSION 2
@@ -90,4 +90,22 @@ def test_builtin_constants(self):
9090

9191
def test_min(self):
9292
self.assertEqual(min((), default=1, key="adsf"), 1)
93+
94+
def test_sort_keyfunc(self):
95+
lists = [[], [1], [1,2], [1,2,3], [1,3,2], [3,2,1], [9,3,8,1,7,9,3,6,7,8]]
9396

97+
for l in lists:
98+
count = 0
99+
100+
def keyfunc(v):
101+
nonlocal count
102+
count += 1
103+
return v
104+
105+
result = sorted(l, key = keyfunc)
106+
self.assertEqual(len(l), count)
107+
self.assertEqual(sorted(l), result)
108+
count = 0
109+
result = sorted(l, key = keyfunc, reverse = True)
110+
self.assertEqual(len(l), count)
111+
self.assertEqual(sorted(l, reverse = True), result)

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/common/SortNodes.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,11 @@ public static KeySortComparator forClass(Class<?> clazz) {
376376
}
377377

378378
private void sortWithKey(VirtualFrame frame, Object[] array, int len, Object keyfunc, boolean reverse, CallNode callNode, CallContext callContext) {
379+
if (len == 0) {
380+
return;
381+
}
382+
// some packages expect "keyfunc" to be called even for one-element lists
383+
Object key = callNode.execute(frame, keyfunc, array[0]);
379384
if (len <= 1) {
380385
return;
381386
}
@@ -391,7 +396,6 @@ private void sortWithKey(VirtualFrame frame, Object[] array, int len, Object key
391396
* Look at the first key and determine which comparator we could use to compare if the
392397
* keys turn all to be the same primitive type
393398
*/
394-
Object key = callNode.execute(frame, keyfunc, array[0]);
395399
pairArray[reverse ? len - 1 : 0] = new SortingPair(key, array[0]);
396400
Class<?> keyClass = keyClassProfile.profile(key.getClass());
397401
KeySortComparator keySortComparator = KeySortComparator.forClass(keyClass);

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/nodes/call/special/LookupAndCallReversibleBinaryNode.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,11 @@ private Object doCallObjectR(VirtualFrame frame, Node inliningTarget, Object lef
216216
if (hasLeftCallable.profile(inliningTarget, leftCallable != PNone.NO_VALUE)) {
217217
if (hasRightCallable.profile(inliningTarget, rightCallable != PNone.NO_VALUE) &&
218218
(!isSameTypeNode.execute(inliningTarget, leftClass, rightClass) && isSubtype.execute(frame, rightClass, leftClass) ||
219-
isFlagSequenceCompat(inliningTarget, leftClass, rightClass, slot, noLeftBuiltinClassType, noRightBuiltinClassType))) {
219+
isFlagSequenceCompat(inliningTarget, leftClass, rightClass, slot, noLeftBuiltinClassType, noRightBuiltinClassType) ||
220+
// this condition is a quick fix for the fact that
221+
// CPython tries both normal and reverse nb_multiply
222+
// before trying sq_repeat (happens in numpy):
223+
(slot == SpecialMethodSlot.Mul && leftClass == PythonBuiltinClassType.PList))) {
220224
result = dispatch(frame, inliningTarget, ensureReverseDispatch(), rightCallable, right, left, rightClass, rslot, isSubtype, getEnclosingType);
221225
if (result != PNotImplemented.NOT_IMPLEMENTED) {
222226
return result;

0 commit comments

Comments
 (0)