diff --git a/src/eligibility_signposting_api/services/calculators/eligibility_calculator.py b/src/eligibility_signposting_api/services/calculators/eligibility_calculator.py index 0042c7d7..8dbe951b 100644 --- a/src/eligibility_signposting_api/services/calculators/eligibility_calculator.py +++ b/src/eligibility_signposting_api/services/calculators/eligibility_calculator.py @@ -4,6 +4,7 @@ from collections import defaultdict from collections.abc import Collection, Iterator, Mapping from dataclasses import dataclass, field +from functools import cached_property from itertools import groupby from typing import Any @@ -45,6 +46,13 @@ def campaigns_grouped_by_condition_name( ): yield condition_name, list(campaign_group) + @cached_property + def person_cohorts(self) -> set[str]: + cohorts_row: Mapping[str, dict[str, dict[str, dict[str, Any]]]] = next( + (row for row in self.person_data if row.get("ATTRIBUTE_TYPE") == "COHORTS"), {} + ) + return set(cohorts_row.get("COHORT_MAP", {}).get("cohorts", {}).get("M", {}).keys()) + def evaluate_eligibility(self) -> eligibility.EligibilityStatus: """Iterates over campaign groups, evaluates eligibility, and returns a consolidated status.""" @@ -78,12 +86,7 @@ def check_base_eligibility(self, iteration: rules.Iteration) -> bool: } if magic_cohort in iteration_cohorts: return True - - cohorts_row: Mapping[str, dict[str, dict[str, dict[str, Any]]]] = next( - (row for row in self.person_data if row.get("ATTRIBUTE_TYPE") == "COHORTS"), {} - ) - person_cohorts: set[str] = set(cohorts_row.get("COHORT_MAP", {}).get("cohorts", {}).get("M", {}).keys()) - return bool(iteration_cohorts & person_cohorts) + return bool(iteration_cohorts & self.person_cohorts) def evaluate_eligibility_by_iteration_rules( self, campaign_group: list[rules.CampaignConfig] @@ -123,8 +126,12 @@ def evaluate_priority_group( worst_status_so_far_for_condition: eligibility.Status, ) -> tuple[eligibility.Status, list[eligibility.Reason], list[eligibility.Reason]]: exclusion_reasons, actionable_reasons = [], [] + exclude_capable_rules = [ - ir for ir in iteration_rule_group if ir.type in (rules.RuleType.filter, rules.RuleType.suppression) + ir + for ir in iteration_rule_group + if ir.type in (rules.RuleType.filter, rules.RuleType.suppression) + and (ir.cohort_label is None or (ir.cohort_label in self.person_cohorts)) ] best_status = eligibility.Status.not_eligible if exclude_capable_rules else eligibility.Status.actionable diff --git a/tests/fixtures/builders/model/rule.py b/tests/fixtures/builders/model/rule.py index f1a04fbe..90fa1839 100644 --- a/tests/fixtures/builders/model/rule.py +++ b/tests/fixtures/builders/model/rule.py @@ -19,7 +19,9 @@ def future_date(days_ahead: int = 365) -> date: class IterationCohortFactory(ModelFactory[rules.IterationCohort]): ... -class IterationRuleFactory(ModelFactory[rules.IterationRule]): ... +class IterationRuleFactory(ModelFactory[rules.IterationRule]): + attribute_target = None + cohort_label = None class IterationFactory(ModelFactory[rules.Iteration]): diff --git a/tests/unit/services/calculators/test_eligibility_calculator.py b/tests/unit/services/calculators/test_eligibility_calculator.py index 893cefb7..8b4d18d2 100644 --- a/tests/unit/services/calculators/test_eligibility_calculator.py +++ b/tests/unit/services/calculators/test_eligibility_calculator.py @@ -689,7 +689,7 @@ def test_status_on_target_based_on_last_successful_date( ), rule_builder.IterationRuleFactory.build( type=rules.RuleType.suppression, - name=rules.RuleName("You have a future booking to be vaccinated against RSV"), + name=rules.RuleName("You have a vaccination date in the future for RSV"), description=rules.RuleDescription("Exclude anyone with future Completed RSV Vaccination"), priority=10, operator=rules.RuleOperator.day_lte, @@ -763,3 +763,48 @@ def test_status_on_cohort_attribute_level(faker: Faker): has_item(is_condition().with_condition_name(ConditionName("RSV")).and_status(Status.not_eligible)) ), ) + + +@pytest.mark.parametrize( + ("person_cohorts", "cohort_label", "expected_status", "test_comment"), + [ + (["cohort1", "cohort2"], "cohort1", Status.not_actionable, "matches the cohort label"), + (["cohort2", "cohort3"], "cohort1", Status.actionable, "doesn't match the cohort label"), + ], +) +def test_status_if_iteration_rules_contains_cohort_label_field( + person_cohorts, cohort_label: str, expected_status: Status, test_comment: str, faker: Faker +): + # Given + nhs_number = NHSNumber(faker.nhs_number()) + date_of_birth = DateOfBirth(faker.date_of_birth(minimum_age=66, maximum_age=74)) + + person_rows = person_rows_builder(nhs_number, date_of_birth=date_of_birth, cohorts=person_cohorts) + campaign_configs = [ + rule_builder.CampaignConfigFactory.build( + target="RSV", + iterations=[ + rule_builder.IterationFactory.build( + iteration_cohorts=[ + rule_builder.IterationCohortFactory.build(cohort_label="cohort1"), + rule_builder.IterationCohortFactory.build(cohort_label="cohort2"), + ], + iteration_rules=[rule_builder.PersonAgeSuppressionRuleFactory.build(cohort_label=cohort_label)], + ) + ], + ) + ] + + calculator = EligibilityCalculator(person_rows, campaign_configs) + + # When + actual = calculator.evaluate_eligibility() + + # Then + assert_that( + actual, + is_eligibility_status().with_conditions( + has_items(is_condition().with_condition_name(ConditionName("RSV")).and_status(expected_status)) + ), + test_comment, + ) diff --git a/tests/unit/services/calculators/test_rule_calculator.py b/tests/unit/services/calculators/test_rule_calculator.py new file mode 100644 index 00000000..b8069a16 --- /dev/null +++ b/tests/unit/services/calculators/test_rule_calculator.py @@ -0,0 +1,50 @@ +from collections.abc import Collection, Mapping +from typing import Any + +import pytest + +from eligibility_signposting_api.model import rules +from eligibility_signposting_api.services.calculators.rule_calculator import RuleCalculator +from tests.fixtures.builders.model import rule as rule_builder + +Row = Collection[Mapping[str, Any]] + + +@pytest.mark.parametrize( + ("person_data", "rule", "expected"), + [ + # PERSON attribute level + ( + [{"ATTRIBUTE_TYPE": "PERSON", "POSTCODE": "SW19"}], + rule_builder.IterationRuleFactory.build( + attribute_level=rules.RuleAttributeLevel.PERSON, attribute_name="POSTCODE" + ), + "SW19", + ), + # TARGET attribute level + ( + [{"ATTRIBUTE_TYPE": "RSV", "LAST_SUCCESSFUL_DATE": "20240101"}], + rule_builder.IterationRuleFactory.build( + attribute_level=rules.RuleAttributeLevel.TARGET, + attribute_name="LAST_SUCCESSFUL_DATE", + attribute_target="RSV", + ), + "20240101", + ), + # COHORT attribute level + ( + [{"ATTRIBUTE_TYPE": "COHORTS", "COHORT_LABEL": ""}], + rule_builder.IterationRuleFactory.build( + attribute_level=rules.RuleAttributeLevel.COHORT, attribute_name="COHORT_LABEL" + ), + "", + ), + ], +) +def test_get_attribute_value_for_all_attribute_levels(person_data: Row, rule: rules.IterationRule, expected: str): + # Given + calc = RuleCalculator(person_data=person_data, rule=rule) + # When + actual = calc.get_attribute_value() + # Then + assert actual == expected