Skip to content

Commit 1cc6750

Browse files
authored
Merge pull request #81 from NHSDigital/feature/ELI-151-cohort-based-filtering
QA: looks good! ELI-151 - Cohort based filtering
2 parents 7bb25f3 + 8a5bcfc commit 1cc6750

File tree

13 files changed

+506
-205
lines changed

13 files changed

+506
-205
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ ignore = ["COM812", "D"]
7474

7575
[tool.ruff.lint.per-file-ignores]
7676
"src/eligibility_signposting_api/repos/*" = ["ANN401"]
77-
"tests/*" = ["ANN", "INP", "S101", "S106"]
77+
"tests/*" = ["ANN", "INP", "S101", "S106", "S311"]
7878

7979
[tool.pyright]
8080
include = ["src/"]

src/eligibility_signposting_api/model/eligibility.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from dataclasses import dataclass
22
from datetime import date
33
from enum import Enum, auto
4-
from typing import NewType
4+
from functools import total_ordering
5+
from typing import NewType, Self
56

67
NHSNumber = NewType("NHSNumber", str)
78
DateOfBirth = NewType("DateOfBirth", date)
@@ -18,11 +19,17 @@ class RuleType(str, Enum):
1819
redirect = "R"
1920

2021

22+
@total_ordering
2123
class Status(Enum):
2224
not_eligible = auto()
2325
not_actionable = auto()
2426
actionable = auto()
2527

28+
def __lt__(self, other: Self) -> bool:
29+
if self.__class__ is other.__class__:
30+
return self.value < other.value
31+
return NotImplemented
32+
2633

2734
@dataclass
2835
class Reason:

src/eligibility_signposting_api/model/rules.py

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
from __future__ import annotations
22

33
import typing
4-
from datetime import date, datetime
4+
from datetime import UTC, date, datetime
55
from enum import Enum
6+
from functools import cached_property
7+
from operator import attrgetter
68
from typing import Literal, NewType
79

810
from pydantic import BaseModel, Field, field_serializer, field_validator
@@ -24,6 +26,7 @@
2426
RuleComparator = NewType("RuleComparator", str)
2527
StartDate = NewType("StartDate", date)
2628
EndDate = NewType("EndDate", date)
29+
CohortLabel = NewType("CohortLabel", str)
2730

2831

2932
class RuleType(str, Enum):
@@ -77,10 +80,10 @@ class RuleAttributeLevel(str, Enum):
7780

7881

7982
class IterationCohort(BaseModel):
80-
cohort_label: str | None = Field(None, alias="CohortLabel")
83+
cohort_label: CohortLabel | None = Field(None, alias="CohortLabel")
8184
priority: int | None = Field(None, alias="Priority")
8285

83-
model_config = {"populate_by_name": True}
86+
model_config = {"populate_by_name": True, "extra": "ignore"}
8487

8588

8689
class IterationRule(BaseModel):
@@ -90,11 +93,12 @@ class IterationRule(BaseModel):
9093
priority: RulePriority = Field(..., alias="Priority")
9194
attribute_level: RuleAttributeLevel = Field(..., alias="AttributeLevel")
9295
attribute_name: RuleAttributeName = Field(..., alias="AttributeName")
96+
cohort_label: CohortLabel | None = Field(None, alias="CohortLabel")
9397
operator: RuleOperator = Field(..., alias="Operator")
9498
comparator: RuleComparator = Field(..., alias="Comparator")
9599
attribute_target: str | None = Field(None, alias="AttributeTarget")
96100

97-
model_config = {"populate_by_name": True}
101+
model_config = {"populate_by_name": True, "extra": "ignore"}
98102

99103

100104
class Iteration(BaseModel):
@@ -109,10 +113,7 @@ class Iteration(BaseModel):
109113
iteration_cohorts: list[IterationCohort] = Field(..., alias="IterationCohorts")
110114
iteration_rules: list[IterationRule] = Field(..., alias="IterationRules")
111115

112-
model_config = {
113-
"populate_by_name": True,
114-
"arbitrary_types_allowed": True,
115-
}
116+
model_config = {"populate_by_name": True, "arbitrary_types_allowed": True, "extra": "ignore"}
116117

117118
@field_validator("iteration_date", mode="before")
118119
@classmethod
@@ -146,10 +147,7 @@ class CampaignConfig(BaseModel):
146147
approval_maximum: int | None = Field(None, alias="ApprovalMaximum")
147148
iterations: list[Iteration] = Field(..., alias="Iterations")
148149

149-
model_config = {
150-
"populate_by_name": True,
151-
"arbitrary_types_allowed": True,
152-
}
150+
model_config = {"populate_by_name": True, "arbitrary_types_allowed": True, "extra": "ignore"}
153151

