Skip to content

Commit aa95d40

Browse files
authored
Merge pull request #168 from DillonZChen/main
Implement ordering on Variable objects
2 parents db51bab + c895501 commit aa95d40

File tree

2 files changed

+49
-7
lines changed

2 files changed

+49
-7
lines changed

pddl/logic/terms.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -65,13 +65,11 @@ def __eq__(self, other):
6565

6666
def __lt__(self, other):
6767
"""Compare with another term."""
68-
if isinstance(other, Constant):
69-
return (self.name, sorted(self.type_tags)) < (
70-
other.name,
71-
sorted(other.type_tags),
72-
)
73-
else:
68+
if not isinstance(other, Term):
7469
return super().__lt__(other)
70+
lhs = (type(self).__name__, self.name, sorted(self.type_tags))
71+
rhs = (type(other).__name__, other.name, sorted(other.type_tags))
72+
return lhs < rhs
7573

7674

7775
# TODO check correctness

tests/test_logic/test_terms.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
"""Test pddl.logic.terms module."""
1414
import pytest
1515

16-
from pddl.logic.terms import Variable
16+
from pddl.logic.terms import Constant, Variable
1717

1818

1919
def test_no_duplicated_type_tags() -> None:
@@ -22,3 +22,47 @@ def test_no_duplicated_type_tags() -> None:
2222
ValueError, match=r"duplicate element in collection \['b', 'b'\]: 'b'"
2323
):
2424
Variable("a", ["b", "b"])
25+
26+
27+
def test_variable_ordering() -> None:
28+
"""Test variable ordering."""
29+
v1 = Variable("x", ["type1"])
30+
v2 = Variable("y", ["type3"])
31+
v3 = Variable("z", ["type2"])
32+
assert v1 < v2
33+
assert v1 < v3
34+
assert v2 < v3
35+
assert not (v1 > v2)
36+
assert not (v1 > v3)
37+
assert not (v2 > v3)
38+
39+
40+
def test_constant_ordering() -> None:
41+
"""Test constant ordering."""
42+
c1 = Constant("a", "type1")
43+
c2 = Constant("b", "type2")
44+
c3 = Constant("c", "type1")
45+
assert c1 < c2
46+
assert c1 < c3
47+
assert c2 < c3
48+
assert not (c1 > c2)
49+
assert not (c1 > c3)
50+
assert not (c2 > c3)
51+
52+
53+
def test_term_ordering() -> None:
54+
"""Test term ordering between variables and constants."""
55+
v1 = Variable("x", ["type1"])
56+
v2 = Variable("y", ["type3"])
57+
v3 = Variable("z", ["type2"])
58+
c1 = Constant("a", "type1")
59+
c2 = Constant("b", "type2")
60+
c3 = Constant("c", "type1")
61+
assert c1 < v1
62+
assert c2 < v2
63+
assert c3 < v3
64+
assert not (v1 < c1)
65+
assert not (v2 < c2)
66+
assert not (v3 < c3)
67+
array = [c1, c2, c3, v1, v2, v3]
68+
assert array == sorted(array)

0 commit comments

Comments
 (0)