Skip to content

Commit 3863c3a

Browse files
committed
[GR-34614][GR-44477] Fixes for PyTorch
PullRequest: graalpython/2665
2 parents 7f47600 + 906cad2 commit 3863c3a

File tree

22 files changed

+364
-91
lines changed

22 files changed

+364
-91
lines changed

graalpython/com.oracle.graal.python.cext/src/capi.c

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -894,6 +894,17 @@ PRIMITIVE_ARRAY_TO_NATIVE(Long, int64_t, i64, polyglot_as_i64);
894894
PRIMITIVE_ARRAY_TO_NATIVE(Double, double, double, polyglot_as_double);
895895
PRIMITIVE_ARRAY_TO_NATIVE(Object, PyObjectPtr, PyObjectPtr, (PyObjectPtr));
896896

897+
void PyTruffle_PrimitiveArrayFree(void* array) {
898+
free(array);
899+
}
900+
901+
void PyTruffle_ObjectArrayFree(PyObject** array, int32_t size) {
902+
for (int i = 0; i < size; i++) {
903+
Py_DECREF(array[i]);
904+
}
905+
free(array);
906+
}
907+
897908
PyAPI_FUNC(Py_ssize_t) PyTruffle_Object_Size(PyObject *op) {
898909
return ((PyVarObject*)op)->ob_size;
899910
}

graalpython/com.oracle.graal.python.test/src/tests/cpyext/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2018, 2022, Oracle and/or its affiliates. All rights reserved.
1+
# Copyright (c) 2018, 2023, Oracle and/or its affiliates. All rights reserved.
22
# DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
33
#
44
# The Universal Permissive License (UPL), Version 1.0
@@ -36,6 +36,7 @@
3636
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
3737
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
3838
# SOFTWARE.
39+
import gc
3940

4041
import sys
4142
import os
@@ -400,6 +401,7 @@ def test(self):
400401
assert cresult == presult, ("%r == %r in %s(%s)" % (cresult, presult, self.name, pargs[i]))
401402
else:
402403
assert self.cmpfunc(cresult, presult), ("%r == %r in %s(%s)" % (cresult, presult, self.name, pargs[i]))
404+
gc.collect()
403405

404406
def __get__(self, instance, typ=None):
405407
if typ is None:

graalpython/com.oracle.graal.python.test/src/tests/cpyext/test_bytes.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2018, 2021, Oracle and/or its affiliates. All rights reserved.
1+
# Copyright (c) 2018, 2023, Oracle and/or its affiliates. All rights reserved.
22
# DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
33
#
44
# The Universal Permissive License (UPL), Version 1.0
@@ -132,6 +132,31 @@ def compile_module(self, name):
132132
resulttype="int"
133133
)
134134

