Skip to content

Commit 9ad077c

Browse files
committed
Migrate to nb_add/sq_concat CPython like slots
1 parent ded5688 commit 9ad077c

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

44 files changed

+1228
-270
lines changed

graalpython/com.oracle.graal.python.annotations/src/com/oracle/graal/python/annotations/Slot.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,10 @@
9393

9494
enum SlotKind {
9595
nb_bool,
96+
nb_add,
9697
sq_length,
9798
sq_item,
99+
sq_concat,
98100
mp_length,
99101
mp_subscript,
100102
tp_descr_get,

graalpython/com.oracle.graal.python.cext/src/capi.c

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,6 @@ PyAPI_FUNC(int64_t) get_methods_flags(PyTypeObject *cls) {
367367
PyNumberMethods* number = cls->tp_as_number;
368368
if (number != NULL) {
369369
#define COMPUTE_FLAGS(NAME, BIT_IDX) flags |= ((number->NAME != NULL) * BIT_IDX);
370-
COMPUTE_FLAGS(nb_add, NB_ADD)
371370
COMPUTE_FLAGS(nb_subtract, NB_SUBTRACT)
372371
COMPUTE_FLAGS(nb_multiply, NB_MULTIPLY)
373372
COMPUTE_FLAGS(nb_remainder, NB_REMAINDER)
@@ -409,7 +408,6 @@ PyAPI_FUNC(int64_t) get_methods_flags(PyTypeObject *cls) {
409408
if (sequence != NULL) {
410409
#define COMPUTE_FLAGS(NAME, BIT_IDX) flags |= ((sequence->NAME != NULL) * BIT_IDX);
411410
COMPUTE_FLAGS(sq_length, SQ_LENGTH)
412-
COMPUTE_FLAGS(sq_concat, SQ_CONCAT)
413411
COMPUTE_FLAGS(sq_repeat, SQ_REPEAT)
414412
COMPUTE_FLAGS(sq_item, SQ_ITEM)
415413
COMPUTE_FLAGS(sq_ass_item, SQ_ASS_ITEM)

graalpython/com.oracle.graal.python.processor/src/com/oracle/graal/python/processor/SlotsMapping.java

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242

4343
import com.oracle.graal.python.annotations.Slot;
4444
import com.oracle.graal.python.annotations.Slot.SlotKind;
45+
import com.oracle.graal.python.processor.SlotsProcessor.TpSlotData;
4546

4647
public class SlotsMapping {
4748
private static String getSuffix(boolean isComplex) {
@@ -51,6 +52,8 @@ private static String getSuffix(boolean isComplex) {
5152
static String getSlotBaseClass(Slot s) {
5253
return switch (s.value()) {
5354
case nb_bool -> "TpSlotInquiry.TpSlotInquiryBuiltin";
55+
case nb_add -> "TpSlotBinaryOp.TpSlotBinaryOpBuiltin";
56+
case sq_concat -> "TpSlotBinaryFunc.TpSlotSqConcat";
5457
case sq_length, mp_length -> "TpSlotLen.TpSlotLenBuiltin" + getSuffix(s.isComplex());
5558
case sq_item -> "TpSlotSizeArgFun.TpSlotSizeArgFunBuiltin";
5659
case mp_subscript -> "TpSlotBinaryFunc.TpSlotMpSubscript";
@@ -65,6 +68,8 @@ static String getSlotNodeBaseClass(Slot s) {
6568
return switch (s.value()) {
6669
case tp_descr_get -> "com.oracle.graal.python.builtins.objects.type.slots.TpSlotDescrGet.DescrGetBuiltinNode";
6770
case nb_bool -> "com.oracle.graal.python.builtins.objects.type.slots.TpSlotInquiry.NbBoolBuiltinNode";
71+
case nb_add -> "com.oracle.graal.python.builtins.objects.type.slots.TpSlotBinaryOp.BinaryOpBuiltinNode";
72+
case sq_concat -> "com.oracle.graal.python.builtins.objects.type.slots.TpSlotBinaryFunc.SqConcatBuiltinNode";
6873
case sq_length, mp_length -> "com.oracle.graal.python.builtins.objects.type.slots.TpSlotLen.LenBuiltinNode";
6974
case sq_item -> "com.oracle.graal.python.builtins.objects.type.slots.TpSlotSizeArgFun.SqItemBuiltinNode";
7075
case mp_subscript -> "com.oracle.graal.python.builtins.objects.type.slots.TpSlotBinaryFunc.MpSubscriptBuiltinNode";
@@ -79,7 +84,7 @@ static String getUncachedExecuteSignature(SlotKind s) {
7984
case nb_bool -> "boolean executeUncached(Object self)";
8085
case tp_descr_get -> "Object executeUncached(Object self, Object obj, Object type)";
8186
case sq_length, mp_length -> "int executeUncached(Object self)";
82-
case tp_getattro, tp_descr_set, tp_setattro, sq_item, mp_subscript ->
87+
case tp_getattro, tp_descr_set, tp_setattro, sq_item, mp_subscript, nb_add, sq_concat ->
8388
throw new AssertionError("Should not reach here: should be always complex");
8489
};
8590
}
@@ -88,15 +93,15 @@ static boolean supportsComplex(SlotKind s) {
8893
return switch (s) {
8994
case nb_bool -> false;
9095
case sq_length, mp_length, tp_getattro, tp_descr_get, tp_descr_set,
91-
tp_setattro, sq_item, mp_subscript ->
96+
tp_setattro, sq_item, mp_subscript, nb_add, sq_concat ->
9297
true;
9398
};
9499
}
95100

96101
static boolean supportsSimple(SlotKind s) {
97102
return switch (s) {
98103
case nb_bool, sq_length, mp_length, tp_descr_get -> true;
99-
case tp_getattro, tp_descr_set, tp_setattro, sq_item, mp_subscript -> false;
104+
case tp_getattro, tp_descr_set, tp_setattro, sq_item, mp_subscript, nb_add, sq_concat -> false;
100105
};
101106
}
102107

@@ -105,8 +110,18 @@ static String getUncachedExecuteCall(SlotKind s) {
105110
case nb_bool -> "executeBool(null, self)";
106111
case sq_length, mp_length -> "executeInt(null, self)";
107112
case tp_descr_get -> "execute(null, self, obj, type)";
108-
case tp_getattro, tp_descr_set, tp_setattro, sq_item, mp_subscript ->
113+
case tp_getattro, tp_descr_set, tp_setattro, sq_item, mp_subscript, nb_add, sq_concat ->
109114
throw new AssertionError("Should not reach here: should be always complex");
110115
};
111116
}
117+
118+
public static String getExtraCtorArgs(TpSlotData slot) {
119+
return switch (slot.slot().value()) {
120+
case nb_add -> ", com.oracle.graal.python.nodes.SpecialMethodNames.J___ADD__";
121+
case nb_bool, tp_setattro, tp_getattro,
122+
tp_descr_set, tp_descr_get, mp_subscript,
123+
mp_length, sq_concat, sq_item, sq_length ->
124+
"";
125+
};
126+
}
112127
}

graalpython/com.oracle.graal.python.processor/src/com/oracle/graal/python/processor/SlotsProcessor.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ private void writeSlot(CodeWriter w, TpSlotData slot) throws IOException {
184184
w.startLn().//
185185
write("super(").//
186186
write(getNodeFactory(slot, slot.slotNodeType()) + ".getInstance()").//
187-
endLn(");");
187+
write(SlotsMapping.getExtraCtorArgs(slot)).endLn(");");
188188

189189
}
190190
w.writeLn("}");

graalpython/com.oracle.graal.python.test/src/com/oracle/graal/python/test/builtin/MethodFlagsTest.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2023, Oracle and/or its affiliates. All rights reserved.
2+
* Copyright (c) 2023, 2024, Oracle and/or its affiliates. All rights reserved.
33
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
44
*
55
* The Universal Permissive License (UPL), Version 1.0
@@ -87,7 +87,6 @@ private static void assertMethodConsistentWithFlag(PythonBuiltinClassType clazz,
8787
@Test
8888
public void testMethodFlagsConsistency() {
8989
for (PythonBuiltinClassType clazz : PythonBuiltinClassType.VALUES) {
90-
assertMethodConsistentWithFlag(clazz, MethodsFlags.NB_ADD | MethodsFlags.SQ_CONCAT, SpecialMethodNames.T___ADD__);
9190
assertMethodConsistentWithFlag(clazz, MethodsFlags.NB_SUBTRACT, SpecialMethodNames.T___SUB__);
9291
assertMethodConsistentWithFlag(clazz, MethodsFlags.NB_MULTIPLY | MethodsFlags.SQ_REPEAT, SpecialMethodNames.T___MUL__);
9392
assertMethodConsistentWithFlag(clazz, MethodsFlags.NB_REMAINDER, SpecialMethodNames.T___MOD__);

graalpython/com.oracle.graal.python.test/src/com/oracle/graal/python/test/builtin/objects/TpSlotsTests.java

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,26 @@ public void testBuilderBasic() {
7272
checkSlotValue(TpSlotMeta.TP_SETATTRO, slots.combined_tp_setattro_setattr());
7373
}
7474

75+
@Test
76+
public void testBuilderExplicitGroups() {
77+
Builder builder = TpSlots.newBuilder();
78+
builder.allocateAllGroups();
79+
TpSlots slots = builder.build();
80+
for (TpSlotGroup group : TpSlotGroup.VALID_VALUES) {
81+
Assert.assertTrue(getGroup(slots, group));
82+
}
83+
}
84+
85+
@Test
86+
public void testBuilderExplicitGroup() {
87+
for (TpSlotGroup group : TpSlotGroup.VALID_VALUES) {
88+
Builder builder = TpSlots.newBuilder();
89+
builder.setExplicitGroup(group);
90+
TpSlots slots = builder.build();
91+
Assert.assertTrue(getGroup(slots, group));
92+
}
93+
}
94+
7595
@Test
7696
public void testBuilderOptimizations1() {
7797
Builder builder = TpSlots.newBuilder();
@@ -112,18 +132,21 @@ private static void verifySlots(TpSlots slots, Function<TpSlotMeta, Boolean> che
112132
Assert.assertNull(def.name(), slotValue);
113133
}
114134
}
115-
for (TpSlotGroup group : TpSlotGroup.values()) {
116-
switch (group) {
117-
case NO_GROUP -> {
118-
}
119-
case AS_NUMBER -> Assert.assertEquals(slots.has_as_number(), groupsSeen.contains(group));
120-
case AS_SEQUENCE -> Assert.assertEquals(slots.has_as_sequence(), groupsSeen.contains(group));
121-
case AS_MAPPING -> Assert.assertEquals(slots.has_as_mapping(), groupsSeen.contains(group));
122-
}
135+
for (TpSlotGroup group : TpSlotGroup.VALID_VALUES) {
136+
Assert.assertEquals(getGroup(slots, group), groupsSeen.contains(group));
123137
}
124138
}
125139

126140
private static void checkSlotValue(TpSlotMeta def, TpSlot slotValue) {
127141
Assert.assertTrue(def.name(), slotValue instanceof TpSlotNative slotNative && slotNative.getCallable() == def);
128142
}
143+
144+
private static boolean getGroup(TpSlots slots, TpSlotGroup group) {
145+
return switch (group) {
146+
case NO_GROUP -> throw new IllegalArgumentException();
147+
case AS_NUMBER -> slots.has_as_number();
148+
case AS_SEQUENCE -> slots.has_as_sequence();
149+
case AS_MAPPING -> slots.has_as_mapping();
150+
};
151+
}
129152
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
/*
2+
* Copyright (c) 2024, 2024, Oracle and/or its affiliates. All rights reserved.
3+
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4+
*
5+
* The Universal Permissive License (UPL), Version 1.0
6+
*
7+
* Subject to the condition set forth below, permission is hereby granted to any
8+
* person obtaining a copy of this software, associated documentation and/or
9+
* data (collectively the "Software"), free of charge and under any and all
10+
* copyright rights in the Software, and any and all patent rights owned or
11+
* freely licensable by each licensor hereunder covering either (i) the
12+
* unmodified Software as contributed to or provided by such licensor, or (ii)
13+
* the Larger Works (as defined below), to deal in both
14+
*
15+
* (a) the Software, and
16+
*
17+
* (b) any piece of software and/or hardware listed in the lrgrwrks.txt file if
18+
* one is included with the Software each a "Larger Work" to which the Software
19+
* is contributed by such licensors),
20+
*
21+
* without restriction, including without limitation the rights to copy, create
22+
* derivative works of, display, perform, and distribute the Software and make,
23+
* use, sell, offer for sale, import, export, have made, and have sold the
24+
* Software and the Larger Work(s), and to sublicense the foregoing rights on
25+
* either these or other terms.
26+
*
27+
* This license is subject to the following condition:
28+
*
29+
* The above copyright notice and either this complete permission notice or at a
30+
* minimum a reference to the UPL must be included in all copies or substantial
31+
* portions of the Software.
32+
*
33+
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
34+
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
35+
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
36+
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
37+
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
38+
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
39+
* SOFTWARE.
40+
*/
41+
package com.oracle.graal.python.test.builtin.objects.cext;
42+
43+
import static org.hamcrest.CoreMatchers.equalTo;
44+
import static org.hamcrest.CoreMatchers.instanceOf;
45+
import static org.hamcrest.CoreMatchers.not;
46+
import static org.hamcrest.MatcherAssert.assertThat;
47+
48+
import org.junit.Test;
49+
50+
import com.oracle.graal.python.builtins.objects.cext.capi.PyProcsWrapper.TpSlotWrapper;
51+
import com.oracle.graal.python.builtins.objects.type.TpSlots.TpSlotMeta;
52+
import com.oracle.graal.python.builtins.objects.type.slots.TpSlot.TpSlotPythonSingle;
53+
import com.oracle.graal.python.util.PythonUtils;
54+
import com.oracle.truffle.api.strings.TruffleString;
55+
56+
public class SlotWrapperTests {
57+
// Doesn't really need to be a real callable or type
58+
private static final Object DUMMY_CALLABLE = new Object();
59+
private static final Object DUMMY_TYPE = new Object();
60+
61+
@Test
62+
public void testCloneContract() {
63+
TruffleString testName = PythonUtils.toTruffleStringUncached("__test__");
64+
TpSlotPythonSingle slotValue = new TpSlotPythonSingle(DUMMY_CALLABLE, DUMMY_TYPE, testName);
65+
TpSlotPythonSingle slotValue2 = new TpSlotPythonSingle(DUMMY_CALLABLE, DUMMY_TYPE, testName);
66+
for (TpSlotMeta slot : TpSlotMeta.values()) {
67+
if (!slot.supportsManagedSlotValues()) {
68+
continue;
69+
}
70+
// We rely on the wrapper not doing any validation of the value...
71+
TpSlotWrapper wrapper = slot.createNativeWrapper(slotValue);
72+
assertThat(wrapper.getSlot(), equalTo(slotValue));
73+
74+
TpSlotWrapper clonedWrapper = wrapper.cloneWith(slotValue2);
75+
assertThat(clonedWrapper, not(equalTo(wrapper)));
76+
assertThat(clonedWrapper, instanceOf(wrapper.getClass()));
77+
assertThat(clonedWrapper.getSlot(), equalTo(slotValue2));
78+
}
79+
}
80+
}

graalpython/com.oracle.graal.python.test/src/tests/cpyext/test_tp_slots.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,15 @@
5353
static PyObject* get_tp_setattro(PyObject* unused, PyObject* object) {
5454
return PyLong_FromVoidPtr(Py_TYPE(object)->tp_setattro);
5555
}
56+
static PyObject* get_nb_bool(PyObject* unused, PyObject* object) {
57+
return PyLong_FromVoidPtr(Py_TYPE(object)->tp_as_number == NULL ? NULL : Py_TYPE(object)->tp_as_number->nb_bool);
58+
}
59+
static PyObject* get_tp_as_number(PyObject* unused, PyObject* object) {
60+
return PyLong_FromVoidPtr(Py_TYPE(object)->tp_as_number);
61+
}
62+
static PyObject* get_sq_concat(PyObject* unused, PyObject* object) {
63+
return PyLong_FromVoidPtr(Py_TYPE(object)->tp_as_sequence == NULL ? NULL : Py_TYPE(object)->tp_as_sequence->sq_concat);
64+
}
5665
static PyObject* get_tp_descr_get(PyObject* unused, PyObject* object) {
5766
return PyLong_FromVoidPtr(Py_TYPE(object)->tp_descr_get);
5867
}
@@ -62,6 +71,9 @@
6271
'{"get_tp_attro", (PyCFunction)get_tp_attro, METH_O | METH_STATIC, ""},' +
6372
'{"get_tp_setattr", (PyCFunction)get_tp_setattr, METH_O | METH_STATIC, ""},' +
6473
'{"get_tp_setattro", (PyCFunction)get_tp_setattro, METH_O | METH_STATIC, ""},' +
74+
'{"get_nb_bool", (PyCFunction)get_nb_bool, METH_O | METH_STATIC, ""},' +
75+
'{"get_tp_as_number", (PyCFunction)get_tp_as_number, METH_O | METH_STATIC, ""},' +
76+
'{"get_sq_concat", (PyCFunction)get_sq_concat, METH_O | METH_STATIC, ""},' +
6577
'{"get_tp_descr_get", (PyCFunction)get_tp_descr_get, METH_O | METH_STATIC, ""}')
6678

6779

@@ -513,3 +525,44 @@ class DisablesHash2(TypeWithHash):
513525
# )
514526
#
515527
# assert_has_no_hash(TypeWithoutHashExplicit())
528+
529+
530+
def test_attr_update():
531+
# Note: version with managed super type whose attribute is updated and should
532+
# be propagated to the native subtype segfaults on CPython in various ways
533+
TypeToBeUpdated = CPyExtHeapType("TypeToBeUpdated")
534+
assert SlotsGetter.get_tp_as_number(TypeToBeUpdated()) != 0
535+
536+
TypeToBeUpdated.__bool__ = lambda self: False
537+
assert not bool(TypeToBeUpdated())
538+
539+
TypeToBeUpdated.__add__ = lambda self,other: f"plus {other}"
540+
assert TypeToBeUpdated() + "test" == "plus test"
541+
542+
class ManagedGoesNative:
543+
pass
544+
545+
assert bool(ManagedGoesNative())
546+
assert SlotsGetter.get_nb_bool(ManagedGoesNative()) == 0 # Sends it to native
547+
assert SlotsGetter.get_tp_as_number(ManagedGoesNative()) != 0
548+
assert bool(ManagedGoesNative())
549+
550+
ManagedGoesNative.__bool__ = lambda self: False
551+
assert not bool(ManagedGoesNative())
552+
assert SlotsGetter.get_nb_bool(ManagedGoesNative()) != 0
553+
554+
555+
def test_nb_add_inheritace_does_not_add_sq_concat():
556+
NbAddOnlyHeapType = CPyExtHeapType("NbAddOnlyHeapType",
557+
code = 'PyObject* my_nb_add(PyObject* self, PyObject *other) { return Py_NewRef(self); }',
558+
slots=['{Py_nb_add, &my_nb_add}'])
559+
class ManagedSub(NbAddOnlyHeapType):
560+
pass
561+
562+
assert ManagedSub.__add__
563+
assert SlotsGetter.get_sq_concat(ManagedSub()) == 0
564+
565+
class ManagedSub2(NbAddOnlyHeapType):
566+
def __add__(self, other): return NotImplemented
567+
568+
assert SlotsGetter.get_sq_concat(ManagedSub2()) == 0

graalpython/com.oracle.graal.python.test/src/tests/test_flag_sequence_bug_compat.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2020, 2021, Oracle and/or its affiliates. All rights reserved.
1+
# Copyright (c) 2020, 2024, Oracle and/or its affiliates. All rights reserved.
22
# DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
33
#
44
# The Universal Permissive License (UPL), Version 1.0
@@ -51,7 +51,7 @@ def __add__(self, other):
5151
class BA(bytearray):
5252
pass
5353

54-
# assert BA(b'abc') + C() == "RADD"
54+
assert BA(b'abc') + C() == "RADD"
5555
assert C() + BA(b'abc') == "ADD"
5656

5757
# bytes
@@ -61,7 +61,7 @@ class BA(bytearray):
6161
class B(bytes):
6262
pass
6363

64-
# assert B(b'ab') + C() == "RADD"
64+
assert B(b'ab') + C() == "RADD"
6565
assert C() + B(b'ab') == "ADD"
6666

6767
# list
@@ -71,7 +71,7 @@ class B(bytes):
7171
class L(list):
7272
pass
7373

74-
# assert L([1,2]) + C() == "RADD"
74+
assert L([1,2]) + C() == "RADD"
7575
assert C() + L([1,2]) == "ADD"
7676

7777
# tuple
@@ -81,7 +81,7 @@ class L(list):
8181
class T(tuple):
8282
pass
8383

84-
# assert T((1,2)) + C() == "RADD"
84+
assert T((1,2)) + C() == "RADD"
8585
assert C() + T((1,2)) == "ADD"
8686

8787
# str
@@ -91,7 +91,7 @@ class T(tuple):
9191
class S(str):
9292
pass
9393

94-
# assert S(":") + C() == "RADD"
94+
assert S(":") + C() == "RADD"
9595
assert C() + S(":") == "ADD"
9696

9797
# int

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/PythonBuiltinClassType.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1031,7 +1031,7 @@ private static void initSlots(PythonBuiltinClassType type) {
10311031
initSlots(type.base);
10321032
}
10331033
var slots = type.base.slots.copy();
1034-
slots.override(type.declaredSlots);
1034+
slots.overrideIgnoreGroups(type.declaredSlots);
10351035
type.slots = slots.build();
10361036
}
10371037

0 commit comments

Comments
 (0)