Skip to content

Commit 4874ad2

Browse files
committed
[GR-34614][GR-45275] More fixes for PyTorch - fix slot and member inheritance
PullRequest: graalpython/2712
2 parents 181b384 + e74890d commit 4874ad2

File tree

23 files changed

+548
-304
lines changed

23 files changed

+548
-304
lines changed

graalpython/com.oracle.graal.python.cext/src/typeobject.c

Lines changed: 97 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -228,48 +228,101 @@ static void inherit_special(PyTypeObject *type, PyTypeObject *base) {
228228
}
229229

230230
static void inherit_slots(PyTypeObject *type, PyTypeObject *base) {
231-
PyTypeObject *basebase;
232-
233-
#undef SLOTDEFINED
234-
#undef COPYSLOT
235-
#undef SLOTDEFINED2
236-
#undef COPYSLOT2
237-
#undef COPYBUF
238-
239-
#define SLOTDEFINED(SLOT) \
240-
(PyTypeObject_##SLOT(base) != 0 && \
241-
(basebase == NULL || PyTypeObject_##SLOT(base) != PyTypeObject_##SLOT(basebase)))
242-
243-
#define COPYSLOT(SLOT) \
244-
if (!PyTypeObject_##SLOT(type) && SLOTDEFINED(SLOT)) set_PyTypeObject_##SLOT(type, PyTypeObject_##SLOT(base));
245-
246-
#define SLOTDEFINED2(SLOT, SLOT2) \
247-
(PyTypeObject_##SLOT(base)->SLOT2 != 0 && \
248-
(basebase == NULL || PyTypeObject_##SLOT(base) != PyTypeObject_##SLOT(basebase)))
249-
250-
#define COPYSLOT2(SLOT, SLOT2) \
251-
if (!PyTypeObject_##SLOT(type)->SLOT2 && SLOTDEFINED2(SLOT, SLOT2)) PyTypeObject_##SLOT(type)->SLOT2 = PyTypeObject_##SLOT(base)->SLOT2;
252-
253-
#define COPYBUF(SLOT) COPYSLOT2(tp_as_buffer, SLOT)
254-
255-
if (type->tp_as_buffer != NULL && base->tp_as_buffer != NULL) {
256-
basebase = base->tp_base;
257-
if (basebase->tp_as_buffer == NULL)
258-
basebase = NULL;
231+
const PyTypeObject *basebase = base->tp_base;
232+
233+
#define SLOTDEFINED(BASE, BASEBASE, GETFUNC) \
234+
(GETFUNC(BASE) != 0 && (BASEBASE == NULL || GETFUNC(BASE) != GETFUNC(BASEBASE)))
235+
236+
#define COPYSLOT_GENERIC(SLOT, TYPE, BASE, BASEBASE, GETFUNC) \
237+
if (!TYPE->SLOT && SLOTDEFINED(BASE, BASEBASE, GETFUNC)) TYPE->SLOT = GETFUNC(BASE);
238+
239+
#define COPYSLOT(SLOT) COPYSLOT_GENERIC(SLOT, type, base, basebase, PyTypeObject_##SLOT)
240+
#define COPYSLOT_GENERIC_GROUP(SLOT, GROUP, PREFIX) COPYSLOT_GENERIC(SLOT, type_##GROUP, base_##GROUP, basebase_##GROUP, PREFIX##_##SLOT)
241+
#define GROUP_DECL(GROUP, TYPE) \
242+
TYPE* type_##GROUP = type->GROUP; \
243+
TYPE* base_##GROUP = PyTypeObject_##GROUP(base); \
244+
TYPE* basebase_##GROUP = basebase ? PyTypeObject_##GROUP(basebase) : NULL; \
245+
if (type_##GROUP != NULL && base_##GROUP != NULL)
246+
#define COPYASYNC(SLOT) COPYSLOT_GENERIC_GROUP(SLOT, tp_as_async, PyAsyncMethods)
247+
#define COPYNUM(SLOT) COPYSLOT_GENERIC_GROUP(SLOT, tp_as_number, PyNumberMethods)
248+
#define COPYSEQ(SLOT) COPYSLOT_GENERIC_GROUP(SLOT, tp_as_sequence, PySequenceMethods)
249+
#define COPYMAP(SLOT) COPYSLOT_GENERIC_GROUP(SLOT, tp_as_mapping, PyMappingMethods)
250+
#define COPYBUF(SLOT) COPYSLOT_GENERIC_GROUP(SLOT, tp_as_buffer, PyBufferProcs)
251+
252+
GROUP_DECL(tp_as_number, PyNumberMethods) {
253+
COPYNUM(nb_add);
254+
COPYNUM(nb_subtract);
255+
COPYNUM(nb_multiply);
256+
COPYNUM(nb_remainder);
257+
COPYNUM(nb_divmod);
258+
COPYNUM(nb_power);
259+
COPYNUM(nb_negative);
260+
COPYNUM(nb_positive);
261+
COPYNUM(nb_absolute);
262+
COPYNUM(nb_bool);
263+
COPYNUM(nb_invert);
264+
COPYNUM(nb_lshift);
265+
COPYNUM(nb_rshift);
266+
COPYNUM(nb_and);
267+
COPYNUM(nb_xor);
268+
COPYNUM(nb_or);
269+
COPYNUM(nb_int);
270+
COPYNUM(nb_float);
271+
COPYNUM(nb_inplace_add);
272+
COPYNUM(nb_inplace_subtract);
273+
COPYNUM(nb_inplace_multiply);
274+
COPYNUM(nb_inplace_remainder);
275+
COPYNUM(nb_inplace_power);
276+
COPYNUM(nb_inplace_lshift);
277+
COPYNUM(nb_inplace_rshift);
278+
COPYNUM(nb_inplace_and);
279+
COPYNUM(nb_inplace_xor);
280+
COPYNUM(nb_inplace_or);
281+
COPYNUM(nb_true_divide);
282+
COPYNUM(nb_floor_divide);
283+
COPYNUM(nb_inplace_true_divide);
284+
COPYNUM(nb_inplace_floor_divide);
285+
COPYNUM(nb_index);
286+
COPYNUM(nb_matrix_multiply);
287+
COPYNUM(nb_inplace_matrix_multiply);
288+
}
289+
290+
GROUP_DECL(tp_as_async, PyAsyncMethods) {
291+
COPYASYNC(am_await);
292+
COPYASYNC(am_aiter);
293+
COPYASYNC(am_anext);
294+
}
295+
296+
GROUP_DECL(tp_as_sequence, PySequenceMethods) {
297+
COPYSEQ(sq_length);
298+
COPYSEQ(sq_concat);
299+
COPYSEQ(sq_repeat);
300+
COPYSEQ(sq_item);
301+
COPYSEQ(sq_ass_item);
302+
COPYSEQ(sq_contains);
303+
COPYSEQ(sq_inplace_concat);
304+
COPYSEQ(sq_inplace_repeat);
305+
}
306+
307+
GROUP_DECL(tp_as_mapping, PyMappingMethods) {
308+
COPYMAP(mp_length);
309+
COPYMAP(mp_subscript);
310+
COPYMAP(mp_ass_subscript);
311+
}
312+
313+
GROUP_DECL(tp_as_buffer, PyBufferProcs) {
259314
COPYBUF(bf_getbuffer);
260315
COPYBUF(bf_releasebuffer);
261316
}
262317

263-
basebase = PyTypeObject_tp_base(base);
264-
265318
COPYSLOT(tp_dealloc);
266-
if (PyTypeObject_tp_getattr(type) == NULL && PyTypeObject_tp_getattro(type) == NULL) {
267-
set_PyTypeObject_tp_getattr(type, PyTypeObject_tp_getattr(base));
268-
set_PyTypeObject_tp_getattro(type, PyTypeObject_tp_getattro(base));
319+
if (type->tp_getattr == NULL && type->tp_getattro == NULL) {
320+
type->tp_getattr = PyTypeObject_tp_getattr(base);
321+
type->tp_getattro = PyTypeObject_tp_getattro(base);
269322
}
270-
if (PyTypeObject_tp_setattr(type) == NULL && PyTypeObject_tp_setattro(type) == NULL) {
271-
set_PyTypeObject_tp_setattr(type, PyTypeObject_tp_setattr(base));
272-
set_PyTypeObject_tp_setattro(type, PyTypeObject_tp_setattro(base));
323+
if (type->tp_setattr == NULL && type->tp_setattro == NULL) {
324+
type->tp_setattr = PyTypeObject_tp_setattr(base);
325+
type->tp_setattro = PyTypeObject_tp_setattro(base);
273326
}
274327
{
275328
/* Always inherit tp_vectorcall_offset to support PyVectorcall_Call().
@@ -279,11 +332,11 @@ static void inherit_slots(PyTypeObject *type, PyTypeObject *base) {
279332

280333
/* Inherit _Py_TPFLAGS_HAVE_VECTORCALL for non-heap types
281334
* if tp_call is not overridden */
282-
if (!PyTypeObject_tp_call(type) &&
335+
if (!type->tp_call &&
283336
(PyTypeObject_tp_flags(base) & _Py_TPFLAGS_HAVE_VECTORCALL) &&
284-
!(PyTypeObject_tp_flags(type) & Py_TPFLAGS_HEAPTYPE))
337+
!(type->tp_flags & Py_TPFLAGS_HEAPTYPE))
285338
{
286-
set_PyTypeObject_tp_flags(type, PyTypeObject_tp_flags(type) | _Py_TPFLAGS_HAVE_VECTORCALL);
339+
type->tp_flags |= _Py_TPFLAGS_HAVE_VECTORCALL;
287340
}
288341
/* COPYSLOT(tp_call); */
289342
}
@@ -293,24 +346,24 @@ static void inherit_slots(PyTypeObject *type, PyTypeObject *base) {
293346
}
294347
{
295348
COPYSLOT(tp_alloc);
296-
if ((PyTypeObject_tp_flags(type) & Py_TPFLAGS_HAVE_FINALIZE) &&
349+
if ((type->tp_flags & Py_TPFLAGS_HAVE_FINALIZE) &&
297350
(PyTypeObject_tp_flags(base) & Py_TPFLAGS_HAVE_FINALIZE)) {
298351
COPYSLOT(tp_finalize);
299352
}
300-
if ((PyTypeObject_tp_flags(type) & Py_TPFLAGS_HAVE_GC) ==
353+
if ((type->tp_flags & Py_TPFLAGS_HAVE_GC) ==
301354
(PyTypeObject_tp_flags(base) & Py_TPFLAGS_HAVE_GC)) {
302355
/* They agree about gc. */
303356
COPYSLOT(tp_free);
304357
}
305-
else if ((PyTypeObject_tp_flags(type) & Py_TPFLAGS_HAVE_GC) &&
306-
PyTypeObject_tp_free(type) == NULL &&
358+
else if ((type->tp_flags & Py_TPFLAGS_HAVE_GC) &&
359+
type->tp_free == NULL &&
307360
PyTypeObject_tp_free(base) == PyObject_Free) {
308361
/* A bit of magic to plug in the correct default
309362
* tp_free function when a derived class adds gc,
310363
* didn't define tp_free, and the base uses the
311364
* default non-gc tp_free.
312365
*/
313-
set_PyTypeObject_tp_free(type, PyObject_GC_Del);
366+
type->tp_free = PyObject_GC_Del;
314367
}
315368
/* else they didn't agree about gc, and there isn't something
316369
* obvious to be done -- the type is on its own.

graalpython/com.oracle.graal.python.test/src/tests/cpyext/__init__.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,15 @@
4949
GRAALPYTHON = sys.implementation.name == "graalpy"
5050

5151

52+
def assert_raises(err, fn, *args, **kwargs):
53+
raised = False
54+
try:
55+
fn(*args, **kwargs)
56+
except err:
57+
raised = True
58+
assert raised
59+
60+
5261
def unhandled_error_compare(x, y):
5362
if (isinstance(x, BaseException) and isinstance(y, BaseException)):
5463
return type(x) == type(y)
@@ -552,6 +561,11 @@ def CPyExtType(name, code, **kwargs):
552561
{{NULL, NULL, 0, NULL}}
553562
}};
554563
564+
static struct PyGetSetDef {name}_getset[] = {{
565+
""" + ("""{tp_getset},""" if "tp_getset" in kwargs else "") + """
566+
{{NULL, NULL, NULL, NULL, NULL}}
567+
}};
568+
555569
static struct PyMemberDef {name}_members[] = {{
556570
""" + ("""{tp_members},""" if "tp_members" in kwargs else "") + """
557571
{{NULL, 0, 0, 0, NULL}}
@@ -587,7 +601,7 @@ def CPyExtType(name, code, **kwargs):
587601
{tp_iternext}, /* tp_iternext */
588602
{name}_methods, /* tp_methods */
589603
{name}_members, /* tp_members */
590-
0, /* tp_getset */
604+
{name}_getset, /* tp_getset */
591605
{tp_base}, /* tp_base */
592606
{tp_dict}, /* tp_dict */
593607
{tp_descr_get}, /* tp_descr_get */

graalpython/com.oracle.graal.python.test/src/tests/cpyext/test_member.py

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved.
1+
# Copyright (c) 2022, 2023, Oracle and/or its affiliates. All rights reserved.
22
# DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
33
#
44
# The Universal Permissive License (UPL), Version 1.0
@@ -37,21 +37,11 @@
3737
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
3838
# SOFTWARE.
3939

40-
import sys
41-
from . import CPyExtType
40+
from . import CPyExtType, assert_raises
4241

4342
__dir__ = __file__.rpartition("/")[0]
4443

4544

46-
def assert_raises(err, fn, *args, **kwargs):
47-
raised = False
48-
try:
49-
fn(*args, **kwargs)
50-
except err:
51-
raised = True
52-
assert raised
53-
54-
5545
def _reference_classmethod(args):
5646
if isinstance(args[0], type(list.append)):
5747
return classmethod(args[0])()
@@ -279,4 +269,4 @@ def test_member(self):
279269
assert_raises(TypeError, setattr, obj, "member_char", "xyz")
280270
assert obj.member_char == "x"
281271

282-
warnings.resetwarnings()
272+
warnings.resetwarnings()

graalpython/com.oracle.graal.python.test/src/tests/cpyext/test_method.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -38,20 +38,11 @@
3838
# SOFTWARE.
3939
import types
4040

41-
from . import CPyExtType, CPyExtTestCase, unhandled_error_compare, CPyExtFunction
41+
from . import CPyExtType, CPyExtTestCase, unhandled_error_compare, CPyExtFunction, assert_raises
4242

4343
__dir__ = __file__.rpartition("/")[0]
4444

4545

46-
def assert_raises(err, fn, *args, **kwargs):
47-
raised = False
48-
try:
49-
fn(*args, **kwargs)
50-
except err:
51-
raised = True
52-
assert raised
53-
54-
5546
def _reference_classmethod(args):
5647
if isinstance(args[0], type(list.append)):
5748
return classmethod(args[0])()

graalpython/com.oracle.graal.python.test/src/tests/cpyext/test_object.py

Lines changed: 82 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939

4040
import sys
4141

42-
from . import CPyExtType, CPyExtTestCase, CPyExtFunction, GRAALPYTHON, unhandled_error_compare
42+
from . import CPyExtType, CPyExtTestCase, CPyExtFunction, GRAALPYTHON, unhandled_error_compare, assert_raises
4343

4444
__dir__ = __file__.rpartition("/")[0]
4545

@@ -370,6 +370,9 @@ class TestNoDictoffsetSubclass(TestNoDictoffset):
370370
obj.__dict__["newAttr"] = 123
371371
assert obj.newAttr == 123, "invalid attr"
372372

373+
obj.__dict__ = {'a': 1}
374+
assert obj.a == 1
375+
373376
def ignore_test_float_subclass(self):
374377
TestFloatSubclass = CPyExtType("TestFloatSubclass",
375378
"""
@@ -741,6 +744,84 @@ class X(_A, B):
741744
x = X()
742745
assert x.foo == "foo"
743746

747+
def test_getset(self):
748+
TestGetter = CPyExtType(
749+
"TestGetter",
750+
"""
751+
static PyObject* foo_getter(PyObject* self, void* unused) {
752+
return PyUnicode_FromString("getter");
753+
}
754+
""",
755+
tp_getset='{"foo", foo_getter, (setter)NULL, NULL, NULL}',
756+
)
757+
obj = TestGetter()
758+
assert obj.foo == 'getter'
759+
760+
def call_set():
761+
obj.foo = 'set'
762+
763+
assert_raises(AttributeError, call_set)
764+
765+
TestSetter = CPyExtType(
766+
"TestSetter",
767+
"""
768+
static int state;
769+
770+
static PyObject* foo_getter(PyObject* self, void* unused) {
771+
if (state == 0)
772+
return PyUnicode_FromString("unset");
773+
else
774+
return PyUnicode_FromString("set");
775+
}
776+
777+
static int foo_setter(PyObject* self, PyObject* val, void* unused) {
778+
state = val != NULL;
779+
return 0;
780+
}
781+
""",
782+
tp_getset='{"foo", foo_getter, (setter)foo_setter, NULL, NULL}',
783+
)
784+
obj = TestSetter()
785+
assert obj.foo == 'unset'
786+
obj.foo = 'asdf'
787+
assert obj.foo == 'set'
788+
del obj.foo
789+
assert obj.foo == 'unset'
790+
791+
def test_member_kind_precedence(self):
792+
TestWithConflictingMember1 = CPyExtType(
793+
"TestWithConflictingMember1",
794+
"""
795+
static PyObject* foo_method(PyObject* self, PyObject* unused) {
796+
return PyUnicode_FromString("method");
797+
}
798+
799+
static PyObject* foo_getter(PyObject* self, void* unused) {
800+
return PyUnicode_FromString("getter");
801+
}
802+
""",
803+
cmembers="PyObject* foo_member;",
804+
tp_members='{"foo", T_OBJECT, offsetof(TestWithConflictingMember1Object, foo_member), 0, NULL}',
805+
tp_methods='{"foo", foo_method, METH_NOARGS, ""}',
806+
tp_getset='{"foo", foo_getter, (setter)NULL, NULL, NULL}',
807+
)
808+
obj = TestWithConflictingMember1()
809+
assert obj.foo() == 'method'
810+
811+
TestWithConflictingMember2 = CPyExtType(
812+
"TestWithConflictingMember2",
813+
"""
814+
static PyObject* foo_getter(PyObject* self, void* unused) {
815+
return PyUnicode_FromString("getter");
816+
}
817+
""",
818+
cmembers="PyObject* foo_member;",
819+
tp_members='{"foo", T_OBJECT, offsetof(TestWithConflictingMember2Object, foo_member), 0, NULL}',
820+
tp_getset='{"foo", foo_getter, (setter)NULL, NULL, NULL}',
821+
)
822+
obj = TestWithConflictingMember2()
823+
assert obj.foo is None # The member takes precedence
824+
744825
def test_slot_precedence(self):
745826
MapAndSeq = CPyExtType("MapAndSeq",
746827
'''

0 commit comments

Comments
 (0)