Skip to content

Commit 251fe89

Browse files
Add new conditions and tests (#614)
* add new conditions and tests * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent b8f0d0b commit 251fe89

File tree

2 files changed

+44
-0
lines changed

2 files changed

+44
-0
lines changed

openff/evaluator/_tests/test_protocols/test_groups.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,41 @@ def test_conditional_protocol_group_fail():
6767
protocol_group.execute(directory, ComputeResources())
6868

6969

70+
@pytest.mark.parametrize(
71+
"left, right, condition_type, outcome",
72+
[
73+
(1, 1, ConditionalGroup.Condition.Type.EqualTo, True),
74+
(1, 1, ConditionalGroup.Condition.Type.GreaterThan, False),
75+
(1, 1, ConditionalGroup.Condition.Type.GreaterThanOrEqualTo, True),
76+
(1, 1, ConditionalGroup.Condition.Type.LessThan, False),
77+
(1, 1, ConditionalGroup.Condition.Type.LessThanOrEqualTo, True),
78+
(1, 2, ConditionalGroup.Condition.Type.EqualTo, False),
79+
(1, 2, ConditionalGroup.Condition.Type.GreaterThan, False),
80+
(1, 2, ConditionalGroup.Condition.Type.GreaterThanOrEqualTo, False),
81+
(1, 2, ConditionalGroup.Condition.Type.LessThan, True),
82+
(1, 2, ConditionalGroup.Condition.Type.LessThanOrEqualTo, True),
83+
(2, 1, ConditionalGroup.Condition.Type.EqualTo, False),
84+
(2, 1, ConditionalGroup.Condition.Type.GreaterThan, True),
85+
(2, 1, ConditionalGroup.Condition.Type.GreaterThanOrEqualTo, True),
86+
(2, 1, ConditionalGroup.Condition.Type.LessThan, False),
87+
(2, 1, ConditionalGroup.Condition.Type.LessThanOrEqualTo, False),
88+
],
89+
)
90+
def test_evaluate_condition(left, right, condition_type, outcome):
91+
"""Tests that the conditions of a conditional group
92+
are correctly evaluated."""
93+
94+
group = ConditionalGroup("conditional_group")
95+
96+
condition = ConditionalGroup.Condition()
97+
condition.left_hand_value = left
98+
condition.right_hand_value = right
99+
condition.type = ConditionalGroup.Condition.Type(condition_type)
100+
101+
evaluated = group._evaluate_condition(condition)
102+
assert evaluated == outcome
103+
104+
70105
def test_conditional_group_self_reference():
71106
"""Tests that protocols within a conditional group
72107
can access the outputs of its parent, such as the

openff/evaluator/protocols/groups.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@ class Type(Enum):
4646

4747
LessThan = "lessthan"
4848
GreaterThan = "greaterthan"
49+
EqualTo = "equalto"
50+
LessThanOrEqualTo = "lessthanorequalto"
51+
GreaterThanOrEqualTo = "greaterthanorequalto"
4952

5053
left_hand_value = Attribute(
5154
docstring="The left-hand value to compare.",
@@ -148,6 +151,12 @@ def _evaluate_condition(self, condition):
148151
return left_hand_value < right_hand_value
149152
elif condition.type == self.Condition.Type.GreaterThan:
150153
return left_hand_value > right_hand_value
154+
elif condition.type == self.Condition.Type.EqualTo:
155+
return left_hand_value == right_hand_value
156+
elif condition.type == self.Condition.Type.LessThanOrEqualTo:
157+
return left_hand_value <= right_hand_value
158+
elif condition.type == self.Condition.Type.GreaterThanOrEqualTo:
159+
return left_hand_value >= right_hand_value
151160

152161
raise NotImplementedError()
153162

0 commit comments

Comments
 (0)