Skip to content

Commit 8f97899

Browse files
committed
Add self check to bytearray comparison nodes
1 parent 1c0de15 commit 8f97899

File tree

3 files changed

+43
-16
lines changed

3 files changed

+43
-16
lines changed

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/bytes/ByteArrayBuiltins.java

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
package com.oracle.graal.python.builtins.objects.bytes;
2828

2929
import static com.oracle.graal.python.nodes.BuiltinNames.J_APPEND;
30+
import static com.oracle.graal.python.nodes.BuiltinNames.J_BYTEARRAY;
3031
import static com.oracle.graal.python.nodes.BuiltinNames.J_EXTEND;
3132
import static com.oracle.graal.python.nodes.SpecialAttributeNames.T___DICT__;
3233
import static com.oracle.graal.python.nodes.SpecialMethodNames.J___DELITEM__;
@@ -78,6 +79,7 @@
7879
import com.oracle.graal.python.builtins.objects.slice.SliceNodes;
7980
import com.oracle.graal.python.builtins.objects.type.TypeNodes;
8081
import com.oracle.graal.python.builtins.objects.type.TypeNodes.InlinedIsSameTypeNode;
82+
import com.oracle.graal.python.lib.PyByteArrayCheckNode;
8183
import com.oracle.graal.python.lib.PyIndexCheckNode;
8284
import com.oracle.graal.python.lib.PyNumberAsSizeNode;
8385
import com.oracle.graal.python.lib.PyObjectLookupAttr;
@@ -799,24 +801,22 @@ public Object reduce(VirtualFrame frame, PByteArray self,
799801
abstract static class AbstractComparisonNode extends BytesNodes.AbstractComparisonBaseNode {
800802
@Specialization
801803
@SuppressWarnings("truffle-static-method")
802-
boolean cmp(PBytesLike self, PBytesLike other,
804+
boolean cmp(PByteArray self, PBytesLike other,
803805
@Shared @Cached GetInternalByteArrayNode getArray) {
804806
SequenceStorage selfStorage = self.getSequenceStorage();
805807
SequenceStorage otherStorage = other.getSequenceStorage();
806808
return doCmp(getArray.execute(selfStorage), selfStorage.length(), getArray.execute(otherStorage), otherStorage.length());
807809
}
808810

809-
@Specialization
811+
@Specialization(guards = {"check.execute(inliningTarget, self)", "acquireLib.hasBuffer(other)"}, limit = "3")
810812
@SuppressWarnings("truffle-static-method")
811813
Object cmp(VirtualFrame frame, Object self, Object other,
812814
@Bind("this") Node inliningTarget,
815+
@SuppressWarnings("unused") @Cached PyByteArrayCheckNode check,
813816
@Cached GetBytesStorage getBytesStorage,
814817
@Shared @Cached GetInternalByteArrayNode getArray,
815-
@CachedLibrary(limit = "3") PythonBufferAcquireLibrary acquireLib,
818+
@CachedLibrary("other") PythonBufferAcquireLibrary acquireLib,
816819
@CachedLibrary(limit = "3") PythonBufferAccessLibrary bufferLib) {
817-
if (!acquireLib.hasBuffer(other)) {
818-
return PNotImplemented.NOT_IMPLEMENTED;
819-
}
820820
SequenceStorage selfStorage = getBytesStorage.execute(inliningTarget, self);
821821
Object otherBuffer = acquireLib.acquireReadonly(other, frame, this);
822822
try {
@@ -826,6 +826,23 @@ Object cmp(VirtualFrame frame, Object self, Object other,
826826
bufferLib.release(otherBuffer);
827827
}
828828
}
829+
830+
@Specialization(guards = {"check.execute(inliningTarget, self)", "!acquireLib.hasBuffer(other)"}, limit = "1")
831+
@SuppressWarnings("unused")
832+
static Object cmp(VirtualFrame frame, Object self, Object other,
833+
@Bind("this") Node inliningTarget,
834+
@Cached PyByteArrayCheckNode check,
835+
@CachedLibrary(limit = "3") PythonBufferAcquireLibrary acquireLib) {
836+
return PNotImplemented.NOT_IMPLEMENTED;
837+
}
838+
839+
@Specialization(guards = "!check.execute(inliningTarget, self)", limit = "1")
840+
@SuppressWarnings({"truffle-static-method", "unused"})
841+
Object error(VirtualFrame frame, Object self, Object other,
842+
@Bind("this") Node inliningTarget,
843+
@Cached PyByteArrayCheckNode check) {
844+
throw raise(TypeError, ErrorMessages.DESCRIPTOR_REQUIRES_OBJ, J___EQ__, J_BYTEARRAY, self);
845+
}
829846
}
830847

831848
@Builtin(name = J___EQ__, minNumOfPositionalArgs = 2)

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/bytes/BytesBuiltins.java

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import static com.oracle.graal.python.builtins.objects.bytes.BytesNodes.adjustStartIndex;
3232
import static com.oracle.graal.python.builtins.objects.bytes.BytesUtils.toLower;
3333
import static com.oracle.graal.python.builtins.objects.bytes.BytesUtils.toUpper;
34+
import static com.oracle.graal.python.nodes.BuiltinNames.J_BYTES;
3435
import static com.oracle.graal.python.nodes.BuiltinNames.J_DECODE;
3536
import static com.oracle.graal.python.nodes.BuiltinNames.J_ENDSWITH;
3637
import static com.oracle.graal.python.nodes.BuiltinNames.J_REMOVEPREFIX;
@@ -563,25 +564,33 @@ boolean cmp(PBytes self, PBytes other,
563564
return doCmp(getArray.execute(selfStorage), selfStorage.length(), getArray.execute(otherStorage), otherStorage.length());
564565
}
565566

566-
@Specialization(guards = "check.execute(inliningTarget, other)", limit = "1")
567+
@Specialization(guards = {"check.execute(inliningTarget, self)", "check.execute(inliningTarget, other)"}, limit = "1")
567568
@SuppressWarnings("truffle-static-method")
568569
boolean cmp(Object self, Object other,
569570
@Bind("this") Node inliningTarget,
570-
@SuppressWarnings("unused") @Cached PyBytesCheckNode check,
571+
@SuppressWarnings("unused") @Shared @Cached PyBytesCheckNode check,
571572
@Cached GetBytesStorage getBytesStorage,
572573
@Shared @Cached GetInternalByteArrayNode getArray) {
573574
SequenceStorage selfStorage = getBytesStorage.execute(inliningTarget, self);
574575
SequenceStorage otherStorage = getBytesStorage.execute(inliningTarget, other);
575576
return doCmp(getArray.execute(selfStorage), selfStorage.length(), getArray.execute(otherStorage), otherStorage.length());
576577
}
577578

578-
@Specialization(guards = "!check.execute(inliningTarget, other)", limit = "1")
579+
@Specialization(guards = {"check.execute(inliningTarget, self)", "!check.execute(inliningTarget, other)"}, limit = "1")
579580
@SuppressWarnings("unused")
580581
static Object cmp(Object self, Object other,
581582
@Bind("this") Node inliningTarget,
582-
@Cached PyBytesCheckNode check) {
583+
@Shared @Cached PyBytesCheckNode check) {
583584
return PNotImplemented.NOT_IMPLEMENTED;
584585
}
586+
587+
@Specialization(guards = "!check.execute(inliningTarget, self)", limit = "1")
588+
@SuppressWarnings({"unused", "truffle-static-method"})
589+
Object error(Object self, Object other,
590+
@Bind("this") Node inliningTarget,
591+
@Shared @Cached PyBytesCheckNode check) {
592+
throw raise(TypeError, ErrorMessages.DESCRIPTOR_REQUIRES_OBJ, J___EQ__, J_BYTES, self);
593+
}
585594
}
586595

587596
@Builtin(name = J___EQ__, minNumOfPositionalArgs = 2)

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/lib/PyByteArrayCheckNode.java

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,23 +46,25 @@
4646
import com.oracle.graal.python.nodes.PGuards;
4747
import com.oracle.graal.python.nodes.classes.IsSubtypeNode;
4848
import com.oracle.graal.python.nodes.object.InlinedGetClassNode;
49-
import com.oracle.truffle.api.dsl.Bind;
5049
import com.oracle.truffle.api.dsl.Cached;
5150
import com.oracle.truffle.api.dsl.Fallback;
51+
import com.oracle.truffle.api.dsl.GenerateCached;
52+
import com.oracle.truffle.api.dsl.GenerateInline;
5253
import com.oracle.truffle.api.dsl.GenerateUncached;
5354
import com.oracle.truffle.api.dsl.ImportStatic;
5455
import com.oracle.truffle.api.dsl.Specialization;
55-
import com.oracle.truffle.api.frame.VirtualFrame;
5656
import com.oracle.truffle.api.nodes.Node;
5757

5858
/**
5959
* Equivalent of CPython's {@code PyByteArray_Check}.
6060
*/
6161
@ImportStatic(PGuards.class)
6262
@GenerateUncached
63+
@GenerateInline
64+
@GenerateCached(false)
6365
public abstract class PyByteArrayCheckNode extends Node {
6466

65-
public abstract boolean execute(VirtualFrame frame, Object object);
67+
public abstract boolean execute(Node inliningTarget, Object object);
6668

6769
public static boolean executeUncached(Object object) {
6870
return PyByteArrayCheckNodeGen.getUncached().execute(null, object);
@@ -75,11 +77,10 @@ static boolean check(PByteArray obj) {
7577
}
7678

7779
@Specialization
78-
static boolean check(VirtualFrame frame, PythonAbstractNativeObject obj,
79-
@Bind("this") Node inliningTarget,
80+
static boolean check(Node inliningTarget, PythonAbstractNativeObject obj,
8081
@Cached InlinedGetClassNode getClassNode,
8182
@Cached IsSubtypeNode isSubtypeNode) {
82-
return isSubtypeNode.execute(frame, getClassNode.execute(inliningTarget, obj), PythonBuiltinClassType.PByteArray);
83+
return isSubtypeNode.execute(null, getClassNode.execute(inliningTarget, obj), PythonBuiltinClassType.PByteArray);
8384
}
8485

8586
@Fallback

0 commit comments

Comments
 (0)