Skip to content

Commit 7366711

Browse files
committed
PyObject_Bytes can use directly BytesNode if provided object has __BYTES__
1 parent 2a4c647 commit 7366711

File tree

2 files changed

+33
-36
lines changed

2 files changed

+33
-36
lines changed

graalpython/com.oracle.graal.python.test/src/tests/cpyext/test_object.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2018, 2021, Oracle and/or its affiliates. All rights reserved.
1+
# Copyright (c) 2018, 2022, Oracle and/or its affiliates. All rights reserved.
22
# DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
33
#
44
# The Universal Permissive License (UPL), Version 1.0
@@ -52,6 +52,7 @@ def _reference_bytes(args):
5252
res = obj.__bytes__()
5353
if not isinstance(res, bytes):
5454
raise TypeError("__bytes__ returned non-bytes (type %s)" % type(res).__name__)
55+
return res
5556
if isinstance(obj, (list, tuple, memoryview)) or (not isinstance(obj, str) and hasattr(obj, "__iter__")):
5657
return bytes(obj)
5758
raise TypeError("cannot convert '%s' object to bytes" % type(obj).__name__)
@@ -683,6 +684,16 @@ def test_doc(self):
683684
assert len(obj.some_member.__doc__) == len(expected_doc)
684685
assert obj.some_member.__doc__ == expected_doc
685686

687+
class CBytes:
688+
def __bytes__(self):
689+
return b'abc'
690+
691+
class CBytesWrongReturn:
692+
def __bytes__(self):
693+
return 'abc'
694+
695+
class DummyBytes(bytes):
696+
pass
686697

687698
class TestObjectFunctions(CPyExtTestCase):
688699
def compile_module(self, name):
@@ -712,6 +723,10 @@ def compile_module(self, name):
712723
(memoryview(b"world"),),
713724
(1.234,),
714725
(bytearray(b"blah"),),
726+
(CBytes(),),
727+
(CBytesWrongReturn(),),
728+
(DummyBytes(),),
729+
([1,2,3],),
715730
),
716731
arguments=["PyObject* obj"],
717732
resultspec="O",

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/cext/PythonCextObjectBuiltins.java

