Skip to content

Commit 28746f3

Browse files
committed
Add proper reference keeping to SetFuncNode
1 parent b649860 commit 28746f3

File tree

4 files changed

+42
-30
lines changed

4 files changed

+42
-30
lines changed

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/ctypes/CDataTypeBuiltins.java

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -429,10 +429,6 @@ void PyCData_set(VirtualFrame frame, CDataObject dst, Object type, FieldSet setf
429429
memcpyNode,
430430
writePointerNode);
431431

432-
/* KeepRef steals a refcount from it's last argument */
433-
/*
434-
* If KeepRef fails, we are stumped. The dst memory block has already been changed
435-
*/
436432
keepRefNode.execute(frame, dst, index, result);
437433
}
438434

@@ -574,7 +570,7 @@ void KeepRef(VirtualFrame frame, CDataObject target, int index, Object keep,
574570
@Cached PythonObjectFactory factory) {
575571
CDataObject ob = PyCData_GetContainer(target, factory);
576572
if (!PGuards.isDict(ob.b_objects)) {
577-
ob.b_objects = keep; /* refcount consumed */
573+
ob.b_objects = keep;
578574
return;
579575
}
580576
PDict dict = (PDict) ob.b_objects;

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/ctypes/CFieldBuiltins.java

Lines changed: 39 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
import static com.oracle.graal.python.util.PythonUtils.ARRAY_ACCESSOR_SWAPPED;
5555
import static com.oracle.graal.python.util.PythonUtils.TS_ENCODING;
5656
import static com.oracle.graal.python.util.PythonUtils.toTruffleStringUncached;
57+
import static com.oracle.graal.python.util.PythonUtils.tsLiteral;
5758

5859
import java.util.List;
5960

@@ -70,10 +71,12 @@
7071
import com.oracle.graal.python.builtins.modules.ctypes.StgDictBuiltins.PyTypeStgDictNode;
7172
import com.oracle.graal.python.builtins.modules.ctypes.memory.Pointer;
7273
import com.oracle.graal.python.builtins.modules.ctypes.memory.PointerNodes;
74+
import com.oracle.graal.python.builtins.modules.ctypes.memory.PointerReference;
7375
import com.oracle.graal.python.builtins.objects.PNone;
7476
import com.oracle.graal.python.builtins.objects.buffer.PythonBufferAccessLibrary;
7577
import com.oracle.graal.python.builtins.objects.bytes.BytesNodes.ToBytesWithoutFrameNode;
7678
import com.oracle.graal.python.builtins.objects.bytes.PBytesLike;
79+
import com.oracle.graal.python.builtins.objects.capsule.PyCapsule;
7780
import com.oracle.graal.python.builtins.objects.common.SequenceStorageNodes.GetInternalByteArrayNode;
7881
import com.oracle.graal.python.builtins.objects.ints.PInt;
7982
import com.oracle.graal.python.builtins.objects.str.StringUtils.SimpleTruffleStringFormatNode;
@@ -90,6 +93,7 @@
9093
import com.oracle.graal.python.nodes.function.builtins.PythonTernaryBuiltinNode;
9194
import com.oracle.graal.python.nodes.function.builtins.PythonUnaryBuiltinNode;
9295
import com.oracle.graal.python.nodes.util.CastToTruffleStringNode;
96+
import com.oracle.graal.python.runtime.PythonContext;
9397
import com.oracle.graal.python.runtime.object.PythonObjectFactory;
9498
import com.oracle.truffle.api.CompilerDirectives;
9599
import com.oracle.truffle.api.dsl.Bind;
@@ -395,7 +399,7 @@ protected abstract static class SetFuncNode extends Node {
395399
abstract Object execute(VirtualFrame frame, FieldSet setfunc, Pointer ptr, Object value, int size);
396400

397401
@Specialization(guards = "setfunc == b_set || setfunc == B_set")
398-
Object b_set(VirtualFrame frame, @SuppressWarnings("unused") FieldSet setfunc, Pointer ptr, Object value, @SuppressWarnings("unused") int size,
402+
static Object b_set(VirtualFrame frame, @SuppressWarnings("unused") FieldSet setfunc, Pointer ptr, Object value, @SuppressWarnings("unused") int size,
399403
@Bind("this") Node inliningTarget,
400404
@Shared @Cached PyLongAsLongNode asLongNode,
401405
@Shared @Cached PointerNodes.WriteByteNode writeByteNode) {
@@ -405,7 +409,7 @@ Object b_set(VirtualFrame frame, @SuppressWarnings("unused") FieldSet setfunc, P
405409
}
406410

407411
@Specialization(guards = "setfunc == h_set || setfunc == H_set")
408-
Object h_set(VirtualFrame frame, @SuppressWarnings("unused") FieldSet setfunc, Pointer ptr, Object value, @SuppressWarnings("unused") int size,
412+
static Object h_set(VirtualFrame frame, @SuppressWarnings("unused") FieldSet setfunc, Pointer ptr, Object value, @SuppressWarnings("unused") int size,
409413
@Bind("this") Node inliningTarget,
410414
@Shared @Cached PyLongAsLongNode asLongNode,
411415
@Shared @Cached PointerNodes.WriteShortNode writeShortNode) {
@@ -415,7 +419,7 @@ Object h_set(VirtualFrame frame, @SuppressWarnings("unused") FieldSet setfunc, P
415419
}
416420

417421
@Specialization(guards = "setfunc == h_set_sw || setfunc == H_set_sw")
418-
Object h_set_sw(VirtualFrame frame, @SuppressWarnings("unused") FieldSet setfunc, Pointer ptr, Object value, @SuppressWarnings("unused") int size,
422+
static Object h_set_sw(VirtualFrame frame, @SuppressWarnings("unused") FieldSet setfunc, Pointer ptr, Object value, @SuppressWarnings("unused") int size,
419423
@Bind("this") Node inliningTarget,
420424
@Shared @Cached PyLongAsLongNode asLongNode,
421425
@Shared @Cached PointerNodes.WriteShortNode writeShortNode) {
@@ -426,7 +430,7 @@ Object h_set_sw(VirtualFrame frame, @SuppressWarnings("unused") FieldSet setfunc
426430
}
427431

428432
@Specialization(guards = "setfunc == i_set || setfunc == I_set")
429-
Object i_set(VirtualFrame frame, @SuppressWarnings("unused") FieldSet setfunc, Pointer ptr, Object value, @SuppressWarnings("unused") int size,
433+
static Object i_set(VirtualFrame frame, @SuppressWarnings("unused") FieldSet setfunc, Pointer ptr, Object value, @SuppressWarnings("unused") int size,
430434
@Bind("this") Node inliningTarget,
431435
@Shared @Cached PyLongAsLongNode asLongNode,
432436
@Shared @Cached PointerNodes.WriteIntNode writeIntNode) {
@@ -436,7 +440,7 @@ Object i_set(VirtualFrame frame, @SuppressWarnings("unused") FieldSet setfunc, P
436440
}
437441

438442
@Specialization(guards = "setfunc == i_set_sw || setfunc == I_set_sw")
439-
Object i_set_sw(VirtualFrame frame, @SuppressWarnings("unused") FieldSet setfunc, Pointer ptr, Object value, @SuppressWarnings("unused") int size,
443+
static Object i_set_sw(VirtualFrame frame, @SuppressWarnings("unused") FieldSet setfunc, Pointer ptr, Object value, @SuppressWarnings("unused") int size,
440444
@Bind("this") Node inliningTarget,
441445
@Shared @Cached PyLongAsLongNode asLongNode,
442446
@Shared @Cached PointerNodes.WriteIntNode writeIntNode) {
@@ -477,7 +481,7 @@ static Object bool_set(VirtualFrame frame, @SuppressWarnings("unused") FieldSet
477481
}
478482

479483
@Specialization(guards = "setfunc == l_set || setfunc == L_set")
480-
Object l_set(VirtualFrame frame, @SuppressWarnings("unused") FieldSet setfunc, Pointer ptr, Object value, @SuppressWarnings("unused") int size,
484+
static Object l_set(VirtualFrame frame, @SuppressWarnings("unused") FieldSet setfunc, Pointer ptr, Object value, @SuppressWarnings("unused") int size,
481485
@Bind("this") Node inliningTarget,
482486
@Shared @Cached PyLongAsLongNode asLongNode,
483487
@Shared @Cached PointerNodes.WriteLongNode writeLongNode) {
@@ -487,7 +491,7 @@ Object l_set(VirtualFrame frame, @SuppressWarnings("unused") FieldSet setfunc, P
487491
}
488492

489493
@Specialization(guards = "setfunc == l_set_sw || setfunc == L_set_sw")
490-
Object l_set_sw(VirtualFrame frame, @SuppressWarnings("unused") FieldSet setfunc, Pointer ptr, Object value, @SuppressWarnings("unused") int size,
494+
static Object l_set_sw(VirtualFrame frame, @SuppressWarnings("unused") FieldSet setfunc, Pointer ptr, Object value, @SuppressWarnings("unused") int size,
491495
@Bind("this") Node inliningTarget,
492496
@Shared @Cached PyLongAsLongNode asLongNode,
493497
@Shared @Cached PointerNodes.WriteLongNode writeLongNode) {
@@ -550,7 +554,7 @@ static Object O_set(@SuppressWarnings("unused") FieldSet setfunc, Pointer ptr, O
550554
}
551555

552556
@Specialization(guards = "setfunc == c_set")
553-
Object c_set(@SuppressWarnings("unused") FieldSet setfunc, Pointer ptr, Object value, @SuppressWarnings("unused") int size,
557+
static Object c_set(@SuppressWarnings("unused") FieldSet setfunc, Pointer ptr, Object value, @SuppressWarnings("unused") int size,
554558
@Bind("this") Node inliningTarget,
555559
@Cached GetInternalByteArrayNode getBytes,
556560
@Shared @Cached PointerNodes.WriteByteNode writeByteNode,
@@ -576,7 +580,7 @@ Object c_set(@SuppressWarnings("unused") FieldSet setfunc, Pointer ptr, Object v
576580

577581
/* u - a single wchar_t character */
578582
@Specialization(guards = "setfunc == u_set")
579-
Object u_set(@SuppressWarnings("unused") FieldSet setfunc, Pointer ptr, Object value, @SuppressWarnings("unused") int size,
583+
static Object u_set(@SuppressWarnings("unused") FieldSet setfunc, Pointer ptr, Object value, @SuppressWarnings("unused") int size,
580584
@Bind("this") Node inliningTarget,
581585
@Cached CastToTruffleStringNode toString,
582586
@Cached TruffleString.SwitchEncodingNode switchEncodingNode,
@@ -596,7 +600,7 @@ Object u_set(@SuppressWarnings("unused") FieldSet setfunc, Pointer ptr, Object v
596600
}
597601

598602
@Specialization(guards = "setfunc == U_set")
599-
Object U_set(@SuppressWarnings("unused") FieldSet setfunc, Pointer ptr, Object value, int size,
603+
static Object U_set(@SuppressWarnings("unused") FieldSet setfunc, Pointer ptr, Object value, int size,
600604
@Bind("this") Node inliningTarget,
601605
@Cached CastToTruffleStringNode toString,
602606
@Cached TruffleString.SwitchEncodingNode switchEncodingNode,
@@ -617,7 +621,7 @@ Object U_set(@SuppressWarnings("unused") FieldSet setfunc, Pointer ptr, Object v
617621
}
618622

619623
@Specialization(guards = "setfunc == s_set")
620-
Object s_set(@SuppressWarnings("unused") FieldSet setfunc, Pointer ptr, Object value, int length,
624+
static Object s_set(@SuppressWarnings("unused") FieldSet setfunc, Pointer ptr, Object value, int length,
621625
@Bind("this") Node inliningTarget,
622626
@Cached ToBytesWithoutFrameNode getBytes,
623627
@Shared @Cached PointerNodes.WriteBytesNode writeBytesNode,
@@ -642,19 +646,20 @@ Object s_set(@SuppressWarnings("unused") FieldSet setfunc, Pointer ptr, Object v
642646
}
643647

644648
@Specialization(guards = "setfunc == z_set")
645-
Object z_set(@SuppressWarnings("unused") FieldSet setfunc, Pointer ptr, Object value, @SuppressWarnings("unused") int size,
649+
static Object z_set(@SuppressWarnings("unused") FieldSet setfunc, Pointer ptr, Object value, @SuppressWarnings("unused") int size,
646650
@Bind("this") Node inliningTarget,
647651
@Shared @Cached PyLongCheckNode longCheckNode,
648652
@Shared @Cached PointerNodes.PointerFromLongNode pointerFromLongNode,
649653
@CachedLibrary(limit = "1") PythonBufferAccessLibrary bufferLib,
650654
@Shared @Cached PointerNodes.WritePointerNode writePointerNode,
651-
@Shared @Cached PRaiseNode raiseNode) {
655+
@Shared @Cached PRaiseNode raiseNode,
656+
@Shared @Cached PythonObjectFactory factory) {
652657
if (value == PNone.NONE) {
653658
writePointerNode.execute(inliningTarget, ptr, Pointer.NULL);
654659
return PNone.NONE;
655660
} else if (longCheckNode.execute(value)) {
656661
writePointerNode.execute(inliningTarget, ptr, pointerFromLongNode.execute(inliningTarget, value));
657-
return value;
662+
return PNone.NONE;
658663
}
659664
if (PGuards.isPBytes(value)) {
660665
int len = bufferLib.getBufferLength(value);
@@ -663,27 +668,33 @@ Object z_set(@SuppressWarnings("unused") FieldSet setfunc, Pointer ptr, Object v
663668
/* ptr is a char**, we need to add the indirection */
664669
Pointer valuePtr = Pointer.bytes(bytes);
665670
writePointerNode.execute(inliningTarget, ptr, valuePtr);
671+
/*
672+
* We make a copy of the memory, so we need to register a destructor to free the
673+
* memory in case it goes to native.
674+
*/
675+
new PointerReference(value, valuePtr, PythonContext.get(inliningTarget).getSharedFinalizer());
666676
return value;
667677
}
668678
throw raiseNode.raise(TypeError, ErrorMessages.BYTES_OR_INT_ADDR_EXPECTED_INSTEAD_OF_P, value);
669679
}
670680

671681
@Specialization(guards = "setfunc == Z_set")
672-
Object Z_set(@SuppressWarnings("unused") FieldSet setfunc, Pointer ptr, Object value, @SuppressWarnings("unused") int size,
682+
static Object Z_set(@SuppressWarnings("unused") FieldSet setfunc, Pointer ptr, Object value, @SuppressWarnings("unused") int size,
673683
@Bind("this") Node inliningTarget,
674684
@Cached CastToTruffleStringNode toString,
675685
@Shared @Cached PyLongCheckNode longCheckNode,
676686
@Shared @Cached PointerNodes.PointerFromLongNode pointerFromLongNode,
677687
@Cached TruffleString.SwitchEncodingNode switchEncodingNode,
678688
@Cached TruffleString.CopyToByteArrayNode copyToByteArrayNode,
679689
@Shared @Cached PointerNodes.WritePointerNode writePointerNode,
680-
@Shared @Cached PRaiseNode raiseNode) { // CTYPES_UNICODE
690+
@Shared @Cached PRaiseNode raiseNode,
691+
@Shared @Cached PythonObjectFactory factory) { // CTYPES_UNICODE
681692
if (value == PNone.NONE) {
682693
writePointerNode.execute(inliningTarget, ptr, Pointer.NULL);
683694
return PNone.NONE;
684695
} else if (longCheckNode.execute(value)) {
685696
writePointerNode.execute(inliningTarget, ptr, pointerFromLongNode.execute(inliningTarget, value));
686-
return value;
697+
return PNone.NONE;
687698
}
688699
if (!PGuards.isString(value)) {
689700
throw raiseNode.raise(TypeError, ErrorMessages.UNICODE_STR_OR_INT_ADDR_EXPECTED_INSTEAD_OF_P, value);
@@ -697,11 +708,11 @@ Object Z_set(@SuppressWarnings("unused") FieldSet setfunc, Pointer ptr, Object v
697708
/* ptr is a char**, we need to add the indirection */
698709
Pointer valuePtr = Pointer.bytes(bytes);
699710
writePointerNode.execute(inliningTarget, ptr, valuePtr);
700-
return str;
711+
return createPyMemCapsule(valuePtr, factory);
701712
}
702713

703714
@Specialization(guards = "setfunc == P_set")
704-
Object P_set(@SuppressWarnings("unused") FieldSet setfunc, Pointer ptr, Object value, @SuppressWarnings("unused") int size,
715+
static Object P_set(@SuppressWarnings("unused") FieldSet setfunc, Pointer ptr, Object value, @SuppressWarnings("unused") int size,
705716
@Bind("this") Node inliningTarget,
706717
@Shared @Cached PyLongCheckNode longCheckNode,
707718
@Shared @Cached PointerNodes.PointerFromLongNode pointerFromLongNode,
@@ -722,11 +733,18 @@ Object P_set(@SuppressWarnings("unused") FieldSet setfunc, Pointer ptr, Object v
722733

723734
@SuppressWarnings("unused")
724735
@Fallback
725-
Object error(VirtualFrame frame, FieldSet setfunc, Pointer ptr, Object value, int size) {
736+
static Object error(VirtualFrame frame, FieldSet setfunc, Pointer ptr, Object value, int size) {
726737
CompilerDirectives.transferToInterpreterAndInvalidate();
727738
throw PRaiseNode.getUncached().raise(NotImplementedError, toTruffleStringUncached("Field setter %s is not supported yet."), setfunc.name());
728739
}
729740

741+
private static TruffleString CTYPES_CFIELD_CAPSULE_NAME_PYMEM = tsLiteral("_ctypes/cfield.c pymem");
742+
743+
private static PyCapsule createPyMemCapsule(Pointer pointer, PythonObjectFactory factory) {
744+
PyCapsule capsule = factory.createCapsule(pointer, CTYPES_CFIELD_CAPSULE_NAME_PYMEM, null);
745+
new PointerReference(capsule, pointer, PythonContext.get(factory).getSharedFinalizer());
746+
return capsule;
747+
}
730748
}
731749

732750
@ImportStatic(FieldGet.class)
@@ -897,7 +915,7 @@ static double f_get_sw(@SuppressWarnings("unused") FieldGet getfunc, Pointer ptr
897915
}
898916

899917
@Specialization(guards = "getfunc == O_get")
900-
Object O_get(@SuppressWarnings("unused") FieldGet getfunc, Pointer ptr, @SuppressWarnings("unused") int size,
918+
static Object O_get(@SuppressWarnings("unused") FieldGet getfunc, Pointer ptr, @SuppressWarnings("unused") int size,
901919
@Bind("this") Node inliningTarget,
902920
@Shared @Cached PointerNodes.ReadPointerNode readPointerNode,
903921
@Cached PointerNodes.ReadPythonObject readPythonObject,

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/ctypes/SimpleCDataBuiltins.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,6 @@ static void Simple_set_value(VirtualFrame frame, CDataObject self, Object value,
106106
assert dict.setfunc != FieldSet.nil;
107107
Object result = setFuncNode.execute(frame, dict.setfunc, self.b_ptr, value, dict.size);
108108

109-
/* consumes the refcount the setfunc returns */
110109
keepRefNode.execute(frame, self, 0, result);
111110
}
112111

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/ctypes/memory/PointerReference.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
package com.oracle.graal.python.builtins.modules.ctypes.memory;
22

3-
import com.oracle.graal.python.builtins.modules.ctypes.CDataObject;
43
import com.oracle.graal.python.runtime.AsyncHandler;
54

65
public class PointerReference extends AsyncHandler.SharedFinalizer.FinalizableReference {
76

8-
public PointerReference(CDataObject cDataObject, Pointer pointer, AsyncHandler.SharedFinalizer sharedFinalizer) {
9-
super(cDataObject, pointer, sharedFinalizer);
7+
public PointerReference(Object referent, Pointer pointer, AsyncHandler.SharedFinalizer sharedFinalizer) {
8+
super(referent, pointer, sharedFinalizer);
109
}
1110

1211
@Override

0 commit comments

Comments
 (0)