Skip to content

Commit 7f199f3

Browse files
Considered cohort_label in rules
1 parent ae2d338 commit 7f199f3

File tree

3 files changed

+62
-8
lines changed

3 files changed

+62
-8
lines changed

src/eligibility_signposting_api/services/calculators/eligibility_calculator.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from collections import defaultdict
55
from collections.abc import Collection, Iterator, Mapping
66
from dataclasses import dataclass, field
7+
from functools import cached_property
78
from itertools import groupby
89
from typing import Any
910

@@ -45,6 +46,13 @@ def campaigns_grouped_by_condition_name(
4546
):
4647
yield condition_name, list(campaign_group)
4748

49+
@cached_property
50+
def person_cohorts(self) -> set[str]:
51+
cohorts_row: Mapping[str, dict[str, dict[str, dict[str, Any]]]] = next(
52+
(row for row in self.person_data if row.get("ATTRIBUTE_TYPE") == "COHORTS"), {}
53+
)
54+
return set(cohorts_row.get("COHORT_MAP", {}).get("cohorts", {}).get("M", {}).keys())
55+
4856
def evaluate_eligibility(self) -> eligibility.EligibilityStatus:
4957
"""Iterates over campaign groups, evaluates eligibility, and returns a consolidated status."""
5058

@@ -78,12 +86,7 @@ def check_base_eligibility(self, iteration: rules.Iteration) -> bool:
7886
}
7987
if magic_cohort in iteration_cohorts:
8088
return True
81-
82-
cohorts_row: Mapping[str, dict[str, dict[str, dict[str, Any]]]] = next(
83-
(row for row in self.person_data if row.get("ATTRIBUTE_TYPE") == "COHORTS"), {}
84-
)
85-
person_cohorts: set[str] = set(cohorts_row.get("COHORT_MAP", {}).get("cohorts", {}).get("M", {}).keys())
86-
return bool(iteration_cohorts & person_cohorts)
89+
return bool(iteration_cohorts & self.person_cohorts)
8790

8891
def evaluate_eligibility_by_iteration_rules(
8992
self, campaign_group: list[rules.CampaignConfig]
@@ -123,8 +126,12 @@ def evaluate_priority_group(
123126
worst_status_so_far_for_condition: eligibility.Status,
124127
) -> tuple[eligibility.Status, list[eligibility.Reason], list[eligibility.Reason]]:
125128
exclusion_reasons, actionable_reasons = [], []
129+
126130
exclude_capable_rules = [
127-
ir for ir in iteration_rule_group if ir.type in (rules.RuleType.filter, rules.RuleType.suppression)
131+
ir
132+
for ir in iteration_rule_group
133+
if ir.type in (rules.RuleType.filter, rules.RuleType.suppression)
134+
and (ir.cohort_label is None or (ir.cohort_label in self.person_cohorts))
128135
]
129136

130137
best_status = eligibility.Status.not_eligible if exclude_capable_rules else eligibility.Status.actionable

tests/fixtures/builders/model/rule.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@ def future_date(days_ahead: int = 365) -> date:
1919
class IterationCohortFactory(ModelFactory[rules.IterationCohort]): ...
2020

2121

22-
class IterationRuleFactory(ModelFactory[rules.IterationRule]): ...
22+
class IterationRuleFactory(ModelFactory[rules.IterationRule]):
23+
attribute_target = None
24+
cohort_label = None
2325

2426

2527
class IterationFactory(ModelFactory[rules.Iteration]):

tests/unit/services/calculators/test_eligibility_calculator.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -763,3 +763,48 @@ def test_status_on_cohort_attribute_level(faker: Faker):
763763
has_item(is_condition().with_condition_name(ConditionName("RSV")).and_status(Status.not_eligible))
764764
),
765765
)
766+
767+
768+
@pytest.mark.parametrize(
769+
("person_cohorts", "cohort", "expected_status", "test_comment"),
770+
[
771+
(["cohort1", "cohort2"], "cohort1", Status.not_actionable, "matches the cohort label"),
772+
(["cohort2", "cohort3"], "cohort1", Status.actionable, "doesn't match the cohort label"),
773+
],
774+
)
775+
def test_status_if_iteration_rules_contains_cohort_label_field(
776+
person_cohorts, cohort: str, expected_status: Status, test_comment: str, faker: Faker
777+
):
778+
# Given
779+
nhs_number = NHSNumber(faker.nhs_number())
780+
date_of_birth = DateOfBirth(faker.date_of_birth(minimum_age=66, maximum_age=74))
781+
782+
person_rows = person_rows_builder(nhs_number, date_of_birth=date_of_birth, cohorts=person_cohorts)
783+
campaign_configs = [
784+
rule_builder.CampaignConfigFactory.build(
785+
target="RSV",
786+
iterations=[
787+
rule_builder.IterationFactory.build(
788+
iteration_cohorts=[
789+
rule_builder.IterationCohortFactory.build(cohort_label="cohort1"),
790+
rule_builder.IterationCohortFactory.build(cohort_label="cohort2"),
791+
],
792+
iteration_rules=[rule_builder.PersonAgeSuppressionRuleFactory.build(cohort_label=cohort)],
793+
)
794+
],
795+
)
796+
]
797+
798+
calculator = EligibilityCalculator(person_rows, campaign_configs)
799+
800+
# When
801+
actual = calculator.evaluate_eligibility()
802+
803+
# Then
804+
assert_that(
805+
actual,
806+
is_eligibility_status().with_conditions(
807+
has_items(is_condition().with_condition_name(ConditionName("RSV")).and_status(expected_status))
808+
),
809+
test_comment,
810+
)

0 commit comments

Comments
 (0)