Lines changed: 17 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
import com.oracle.graal.python.builtins.Python3Core;
5151
import com.oracle.graal.python.builtins.PythonBuiltinClassType;
5252
import com.oracle.graal.python.builtins.PythonBuiltins;
53+
import com.oracle.graal.python.builtins.modules.BuiltinConstructors.BytesNode;
5354
import com.oracle.graal.python.builtins.modules.BuiltinFunctions.IsInstanceNode;
5455
import com.oracle.graal.python.builtins.modules.BuiltinFunctions.IsSubClassNode;
5556
import com.oracle.graal.python.builtins.modules.cext.PythonCextBuiltins.CastArgsNode;
@@ -75,12 +76,8 @@
7576
import com.oracle.graal.python.lib.PyObjectReprAsObjectNode;
7677
import com.oracle.graal.python.lib.PyObjectSetItem;
7778
import com.oracle.graal.python.lib.PyObjectStrAsObjectNode;
78-
import static com.oracle.graal.python.nodes.ErrorMessages.RETURNED_NONBYTES;
79-
import com.oracle.graal.python.nodes.SpecialMethodNames;
80-
import static com.oracle.graal.python.nodes.SpecialMethodNames.__SETATTR__;
8179
import com.oracle.graal.python.nodes.attributes.GetAttributeNode.GetAnyAttributeNode;
8280
import com.oracle.graal.python.nodes.call.CallNode;
83-
import com.oracle.graal.python.nodes.call.special.CallUnaryMethodNode;
8481
import com.oracle.graal.python.nodes.classes.IsSubtypeNode;
8582
import com.oracle.graal.python.nodes.expression.BinaryComparisonNode;
8683
import com.oracle.graal.python.nodes.function.PythonBuiltinBaseNode;
@@ -94,7 +91,6 @@
9491
import com.oracle.graal.python.util.OverflowException;
9592
import com.oracle.graal.python.util.PythonUtils;
9693
import com.oracle.truffle.api.CompilerDirectives;
97-
import com.oracle.truffle.api.dsl.Bind;
9894
import com.oracle.truffle.api.dsl.Cached;
9995
import com.oracle.truffle.api.dsl.Cached.Shared;
10096
import com.oracle.truffle.api.dsl.GenerateNodeFactory;
@@ -486,7 +482,7 @@ Object getAttr(VirtualFrame frame, Object obj, Object attr,
486482
@Cached GetAttributeNode getAttrNode,
487483
@Cached TransformExceptionToNativeNode transformExceptionToNativeNode) {
488484
try {
489-
return getAttrNode.execute(frame, getCore().lookupType(PythonBuiltinClassType.PythonObject), SpecialMethodNames.__GETATTRIBUTE__);
485+
return getAttrNode.execute(frame, obj, attr);
490486
} catch (PException e) {
491487
transformExceptionToNativeNode.execute(frame, e);
492488
return getContext().getNativeNull();
@@ -502,7 +498,7 @@ int setAttr(VirtualFrame frame, Object obj, Object attr, Object value,
502498
@Cached SetattrNode setAttrNode,
503499
@Cached TransformExceptionToNativeNode transformExceptionToNativeNode) {
504500
try {
505-
setAttrNode.execute(frame, getCore().lookupType(PythonBuiltinClassType.PythonObject), __SETATTR__, value);
501+
setAttrNode.execute(frame, obj, attr, value);
506502
return 0;
507503
} catch (PException e) {
508504
transformExceptionToNativeNode.execute(frame, e);
@@ -568,49 +564,35 @@ Object bytes(VirtualFrame frame, Object bytes,
568564
return bytes;
569565
}
570566

571-
@Specialization(guards = {"!isBytes(obj)", "!isBytesSubtype(frame, obj, getClassNode, isSubtypeNode)", "!isPNone(bytesCallable)"})
567+
@Specialization(guards = {"!isBytes(obj)", "!isBytesSubtype(frame, obj, getClassNode, isSubtypeNode)", "hasBytes(frame, obj, lookupAttrNode)"}, limit = "1")
572568
Object bytes(VirtualFrame frame, Object obj,
573-
@SuppressWarnings("unused") @Cached GetClassNode getClassNode,
574-
@SuppressWarnings("unused") @Cached IsSubtypeNode isSubtypeNode,
569+
@Shared("getClass") @SuppressWarnings("unused") @Cached GetClassNode getClassNode,
570+
@Shared("isSubtype") @SuppressWarnings("unused") @Cached IsSubtypeNode isSubtypeNode,
575571
@Cached PyObjectLookupAttr lookupAttrNode,
576-
@Bind("getBytes(frame, obj, lookupAttrNode)") Object bytesCallable,
577-
@Cached CallUnaryMethodNode callNode,
578-
@Cached BranchProfile branchProfile,
579-
@Cached PRaiseNativeNode raiseNativeNode,
572+
@Cached BytesNode bytesNode,
580573
@Cached TransformExceptionToNativeNode transformExceptionToNativeNode) {
581574
try {
582-
Object res = callNode.executeObject(frame, bytesCallable, obj);
583-
if (!isBytesSubtype(frame, res, getClassNode, isSubtypeNode)) {
584-
branchProfile.enter();
585-
return raiseNativeNode.execute(frame, getContext().getNativeNull(), TypeError, RETURNED_NONBYTES, new Object[]{__BYTES__, res});
586-
}
587-
return res;
575+
return bytesNode.execute(frame, PythonBuiltinClassType.PBytes, obj, PNone.NO_VALUE, PNone.NO_VALUE);
588576
} catch (PException e) {
589577
transformExceptionToNativeNode.execute(e);
590578
return getContext().getNativeNull();
591579
}
592580
}
593581

594-
@Specialization(guards = {"!isBytes(obj)", "!isBytesSubtype(frame, obj, getClassNode, isSubtypeNode)", "isPNone(getBytes(frame, obj, lookupAttrNode))"})
595-
Object bytes(VirtualFrame frame, Object obj,
596-
@SuppressWarnings("unused") @Cached GetClassNode getClassNode,
597-
@SuppressWarnings("unused") @Cached IsSubtypeNode isSubtypeNode,
582+
@Specialization(guards = {"!isBytes(obj)", "!isBytesSubtype(frame, obj, getClassNode, isSubtypeNode)", "!hasBytes(frame, obj, lookupAttrNode)"}, limit = "1")
583+
static Object bytes(VirtualFrame frame, Object obj,
584+
@Shared("getClass") @SuppressWarnings("unused") @Cached GetClassNode getClassNode,
585+
@Shared("isSubtype") @SuppressWarnings("unused") @Cached IsSubtypeNode isSubtypeNode,
598586
@Cached PyObjectLookupAttr lookupAttrNode,
599-
@Cached PyBytesFromObjectNode fromObjectNode,
600-
@Cached TransformExceptionToNativeNode transformExceptionToNativeNode) {
601-
try {
602-
return fromObjectNode.execute(frame, obj);
603-
} catch (PException e) {
604-
transformExceptionToNativeNode.execute(e);
605-
return getContext().getNativeNull();
606-
}
587+
@Cached PyBytesFromObjectNode fromObjectNode) {
588+
return fromObjectNode.execute(frame, obj);
607589
}
608590

609-
protected Object getBytes(VirtualFrame frame, Object obj, PyObjectLookupAttr lookupAttrNode) {
610-
return lookupAttrNode.execute(frame, obj, __BYTES__);
591+
protected static boolean hasBytes(VirtualFrame frame, Object obj, PyObjectLookupAttr lookupAttrNode) {
592+
return lookupAttrNode.execute(frame, obj, __BYTES__) != PNone.NO_VALUE;
611593
}
612594

613-
protected boolean isBytesSubtype(VirtualFrame frame, Object obj, GetClassNode getClassNode, IsSubtypeNode isSubtypeNode) {
595+
protected static boolean isBytesSubtype(VirtualFrame frame, Object obj, GetClassNode getClassNode, IsSubtypeNode isSubtypeNode) {
614596
return isSubtypeNode.execute(frame, getClassNode.execute(obj), PythonBuiltinClassType.PBytes);
615597
}
616598
}

0 commit comments

Comments
 (0)