135+
test_native_storage = CPyExtFunction(
136+
lambda args: args[0].encode('utf-8')[-1],
137+
lambda: (("hello",), ("world",)),
138+
argspec="O",
139+
arguments=["PyObject* arg"],
140+
resultspec="i",
141+
# The code is creating the bytes objects in such a roundabout way in order to make sure the native storage will
142+
# get collected after the test
143+
code="""
144+
int wrap_test_native_storage(PyObject* str) {
145+
PyObject* bytes = PyUnicode_AsUTF8String(str);
146+
if (bytes == NULL)
147+
return -1;
148+
char* s;
149+
Py_ssize_t sz;
150+
if (PyBytes_AsStringAndSize(bytes, &s, &sz) < 0)
151+
return -1;
152+
int ret = s[sz - 1];
153+
Py_DECREF(bytes);
154+
return ret;
155+
}
156+
""",
157+
callfunction='wrap_test_native_storage',
158+
)
159+
135160
# PyBytes_Size
136161
test_PyBytes_Size = CPyExtFunction(
137162
lambda b: len(b[0]),

graalpython/com.oracle.graal.python.test/src/tests/test_frame_tests.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ def test_lineno():
5151
def test_nested_lineno():
5252
def test_nested():
5353
return sys._getframe(0)
54+
5455
f = test_nested()
5556
assert f.f_lineno == 53
5657

@@ -67,9 +68,11 @@ def test_read_and_write_locals():
6768

6869
def test_backref():
6970
a = 'test_backref'
71+
7072
def foo():
7173
a = 'foo'
7274
return sys._getframe(0).f_back
75+
7376
assert foo().f_locals['a'] == 'test_backref'
7477

7578
def get_frame():
@@ -106,7 +109,7 @@ def foo(i):
106109
return stack
107110
else:
108111
# This recursive call will cause
109-
return foo(i+1)
112+
return foo(i + 1)
110113

111114
def bar():
112115
return foo(0)
@@ -151,6 +154,15 @@ def foo():
151154

152155
assert type(locals()['cell']).__name__ == 'cell'
153156

157+
158+
def test_locals_freevar_in_class():
159+
x = 1
160+
161+
class Foo:
162+
c = x
163+
assert 'c' in locals()
164+
assert 'x' not in locals()
165+
154166
# GR-22089
155167
# def test_backref_from_traceback():
156168
# def bar():

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/cext/PythonCextAbstractBuiltins.java

Lines changed: 55 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
import static com.oracle.graal.python.builtins.objects.cext.capi.NativeCAPISymbol.FUN_PY_TRUFFLE_PY_MAPPING_SIZE;
4949
import static com.oracle.graal.python.builtins.objects.cext.capi.NativeCAPISymbol.FUN_PY_TRUFFLE_PY_OBJECT_SIZE;
5050
import static com.oracle.graal.python.builtins.objects.cext.capi.NativeCAPISymbol.FUN_PY_TRUFFLE_PY_SEQUENCE_SIZE;
51+
import static com.oracle.graal.python.builtins.objects.cext.capi.transitions.ArgDescriptor.ConstCharPtr;
5152
import static com.oracle.graal.python.builtins.objects.cext.capi.transitions.ArgDescriptor.ConstCharPtrAsTruffleString;
5253
import static com.oracle.graal.python.builtins.objects.cext.capi.transitions.ArgDescriptor.Int;
5354
import static com.oracle.graal.python.builtins.objects.cext.capi.transitions.ArgDescriptor.PyObject;
@@ -57,6 +58,7 @@
5758
import static com.oracle.graal.python.nodes.ErrorMessages.BASE_MUST_BE;
5859
import static com.oracle.graal.python.nodes.ErrorMessages.OBJ_ISNT_MAPPING;
5960
import static com.oracle.graal.python.nodes.ErrorMessages.P_OBJ_DOES_NOT_SUPPORT_ITEM_ASSIGMENT;
61+
import static com.oracle.graal.python.nodes.SpecialAttributeNames.T___DOC__;
6062
import static com.oracle.graal.python.nodes.SpecialMethodNames.T_ITEMS;
6163
import static com.oracle.graal.python.nodes.SpecialMethodNames.T_KEYS;
6264
import static com.oracle.graal.python.nodes.SpecialMethodNames.T_VALUES;
@@ -81,19 +83,24 @@
8183
import com.oracle.graal.python.builtins.modules.cext.PythonCextBuiltins.CApiUnaryBuiltinNode;
8284
import com.oracle.graal.python.builtins.modules.cext.PythonCextErrBuiltins.PyErr_Restore;
8385
import com.oracle.graal.python.builtins.objects.PNone;
86+
import com.oracle.graal.python.builtins.objects.cext.PythonNativeClass;
87+
import com.oracle.graal.python.builtins.objects.cext.capi.CExtNodes.AsCharPointerNode;
8488
import com.oracle.graal.python.builtins.objects.cext.capi.CExtNodes.PCallCapiFunction;
8589
import com.oracle.graal.python.builtins.objects.cext.capi.CExtNodes.ToSulongNode;
8690
import com.oracle.graal.python.builtins.objects.cext.capi.DynamicObjectNativeWrapper.PrimitiveNativeWrapper;
87-
import com.oracle.graal.python.builtins.objects.cext.capi.transitions.ArgDescriptor;
8891
import com.oracle.graal.python.builtins.objects.dict.DictBuiltins.ItemsNode;
8992
import com.oracle.graal.python.builtins.objects.dict.DictBuiltins.KeysNode;
9093
import com.oracle.graal.python.builtins.objects.dict.DictBuiltins.ValuesNode;
9194
import com.oracle.graal.python.builtins.objects.dict.PDict;
95+
import com.oracle.graal.python.builtins.objects.function.PBuiltinFunction;
96+
import com.oracle.graal.python.builtins.objects.getsetdescriptor.GetSetDescriptor;
9297
import com.oracle.graal.python.builtins.objects.ints.PInt;
9398
import com.oracle.graal.python.builtins.objects.iterator.IteratorNodes;
9499
import com.oracle.graal.python.builtins.objects.list.PList;
95100
import com.oracle.graal.python.builtins.objects.mappingproxy.PMappingproxy;
101+
import com.oracle.graal.python.builtins.objects.method.PBuiltinMethod;
96102
import com.oracle.graal.python.builtins.objects.type.TypeNodes.InlinedIsSameTypeNode;
103+
import com.oracle.graal.python.builtins.objects.type.TypeNodes.IsTypeNode;
97104
import com.oracle.graal.python.lib.PyIndexCheckNode;
98105
import com.oracle.graal.python.lib.PyMappingCheckNode;
99106
import com.oracle.graal.python.lib.PyNumberCheckNode;
@@ -107,6 +114,8 @@
107114
import com.oracle.graal.python.lib.PySequenceCheckNode;
108115
import com.oracle.graal.python.lib.PySliceNew;
109116
import com.oracle.graal.python.nodes.ErrorMessages;
117+
import com.oracle.graal.python.nodes.attributes.WriteAttributeToDynamicObjectNode;
118+
import com.oracle.graal.python.nodes.attributes.WriteAttributeToObjectNode;
110119
import com.oracle.graal.python.nodes.builtins.ListNodes.ConstructListNode;
111120
import com.oracle.graal.python.nodes.call.CallNode;
112121
import com.oracle.graal.python.nodes.call.special.LookupAndCallTernaryNode;
@@ -879,19 +888,61 @@ Object check(Object object,
879888
}
880889
}
881890

