Skip to content

Commit 1c115f5

Browse files
committed
Support native string subclasses in more builtins
1 parent dbed777 commit 1c115f5

File tree

5 files changed

+107
-19
lines changed

5 files changed

+107
-19
lines changed

graalpython/com.oracle.graal.python.test/src/tests/cpyext/test_unicode.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@
4141
import re
4242
import sys
4343

44-
from . import CPyExtType, CPyExtTestCase, CPyExtFunction, unhandled_error_compare, GRAALPYTHON, CPyExtFunctionOutVars
44+
from . import CPyExtType, CPyExtTestCase, CPyExtFunction, unhandled_error_compare, GRAALPYTHON, CPyExtFunctionOutVars, \
45+
is_native_object
4546

4647
__dir__ = __file__.rpartition("/")[0]
4748

@@ -219,6 +220,18 @@ def gen_intern_args():
219220
return args
220221

221222

223+
UnicodeSubclass = CPyExtType(
224+
"UnicodeSubclass",
225+
'',
226+
struct_base='PyUnicodeObject base;',
227+
tp_itemsize='sizeof(char)',
228+
tp_base='&PyUnicode_Type',
229+
tp_new='0',
230+
tp_alloc='0',
231+
tp_free='0',
232+
)
233+
234+
222235
class TestPyUnicode(CPyExtTestCase):
223236

