Skip to content

Commit 7e2021b

Browse files
committed
GR-10939: sets -> add support for comparison with objects implementing the richcmp protocol
- added relevant unittest
1 parent 87cae01 commit 7e2021b

File tree

2 files changed

+77
-0
lines changed

2 files changed

+77
-0
lines changed

graalpython/com.oracle.graal.python.test/src/tests/test_set.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,3 +215,40 @@ def test_intersection():
215215
def test_same_id():
216216
empty_ids = set([id(frozenset()) for i in range(100)])
217217
assert len(empty_ids) == 1
218+
219+
220+
def test_rich_compare():
221+
class TestRichSetCompare:
222+
def __gt__(self, some_set):
223+
self.gt_called = True
224+
return False
225+
def __lt__(self, some_set):
226+
self.lt_called = True
227+
return False
228+
def __ge__(self, some_set):
229+
self.ge_called = True
230+
return False
231+
def __le__(self, some_set):
232+
self.le_called = True
233+
return False
234+
235+
# This first tries the builtin rich set comparison, which doesn't know
236+
# how to handle the custom object. Upon returning NotImplemented, the
237+
# corresponding comparison on the right object is invoked.
238+
myset = {1, 2, 3}
239+
240+
myobj = TestRichSetCompare()
241+
myset < myobj
242+
assert myobj.gt_called
243+
244+
myobj = TestRichSetCompare()
245+
myset > myobj
246+
assert myobj.lt_called
247+
248+
myobj = TestRichSetCompare()
249+
myset <= myobj
250+
assert myobj.ge_called
251+
252+
myobj = TestRichSetCompare()
253+
myset >= myobj
254+
assert myobj.le_called

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/set/FrozenSetBuiltins.java

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,13 +53,15 @@
5353
import com.oracle.graal.python.builtins.objects.common.PHashingCollection;
5454
import com.oracle.graal.python.builtins.objects.set.FrozenSetBuiltinsFactory.BinaryUnionNodeGen;
5555
import com.oracle.graal.python.nodes.PBaseNode;
56+
import com.oracle.graal.python.nodes.call.special.LookupAndCallBinaryNode;
5657
import com.oracle.graal.python.nodes.control.GetIteratorNode;
5758
import com.oracle.graal.python.nodes.control.GetNextNode;
5859
import com.oracle.graal.python.nodes.function.PythonBuiltinBaseNode;
5960
import com.oracle.graal.python.nodes.function.PythonBuiltinNode;
6061
import com.oracle.graal.python.nodes.function.builtins.PythonBinaryBuiltinNode;
6162
import com.oracle.graal.python.nodes.function.builtins.PythonUnaryBuiltinNode;
6263
import com.oracle.graal.python.runtime.exception.PException;
64+
import com.oracle.graal.python.runtime.exception.PythonErrorType;
6365
import com.oracle.truffle.api.CompilerDirectives;
6466
import com.oracle.truffle.api.CompilerDirectives.CompilationFinal;
6567
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
@@ -366,11 +368,29 @@ boolean isSuperSet(PBaseSet self, String other,
366368
@Builtin(name = __LE__, fixedNumOfArguments = 2)
367369
@GenerateNodeFactory
368370
abstract static class LessEqualNode extends IsSubsetNode {
371+
@Specialization
372+
Object isLessEqual(PBaseSet self, Object other,
373+
@Cached("create(__GE__)") LookupAndCallBinaryNode lookupAndCallBinaryNode) {
374+
Object result = lookupAndCallBinaryNode.executeObject(other, self);
375+
if (result != PNone.NO_VALUE) {
376+
return result;
377+
}
378+
throw raise(PythonErrorType.TypeError, "unorderable types: %p <= %p", self, other);
379+
}
369380
}
370381

371382
@Builtin(name = __GE__, fixedNumOfArguments = 2)
372383
@GenerateNodeFactory
373384
abstract static class GreaterEqualNode extends IsSupersetNode {
385+
@Specialization
386+
Object isGreaterEqual(PBaseSet self, Object other,
387+
@Cached("create(__LE__)") LookupAndCallBinaryNode lookupAndCallBinaryNode) {
388+
Object result = lookupAndCallBinaryNode.executeObject(other, self);
389+
if (result != PNone.NO_VALUE) {
390+
return result;
391+
}
392+
throw raise(PythonErrorType.TypeError, "unorderable types: %p >= %p", self, other);
393+
}
374394
}
375395

376396
@Builtin(name = __LT__, fixedNumOfArguments = 2)
@@ -403,6 +423,16 @@ boolean isLessThan(PBaseSet self, String other,
403423
}
404424
return (Boolean) getLessEqualNode().execute(self, other);
405425
}
426+
427+
@Specialization
428+
Object isLessThan(PBaseSet self, Object other,
429+
@Cached("create(__GT__)") LookupAndCallBinaryNode lookupAndCallBinaryNode) {
430+
Object result = lookupAndCallBinaryNode.executeObject(other, self);
431+
if (result != PNone.NO_VALUE) {
432+
return result;
433+
}
434+
throw raise(PythonErrorType.TypeError, "unorderable types: %p < %p", self, other);
435+
}
406436
}
407437

408438
@Builtin(name = __GT__, fixedNumOfArguments = 2)
@@ -435,5 +465,15 @@ boolean isGreaterThan(PBaseSet self, String other,
435465
}
436466
return (Boolean) getGreaterEqualNode().execute(self, other);
437467
}
468+
469+
@Specialization
470+
Object isLessThan(PBaseSet self, Object other,
471+
@Cached("create(__LT__)") LookupAndCallBinaryNode lookupAndCallBinaryNode) {
472+
Object result = lookupAndCallBinaryNode.executeObject(other, self);
473+
if (result != PNone.NO_VALUE) {
474+
return result;
475+
}
476+
throw raise(PythonErrorType.TypeError, "unorderable types: %p > %p", self, other);
477+
}
438478
}
439479
}

0 commit comments

Comments
 (0)