diff --git a/src/eligibility_signposting_api/model/rules.py b/src/eligibility_signposting_api/model/rules.py index 6048dabb..c416b42e 100644 --- a/src/eligibility_signposting_api/model/rules.py +++ b/src/eligibility_signposting_api/model/rules.py @@ -1,13 +1,14 @@ from __future__ import annotations import typing +from collections import Counter from datetime import UTC, date, datetime from enum import StrEnum from functools import cached_property from operator import attrgetter from typing import Literal, NewType -from pydantic import BaseModel, Field, field_serializer, field_validator +from pydantic import BaseModel, Field, field_serializer, field_validator, model_validator if typing.TYPE_CHECKING: # pragma: no cover from pydantic import SerializationInfo @@ -146,7 +147,7 @@ class CampaignConfig(BaseModel): end_date: EndDate = Field(..., alias="EndDate") approval_minimum: int | None = Field(None, alias="ApprovalMinimum") approval_maximum: int | None = Field(None, alias="ApprovalMaximum") - iterations: list[Iteration] = Field(..., alias="Iterations") + iterations: list[Iteration] = Field(..., min_length=1, alias="Iterations") model_config = {"populate_by_name": True, "arbitrary_types_allowed": True, "extra": "ignore"} @@ -162,16 +163,47 @@ def parse_dates(cls, v: str | date) -> date: def serialize_dates(v: date, _info: SerializationInfo) -> str: return v.strftime("%Y%m%d") + @model_validator(mode="after") + def check_start_and_end_dates_sensible(self) -> typing.Self: + if self.start_date > self.end_date: + message = f"start date {self.start_date} after end date {self.end_date}" + raise ValueError(message) + return self + + @model_validator(mode="after") + def check_no_overlapping_iterations(self) -> typing.Self: + iterations_by_date = Counter([i.iteration_date for i in self.iterations]) + if multiple_found := next(((d, c) for d, c in iterations_by_date.most_common() if c > 1), None): + iteration_date, count = multiple_found + message = f"{count} iterations with iteration date {iteration_date} in campaign {self.id}" + raise ValueError(message) + return self + + @model_validator(mode="after") + def check_has_iteration_from_start(self) -> typing.Self: + iterations_by_date = sorted(self.iterations, key=attrgetter("iteration_date")) + if first_iteration := next(iter(iterations_by_date), None): + if first_iteration.iteration_date > self.start_date: + message = ( + f"campaign {self.id} starts on {self.start_date}, " + f"1st iteration starts later - {first_iteration.iteration_date}" + ) + raise ValueError(message) + return self + # Should never happen, since we are constraining self.iterations with a min_length of 1 + message = f"campaign {self.id} has no iterations." + raise ValueError(message) + @cached_property def campaign_live(self) -> bool: today = datetime.now(tz=UTC).date() return self.start_date <= today <= self.end_date @cached_property - def current_iteration(self) -> Iteration | None: + def current_iteration(self) -> Iteration: today = datetime.now(tz=UTC).date() iterations_by_date_descending = sorted(self.iterations, key=attrgetter("iteration_date"), reverse=True) - return next((i for i in iterations_by_date_descending if i.iteration_date <= today), None) + return next(i for i in iterations_by_date_descending if i.iteration_date <= today) class Rules(BaseModel): diff --git a/src/eligibility_signposting_api/services/calculators/eligibility_calculator.py b/src/eligibility_signposting_api/services/calculators/eligibility_calculator.py index 31419b83..0042c7d7 100644 --- a/src/eligibility_signposting_api/services/calculators/eligibility_calculator.py +++ b/src/eligibility_signposting_api/services/calculators/eligibility_calculator.py @@ -32,7 +32,7 @@ class EligibilityCalculator: @property def active_campaigns(self) -> list[rules.CampaignConfig]: - return [cc for cc in self.campaign_configs if cc.campaign_live and cc.current_iteration] + return [cc for cc in self.campaign_configs if cc.campaign_live] @property def campaigns_grouped_by_condition_name( @@ -71,11 +71,8 @@ def get_the_base_eligible_campaigns(self, campaign_group: list[rules.CampaignCon return base_eligible_campaigns return [] - def check_base_eligibility(self, iteration: rules.Iteration | None) -> bool: + def check_base_eligibility(self, iteration: rules.Iteration) -> bool: """Return cohorts for which person is base eligible.""" - - if not iteration: - return False # pragma: no cover iteration_cohorts: set[str] = { cohort.cohort_label for cohort in iteration.iteration_cohorts if cohort.cohort_label } @@ -100,7 +97,7 @@ def evaluate_eligibility_by_iteration_rules( status_with_reasons: dict[eligibility.Status, list[eligibility.Reason]] = defaultdict() - for iteration in [cc.current_iteration for cc in campaign_group if cc.current_iteration]: + for iteration in [cc.current_iteration for cc in campaign_group]: # Until we see a worse status, we assume someone is actionable for this iteration. worst_status = eligibility.Status.actionable exclusion_reasons, actionable_reasons = [], [] diff --git a/tests/fixtures/builders/model/rule.py b/tests/fixtures/builders/model/rule.py index 38b88276..f1a04fbe 100644 --- a/tests/fixtures/builders/model/rule.py +++ b/tests/fixtures/builders/model/rule.py @@ -1,4 +1,5 @@ from datetime import UTC, date, datetime, timedelta +from operator import attrgetter from random import randint from polyfactory import Use @@ -27,13 +28,44 @@ class IterationFactory(ModelFactory[rules.Iteration]): iteration_date = Use(past_date) -class CampaignConfigFactory(ModelFactory[rules.CampaignConfig]): +class RawCampaignConfigFactory(ModelFactory[rules.CampaignConfig]): iterations = Use(IterationFactory.batch, size=2) start_date = Use(past_date) end_date = Use(future_date) +class CampaignConfigFactory(RawCampaignConfigFactory): + @classmethod + def build(cls, **kwargs) -> rules.CampaignConfig: + """Ensure invariants are met: + * no iterations with duplicate iteration dates + * must have iteration active from campaign start date""" + processed_kwargs = cls.process_kwargs(**kwargs) + start_date: date = processed_kwargs["start_date"] + iterations: list[rules.Iteration] = processed_kwargs["iterations"] + + CampaignConfigFactory.fix_iteration_date_invariants(iterations, start_date) + + data = super().build(**processed_kwargs).dict() + return cls.__model__(**data) + + @staticmethod + def fix_iteration_date_invariants(iterations: list[rules.Iteration], start_date: date) -> None: + iterations.sort(key=attrgetter("iteration_date")) + iterations[0].iteration_date = start_date + + seen: set[date] = set() + previous: date = iterations[0].iteration_date + for iteration in iterations: + current = iteration.iteration_date if iteration.iteration_date >= previous else previous + timedelta(days=1) + while current in seen: + current += timedelta(days=1) + seen.add(current) + iteration.iteration_date = current + previous = current + + class PersonAgeSuppressionRuleFactory(IterationRuleFactory): type = rules.RuleType.suppression name = rules.RuleName("Exclude too young less than 75") diff --git a/tests/unit/model/__init__.py b/tests/unit/model/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/model/test_rules.py b/tests/unit/model/test_rules.py new file mode 100644 index 00000000..cde9c7b4 --- /dev/null +++ b/tests/unit/model/test_rules.py @@ -0,0 +1,61 @@ +import pytest +from dateutil.relativedelta import relativedelta +from faker import Faker + +from tests.fixtures.builders.model.rule import IterationFactory, RawCampaignConfigFactory + + +def test_campaign_must_have_at_least_one_iteration(): + # Given + + # When, Then + with pytest.raises( + ValueError, + match=r"1 validation error for CampaignConfig\n" + r"iterations\n" + r".*List should have at least 1 item", + ): + RawCampaignConfigFactory.build(iterations=[]) + + +def test_campaign_start_date_must_be_before_end_date(faker: Faker): + # Given + start_date = faker.date_object() + end_date = start_date - relativedelta(days=1) + + # When, Then + with pytest.raises( + ValueError, + match=r"1 validation error for CampaignConfig\n" + r".*start date .* after end date", + ): + RawCampaignConfigFactory.build(start_date=start_date, end_date=end_date) + + +def test_iteration_with_overlapping_start_dates_not_allowed(faker: Faker): + # Given + start_date = faker.date_object() + iteration1 = IterationFactory.build(iteration_date=start_date) + iteration2 = IterationFactory.build(iteration_date=start_date) + + # When, Then + with pytest.raises( + ValueError, + match=r"1 validation error for CampaignConfig\n" + r".*2 iterations with iteration date", + ): + RawCampaignConfigFactory.build(start_date=start_date, iterations=[iteration1, iteration2]) + + +def test_iteration_must_have_active_iteration_from_its_start(faker: Faker): + # Given + start_date = faker.date_object() + iteration = IterationFactory.build(iteration_date=start_date + relativedelta(days=1)) + + # When, Then + with pytest.raises( + ValueError, + match=r"1 validation error for CampaignConfig\n" + r".*1st iteration starts later", + ): + RawCampaignConfigFactory.build(start_date=start_date, iterations=[iteration]) diff --git a/tests/unit/services/calculators/test_eligibility_calculator.py b/tests/unit/services/calculators/test_eligibility_calculator.py index 2e6ed793..8e8e564c 100644 --- a/tests/unit/services/calculators/test_eligibility_calculator.py +++ b/tests/unit/services/calculators/test_eligibility_calculator.py @@ -3,7 +3,7 @@ import pytest from faker import Faker from freezegun import freeze_time -from hamcrest import assert_that, contains_exactly, empty, has_item, has_items +from hamcrest import assert_that, contains_exactly, has_item, has_items from eligibility_signposting_api.model import rules from eligibility_signposting_api.model import rules as rules_model @@ -247,32 +247,6 @@ def test_simple_rule_only_excludes_from_live_iteration(faker: Faker): ) -@freeze_time("2025-04-25") -def test_campaign_with_no_active_iteration_not_considered(faker: Faker): - # Given - nhs_number = NHSNumber(faker.nhs_number()) - - person_rows = person_rows_builder(nhs_number) - campaign_configs = [ - rule_builder.CampaignConfigFactory.build( - target="RSV", - iterations=[ - rule_builder.IterationFactory.build( - iteration_date=rules_model.IterationDate(datetime.date(2025, 4, 26)), - ) - ], - ) - ] - - calculator = EligibilityCalculator(person_rows, campaign_configs) - - # When - actual = calculator.evaluate_eligibility() - - # Then - assert_that(actual, is_eligibility_status().with_conditions(empty())) - - @pytest.mark.parametrize( ("rule_type", "expected_status"), [