Skip to content

Commit d783363

Browse files
committed
Fix: Correctly read from UCS2/4 arrays and add tests.
1 parent 3b15dc7 commit d783363

File tree

3 files changed

+102
-33
lines changed

3 files changed

+102
-33
lines changed

graalpython/com.oracle.graal.python.cext/src/unicodeobject.c

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -418,12 +418,20 @@ static PyObject* _PyUnicode_FromUCS1(const Py_UCS1* u, Py_ssize_t size) {
418418
return polyglot_from_string((const char *) u, "ISO-8859-1");
419419
}
420420

421+
typedef PyObject*(*PyTruffle_Unicode_FromWchar_t)(int8_t*, int64_t, int64_t, void*);
422+
421423
static PyObject* _PyUnicode_FromUCS2(const Py_UCS2 *u, Py_ssize_t size) {
422-
return UPCALL_CEXT_O(_jls_PyTruffle_Unicode_FromWchar, polyglot_from_i16_array(u, size), 2, NULL);
424+
// This does deliberately not use UPCALL_CEXT_O to avoid argument conversion since
425+
// 'PyTruffle_Unicode_FromWchar' really expects the bare pointer.
426+
int64_t bsize = size * sizeof(Py_UCS2);
427+
return ((PyTruffle_Unicode_FromWchar_t) _jls_PyTruffle_Unicode_FromWchar)(polyglot_from_i8_array((int8_t*)u, bsize), bsize, 2, NULL);
423428
}
424429

425430
static PyObject* _PyUnicode_FromUCS4(const Py_UCS4 *u, Py_ssize_t size) {
426-
return UPCALL_CEXT_O(_jls_PyTruffle_Unicode_FromWchar, polyglot_from_i32_array(u, size), 4, NULL);
431+
// This does deliberately not use UPCALL_CEXT_O to avoid argument conversion since
432+
// 'PyTruffle_Unicode_FromWchar' really expects the bare pointer.
433+
int64_t bsize = size * sizeof(Py_UCS4);
434+
return ((PyTruffle_Unicode_FromWchar_t) _jls_PyTruffle_Unicode_FromWchar)(polyglot_from_i8_array((int8_t*)u, bsize), bsize, 4, NULL);
427435
}
428436

429437
// taken from CPython "Python/Objects/unicodeobject.c"

graalpython/com.oracle.graal.python.test/src/tests/cpyext/test_unicode.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -482,17 +482,27 @@ def compile_module(self, name):
482482
arguments=["int ordinal"],
483483
cmpfunc=unhandled_error_compare
484484
)
485-
486485

487-
test_PyUnicode_AsUnicodeEscapeString = CPyExtFunction(
488-
_reference_unicode_escape,
486+
# NOTE: this test assumes that Python uses UTF-8 encoding for source files
487+
test_PyUnicode_FromKindAndData = CPyExtFunction(
488+
lambda args: args[3],
489489
lambda: (
490-
("abcd", ),
491-
("öüä", ),
490+
(4, bytearray([0xA2, 0x0E, 0x02, 0x00]), 1, "𠺢"),
491+
(4, bytearray([0xA2, 0x0E, 0x02, 0x00, 0x4C, 0x0F, 0x02, 0x00]), 2, "𠺢𠽌"),
492+
(2, bytearray([0x30, 0x20]), 1, "‰"),
493+
(2, bytearray([0x30, 0x20, 0x3C, 0x20]), 2, "‰‼"),
492494
),
495+
code='''PyObject* wrap_PyUnicode_FromKindAndData(int kind, Py_buffer buffer, Py_ssize_t size, PyObject* dummy) {
496+
PyObject* res;
497+
res = PyUnicode_FromKindAndData(kind, (const char *)buffer.buf, size);
498+
Py_XINCREF(res);
499+
return res;
500+
}
501+
''',
493502
resultspec="O",
494-
argspec='O',
495-
arguments=["PyObject* str"],
503+
argspec='iy*nO',
504+
arguments=["int kind", "Py_buffer buffer", "Py_ssize_t size", "PyObject* dummy"],
505+
callfunction="wrap_PyUnicode_FromKindAndData",
496506
cmpfunc=unhandled_error_compare
497507
)
498508

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/PythonCextBuiltins.java

Lines changed: 75 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
import java.io.PrintWriter;
5050
import java.math.BigInteger;
5151
import java.nio.ByteBuffer;
52+
import java.nio.ByteOrder;
5253
import java.nio.CharBuffer;
5354
import java.nio.charset.CharacterCodingException;
5455
import java.nio.charset.Charset;
@@ -174,6 +175,7 @@
174175
import com.oracle.graal.python.runtime.ExecutionContext.IndirectCallContext;
175176
import com.oracle.graal.python.runtime.PythonContext;
176177
import com.oracle.graal.python.runtime.PythonCore;
178+
import com.oracle.graal.python.runtime.PythonOptions;
177179
import com.oracle.graal.python.runtime.exception.ExceptionUtils;
178180
import com.oracle.graal.python.runtime.exception.PException;
179181
import com.oracle.graal.python.runtime.exception.PythonErrorType;
@@ -207,6 +209,7 @@
207209
import com.oracle.truffle.api.interop.UnsupportedMessageException;
208210
import com.oracle.truffle.api.interop.UnsupportedTypeException;
209211
import com.oracle.truffle.api.library.CachedLibrary;
212+
import com.oracle.truffle.api.nodes.ExplodeLoop;
210213
import com.oracle.truffle.api.nodes.Node;
211214
import com.oracle.truffle.api.nodes.NodeVisitor;
212215
import com.oracle.truffle.api.nodes.RootNode;
@@ -1072,44 +1075,50 @@ private <T> T raiseNative(VirtualFrame frame, T defaultValue, PythonBuiltinClass
10721075
}
10731076
}
10741077

