Skip to content

Commit d7c6910

Browse files
committed
Keep a reference to memory returned by PyUnicode_AsUTF8AndSize
1 parent 93d3625 commit d7c6910

File tree

4 files changed

+124
-38
lines changed

4 files changed

+124
-38
lines changed

graalpython/com.oracle.graal.python.cext/src/unicodeobject.c

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -246,12 +246,11 @@ const char* PyUnicode_AsUTF8(PyObject *unicode) {
246246
}
247247

248248
const char* PyUnicode_AsUTF8AndSize(PyObject *unicode, Py_ssize_t *psize) {
249-
PyObject *result;
250-
result = _PyUnicode_AsUTF8String(unicode, NULL);
251-
if (psize) {
252-
*psize = PyObject_Length(result);
249+
const char* charptr = GraalPyTruffle_Unicode_AsUTF8AndSize_CharPtr(unicode);
250+
if (charptr && psize) {
251+
*psize = GraalPyTruffle_Unicode_AsUTF8AndSize_Size(unicode);
253252
}
254-
return PyBytes_AsString(result);
253+
return charptr;
255254
}
256255

257256
// taken from CPython "Python/Objects/unicodeobject.c"
@@ -295,13 +294,11 @@ Py_UNICODE* PyUnicode_AsUnicode(PyObject *unicode) {
295294
}
296295

297296
Py_UNICODE* PyUnicode_AsUnicodeAndSize(PyObject *unicode, Py_ssize_t *size) {
298-
PyObject* bytes = GraalPyTruffle_Unicode_AsWideChar(unicode, Py_UNICODE_SIZE);
299-
if (bytes != NULL) {
300-
// exclude null terminator at the end
301-
*size = PyBytes_Size(bytes) / Py_UNICODE_SIZE;
302-
return (Py_UNICODE*) PyBytes_AsString(bytes);
297+
Py_UNICODE* charptr = GraalPyTruffle_Unicode_AsUnicodeAndSize_CharPtr(unicode);
298+
if (charptr && size) {
299+
*size = GraalPyTruffle_Unicode_AsUnicodeAndSize_Size(unicode);
303300
}
304-
return NULL;
301+
return charptr;
305302
}
306303

307304
int _PyUnicode_Ready(PyObject *unicode) {

graalpython/com.oracle.graal.python.test/src/tests/cpyext/test_unicode.py

Lines changed: 35 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2018, 2022, Oracle and/or its affiliates. All rights reserved.
1+
# Copyright (c) 2018, 2023, Oracle and/or its affiliates. All rights reserved.
22
# DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
33
#
44
# The Universal Permissive License (UPL), Version 1.0
@@ -37,10 +37,10 @@
3737
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
3838
# SOFTWARE.
3939

40-
import sys
4140
import re
41+
import sys
4242

43-
from . import CPyExtTestCase, CPyExtFunction, unhandled_error_compare, GRAALPYTHON
43+
from . import CPyExtTestCase, CPyExtFunction, unhandled_error_compare, GRAALPYTHON, CPyExtFunctionOutVars
4444

4545
__dir__ = __file__.rpartition("/")[0]
4646

@@ -99,9 +99,10 @@ def _reference_contains(args):
9999
raise TypeError
100100
return args[1] in args[0]
101101

102+
102103
def _reference_compare(args):
103104
if not isinstance(args[0], str) or not isinstance(args[1], str):
104-
raise TypeError
105+
raise TypeError
105106

106107
if args[0] == args[1]:
107108
return 0
@@ -110,24 +111,29 @@ def _reference_compare(args):
110111
else:
111112
return 1
112113

114+
113115
def _reference_as_encoded_string(args):
114116
if not isinstance(args[0], str):
115-
raise TypeError
117+
raise TypeError
116118

117119
s = args[0]
118120
encoding = args[1]
119121
errors = args[2]
120122
return s.encode(encoding, errors)
121123

124+
122125
_codecs_module = None
126+
127+
123128
def _reference_as_unicode_escape_string(args):
124129
if not isinstance(args[0], str):
125-
raise TypeError
130+
raise TypeError
126131
global _codecs_module
127132
if not _codecs_module:
128133
import _codecs as _codecs_module
129134
return _codecs_module.unicode_escape_encode(args[0])[0]
130135

136+
131137
def _reference_tailmatch(args):
132138
if not isinstance(args[0], str) or not isinstance(args[1], str):
133139
raise TypeError
@@ -141,6 +147,7 @@ def _reference_tailmatch(args):
141147
return 1 if s[start:end].endswith(substr) else 0
142148
return 1 if s[start:end].startswith(substr) else 0
143149

150+
144151
class CustomString(str):
145152
pass
146153

@@ -336,7 +343,7 @@ def compile_module(self, name):
336343
_reference_fromformat,
337344
lambda: (
338345
("word0: %s; word1: %s; int: %d; long long: %lld", "hello", "world", 1234, 1234),
339-
("word0: %s; word1: %s; int: %d; long long: %lld", "hello", "world", 1234, (1<<44)+123),
346+
("word0: %s; word1: %s; int: %d; long long: %lld", "hello", "world", 1234, (1 << 44) + 123),
340347
),
341348
code="typedef long long longlong_t;",
342349
resultspec="O",
@@ -461,6 +468,20 @@ def compile_module(self, name):
461468
cmpfunc=unhandled_error_compare
462469
)
463470

471+
test_PyUnicode_AsUTF8AndSize = CPyExtFunctionOutVars(
472+
lambda args: (s := args[0].encode("utf-8"), len(s)),
473+
lambda: (
474+
("hello",),
475+
("hellö",),
476+
),
477+
resultspec="yn",
478+
resulttype='const char*',
479+
argspec='O',
480+
arguments=["PyObject* s"],
481+
resultvars=["Py_ssize_t size"],
482+
cmpfunc=unhandled_error_compare
483+
)
484+
464485
test_PyUnicode_DecodeUTF32 = CPyExtFunction(
465486
lambda args: args[1],
466487
lambda: (
@@ -487,7 +508,7 @@ def compile_module(self, name):
487508
test_PyUnicode_DecodeUTF8Stateful = CPyExtFunction(
488509
lambda args: args[0],
489510
lambda: (
490-
("_type_", ),
511+
("_type_",),
491512
),
492513
code="""PyObject* wrap_PyUnicode_DecodeUTF8Stateful(PyObject* _type_str) {
493514
_Py_IDENTIFIER(_type_);
@@ -569,7 +590,8 @@ def compile_module(self, name):
569590
test_PyUnicode_AsUnicode = CPyExtFunction(
570591
lambda args: True,
571592
lambda: (
572-
("hello", b'\x68\x00\x65\x00\x6c\x00\x6c\x00\x6f\x00', b"\x68\x00\x00\x00\x65\x00\x00\x00\x6c\x00\x00\x00\x6c\x00\x00\x00\x6f\x00\x00\x00"),
593+
("hello", b'\x68\x00\x65\x00\x6c\x00\x6c\x00\x6f\x00',
594+
b"\x68\x00\x00\x00\x65\x00\x00\x00\x6c\x00\x00\x00\x6c\x00\x00\x00\x6f\x00\x00\x00"),
573595
),
574596
code=""" PyObject* wrap_PyUnicode_AsUnicode(PyObject* unicodeObj, PyObject* expected_16, PyObject* expected_32) {
575597
Py_ssize_t n = Py_UNICODE_SIZE == 2 ? PyBytes_Size(expected_16) : PyBytes_Size(expected_32);
@@ -595,7 +617,8 @@ def compile_module(self, name):
595617
test_PyUnicode_AsUnicodeAndSize = CPyExtFunction(
596618
lambda args: True,
597619
lambda: (
598-
("hello", b'\x68\x00\x65\x00\x6c\x00\x6c\x00\x6f\x00', b"\x68\x00\x00\x00\x65\x00\x00\x00\x6c\x00\x00\x00\x6c\x00\x00\x00\x6f\x00\x00\x00"),
620+
("hello", b'\x68\x00\x65\x00\x6c\x00\x6c\x00\x6f\x00',
621+
b"\x68\x00\x00\x00\x65\x00\x00\x00\x6c\x00\x00\x00\x6c\x00\x00\x00\x6f\x00\x00\x00"),
599622
),
600623
code=""" PyObject* wrap_PyUnicode_AsUnicodeAndSize(PyObject* unicodeObj, PyObject* expected_16, PyObject* expected_32) {
601624
Py_ssize_t n = Py_UNICODE_SIZE == 2 ? PyBytes_Size(expected_16) : PyBytes_Size(expected_32);
@@ -673,7 +696,7 @@ def compile_module(self, name):
673696
arguments=["PyObject* str", "PyObject* seq"],
674697
cmpfunc=unhandled_error_compare
675698
)
676-
699+
677700
test_PyUnicode_Compare = CPyExtFunction(
678701
_reference_compare,
679702
lambda: (
@@ -738,7 +761,6 @@ def compile_module(self, name):
738761
cmpfunc=unhandled_error_compare
739762
)
740763

741-
742764
test_PyUnicode_AsEncodedString = CPyExtFunction(
743765
_reference_as_encoded_string,
744766
lambda: (
@@ -760,7 +782,7 @@ def compile_module(self, name):
760782
_reference_as_unicode_escape_string,
761783
lambda: (
762784
("abcd",),
763-
("öüä",),
785+
("öüä",),
764786
(1,),
765787
),
766788
resultspec="O",
@@ -873,5 +895,3 @@ def compile_module(self, name):
873895
arguments=["PyObject* string", "PyObject* sep", "Py_ssize_t maxsplit"],
874896
cmpfunc=unhandled_error_compare
875897
)
876-
877-

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/cext/PythonCextUnicodeBuiltins.java

Lines changed: 57 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,11 @@
4747
import static com.oracle.graal.python.builtins.modules.CodecsModuleBuiltins.T_UNICODE_ESCAPE;
4848
import static com.oracle.graal.python.builtins.modules.cext.PythonCextBuiltins.CApiCallPath.Direct;
4949
import static com.oracle.graal.python.builtins.modules.cext.PythonCextBuiltins.CApiCallPath.Ignored;
50+
import static com.oracle.graal.python.builtins.objects.cext.capi.transitions.ArgDescriptor.ConstCharPtr;
5051
import static com.oracle.graal.python.builtins.objects.cext.capi.transitions.ArgDescriptor.ConstCharPtrAsTruffleString;
5152
import static com.oracle.graal.python.builtins.objects.cext.capi.transitions.ArgDescriptor.Int;
5253
import static com.oracle.graal.python.builtins.objects.cext.capi.transitions.ArgDescriptor.PY_UCS4;
54+
import static com.oracle.graal.python.builtins.objects.cext.capi.transitions.ArgDescriptor.PY_UNICODE_PTR;
5355
import static com.oracle.graal.python.builtins.objects.cext.capi.transitions.ArgDescriptor.Pointer;
5456
import static com.oracle.graal.python.builtins.objects.cext.capi.transitions.ArgDescriptor.PyObject;
5557
import static com.oracle.graal.python.builtins.objects.cext.capi.transitions.ArgDescriptor.PyObjectBorrowed;
@@ -97,8 +99,10 @@
9799
import com.oracle.graal.python.builtins.objects.bytes.PBytes;
98100
import com.oracle.graal.python.builtins.objects.bytes.PBytesLike;
99101
import com.oracle.graal.python.builtins.objects.cext.capi.CExtNodes.UnicodeFromFormatNode;
102+
import com.oracle.graal.python.builtins.objects.cext.capi.PySequenceArrayWrapper;
100103
import com.oracle.graal.python.builtins.objects.cext.capi.UnicodeObjectNodes.UnicodeAsWideCharNode;
101104
import com.oracle.graal.python.builtins.objects.cext.capi.transitions.ArgDescriptor;
105+
import com.oracle.graal.python.builtins.objects.cext.common.CExtCommonNodes;
102106
import com.oracle.graal.python.builtins.objects.cext.common.CExtCommonNodes.Charsets;
103107
import com.oracle.graal.python.builtins.objects.cext.common.CExtCommonNodes.EncodeNativeStringNode;
104108
import com.oracle.graal.python.builtins.objects.cext.common.CExtCommonNodes.GetByteArrayNode;
@@ -134,7 +138,6 @@
134138
import com.oracle.graal.python.runtime.exception.PException;
135139
import com.oracle.graal.python.runtime.exception.PythonErrorType;
136140
import com.oracle.graal.python.util.OverflowException;
137-
import com.oracle.truffle.api.CompilerDirectives;
138141
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
139142
import com.oracle.truffle.api.dsl.Bind;
140143
import com.oracle.truffle.api.dsl.Cached;
@@ -147,6 +150,7 @@
147150
import com.oracle.truffle.api.library.CachedLibrary;
148151
import com.oracle.truffle.api.nodes.Node;
149152
import com.oracle.truffle.api.profiles.ConditionProfile;
153+
import com.oracle.truffle.api.profiles.InlinedConditionProfile;
150154
import com.oracle.truffle.api.strings.TruffleString;
151155
import com.oracle.truffle.api.strings.TruffleString.Encoding;
152156
import com.oracle.truffle.api.strings.TruffleString.FromByteArrayNode;
@@ -861,22 +865,63 @@ abstract static class _PyUnicode_AsUTF8String extends NativeEncoderNode {
861865
protected _PyUnicode_AsUTF8String() {
862866
super(StandardCharsets.UTF_8);
863867
}
868+
869+
public static _PyUnicode_AsUTF8String create() {
870+
return PythonCextUnicodeBuiltinsFactory._PyUnicode_AsUTF8StringNodeGen.create();
871+
}
864872
}
865873

866-
@CApiBuiltin(ret = PyObjectTransfer, args = {PyObject}, call = Ignored)
867-
abstract static class PyTruffle_Unicode_AsUnicodeAndSize extends CApiUnaryBuiltinNode {
874+
@CApiBuiltin(ret = ConstCharPtr, args = {PyObject}, call = Direct)
875+
abstract static class PyTruffle_Unicode_AsUTF8AndSize_CharPtr extends CApiUnaryBuiltinNode {
876+
877+
@Specialization
878+
Object doUnicode(PString s,
879+
@Bind("this") Node inliningTarget,
880+
@Cached InlinedConditionProfile profile,
881+
@Cached _PyUnicode_AsUTF8String asUTF8String) {
882+
if (profile.profile(inliningTarget, s.getUtf8Bytes() == null)) {
883+
PBytes bytes = (PBytes) asUTF8String.execute(s, T_STRICT);
884+
s.setUtf8Bytes(bytes);
885+
}
886+
return new PySequenceArrayWrapper(s.getUtf8Bytes(), 1);
887+
}
888+
}
889+
890+
@CApiBuiltin(ret = Py_ssize_t, args = {PyObject}, call = Direct)
891+
abstract static class PyTruffle_Unicode_AsUTF8AndSize_Size extends CApiUnaryBuiltinNode {
892+
893+
@Specialization
894+
Object doUnicode(PString s) {
895+
// PyTruffle_Unicode_AsUTF8AndSize_CharPtr must have been be called before
896+
return s.getUtf8Bytes().getSequenceStorage().length();
897+
}
898+
}
899+
900+
@CApiBuiltin(ret = PY_UNICODE_PTR, args = {PyObject}, call = Direct)
901+
abstract static class PyTruffle_Unicode_AsUnicodeAndSize_CharPtr extends CApiUnaryBuiltinNode {
902+
868903
@Specialization
869-
@TruffleBoundary
870904
Object doUnicode(PString s,
871-
@Cached TruffleString.ToJavaStringNode toJavaStringNode) {
872-
CompilerDirectives.shouldNotReachHere("TODO: can this be reached?");
873-
char[] charArray = toJavaStringNode.execute(s.getValueUncached()).toCharArray();
874-
// stuff into byte[]
875-
ByteBuffer allocate = ByteBuffer.allocate(charArray.length * 2);
876-
for (int i = 0; i < charArray.length; i++) {
877-
allocate.putChar(charArray[i]);
905+
@Bind("this") Node inliningTarget,
906+
@Cached InlinedConditionProfile profile,
907+
@Cached CExtCommonNodes.SizeofWCharNode sizeofWcharNode,
908+
@Cached UnicodeAsWideCharNode asWideCharNode) {
909+
if (profile.profile(inliningTarget, s.getWCharBytes() == null)) {
910+
PBytes bytes = asWideCharNode.executeNativeOrder(s, sizeofWcharNode.execute(getCApiContext()));
911+
s.setWCharBytes(bytes);
878912
}
879-
return getContext().getEnv().asGuestValue(allocate.array());
913+
return new PySequenceArrayWrapper(s.getWCharBytes(), 1);
914+
}
915+
}
916+
917+
@CApiBuiltin(ret = Py_ssize_t, args = {PyObject}, call = Direct)
918+
abstract static class PyTruffle_Unicode_AsUnicodeAndSize_Size extends CApiUnaryBuiltinNode {
919+
920+
@Specialization
921+
Object doUnicode(PString s,
922+
@Cached CExtCommonNodes.SizeofWCharNode sizeofWcharNode) {
923+
// PyTruffle_Unicode_AsUnicodeAndSize_CharPtr must have been be called before
924+
return s.getWCharBytes().getSequenceStorage().length() / sizeofWcharNode.execute(getCApiContext());
880925
}
881926
}
882927

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/str/PString.java

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727

2828
import static com.oracle.graal.python.util.PythonUtils.TS_ENCODING;
2929

30+
import com.oracle.graal.python.builtins.objects.bytes.PBytes;
3031
import com.oracle.graal.python.builtins.objects.str.StringNodes.StringMaterializeNode;
3132
import com.oracle.graal.python.builtins.objects.str.StringNodesFactory.StringMaterializeNodeGen;
3233
import com.oracle.graal.python.nodes.util.CannotCastException;
@@ -55,6 +56,13 @@ public final class PString extends PSequence {
5556
private TruffleString materializedValue;
5657
private NativeCharSequence nativeCharSequence;
5758

59+
/*
60+
* We need to keep a reference to the encoded forms for functions that return char pointers to
61+
* keep the underlying memory alive (NativeSequenceStorage frees memory in finalizer).
62+
*/
63+
private PBytes utf8Bytes;
64+
private PBytes wCharBytes;
65+
5866
public PString(Object clazz, Shape instanceShape, NativeCharSequence value) {
5967
super(clazz, instanceShape);
6068
this.nativeCharSequence = value;
@@ -104,6 +112,22 @@ public String toString() {
104112
return isMaterialized() ? materializedValue.toJavaStringUncached() : nativeCharSequence.toString();
105113
}
106114

115+
public PBytes getUtf8Bytes() {
116+
return utf8Bytes;
117+
}
118+
119+
public void setUtf8Bytes(PBytes bytes) {
120+
this.utf8Bytes = bytes;
121+
}
122+
123+
public PBytes getWCharBytes() {
124+
return wCharBytes;
125+
}
126+
127+
public void setWCharBytes(PBytes bytes) {
128+
this.wCharBytes = bytes;
129+
}
130+
107131
@Override
108132
public SequenceStorage getSequenceStorage() {
109133
CompilerDirectives.transferToInterpreterAndInvalidate();

0 commit comments

Comments
 (0)