224237
test_PyUnicode_FromObject = CPyExtFunction(
@@ -229,6 +242,7 @@ class TestPyUnicode(CPyExtTestCase):
229242
(b"hello",),
230243
(Dummy(),),
231244
(str,),
245+
(UnicodeSubclass("asdf"),),
232246
),
233247
resultspec="O",
234248
argspec='O',
@@ -408,6 +422,7 @@ class TestPyUnicode(CPyExtTestCase):
408422
("hello",),
409423
("world",),
410424
("this is a longer text also cöntaining weird Ümläuts",),
425+
(UnicodeSubclass("asdf"),),
411426
),
412427
resultspec="n",
413428
argspec='O',
@@ -421,6 +436,7 @@ class TestPyUnicode(CPyExtTestCase):
421436
("hello", ", world"),
422437
("", "world"),
423438
("this is a longer text also cöntaining weird Ümläuts", ""),
439+
(UnicodeSubclass("asdf"), "gh"),
424440
),
425441
resultspec="O",
426442
argspec='OO',
@@ -477,6 +493,7 @@ class TestPyUnicode(CPyExtTestCase):
477493
lambda: (
478494
("hello",),
479495
("hellö",),
496+
(UnicodeSubclass("asdf"),),
480497
),
481498
resultspec="s",
482499
argspec='O',
@@ -489,6 +506,8 @@ class TestPyUnicode(CPyExtTestCase):
489506
lambda: (
490507
("hello",),
491508
("hellö",),
509+
(UnicodeSubclass("asdf"),),
510+
(UnicodeSubclass("žluťoučký kůň"),),
492511
),
493512
resultspec="O",
494513
argspec='O',
@@ -501,6 +520,8 @@ class TestPyUnicode(CPyExtTestCase):
501520
lambda: (
502521
("hello",),
503522
("hellö",),
523+
(UnicodeSubclass("asdf"),),
524+
(UnicodeSubclass("žluťoučký kůň"),),
504525
),
505526
resultspec="yn",
506527
resulttype='const char*',
@@ -650,6 +671,7 @@ class TestPyUnicode(CPyExtTestCase):
650671
lambda: (
651672
("hello",),
652673
("hellö",),
674+
(UnicodeSubclass("asdf"),),
653675
),
654676
resultspec="O",
655677
argspec='O',
@@ -662,6 +684,7 @@ class TestPyUnicode(CPyExtTestCase):
662684
lambda: (
663685
("hello",),
664686
("hellö",),
687+
(UnicodeSubclass("asdf"),),
665688
),
666689
resultspec="O",
667690
argspec='O',
@@ -676,6 +699,7 @@ class TestPyUnicode(CPyExtTestCase):
676699
("hellö, %s", ("wörld",)),
677700
("%s, %r", ("hello", "world")),
678701
("nothing else", tuple()),
702+
(UnicodeSubclass("%s, %r"), ("hello", "world")),
679703
),
680704
resultspec="O",
681705
argspec='OO',
@@ -691,6 +715,7 @@ class TestPyUnicode(CPyExtTestCase):
691715
(b"hello",),
692716
("hellö",),
693717
(['a', 'b', 'c'],),
718+
(UnicodeSubclass("asdf"),),
694719
),
695720
resultspec="i",
696721
argspec='O',
@@ -798,6 +823,7 @@ class TestPyUnicode(CPyExtTestCase):
798823
("hello", 0, 1),
799824
("hello", 4, 5),
800825
("hello", 1, 4),
826+
(UnicodeSubclass("asdf"), 2, 4),
801827
),
802828
resultspec="O",
803829
argspec='Onn',
@@ -826,6 +852,7 @@ class TestPyUnicode(CPyExtTestCase):
826852
("a", "b"),
827853
("a", None),
828854
("a", 1),
855+
(UnicodeSubclass("asdf"), "asdf"),
829856
),
830857
resultspec="i",
831858
argspec='OO',
@@ -840,6 +867,7 @@ class TestPyUnicode(CPyExtTestCase):
840867
("a", "b"),
841868
("a", "ab"),
842869
("ab", "a"),
870+
(UnicodeSubclass("asdf"), "asdf"),
843871
),
844872
resultspec="i",
845873
argspec='Os',
@@ -907,6 +935,7 @@ class TestPyUnicode(CPyExtTestCase):
907935
("öüä", "ascii", "ignore"),
908936
("öüä", "ascii", "replace"),
909937
(1, "ascii", "replace"),
938+
(UnicodeSubclass("asdf"), "ascii", "report"),
910939
),
911940
resultspec="O",
912941
argspec='Oss',
@@ -920,6 +949,7 @@ class TestPyUnicode(CPyExtTestCase):
920949
("abcd",),
921950
("öüä",),
922951
(1,),
952+
(UnicodeSubclass("asdf"),),
923953
),
924954
resultspec="O",
925955
argspec='O',
@@ -991,6 +1021,7 @@ class TestPyUnicode(CPyExtTestCase):
9911021
("hello", 100),
9921022
("hello", -1),
9931023
("höllö", 4),
1024+
(UnicodeSubclass("asdf"), 1),
9941025
),
9951026
code='''PyObject* wrap_PyUnicode_ReadChar(PyObject* unicode, Py_ssize_t index) {
9961027
Py_UCS4 res = PyUnicode_ReadChar(unicode, index);
@@ -1012,6 +1043,7 @@ class TestPyUnicode(CPyExtTestCase):
10121043
lambda: (
10131044
("aaa", "bbb"),
10141045
("aaa", "a"),
1046+
(UnicodeSubclass("asdf"), "s"),
10151047
),
10161048
resultspec="i",
10171049
argspec='OO',
@@ -1024,6 +1056,7 @@ class TestPyUnicode(CPyExtTestCase):
10241056
lambda: (
10251057
("foo.bar.baz", ".", 0),
10261058
("foo.bar.baz", ".", 1),
1059+
(UnicodeSubclass("foo.bar.baz"), ".", 1),
10271060
("foo.bar.baz", 7, 0),
10281061
),
10291062
resultspec="O",
@@ -1080,6 +1113,7 @@ class TestPyUnicode(CPyExtTestCase):
10801113
("ššš",),
10811114
("すごい",),
10821115
("😂",),
1116+
(UnicodeSubclass("asdf"),)
10831117
),
10841118
code='''
10851119
PyObject* wrap_PyUnicode_DATA(PyObject* string) {
@@ -1129,3 +1163,22 @@ def test_intern(self):
11291163
s2 = b'some text'.decode('ascii')
11301164
assert tester.set_intern_str(s1) == s2
11311165
assert tester.check_is_same_str_ptr(s2)
1166+
1167+
1168+
class TestNativeUnicodeSubclass:
1169+
def test_builtins(self):
1170+
s = UnicodeSubclass("asdf")
1171+
assert is_native_object(s)
1172+
assert type(s) is UnicodeSubclass
1173+
assert len(s) == 4
1174+
assert s[1] == 's'
1175+
assert s == "asdf"
1176+
assert s + "gh" == "asdfgh"
1177+
assert s > "asc"
1178+
assert s >= "asdf"
1179+
assert s < "b"
1180+
assert s <= "asdf"
1181+
assert s[1:] == "sdf"
1182+
assert "sd" in s
1183+
assert UnicodeSubclass("<{}>").format("asdf") == "<asdf>"
1184+
assert UnicodeSubclass("<%s>") % "asdf" == "<asdf>"

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/cext/PythonCextUnicodeBuiltins.java

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,13 +98,17 @@
9898
import com.oracle.graal.python.builtins.modules.cext.PythonCextBuiltins.CApiUnaryBuiltinNode;
9999
import com.oracle.graal.python.builtins.modules.codecs.ErrorHandlers;
100100
import com.oracle.graal.python.builtins.objects.PNone;
101+
import com.oracle.graal.python.builtins.objects.buffer.PythonBufferAccessLibrary;
101102
import com.oracle.graal.python.builtins.objects.bytes.PBytes;
103+
import com.oracle.graal.python.builtins.objects.cext.PythonAbstractNativeObject;
102104
import com.oracle.graal.python.builtins.objects.cext.capi.CExtNodes.UnicodeFromFormatNode;
103105
import com.oracle.graal.python.builtins.objects.cext.capi.PySequenceArrayWrapper;
104106
import com.oracle.graal.python.builtins.objects.cext.capi.UnicodeObjectNodes.UnicodeAsWideCharNode;
105107
import com.oracle.graal.python.builtins.objects.cext.common.CExtCommonNodes.EncodeNativeStringNode;
106108
import com.oracle.graal.python.builtins.objects.cext.common.CExtCommonNodes.GetByteArrayNode;
107109
import com.oracle.graal.python.builtins.objects.cext.common.CExtCommonNodes.ReadUnicodeArrayNode;
110+
import com.oracle.graal.python.builtins.objects.cext.structs.CFields;
111+
import com.oracle.graal.python.builtins.objects.cext.structs.CStructAccess;
108112
import com.oracle.graal.python.builtins.objects.cext.structs.CStructs;
109113
import com.oracle.graal.python.builtins.objects.common.HashingStorageNodes.HashingStorageGetItem;
110114
import com.oracle.graal.python.builtins.objects.common.HashingStorageNodes.HashingStorageSetItem;
@@ -1024,14 +1028,38 @@ abstract static class PyTruffle_Unicode_AsUTF8AndSize_CharPtr extends CApiUnaryB
10241028
static Object doUnicode(PString s,
10251029
@Bind("this") Node inliningTarget,
10261030
@Cached InlinedConditionProfile profile,
1027-
@Cached _PyUnicode_AsUTF8String asUTF8String) {
1031+
@Shared @Cached _PyUnicode_AsUTF8String asUTF8String) {
10281032
if (profile.profile(inliningTarget, s.getUtf8Bytes() == null)) {
10291033
PBytes bytes = (PBytes) asUTF8String.execute(s, T_STRICT);
10301034
s.setUtf8Bytes(bytes);
10311035
}
10321036
return PySequenceArrayWrapper.ensureNativeSequence(s.getUtf8Bytes());
10331037
}
10341038

1039+
@Specialization
1040+
static Object doNative(PythonAbstractNativeObject s,
1041+
@CachedLibrary(limit = "2") InteropLibrary lib,
1042+
@CachedLibrary(limit = "1") PythonBufferAccessLibrary bufferLib,
1043+
@Cached CStructAccess.ReadPointerNode readPointerNode,
1044+
@Cached CStructAccess.WritePointerNode writePointerNode,
1045+
@Cached CStructAccess.AllocateNode allocateNode,
1046+
@Cached CStructAccess.WriteByteNode writeByteNode,
1047+
@Cached CStructAccess.WriteLongNode writeLongNode,
1048+
@Shared @Cached _PyUnicode_AsUTF8String asUTF8String) {
1049+
Object utf8 = readPointerNode.readFromObj(s, CFields.PyCompactUnicodeObject__utf8);
1050+
if (lib.isNull(utf8)) {
1051+
PBytes bytes = (PBytes) asUTF8String.execute(s, T_STRICT);
1052+
int len = bufferLib.getBufferLength(bytes);
1053+
// TODO leaked?
1054+
Object mem = allocateNode.alloc(len + 1);
1055+
writeByteNode.writeByteArray(mem, bufferLib.getInternalByteArray(bytes), len, 0, 0);
1056+
writePointerNode.writeToObj(s, CFields.PyCompactUnicodeObject__utf8, mem);
1057+
writeLongNode.writeToObject(s, CFields.PyCompactUnicodeObject__utf8_length, len);
1058+
return mem;
1059+
}
1060+
return utf8;
1061+
}
1062+
10351063
@Fallback
10361064
static Object doError(@SuppressWarnings("unused") Object s,
10371065
@Cached PRaiseNode raiseNode) {
@@ -1047,6 +1075,13 @@ Object doUnicode(PString s) {
10471075
// PyTruffle_Unicode_AsUTF8AndSize_CharPtr must have been be called before
10481076
return s.getUtf8Bytes().getSequenceStorage().length();
10491077
}
1078+
1079+
@Specialization
1080+
Object doNative(PythonAbstractNativeObject s,
1081+
@Cached CStructAccess.ReadI64Node readI64Node) {
1082+
// PyTruffle_Unicode_AsUTF8AndSize_CharPtr must have been be called before
1083+
return readI64Node.readFromObj(s, CFields.PyCompactUnicodeObject__utf8_length);
1084+
}
10501085
}
10511086

10521087
@CApiBuiltin(ret = PY_UNICODE_PTR, args = {PyObject}, call = Direct)

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/cext/capi/ExternalFunctionNodes.java

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -324,15 +324,10 @@ static Object doIt(Object object,
324324
@Cached CastToTruffleStringNode castToStringNode,
325325
@Cached NativeToPythonNode nativeToPythonNode) {
326326
Object result = nativeToPythonNode.execute(object);
327-
if (result instanceof TruffleString) {
327+
if (result == PNone.NO_VALUE) {
328328
return result;
329-
} else if (result instanceof PString) {
330-
return castToStringNode.execute(inliningTarget, result);
331-
} else if (result == PNone.NO_VALUE) {
332-
return result;
333-
} else {
334-
throw CompilerDirectives.shouldNotReachHere();
335329
}
330+
return castToStringNode.castKnownString(inliningTarget, result);
336331
}
337332

338333
@NeverDefault

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/cext/structs/CStructAccess.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
package com.oracle.graal.python.builtins.objects.cext.structs;
4242

4343
import com.oracle.graal.python.builtins.objects.PythonAbstractObject;
44+
import com.oracle.graal.python.builtins.objects.cext.PythonAbstractNativeObject;
4445
import com.oracle.graal.python.builtins.objects.cext.PythonNativeObject;
4546
import com.oracle.graal.python.builtins.objects.cext.capi.CExtNodes.FromCharPointerNode;
4647
import com.oracle.graal.python.builtins.objects.cext.capi.CExtNodes.PCallCapiFunction;
@@ -1180,6 +1181,10 @@ public final void write(Object pointer, CFields field, Object value) {
11801181
execute(pointer, field.offset(), value);
11811182
}
11821183

1184+
public final void writeToObj(PythonAbstractNativeObject obj, CFields field, Object value) {
1185+
write(obj.getPtr(), field, value);
1186+
}
1187+
11831188
public final void write(Object pointer, Object value) {
11841189
execute(pointer, 0, value);
11851190
}

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/str/StringBuiltins.java

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -283,13 +283,20 @@ private static Spec getAndValidateSpec(Node inliningTarget, TruffleString format
283283
@GenerateNodeFactory
284284
abstract static class StrFormatNode extends PythonBuiltinNode {
285285

286-
@Specialization(guards = "isString(self)")
286+
@Specialization
287287
static TruffleString format(VirtualFrame frame, Object self, Object[] args, PKeyword[] kwargs,
288288
@Bind("this") Node inliningTarget,
289289
@Cached("createFor(this)") IndirectCallData indirectCallData,
290290
@Cached BuiltinFunctions.FormatNode format,
291-
@Cached CastToTruffleStringNode castToStringNode) {
292-
TemplateFormatter template = new TemplateFormatter(castToStringNode.execute(inliningTarget, self));
291+
@Cached CastToTruffleStringNode castToStringNode,
292+
@Cached PRaiseNode raiseNode) {
293+
TruffleString string;
294+
try {
295+
string = castToStringNode.execute(inliningTarget, self);
296+
} catch (CannotCastException e) {
297+
throw raiseNode.raise(TypeError, ErrorMessages.DESCRIPTOR_S_REQUIRES_S_OBJ_RECEIVED_P, T_FORMAT, "str", self);
298+
}
299+
TemplateFormatter template = new TemplateFormatter(string);
293300
PythonLanguage language = PythonLanguage.get(inliningTarget);
294301
PythonContext context = PythonContext.get(inliningTarget);
295302
Object state = IndirectCallContext.enter(frame, language, context, indirectCallData);
@@ -299,13 +306,6 @@ static TruffleString format(VirtualFrame frame, Object self, Object[] args, PKey
299306
IndirectCallContext.exit(frame, language, context, state);
300307
}
301308
}
302-
303-
@Specialization(guards = "!isString(self)")
304-
@SuppressWarnings("unused")
305-
static TruffleString generic(VirtualFrame frame, Object self, Object[] args, PKeyword[] kwargs,
306-
@Cached PRaiseNode raiseNode) {
307-
throw raiseNode.raise(TypeError, ErrorMessages.DESCRIPTOR_S_REQUIRES_S_OBJ_RECEIVED_P, T_FORMAT, "str", self);
308-
}
309309
}
310310

311311
@Builtin(name = J_FORMAT_MAP, minNumOfPositionalArgs = 2, declaresExplicitSelf = true, parameterNames = {"self", "mapping"})

0 commit comments

Comments
 (0)