1075-
@Builtin(name = "PyTruffle_Unicode_FromWchar", minNumOfPositionalArgs = 3)
1078+
@Builtin(name = "PyTruffle_Unicode_FromWchar", minNumOfPositionalArgs = 4)
10761079
@GenerateNodeFactory
10771080
@TypeSystemReference(PythonArithmeticTypes.class)
1081+
@ImportStatic(PythonOptions.class)
10781082
abstract static class PyTruffle_Unicode_FromWchar extends NativeUnicodeBuiltin {
1079-
@Specialization
1080-
Object doBytes(VirtualFrame frame, Object o, long elementSize, Object errorMarker,
1081-
@Shared("getByteArrayNode") @Cached GetByteArrayNode getByteArrayNode,
1082-
@Shared("lib") @CachedLibrary(limit = "3") InteropLibrary lib) {
1083+
@Specialization(guards = "elementSize == cachedElementSize", limit = "getVariableArgumentInlineCacheLimit()")
1084+
Object doBytes(VirtualFrame frame, Object arr, long n, long elementSize, Object errorMarker,
1085+
@Cached CExtNodes.ToSulongNode toSulongNode,
1086+
@Cached("elementSize") long cachedElementSize,
1087+
@CachedLibrary("arr") InteropLibrary lib,
1088+
@CachedLibrary(limit = "1") InteropLibrary elemLib) {
10831089
try {
10841090
ByteBuffer bytes;
1085-
if (elementSize == 2L) {
1086-
if (!lib.hasArrayElements(o)) {
1091+
if (cachedElementSize == 1L || cachedElementSize == 2L || cachedElementSize == 4L) {
1092+
if (!lib.hasArrayElements(arr)) {
10871093
return raiseNative(frame, errorMarker, PythonErrorType.SystemError, "provided object is not an array", elementSize);
10881094
}
1089-
long size = lib.getArraySize(o);
1090-
bytes = readWithSize(lib, o, (int) size);
1095+
bytes = readWithSize(lib, elemLib, arr, PInt.intValueExact(n), (int) cachedElementSize);
10911096
bytes.flip();
1092-
} else if (elementSize == 4L) {
1093-
bytes = wrap(getByteArrayNode.execute(frame, o, -1));
10941097
} else {
10951098
return raiseNative(frame, errorMarker, PythonErrorType.ValueError, "unsupported 'wchar_t' size; was: %d", elementSize);
10961099
}
1097-
return decode(bytes);
1100+
return toSulongNode.execute(decode(bytes));
1101+
} catch (ArithmeticException e) {
1102+
return raiseNative(frame, errorMarker, PythonErrorType.ValueError, "array size too large");
10981103
} catch (CharacterCodingException e) {
10991104
return raiseNative(frame, errorMarker, PythonErrorType.UnicodeError, "%m", e);
11001105
} catch (IllegalArgumentException e) {
11011106
return raiseNative(frame, errorMarker, PythonErrorType.LookupError, "%m", e);
11021107
} catch (InteropException e) {
11031108
return raiseNative(frame, errorMarker, PythonErrorType.TypeError, "%m", e);
1109+
} catch (IllegalElementTypeException e) {
1110+
return raiseNative(frame, errorMarker, PythonErrorType.UnicodeDecodeError, "Invalid input element type '%p'", e.elem);
11041111
}
11051112
}
11061113

1107-
@Specialization
1108-
Object doBytes(VirtualFrame frame, Object o, PInt elementSize, Object errorMarker,
1109-
@Shared("getByteArrayNode") @Cached GetByteArrayNode getByteArrayNode,
1110-
@Shared("lib") @CachedLibrary(limit = "3") InteropLibrary lib) {
1114+
@Specialization(limit = "getVariableArgumentInlineCacheLimit()")
1115+
Object doBytes(VirtualFrame frame, Object arr, PInt n, PInt elementSize, Object errorMarker,
1116+
@Cached CExtNodes.ToSulongNode toSulongNode,
1117+
@CachedLibrary("arr") InteropLibrary lib,
1118+
@CachedLibrary(limit = "1") InteropLibrary elemLib) {
11111119
try {
1112-
return doBytes(frame, o, elementSize.longValueExact(), errorMarker, getByteArrayNode, lib);
1120+
long es = elementSize.longValueExact();
1121+
return doBytes(frame, arr, n.longValueExact(), es, errorMarker, toSulongNode, es, lib, elemLib);
11131122
} catch (ArithmeticException e) {
11141123
return raiseNative(frame, errorMarker, PythonErrorType.ValueError, "invalid parameters");
11151124
}
@@ -1120,16 +1129,58 @@ private static String decode(ByteBuffer bytes) throws CharacterCodingException {
11201129
return getUTF32Charset(0).newDecoder().decode(bytes).toString();
11211130
}
11221131

1123-
@TruffleBoundary
1124-
private static ByteBuffer readWithSize(InteropLibrary interopLib, Object o, int size) throws UnsupportedMessageException, InvalidArrayIndexException {
1125-
ByteBuffer buf = ByteBuffer.allocate(size * Integer.BYTES);
1126-
for (long i = 0; i < size; i++) {
1127-
Object elem = interopLib.readArrayElement(o, i);
1128-
assert elem instanceof Number && 0 <= ((Number) elem).intValue() && ((Number) elem).intValue() < (1 << 16);
1129-
buf.putInt(((Number) elem).intValue());
1132+
private static ByteBuffer readWithSize(InteropLibrary arrLib, InteropLibrary elemLib, Object o, int size, int elementSize)
1133+
throws UnsupportedMessageException, InvalidArrayIndexException, IllegalElementTypeException {
1134+
ByteBuffer buf = allocate(size * Integer.BYTES);
1135+
for (int i = 0; i < size; i += elementSize) {
1136+
putInt(buf, readElement(arrLib, elemLib, o, i, elementSize));
11301137
}
11311138
return buf;
11321139
}
1140+
1141+
@ExplodeLoop
1142+
private static int readElement(InteropLibrary arrLib, InteropLibrary elemLib, Object arr, int i, int elementSize)
1143+
throws InvalidArrayIndexException, UnsupportedMessageException, IllegalElementTypeException {
1144+
byte[] barr = new byte[4];
1145+
for (int j = 0; j < elementSize; j++) {
1146+
Object elem = arrLib.readArrayElement(arr, i + j);
1147+
// The array object could be one of our wrappers (e.g. 'PySequenceArrayWrapper').
1148+
// Since the Interop library does not allow to specify how many bytes we want to
1149+
// read when we do readArrayElement, our wrappers always return long. So, we check
1150+
// for 'long' here and cast down to 'byte'.
1151+
if (elemLib.fitsInLong(elem)) {
1152+
barr[j] = (byte) elemLib.asLong(elem);
1153+
} else {
1154+
CompilerDirectives.transferToInterpreter();
1155+
throw new IllegalElementTypeException(elem);
1156+
}
1157+
}
1158+
return toInt(barr);
1159+
}
1160+
1161+
@TruffleBoundary(allowInlining = true)
1162+
private static int toInt(byte[] barr) {
1163+
return ByteBuffer.wrap(barr).order(ByteOrder.LITTLE_ENDIAN).getInt();
1164+
}
1165+
1166+
@TruffleBoundary(allowInlining = true)
1167+
private static ByteBuffer allocate(int cap) {
1168+
return ByteBuffer.allocate(cap);
1169+
}
1170+
1171+
@TruffleBoundary(allowInlining = true)
1172+
private static void putInt(ByteBuffer buf, int element) {
1173+
buf.putInt(element);
1174+
}
1175+
1176+
private static final class IllegalElementTypeException extends Exception {
1177+
private static final long serialVersionUID = 0L;
1178+
private final Object elem;
1179+
1180+
IllegalElementTypeException(Object elem) {
1181+
this.elem = elem;
1182+
}
1183+
}
11331184
}
11341185

11351186
@Builtin(name = "PyTruffle_Unicode_FromUTF8", minNumOfPositionalArgs = 2)

0 commit comments

Comments
 (0)