Skip to content

Commit 3c2136c

Browse files
Introduced COHORT_LABEL field in iteration rules (#103)
* Considered cohort_label in rules * Test get value from all attribute_levels * change param name * test_status_on_target_based_on_last_successful_date fixed field value.
1 parent ae2d338 commit 3c2136c

File tree

4 files changed

+113
-9
lines changed

4 files changed

+113
-9
lines changed

src/eligibility_signposting_api/services/calculators/eligibility_calculator.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from collections import defaultdict
55
from collections.abc import Collection, Iterator, Mapping
66
from dataclasses import dataclass, field
7+
from functools import cached_property
78
from itertools import groupby
89
from typing import Any
910

@@ -45,6 +46,13 @@ def campaigns_grouped_by_condition_name(
4546
):
4647
yield condition_name, list(campaign_group)
4748

49+
@cached_property
50+
def person_cohorts(self) -> set[str]:
51+
cohorts_row: Mapping[str, dict[str, dict[str, dict[str, Any]]]] = next(
52+
(row for row in self.person_data if row.get("ATTRIBUTE_TYPE") == "COHORTS"), {}
53+
)
54+
return set(cohorts_row.get("COHORT_MAP", {}).get("cohorts", {}).get("M", {}).keys())
55+
4856
def evaluate_eligibility(self) -> eligibility.EligibilityStatus:
4957
"""Iterates over campaign groups, evaluates eligibility, and returns a consolidated status."""
5058

@@ -78,12 +86,7 @@ def check_base_eligibility(self, iteration: rules.Iteration) -> bool:
7886
}
7987
if magic_cohort in iteration_cohorts:
8088
return True
81-
82-
cohorts_row: Mapping[str, dict[str, dict[str, dict[str, Any]]]] = next(
83-
(row for row in self.person_data if row.get("ATTRIBUTE_TYPE") == "COHORTS"), {}
84-
)
85-
person_cohorts: set[str] = set(cohorts_row.get("COHORT_MAP", {}).get("cohorts", {}).get("M", {}).keys())
86-
return bool(iteration_cohorts & person_cohorts)
89+
return bool(iteration_cohorts & self.person_cohorts)
8790