154152
@field_validator("start_date", "end_date", mode="before")
155153
@classmethod
@@ -163,6 +161,17 @@ def parse_dates(cls, v: str | date) -> date:
163161
def serialize_dates(v: date, _info: SerializationInfo) -> str:
164162
return v.strftime("%Y%m%d")
165163

164+
@cached_property
165+
def campaign_live(self) -> bool:
166+
today = datetime.now(tz=UTC).date()
167+
return self.start_date <= today <= self.end_date
168+
169+
@cached_property
170+
def current_iteration(self) -> Iteration | None:
171+
today = datetime.now(tz=UTC).date()
172+
iterations_by_date_descending = sorted(self.iterations, key=attrgetter("iteration_date"), reverse=True)
173+
return next((i for i in iterations_by_date_descending if i.iteration_date <= today), None)
174+
166175

167176
class Rules(BaseModel):
168177
"""Eligibility rules.
@@ -171,4 +180,4 @@ class Rules(BaseModel):
171180

172181
campaign_config: CampaignConfig = Field(..., alias="CampaignConfig")
173182

174-
model_config = {"populate_by_name": True}
183+
model_config = {"populate_by_name": True, "extra": "ignore"}

src/eligibility_signposting_api/services/eligibility_services.py

Lines changed: 152 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
import logging
2+
from collections import defaultdict
3+
from collections.abc import Collection, Mapping
4+
from typing import Any
25

36
from hamcrest.core.string_description import StringDescription
47
from wireup import service
58

6-
from eligibility_signposting_api.model import eligibility
7-
from eligibility_signposting_api.model.rules import CampaignConfig, IterationRule, RuleAttributeLevel
9+
from eligibility_signposting_api.model import eligibility, rules
810
from eligibility_signposting_api.repos import EligibilityRepo, NotFoundError, RulesRepo
911
from eligibility_signposting_api.services.rules.operators import OperatorRegistry
1012

@@ -40,59 +42,176 @@ def get_eligibility_status(self, nhs_number: eligibility.NHSNumber | None = None
4042
except NotFoundError as e:
4143
raise UnknownPersonError from e
4244
else:
43-
# TODO: Apply rules here # noqa: TD002, TD003, FIX002
4445
return self.evaluate_eligibility(campaign_configs, person_data)
4546

4647
raise UnknownPersonError # pragma: no cover
4748

4849
@staticmethod
4950
def evaluate_eligibility(
50-
campaign_configs: list[CampaignConfig], person_data: list[dict[str, str | None]]
51+
campaign_configs: Collection[rules.CampaignConfig], person_data: Collection[Mapping[str, Any]]
5152
) -> eligibility.EligibilityStatus:
5253
"""Calculate a person's eligibility for vaccination."""
54+
55+
# Get all iterations for which the person is base eligible, i.e. those which *might* provide eligibility
56+
# due to cohort membership.
57+
base_eligible_campaigns, condition_names = EligibilityService.get_base_eligible_campaigns(
58+
campaign_configs, person_data
59+
)
60+
# Evaluate iteration rules to see if the person is actionable, not actionable (due to "F" rules),
61+
# or not eligible (due to "S" rules")
62+
evaluations = EligibilityService.evaluate_for_base_eligible_campaigns(base_eligible_campaigns, person_data)
63+
5364
conditions: dict[eligibility.ConditionName, eligibility.Condition] = {}
54-
for campaign_config in campaign_configs:
65+
# Add all not base eligible conditions to result set.
66+
conditions |= EligibilityService.get_not_base_eligible_conditions(base_eligible_campaigns, condition_names)
67+
# Add all base eligible conditions to result set.
68+
conditions |= EligibilityService.get_base_eligible_conditions(evaluations)
69+
70+
return eligibility.EligibilityStatus(conditions=list(conditions.values()))
71+
72+
@staticmethod
73+
def get_base_eligible_campaigns(
74+
campaign_configs: Collection[rules.CampaignConfig], person_data: Collection[Mapping[str, Any]]
75+
) -> tuple[list[rules.CampaignConfig], set[eligibility.ConditionName]]:
76+
"""Get all campaigns for which the person is base eligible, i.e. those which *might* provide eligibility.
77+
78+
Build and return a collection of campaigns for which the person is base eligible (using cohorts).
79+
Also build and return a set of conditions in the campaigns while we are here.
80+
"""
81+
condition_names: set[eligibility.ConditionName] = set()
82+
base_eligible_campaigns: list[rules.CampaignConfig] = []
83+
84+
for campaign_config in (cc for cc in campaign_configs if cc.campaign_live and cc.current_iteration):
5585
condition_name = eligibility.ConditionName(campaign_config.target)
56-
condition = conditions.setdefault(
57-
condition_name,
58-
eligibility.Condition(condition_name=condition_name, status=eligibility.Status.actionable, reasons=[]),
59-
)
60-
for iteration_rule in [
61-
iteration_rule
62-
for iteration in campaign_config.iterations
63-
for iteration_rule in iteration.iteration_rules
64-
]:
65-
exclusion, reason = EligibilityService.evaluate_exclusion(iteration_rule, person_data)
66-
condition.reasons.append(
67-
eligibility.Reason(
68-
rule_type=eligibility.RuleType(iteration_rule.type),
69-
rule_name=eligibility.RuleName(iteration_rule.name),
70-
rule_result=eligibility.RuleResult(reason),
71-
)
86+
condition_names.add(condition_name)
87+
base_eligible = EligibilityService.evaluate_base_eligibility(campaign_config.current_iteration, person_data)
88+
if base_eligible:
89+
base_eligible_campaigns.append(campaign_config)
90+
91+
return base_eligible_campaigns, condition_names
92+
93+
@staticmethod
94+
def evaluate_base_eligibility(
95+
iteration: rules.Iteration | None, person_data: Collection[Mapping[str, Any]]
96+
) -> set[str]:
97+
"""Return cohorts for which person is base eligible."""
98+
if not iteration:
99+
return set()
100+
iteration_cohorts: set[str] = {
101+
cohort.cohort_label for cohort in iteration.iteration_cohorts if cohort.cohort_label
102+
}
103+
104+
cohorts_row: Mapping[str, dict[str, dict[str, dict[str, Any]]]] = next(
105+
(r for r in person_data if r.get("ATTRIBUTE_TYPE", "") == "COHORTS"), {}
106+
)
107+
person_cohorts = set(cohorts_row.get("COHORT_MAP", {}).get("cohorts", {}).get("M", {}).keys())
108+
109+
return iteration_cohorts.intersection(person_cohorts)
110+
111+
@staticmethod
112+
def get_not_base_eligible_conditions(
113+
base_eligible_campaigns: Collection[rules.CampaignConfig],
114+
condition_names: Collection[eligibility.ConditionName],
115+
) -> dict[eligibility.ConditionName, eligibility.Condition]:
116+
"""Get conditions where the person is not base eligible,
117+
i.e. is not is the cohort for any campaign iteration."""
118+
119+
# for each condition:
120+
# if the person isn't base eligible for any iteration,
121+
# the person is not (base) eligible for the condition
122+
not_eligible_conditions: dict[eligibility.ConditionName, eligibility.Condition] = {}
123+
for condition_name in condition_names:
124+
if condition_name not in {eligibility.ConditionName(cc.target) for cc in base_eligible_campaigns}:
125+
not_eligible_conditions[condition_name] = eligibility.Condition(
126+
condition_name=condition_name, status=eligibility.Status.not_eligible, reasons=[]
72127
)
128+
return not_eligible_conditions
129+
130+
@staticmethod
131+
def evaluate_for_base_eligible_campaigns(
132+
base_eligible_campaigns: Collection[rules.CampaignConfig],
133+
person_data: Collection[Mapping[str, Any]],
134+
) -> dict[eligibility.ConditionName, dict[eligibility.Status, list[eligibility.Reason]]]:
135+
"""Evaluate iteration rules to see if the person is actionable, not actionable (due to "F" rules),
136+
or not eligible (due to "S" rules").
137+
138+
For each condition, evaluate all iterations for inclusion or exclusion."""
139+
base_eligible_evaluations: dict[
140+
eligibility.ConditionName, dict[eligibility.Status, list[eligibility.Reason]]
141+
] = defaultdict(dict)
142+
for condition_name, iteration in [
143+
(eligibility.ConditionName(cc.target), cc.current_iteration)
144+
for cc in base_eligible_campaigns
145+
if cc.current_iteration
146+
]:
147+
status = eligibility.Status.actionable
148+
exclusion_reasons, actionable_reasons = [], []
149+
for iteration_rule in iteration.iteration_rules:
150+
if iteration_rule.type not in (rules.RuleType.filter, rules.RuleType.suppression):
151+
continue
152+
exclusion, reason = EligibilityService.evaluate_exclusion(iteration_rule, person_data)
73153
if exclusion:
74-
condition.status = eligibility.Status.not_actionable
154+
status = min(
155+
status,
156+
eligibility.Status.not_eligible
157+
if iteration_rule.type == rules.RuleType.filter
158+
else eligibility.Status.not_actionable,
159+
)
160+
exclusion_reasons.append(reason)
161+
else:
162+
actionable_reasons.append(reason)
163+
condition_entry = base_eligible_evaluations.setdefault(condition_name, {})
164+
condition_status_entry = condition_entry.setdefault(status, [])
165+
condition_status_entry.extend(
166+
actionable_reasons if status is eligibility.Status.actionable else exclusion_reasons
167+
)
168+
return base_eligible_evaluations
75169

76-
return eligibility.EligibilityStatus(conditions=list(conditions.values()))
170+
@staticmethod
171+
def get_base_eligible_conditions(
172+
base_eligible_evaluations: Mapping[
173+
eligibility.ConditionName, Mapping[eligibility.Status, list[eligibility.Reason]]
174+
],
175+
) -> dict[eligibility.ConditionName, eligibility.Condition]:
176+
"""Get conditions where the person is base eligible, but may be either actionable, not actionable,
177+
or not eligible."""
178+
179+
# for each condition for which the person is base eligible:
180+
# what is the "best" status, i.e. closest to actionable? Add the condition to the result with that status.
181+
eligible_conditions: dict[eligibility.ConditionName, eligibility.Condition] = {}
182+
for condition_name, reasons_by_status in base_eligible_evaluations.items():
183+
best_status = max(reasons_by_status.keys())
184+
eligible_conditions[condition_name] = eligibility.Condition(
185+
condition_name=condition_name, status=best_status, reasons=reasons_by_status[best_status]
186+
)
187+
return eligible_conditions
77188

78189
@staticmethod
79-
def evaluate_exclusion(iteration_rule: IterationRule, person_data: list[dict[str, str | None]]) -> tuple[bool, str]:
190+
def evaluate_exclusion(
191+
iteration_rule: rules.IterationRule, person_data: Collection[Mapping[str, str | None]]
192+
) -> tuple[bool, eligibility.Reason]:
80193
"""Evaluate if a particular rule excludes this person. Return the result, and the reason for the result."""
81194
attribute_value = EligibilityService.get_attribute_value(iteration_rule, person_data)
82195
exclusion, reason = EligibilityService.evaluate_rule(iteration_rule, attribute_value)
83-
reason = (
84-
f"Rule {iteration_rule.name!r} ({iteration_rule.description!r}) "
85-
f"{'' if exclusion else 'not '}excluding - "
86-
f"{iteration_rule.attribute_name!r} {iteration_rule.comparator!r} {reason}"
196+
reason = eligibility.Reason(
197+
rule_name=eligibility.RuleName(iteration_rule.name),
198+
rule_type=eligibility.RuleType(iteration_rule.type),
199+
rule_result=eligibility.RuleResult(
200+
f"Rule {iteration_rule.name!r} ({iteration_rule.description!r}) "
201+
f"{'' if exclusion else 'not '}excluding - "
202+
f"{iteration_rule.attribute_name!r} {iteration_rule.comparator!r} {reason}"
203+
),
87204
)
88205
return exclusion, reason
89206

90207
@staticmethod
91-
def get_attribute_value(iteration_rule: IterationRule, person_data: list[dict[str, str | None]]) -> str | None:
208+
def get_attribute_value(
209+
iteration_rule: rules.IterationRule, person_data: Collection[Mapping[str, str | None]]
210+
) -> str | None:
92211
"""Pull out the correct attribute for a rule from the person's data."""
93212
match iteration_rule.attribute_level:
94-
case RuleAttributeLevel.PERSON:
95-
person: dict[str, str | None] | None = next(
213+
case rules.RuleAttributeLevel.PERSON:
214+
person: Mapping[str, str | None] | None = next(
96215
(r for r in person_data if r.get("ATTRIBUTE_TYPE", "") == "PERSON"), None
97216
)
98217
attribute_value = person.get(iteration_rule.attribute_name) if person else None
@@ -102,10 +221,10 @@ def get_attribute_value(iteration_rule: IterationRule, person_data: list[dict[st
102221
return attribute_value
103222

104223
@staticmethod
105-
def evaluate_rule(iteration_rule: IterationRule, attribute_value: str | None) -> tuple[bool, str]:
224+
def evaluate_rule(iteration_rule: rules.IterationRule, attribute_value: str | None) -> tuple[bool, str]:
106225
"""Evaluate a rule against a person data attribute. Return the result, and the reason for the result."""
107226
matcher_class = OperatorRegistry.get(iteration_rule.operator)
108-
matcher = matcher_class(iteration_rule.comparator)
227+
matcher = matcher_class(rule_value=iteration_rule.comparator)
109228

110229
reason = StringDescription()
111230
if matcher.matches(attribute_value):

tests/fixtures/builders/model/eligibility.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,4 @@ class EligibilityStatusFactory(DataclassFactory[EligibilityStatus]):
1515

1616

1717
def random_str(length: int) -> str:
18-
return "".join(random.choice(string.ascii_lowercase + string.digits) for _ in range(length)) # noqa: S311
18+
return "".join(random.choice(string.ascii_lowercase + string.digits) for _ in range(length))

0 commit comments

Comments
 (0)