882-
@CApiBuiltin(ret = ArgDescriptor.ConstCharPtr, args = {PyObject}, call = Direct)
891+
@CApiBuiltin(ret = ConstCharPtr, args = {PyObject}, call = Direct)
883892
abstract static class PyObject_GetDoc extends CApiUnaryBuiltinNode {
884893
@Specialization
885-
Object check(@SuppressWarnings("unused") Object obj) {
894+
Object get(Object obj,
895+
@Cached PyObjectLookupAttr lookupAttr,
896+
@Cached AsCharPointerNode asCharPointerNode) {
897+
try {
898+
Object doc = lookupAttr.execute(null, obj, T___DOC__);
899+
if (!(doc instanceof PNone)) {
900+
return asCharPointerNode.execute(doc);
901+
}
902+
} catch (PException e) {
903+
// ignore
904+
}
886905
return getNULL();
887906
}
888907
}
889908

890909
@CApiBuiltin(ret = Int, args = {PyObject, ConstCharPtrAsTruffleString}, call = Direct)
891910
abstract static class PyObject_SetDoc extends CApiBinaryBuiltinNode {
892911
@Specialization
893-
int check(@SuppressWarnings("unused") Object obj, @SuppressWarnings("unused") TruffleString value) {
912+
static int set(PBuiltinFunction obj, TruffleString value,
913+
@Shared("write") @Cached WriteAttributeToDynamicObjectNode write) {
914+
write.execute(obj, T___DOC__, value);
915+
return 1;
916+
}
917+
918+
@Specialization
919+
static int set(PBuiltinMethod obj, TruffleString value,
920+
@Shared("write") @Cached WriteAttributeToDynamicObjectNode write) {
921+
set(obj.getFunction(), value, write);
922+
return 1;
923+
}
924+
925+
@Specialization
926+
static int set(GetSetDescriptor obj, TruffleString value,
927+
@Shared("write") @Cached WriteAttributeToDynamicObjectNode write) {
928+
write.execute(obj, T___DOC__, value);
894929
return 1;
895930
}
931+
932+
@Specialization(guards = "isType.execute(type)", limit = "1")
933+
static int set(PythonNativeClass type, TruffleString value,
934+
@SuppressWarnings("unused") @Cached IsTypeNode isType,
935+
// TODO we should write to tp_doc, this writes to __doc__ in the type dict
936+
@Cached("createForceType()") WriteAttributeToObjectNode write) {
937+
write.execute(type, T___DOC__, value);
938+
return 1;
939+
}
940+
941+
@Fallback
942+
@SuppressWarnings("unused")
943+
static int set(Object obj, Object value) {
944+
CompilerDirectives.transferToInterpreterAndInvalidate();
945+
throw CompilerDirectives.shouldNotReachHere("Don't know how to set doc for " + obj.getClass());
946+
}
896947
}
897948
}

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/cext/PythonCextBuiltinRegistry.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1938,7 +1938,7 @@ static CApiBuiltinNode createBuiltinNode(int id) {
19381938
case 389:
19391939
return com.oracle.graal.python.builtins.modules.cext.PythonCextSlotBuiltinsFactory.Py_get_PyTypeObject_tp_baseNodeGen.create();
19401940
case 390:
1941-
return com.oracle.graal.python.builtins.modules.cext.PythonCextSlotBuiltinsFactory.PyGetSlotDummyPyPtrNodeGen.create();
1941+
return com.oracle.graal.python.builtins.modules.cext.PythonCextSlotBuiltinsFactory.Py_get_PyTypeObject_tp_basesNodeGen.create();
19421942
case 391:
19431943
return com.oracle.graal.python.builtins.modules.cext.PythonCextSlotBuiltinsFactory.Py_get_PyTypeObject_tp_basicsizeNodeGen.create();
19441944
case 392:
@@ -1952,9 +1952,9 @@ static CApiBuiltinNode createBuiltinNode(int id) {
19521952
case 396:
19531953
return com.oracle.graal.python.builtins.modules.cext.PythonCextSlotBuiltinsFactory.Py_get_PyTypeObject_tp_delNodeGen.create();
19541954
case 397:
1955-
return com.oracle.graal.python.builtins.modules.cext.PythonCextSlotBuiltinsFactory.PyGetSlotDummyPtrNodeGen.create();
1955+
return com.oracle.graal.python.builtins.modules.cext.PythonCextSlotBuiltinsFactory.Py_get_PyTypeObject_tp_descr_getNodeGen.create();
19561956
case 398:
1957-
return com.oracle.graal.python.builtins.modules.cext.PythonCextSlotBuiltinsFactory.PyGetSlotDummyPtrNodeGen.create();
1957+
return com.oracle.graal.python.builtins.modules.cext.PythonCextSlotBuiltinsFactory.Py_get_PyTypeObject_tp_descr_setNodeGen.create();
19581958
case 399:
19591959
return com.oracle.graal.python.builtins.modules.cext.PythonCextSlotBuiltinsFactory.Py_get_PyTypeObject_tp_dictNodeGen.create();
19601960
case 400:

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/cext/PythonCextBuiltins.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1203,7 +1203,7 @@ Object wrap(Object bufferStructPointer, Object ownerObj, Object lenObj,
12031203
}
12041204
}
12051205
}
1206-
Object buffer = new NativeSequenceStorage(bufPointer, len, len, SequenceStorage.ListStorageType.Byte);
1206+
Object buffer = NativeSequenceStorage.create(bufPointer, len, len, SequenceStorage.ListStorageType.Byte, false);
12071207
int flags = initFlagsNode.execute(ndim, itemsize, shape, strides, suboffsets);
12081208
BufferLifecycleManager bufferLifecycleManager = null;
12091209
if (!lib.isNull(bufferStructPointer)) {

0 commit comments

Comments
 (0)