8891
def evaluate_eligibility_by_iteration_rules(
8992
self, campaign_group: list[rules.CampaignConfig]
@@ -123,8 +126,12 @@ def evaluate_priority_group(
123126
worst_status_so_far_for_condition: eligibility.Status,
124127
) -> tuple[eligibility.Status, list[eligibility.Reason], list[eligibility.Reason]]:
125128
exclusion_reasons, actionable_reasons = [], []
129+
126130
exclude_capable_rules = [
127-
ir for ir in iteration_rule_group if ir.type in (rules.RuleType.filter, rules.RuleType.suppression)
131+
ir
132+
for ir in iteration_rule_group
133+
if ir.type in (rules.RuleType.filter, rules.RuleType.suppression)
134+
and (ir.cohort_label is None or (ir.cohort_label in self.person_cohorts))
128135
]
129136

130137
best_status = eligibility.Status.not_eligible if exclude_capable_rules else eligibility.Status.actionable

tests/fixtures/builders/model/rule.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@ def future_date(days_ahead: int = 365) -> date:
1919
class IterationCohortFactory(ModelFactory[rules.IterationCohort]): ...
2020

2121

22-
class IterationRuleFactory(ModelFactory[rules.IterationRule]): ...
22+
class IterationRuleFactory(ModelFactory[rules.IterationRule]):
23+
attribute_target = None
24+
cohort_label = None
2325

2426

2527
class IterationFactory(ModelFactory[rules.Iteration]):

tests/unit/services/calculators/test_eligibility_calculator.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -689,7 +689,7 @@ def test_status_on_target_based_on_last_successful_date(
689689
),
690690
rule_builder.IterationRuleFactory.build(
691691
type=rules.RuleType.suppression,
692-
name=rules.RuleName("You have a future booking to be vaccinated against RSV"),
692+
name=rules.RuleName("You have a vaccination date in the future for RSV"),
693693
description=rules.RuleDescription("Exclude anyone with future Completed RSV Vaccination"),
694694
priority=10,
695695
operator=rules.RuleOperator.day_lte,
@@ -763,3 +763,48 @@ def test_status_on_cohort_attribute_level(faker: Faker):
763763
has_item(is_condition().with_condition_name(ConditionName("RSV")).and_status(Status.not_eligible))
764764
),
765765
)
766+
767+
768+
@pytest.mark.parametrize(
769+
("person_cohorts", "cohort_label", "expected_status", "test_comment"),
770+
[
771+
(["cohort1", "cohort2"], "cohort1", Status.not_actionable, "matches the cohort label"),
772+
(["cohort2", "cohort3"], "cohort1", Status.actionable, "doesn't match the cohort label"),
773+
],
774+
)
775+
def test_status_if_iteration_rules_contains_cohort_label_field(
776+
person_cohorts, cohort_label: str, expected_status: Status, test_comment: str, faker: Faker
777+
):
778+
# Given
779+
nhs_number = NHSNumber(faker.nhs_number())
780+
date_of_birth = DateOfBirth(faker.date_of_birth(minimum_age=66, maximum_age=74))
781+
782+
person_rows = person_rows_builder(nhs_number, date_of_birth=date_of_birth, cohorts=person_cohorts)
783+
campaign_configs = [
784+
rule_builder.CampaignConfigFactory.build(
785+
target="RSV",
786+
iterations=[
787+
rule_builder.IterationFactory.build(
788+
iteration_cohorts=[
789+
rule_builder.IterationCohortFactory.build(cohort_label="cohort1"),
790+
rule_builder.IterationCohortFactory.build(cohort_label="cohort2"),
791+
],
792+
iteration_rules=[rule_builder.PersonAgeSuppressionRuleFactory.build(cohort_label=cohort_label)],
793+
)
794+
],
795+
)
796+
]
797+
798+
calculator = EligibilityCalculator(person_rows, campaign_configs)
799+
800+
# When
801+
actual = calculator.evaluate_eligibility()
802+
803+
# Then
804+
assert_that(
805+
actual,
806+
is_eligibility_status().with_conditions(
807+
has_items(is_condition().with_condition_name(ConditionName("RSV")).and_status(expected_status))
808+
),
809+
test_comment,
810+
)
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
from collections.abc import Collection, Mapping
2+
from typing import Any
3+
4+
import pytest
5+
6+
from eligibility_signposting_api.model import rules
7+
from eligibility_signposting_api.services.calculators.rule_calculator import RuleCalculator
8+
from tests.fixtures.builders.model import rule as rule_builder
9+
10+
Row = Collection[Mapping[str, Any]]
11+
12+
13+
@pytest.mark.parametrize(
14+
("person_data", "rule", "expected"),
15+
[
16+
# PERSON attribute level
17+
(
18+
[{"ATTRIBUTE_TYPE": "PERSON", "POSTCODE": "SW19"}],
19+
rule_builder.IterationRuleFactory.build(
20+
attribute_level=rules.RuleAttributeLevel.PERSON, attribute_name="POSTCODE"
21+
),
22+
"SW19",
23+
),
24+
# TARGET attribute level
25+
(
26+
[{"ATTRIBUTE_TYPE": "RSV", "LAST_SUCCESSFUL_DATE": "20240101"}],
27+
rule_builder.IterationRuleFactory.build(
28+
attribute_level=rules.RuleAttributeLevel.TARGET,
29+
attribute_name="LAST_SUCCESSFUL_DATE",
30+
attribute_target="RSV",
31+
),
32+
"20240101",
33+
),
34+
# COHORT attribute level
35+
(
36+
[{"ATTRIBUTE_TYPE": "COHORTS", "COHORT_LABEL": ""}],
37+
rule_builder.IterationRuleFactory.build(
38+
attribute_level=rules.RuleAttributeLevel.COHORT, attribute_name="COHORT_LABEL"
39+
),
40+
"",
41+
),
42+
],
43+
)
44+
def test_get_attribute_value_for_all_attribute_levels(person_data: Row, rule: rules.IterationRule, expected: str):
45+
# Given
46+
calc = RuleCalculator(person_data=person_data, rule=rule)
47+
# When
48+
actual = calc.get_attribute_value()
49+
# Then
50+
assert actual == expected

0 commit comments

Comments
 (0)