Skip to content

Commit 0f1328e

Browse files
updated validator to restrict cohort_label for R, X and Y rules (#466)
* updated validator to restrict cohort_label for R, X and Y rules * lint fix
1 parent 62f4351 commit 0f1328e

File tree

6 files changed

+49
-95
lines changed

6 files changed

+49
-95
lines changed

src/eligibility_signposting_api/services/processors/rule_processor.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -148,10 +148,6 @@ def evaluate_rules_priority_group(
148148

149149
return best_status, exclusion_reasons, is_rule_stop
150150

151-
@staticmethod
152-
def get_exclusion_rules(cohort: IterationCohort, rules: Iterable[IterationRule]) -> Iterator[IterationRule]:
153-
return (ir for ir in rules if not ir.parsed_cohort_labels or cohort.cohort_label in ir.parsed_cohort_labels)
154-
155151
def get_cohort_group_results(
156152
self, person: Person, active_iteration: Iteration
157153
) -> dict[CohortLabel, CohortGroupResult]:

src/rules_validation_api/validators/iteration_rules_validator.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,12 @@
22

33
from pydantic import model_validator
44

5-
from eligibility_signposting_api.model.campaign_config import IterationRule, RuleAttributeLevel, RuleAttributeName
5+
from eligibility_signposting_api.model.campaign_config import (
6+
IterationRule,
7+
RuleAttributeLevel,
8+
RuleAttributeName,
9+
RuleType,
10+
)
611

712

813
class IterationRuleValidation(IterationRule):
@@ -16,3 +21,14 @@ def check_cohort_attribute_name(self) -> Self:
1621
msg = "When attribute_level is COHORT, attribute_name must be COHORT_LABEL or None (default:COHORT_LABEL)"
1722
raise ValueError(msg)
1823
return self
24+
25+
@model_validator(mode="after")
26+
def check_cohort_label_for_non_f_and_s_types(self) -> Self:
27+
allowed_types = {RuleType("F"), RuleType("S")}
28+
if self.cohort_label is not None and self.type not in allowed_types:
29+
msg = (
30+
f"CohortLabel is only allowed for rule types F and S. "
31+
f"Found type: {self.type} with cohort_label: {self.cohort_label}"
32+
)
33+
raise ValueError(msg)
34+
return self

tests/unit/services/processors/test_rule_processor.py

Lines changed: 0 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -25,93 +25,6 @@ def rule_processor(mock_person_data_reader):
2525
MOCK_PERSON_DATA = Person([{"ATTRIBUTE_TYPE": "PERSON", "AGE": "30"}])
2626

2727

28-
def test_get_exclusion_rules_no_rules():
29-
cohort = rule_builder.IterationCohortFactory.build(cohort_label="COHORT_A")
30-
rules_to_filter = []
31-
result = list(RuleProcessor.get_exclusion_rules(cohort, rules_to_filter))
32-
assert_that(result, is_([]))
33-
34-
35-
def test_get_exclusion_rules_general_rule():
36-
cohort = rule_builder.IterationCohortFactory.build(cohort_label="COHORT_A")
37-
no_cohort_label_rule = rule_builder.IterationRuleFactory.build(cohort_label=None)
38-
rules_to_filter = [no_cohort_label_rule]
39-
result = list(RuleProcessor.get_exclusion_rules(cohort, rules_to_filter))
40-
assert_that(result, is_([no_cohort_label_rule]))
41-
42-
43-
def test_get_exclusion_rules_matching_cohort_label():
44-
cohort = rule_builder.IterationCohortFactory.build(cohort_label="COHORT_A")
45-
matching_rule = rule_builder.IterationRuleFactory.build(cohort_label="COHORT_A")
46-
rules_to_filter = [matching_rule]
47-
result = list(RuleProcessor.get_exclusion_rules(cohort, rules_to_filter))
48-
assert_that(result, is_([matching_rule]))
49-
50-
51-
def test_get_exclusion_rules_matching_cohort_label_when_it_contains_multiple_cohort_labels():
52-
cohort = rule_builder.IterationCohortFactory.build(cohort_label="COHORT_A")
53-
matching_rule = rule_builder.IterationRuleFactory.build(cohort_label="COHORT_A,COHORT_B")
54-
rules_to_filter = [matching_rule]
55-
result = list(RuleProcessor.get_exclusion_rules(cohort, rules_to_filter))
56-
assert_that(result, is_([matching_rule]))
57-
58-
59-
def test_get_exclusion_rules_non_matching_cohort_label():
60-
cohort = rule_builder.IterationCohortFactory.build(cohort_label="COHORT_A")
61-
non_matching_rule = rule_builder.IterationRuleFactory.build(cohort_label="COHORT_B")
62-
rules_to_filter = [non_matching_rule]
63-
result = list(RuleProcessor.get_exclusion_rules(cohort, rules_to_filter))
64-
assert_that(result, is_([]))
65-
66-
67-
def test_get_exclusion_rules_non_matching_cohort_label_when_it_contains_multiple_cohort_labels():
68-
cohort = rule_builder.IterationCohortFactory.build(cohort_label="COHORT_A")
69-
non_matching_rule = rule_builder.IterationRuleFactory.build(cohort_label="COHORT_B,COHORT_C")
70-
rules_to_filter = [non_matching_rule]
71-
result = list(RuleProcessor.get_exclusion_rules(cohort, rules_to_filter))
72-
assert_that(result, is_([]))
73-
74-
75-
def test_get_exclusion_rules_matching_from_list_cohort_label():
76-
cohort = rule_builder.IterationCohortFactory.build(cohort_label="COHORT_A")
77-
rule1 = rule_builder.IterationRuleFactory.build(cohort_label="COHORT_A")
78-
rule2 = rule_builder.IterationRuleFactory.build(cohort_label="COHORT_B")
79-
rules_to_filter = [rule1, rule2]
80-
result = list(RuleProcessor.get_exclusion_rules(cohort, rules_to_filter))
81-
assert_that(result, is_([rule1]))
82-
83-
84-
def test_get_exclusion_rules_matching_from_list_cohort_label_when_it_contains_multiple_cohort_labels():
85-
cohort = rule_builder.IterationCohortFactory.build(cohort_label="COHORT_A")
86-
rule1 = rule_builder.IterationRuleFactory.build(cohort_label="COHORT_A")
87-
rule2 = rule_builder.IterationRuleFactory.build(cohort_label="COHORT_B,COHORT_C")
88-
rules_to_filter = [rule1, rule2]
89-
result = list(RuleProcessor.get_exclusion_rules(cohort, rules_to_filter))
90-
assert_that(result, is_([rule1]))
91-
92-
93-
def test_get_exclusion_rules_mixed_rules():
94-
cohort = rule_builder.IterationCohortFactory.build(cohort_label="COHORT_A")
95-
no_cohort_label_rule = rule_builder.IterationRuleFactory.build(cohort_label=None, name="General")
96-
matching_rule = rule_builder.IterationRuleFactory.build(cohort_label="COHORT_A", name="Matching")
97-
non_matching_rule = rule_builder.IterationRuleFactory.build(cohort_label="COHORT_B", name="NonMatching")
98-
99-
rules_to_filter = [no_cohort_label_rule, matching_rule, non_matching_rule]
100-
result = list(RuleProcessor.get_exclusion_rules(cohort, rules_to_filter))
101-
assert_that({r.name for r in result}, is_({"General", "Matching"}))
102-
103-
104-
def test_get_exclusion_rules_mixed_rules_when_it_contains_multiple_cohort_labels():
105-
cohort = rule_builder.IterationCohortFactory.build(cohort_label="COHORT_A")
106-
no_cohort_label_rule = rule_builder.IterationRuleFactory.build(cohort_label=None, name="General")
107-
matching_rule = rule_builder.IterationRuleFactory.build(cohort_label="COHORT_A,COHORT_C", name="Matching")
108-
non_matching_rule = rule_builder.IterationRuleFactory.build(cohort_label="COHORT_B,COHORT_C", name="NonMatching")
109-
110-
rules_to_filter = [no_cohort_label_rule, matching_rule, non_matching_rule]
111-
result = list(RuleProcessor.get_exclusion_rules(cohort, rules_to_filter))
112-
assert_that({r.name for r in result}, is_({"General", "Matching"}))
113-
114-
11528
@patch("eligibility_signposting_api.services.processors.rule_processor.RuleCalculator")
11629
def test_evaluate_rules_priority_group_all_actionable(mock_rule_calculator_class, rule_processor):
11730
mock_rule_calculator_class.return_value.evaluate_exclusion.return_value = (

tests/unit/validation/conftest.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@ def valid_iteration_rule_with_only_mandatory_fields():
4545
"AttributeTarget": "RSV",
4646
"AttributeLevel": "TARGET",
4747
"AttributeName": "LAST_SUCCESSFUL_DATE",
48-
"CohortLabel": "elid_all_people",
4948
"Priority": 100,
5049
}
5150

tests/unit/validation/test_iteration_rules_validator.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,3 +259,35 @@ def test_invalid_when_attribute_level_is_cohort_but_attribute_name_is_neither_no
259259
IterationRuleValidation(**data)
260260
msg = "When attribute_level is COHORT, attribute_name must be COHORT_LABEL or None (default:COHORT_LABEL)"
261261
assert msg in str(error.value)
262+
263+
@pytest.mark.parametrize("rule_type", ["F", "S"])
264+
def test_valid_when_cohort_label_present_for_type_f_or_s(
265+
self, rule_type, valid_iteration_rule_with_only_mandatory_fields
266+
):
267+
data = valid_iteration_rule_with_only_mandatory_fields.copy()
268+
data["Type"] = rule_type
269+
data["CohortLabel"] = "Test Cohort"
270+
result = IterationRuleValidation(**data)
271+
assert result.cohort_label == "Test Cohort"
272+
273+
@pytest.mark.parametrize("rule_type", ["R", "X", "Y"])
274+
def test_invalid_when_cohort_label_present_for_non_f_s_types(
275+
self, rule_type, valid_iteration_rule_with_only_mandatory_fields
276+
):
277+
data = valid_iteration_rule_with_only_mandatory_fields.copy()
278+
data["Type"] = rule_type
279+
data["CohortLabel"] = "Invalid Cohort"
280+
with pytest.raises(ValidationError) as error:
281+
IterationRuleValidation(**data)
282+
msg = "CohortLabel is only allowed for rule types F and S."
283+
assert msg in str(error.value)
284+
285+
@pytest.mark.parametrize("rule_type", ["R", "X", "Y"])
286+
def test_valid_when_cohort_label_absent_for_non_f_s_types(
287+
self, rule_type, valid_iteration_rule_with_only_mandatory_fields
288+
):
289+
data = valid_iteration_rule_with_only_mandatory_fields.copy()
290+
data["Type"] = rule_type
291+
data.pop("CohortLabel", None)
292+
result = IterationRuleValidation(**data)
293+
assert result.cohort_label is None

tests/unit/validation/test_iteration_validator.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -353,7 +353,6 @@ def test_valid_iteration_if_actions_mapper_exists_for_rule_routing(
353353
"AttributeTarget": "RSV",
354354
"AttributeLevel": "TARGET",
355355
"AttributeName": "LAST_SUCCESSFUL_DATE",
356-
"CohortLabel": "elid_all_people",
357356
"Priority": 100,
358357
"CommsRouting": default_routing,
359358
}
@@ -389,7 +388,6 @@ def test_invalid_iteration_if_actions_mapper_exists_for_rule_routing(
389388
"AttributeTarget": "RSV",
390389
"AttributeLevel": "TARGET",
391390
"AttributeName": "LAST_SUCCESSFUL_DATE",
392-
"CohortLabel": "elid_all_people",
393391
"Priority": 100,
394392
"CommsRouting": default_routing,
395393
}

0 commit comments

Comments
 (0)