Skip to content

Commit d3ea593

Browse files
eli 372 - IR cohort labels can have multiple labels separated by ",". (#463)
* eli 372 - IR cohort labels can have multiple labels seperated by ",". * eli 372 - added new tests for comma - separated cohort labels * eli 372 - validation test fixed * eli 372 - added more test to validation cohort label * wip integration tests - test_regardless_of_final_status_audit_all_types_of_cohort_status_rules_based_on_cohort_labels * Revert "wip integration tests - test_regardless_of_final_status_audit_all_types_of_cohort_status_rules_based_on_cohort_labels" This reverts commit f373cd0. * wip integration tests - test_multiple_comma_seperated_cohort_labels --------- Co-authored-by: ayeshalshukri1-nhs <[email protected]>
1 parent 8566dc8 commit d3ea593

File tree

6 files changed

+157
-27
lines changed

6 files changed

+157
-27
lines changed

src/eligibility_signposting_api/model/campaign_config.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -138,12 +138,16 @@ def normalize_virtual(cls, value: str) -> Virtual:
138138
class IterationRule(BaseModel):
139139
type: RuleType = Field(..., alias="Type")
140140
name: RuleName = Field(..., alias="Name")
141-
code: RuleCode | None = Field(None, alias="Code", description="use the `rule_code` property instead.")
142-
description: RuleDescription = Field(..., alias="Description", description="use the `rule_text` property instead.")
141+
code: RuleCode | None = Field(None, alias="Code", description="use `rule_code` property instead.")
142+
description: RuleDescription = Field(..., alias="Description", description="use `rule_text` property instead.")
143143
priority: RulePriority = Field(..., alias="Priority")
144144
attribute_level: RuleAttributeLevel = Field(..., alias="AttributeLevel")
145145
attribute_name: RuleAttributeName | None = Field(None, alias="AttributeName")
146-
cohort_label: CohortLabel | None = Field(None, alias="CohortLabel")
146+
cohort_label: CohortLabel | None = Field(
147+
None,
148+
alias="CohortLabel",
149+
description="Raw label input. Prefer using `parsed_cohort_labels` for normalized access.",
150+
)
147151
operator: RuleOperator = Field(..., alias="Operator")
148152
comparator: RuleComparator = Field(..., alias="Comparator")
149153
attribute_target: RuleAttributeTarget | None = Field(None, alias="AttributeTarget")
@@ -197,6 +201,18 @@ def rule_text(self) -> str:
197201
rule_text = rule_entry.rule_text
198202
return rule_text or self.description
199203

204+
@cached_property
205+
def parsed_cohort_labels(self) -> list[str]:
206+
"""
207+
Parses the cohort_label string into a list of individual labels.
208+
209+
Returns:
210+
A list of cohort labels, split by comma. If no label is set, returns an empty list.
211+
"""
212+
if not self.cohort_label:
213+
return []
214+
return [label.strip() for label in self.cohort_label.split(",") if label.strip()]
215+
200216
def __str__(self) -> str:
201217
return json.dumps(self.model_dump(by_alias=True), indent=2)
202218

src/eligibility_signposting_api/services/processors/rule_processor.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -123,8 +123,10 @@ def is_actionable(
123123

124124
@staticmethod
125125
def _should_skip_rule_group(cohort: IterationCohort, group_rules: list[IterationRule]) -> bool:
126-
cohort_specific_rules = [rule for rule in group_rules if rule.cohort_label is not None]
127-
matching_specific_rules = [rule for rule in cohort_specific_rules if rule.cohort_label == cohort.cohort_label]
126+
cohort_specific_rules = [rule for rule in group_rules if rule.parsed_cohort_labels]
127+
matching_specific_rules = [
128+
rule for rule in cohort_specific_rules if cohort.cohort_label in rule.parsed_cohort_labels
129+
]
128130
return bool(cohort_specific_rules and not matching_specific_rules)
129131

130132
def evaluate_rules_priority_group(
@@ -148,13 +150,7 @@ def evaluate_rules_priority_group(
148150

149151
@staticmethod
150152
def get_exclusion_rules(cohort: IterationCohort, rules: Iterable[IterationRule]) -> Iterator[IterationRule]:
151-
return (
152-
ir
153-
for ir in rules
154-
if ir.cohort_label is None
155-
or cohort.cohort_label == ir.cohort_label
156-
or (isinstance(ir.cohort_label, (list, set, tuple)) and cohort.cohort_label in ir.cohort_label)
157-
)
153+
return (ir for ir in rules if not ir.parsed_cohort_labels or cohort.cohort_label in ir.parsed_cohort_labels)
158154

159155
def get_cohort_group_results(
160156
self, person: Person, active_iteration: Iteration

tests/unit/services/calculators/test_eligibility_calculator.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -382,6 +382,46 @@ def test_status_on_target_based_on_last_successful_date(
382382
)
383383

384384

385+
def test_multiple_comma_seperated_cohort_labels(faker: Faker):
386+
# Given
387+
nhs_number = NHSNumber(faker.nhs_number())
388+
date_of_birth = DateOfBirth(faker.date_of_birth(minimum_age=66, maximum_age=74))
389+
390+
person_rows = person_rows_builder(nhs_number, date_of_birth=date_of_birth, cohorts=["cohort5", "cohort6"])
391+
campaign_configs = [
392+
rule_builder.CampaignConfigFactory.build(
393+
target="RSV",
394+
iterations=[
395+
rule_builder.IterationFactory.build(
396+
iteration_cohorts=[
397+
rule_builder.IterationCohortFactory.build(cohort_label="cohort5", cohort_group="test1"),
398+
rule_builder.IterationCohortFactory.build(cohort_label="cohort6", cohort_group="test2"),
399+
],
400+
iteration_rules=[
401+
rule_builder.PersonAgeSuppressionRuleFactory.build(cohort_label="cohort5, cohort6")
402+
],
403+
)
404+
],
405+
)
406+
]
407+
408+
calculator = EligibilityCalculator(person_rows, campaign_configs)
409+
410+
# When
411+
actual = calculator.get_eligibility_status("Y", ["ALL"], "ALL")
412+
413+
# Then
414+
assert_that(
415+
actual,
416+
is_eligibility_status().with_conditions(
417+
has_items(is_condition().with_condition_name(ConditionName("RSV")).and_status(Status.not_actionable))
418+
),
419+
)
420+
421+
all_cohorts = [cohort.cohort_code for cohort in g.audit_log.response.condition[0].eligibility_cohorts]
422+
assert_that(all_cohorts, contains_inanyorder("cohort6", "cohort5"))
423+
424+
385425
@pytest.mark.parametrize(
386426
("person_cohorts", "expected_status", "test_comment"),
387427
[

tests/unit/services/processors/test_cohort_handler.py

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import pytest
44
from hamcrest import assert_that, has_length, is_
55

6+
from eligibility_signposting_api.model.campaign_config import IterationCohort, IterationRule
67
from eligibility_signposting_api.model.eligibility_status import CohortGroupResult, Status
78
from eligibility_signposting_api.model.person import Person
89
from eligibility_signposting_api.services.processors.cohort_handler import (
@@ -86,12 +87,22 @@ def test_filter_rule_handler_is_not_eligible(mock_rule_processor_for_handlers, m
8687
cohort = rule_builder.IterationCohortFactory.build(cohort_label="cohort1", negative_description="Not Eligible")
8788
cohort_results = {}
8889

89-
mock_rule_processor_for_handlers.is_eligible.side_effect = (
90-
lambda p, c, cr, fr: cr.update( # noqa: ARG005
91-
{c.cohort_label: CohortGroupResult(c.cohort_group, Status.not_eligible, [], c.negative_description, [])}
90+
def mark_not_eligible_side_effect(
91+
person: Person, # noqa : ARG001
92+
context: IterationCohort,
93+
results: dict[str, CohortGroupResult],
94+
rules: list[IterationRule], # noqa : ARG001
95+
) -> bool:
96+
results.update(
97+
{
98+
context.cohort_label: CohortGroupResult(
99+
context.cohort_group, Status.not_eligible, [], context.negative_description, []
100+
)
101+
}
92102
)
93-
or False
94-
)
103+
return False
104+
105+
mock_rule_processor_for_handlers.is_eligible.side_effect = mark_not_eligible_side_effect
95106

96107
handler.handle(MOCK_PERSON, cohort, cohort_results, mock_rule_processor_for_handlers)
97108

@@ -109,9 +120,21 @@ def test_suppression_rule_handler_is_actionable(mock_rule_processor_for_handlers
109120
cohort = rule_builder.IterationCohortFactory.build(cohort_label="cohort1", positive_description="Actionable")
110121
cohort_results = {}
111122

112-
mock_rule_processor_for_handlers.is_actionable.side_effect = lambda p, c, cr, sr: cr.update( # noqa: ARG005
113-
{c.cohort_label: CohortGroupResult(c.cohort_group, Status.actionable, [], c.positive_description, [])}
114-
)
123+
def mark_actionable_side_effect(
124+
person: Person, # noqa : ARG001
125+
context: IterationCohort,
126+
results: dict[str, CohortGroupResult],
127+
rules: list[IterationRule], # noqa : ARG001
128+
) -> None:
129+
results.update(
130+
{
131+
context.cohort_label: CohortGroupResult(
132+
context.cohort_group, Status.actionable, [], context.positive_description, []
133+
)
134+
}
135+
)
136+
137+
mock_rule_processor_for_handlers.is_actionable.side_effect = mark_actionable_side_effect
115138

116139
handler.handle(MOCK_PERSON, cohort, cohort_results, mock_rule_processor_for_handlers)
117140

tests/unit/services/processors/test_rule_processor.py

Lines changed: 45 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import pytest
44
from hamcrest import assert_that, empty, is_
55

6-
from eligibility_signposting_api.model.campaign_config import CohortLabel, RuleType
6+
from eligibility_signposting_api.model.campaign_config import CohortLabel, IterationCohort, RuleType
77
from eligibility_signposting_api.model.eligibility_status import CohortGroupResult, Reason, RuleName, Status
88
from eligibility_signposting_api.model.person import Person
99
from eligibility_signposting_api.services.processors.person_data_reader import PersonDataReader
@@ -48,6 +48,14 @@ def test_get_exclusion_rules_matching_cohort_label():
4848
assert_that(result, is_([matching_rule]))
4949

5050

51+
def test_get_exclusion_rules_matching_cohort_label_when_it_contains_multiple_cohort_labels():
52+
cohort = rule_builder.IterationCohortFactory.build(cohort_label="COHORT_A")
53+
matching_rule = rule_builder.IterationRuleFactory.build(cohort_label="COHORT_A,COHORT_B")
54+
rules_to_filter = [matching_rule]
55+
result = list(RuleProcessor.get_exclusion_rules(cohort, rules_to_filter))
56+
assert_that(result, is_([matching_rule]))
57+
58+
5159
def test_get_exclusion_rules_non_matching_cohort_label():
5260
cohort = rule_builder.IterationCohortFactory.build(cohort_label="COHORT_A")
5361
non_matching_rule = rule_builder.IterationRuleFactory.build(cohort_label="COHORT_B")
@@ -56,6 +64,14 @@ def test_get_exclusion_rules_non_matching_cohort_label():
5664
assert_that(result, is_([]))
5765

5866

67+
def test_get_exclusion_rules_non_matching_cohort_label_when_it_contains_multiple_cohort_labels():
68+
cohort = rule_builder.IterationCohortFactory.build(cohort_label="COHORT_A")
69+
non_matching_rule = rule_builder.IterationRuleFactory.build(cohort_label="COHORT_B,COHORT_C")
70+
rules_to_filter = [non_matching_rule]
71+
result = list(RuleProcessor.get_exclusion_rules(cohort, rules_to_filter))
72+
assert_that(result, is_([]))
73+
74+
5975
def test_get_exclusion_rules_matching_from_list_cohort_label():
6076
cohort = rule_builder.IterationCohortFactory.build(cohort_label="COHORT_A")
6177
rule1 = rule_builder.IterationRuleFactory.build(cohort_label="COHORT_A")
@@ -65,6 +81,15 @@ def test_get_exclusion_rules_matching_from_list_cohort_label():
6581
assert_that(result, is_([rule1]))
6682

6783

84+
def test_get_exclusion_rules_matching_from_list_cohort_label_when_it_contains_multiple_cohort_labels():
85+
cohort = rule_builder.IterationCohortFactory.build(cohort_label="COHORT_A")
86+
rule1 = rule_builder.IterationRuleFactory.build(cohort_label="COHORT_A")
87+
rule2 = rule_builder.IterationRuleFactory.build(cohort_label="COHORT_B,COHORT_C")
88+
rules_to_filter = [rule1, rule2]
89+
result = list(RuleProcessor.get_exclusion_rules(cohort, rules_to_filter))
90+
assert_that(result, is_([rule1]))
91+
92+
6893
def test_get_exclusion_rules_mixed_rules():
6994
cohort = rule_builder.IterationCohortFactory.build(cohort_label="COHORT_A")
7095
no_cohort_label_rule = rule_builder.IterationRuleFactory.build(cohort_label=None, name="General")
@@ -76,6 +101,17 @@ def test_get_exclusion_rules_mixed_rules():
76101
assert_that({r.name for r in result}, is_({"General", "Matching"}))
77102

78103

104+
def test_get_exclusion_rules_mixed_rules_when_it_contains_multiple_cohort_labels():
105+
cohort = rule_builder.IterationCohortFactory.build(cohort_label="COHORT_A")
106+
no_cohort_label_rule = rule_builder.IterationRuleFactory.build(cohort_label=None, name="General")
107+
matching_rule = rule_builder.IterationRuleFactory.build(cohort_label="COHORT_A,COHORT_C", name="Matching")
108+
non_matching_rule = rule_builder.IterationRuleFactory.build(cohort_label="COHORT_B,COHORT_C", name="NonMatching")
109+
110+
rules_to_filter = [no_cohort_label_rule, matching_rule, non_matching_rule]
111+
result = list(RuleProcessor.get_exclusion_rules(cohort, rules_to_filter))
112+
assert_that({r.name for r in result}, is_({"General", "Matching"}))
113+
114+
79115
@patch("eligibility_signposting_api.services.processors.rule_processor.RuleCalculator")
80116
def test_evaluate_rules_priority_group_all_actionable(mock_rule_calculator_class, rule_processor):
81117
mock_rule_calculator_class.return_value.evaluate_exclusion.return_value = (
@@ -570,17 +606,22 @@ def test_get_cohort_group_results(
570606
suppression_rules = (rule_builder.IterationRuleFactory.build(type=RuleType.suppression),)
571607
mock_get_rules_by_type.return_value = (filter_rules, suppression_rules)
572608

573-
def mock_handle_side_effect(person, cohort, cohort_results_dict, rule_processor_instance): # noqa: ARG001
609+
def mock_handle_side_effect(
610+
person: Person, # noqa: ARG001
611+
cohort: IterationCohort,
612+
cohort_results: dict[CohortLabel, CohortGroupResult],
613+
rule_processor_instance: RuleProcessor, # noqa: ARG001
614+
):
574615
if cohort.cohort_label == CohortLabel("COHORT_A"):
575-
cohort_results_dict[CohortLabel("COHORT_A")] = CohortGroupResult(
616+
cohort_results[CohortLabel("COHORT_A")] = CohortGroupResult(
576617
cohort_code=cohort.cohort_group,
577618
status=Status.actionable,
578619
reasons=[],
579620
description="Cohort A Description",
580621
audit_rules=[],
581622
)
582623
elif cohort.cohort_label == CohortLabel("COHORT_B"):
583-
cohort_results_dict[CohortLabel("COHORT_B")] = CohortGroupResult(
624+
cohort_results[CohortLabel("COHORT_B")] = CohortGroupResult(
584625
cohort_code=cohort.cohort_group,
585626
status=Status.not_eligible,
586627
reasons=[],

tests/unit/validation/test_iteration_rules_validator.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -163,12 +163,26 @@ def test_invalid_attribute_name(self, attr_name, valid_iteration_rule_with_only_
163163
IterationRuleValidation(**data)
164164

165165
# CohortLabel
166-
@pytest.mark.parametrize("label", ["Cohort_A", "Segment_2025", None, ""])
167-
def test_valid_cohort_label(self, label, valid_iteration_rule_with_only_mandatory_fields):
166+
@pytest.mark.parametrize(
167+
("label", "expected_parsed_cohort_label"),
168+
[
169+
("Cohort_A", ["Cohort_A"]),
170+
("Cohort_A,Cohort_B", ["Cohort_A", "Cohort_B"]),
171+
("Cohort_C,,,,", ["Cohort_C"]),
172+
("Cohort_D,,,,Cohort_E", ["Cohort_D", "Cohort_E"]),
173+
(",,,,Cohort_E,,,,", ["Cohort_E"]),
174+
(",,,,", []),
175+
("", []),
176+
(None, []),
177+
],
178+
)
179+
def test_cohort_label_parsing(
180+
self, label, expected_parsed_cohort_label, valid_iteration_rule_with_only_mandatory_fields
181+
):
168182
data = valid_iteration_rule_with_only_mandatory_fields.copy()
169183
data["CohortLabel"] = label
170184
result = IterationRuleValidation(**data)
171-
assert result.cohort_label == label
185+
assert result.parsed_cohort_labels == expected_parsed_cohort_label
172186

173187
@pytest.mark.parametrize("label", [123, [], {}])
174188
def test_invalid_cohort_label(self, label, valid_iteration_rule_with_only_mandatory_fields):

0 commit comments

Comments
 (0)