Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 36 additions & 4 deletions src/eligibility_signposting_api/model/rules.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from __future__ import annotations

import typing
from collections import Counter
from datetime import UTC, date, datetime
from enum import StrEnum
from functools import cached_property
from operator import attrgetter
from typing import Literal, NewType

from pydantic import BaseModel, Field, field_serializer, field_validator
from pydantic import BaseModel, Field, field_serializer, field_validator, model_validator

if typing.TYPE_CHECKING: # pragma: no cover
from pydantic import SerializationInfo
Expand Down Expand Up @@ -146,7 +147,7 @@ class CampaignConfig(BaseModel):
end_date: EndDate = Field(..., alias="EndDate")
approval_minimum: int | None = Field(None, alias="ApprovalMinimum")
approval_maximum: int | None = Field(None, alias="ApprovalMaximum")
iterations: list[Iteration] = Field(..., alias="Iterations")
iterations: list[Iteration] = Field(..., min_length=1, alias="Iterations")

model_config = {"populate_by_name": True, "arbitrary_types_allowed": True, "extra": "ignore"}

Expand All @@ -162,16 +163,47 @@ def parse_dates(cls, v: str | date) -> date:
def serialize_dates(v: date, _info: SerializationInfo) -> str:
return v.strftime("%Y%m%d")

@model_validator(mode="after")
def check_start_and_end_dates_sensible(self) -> typing.Self:
if self.start_date > self.end_date:
message = f"start date {self.start_date} after end date {self.end_date}"
raise ValueError(message)
return self

@model_validator(mode="after")
def check_no_overlapping_iterations(self) -> typing.Self:
iterations_by_date = Counter([i.iteration_date for i in self.iterations])
if multiple_found := next(((d, c) for d, c in iterations_by_date.most_common() if c > 1), None):
iteration_date, count = multiple_found
message = f"{count} iterations with iteration date {iteration_date} in campaign {self.id}"
raise ValueError(message)
return self

@model_validator(mode="after")
def check_has_iteration_from_start(self) -> typing.Self:
iterations_by_date = sorted(self.iterations, key=attrgetter("iteration_date"))
if first_iteration := next(iter(iterations_by_date), None):
if first_iteration.iteration_date > self.start_date:
message = (
f"campaign {self.id} starts on {self.start_date}, "
f"1st iteration starts later - {first_iteration.iteration_date}"
)
raise ValueError(message)
return self
# Should never happen, since we are constraining self.iterations with a min_length of 1
message = f"campaign {self.id} has no iterations."
raise ValueError(message)

@cached_property
def campaign_live(self) -> bool:
today = datetime.now(tz=UTC).date()
return self.start_date <= today <= self.end_date

@cached_property
def current_iteration(self) -> Iteration | None:
def current_iteration(self) -> Iteration:
today = datetime.now(tz=UTC).date()
iterations_by_date_descending = sorted(self.iterations, key=attrgetter("iteration_date"), reverse=True)
return next((i for i in iterations_by_date_descending if i.iteration_date <= today), None)
return next(i for i in iterations_by_date_descending if i.iteration_date <= today)


class Rules(BaseModel):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class EligibilityCalculator:

@property
def active_campaigns(self) -> list[rules.CampaignConfig]:
return [cc for cc in self.campaign_configs if cc.campaign_live and cc.current_iteration]
return [cc for cc in self.campaign_configs if cc.campaign_live]

@property
def campaigns_grouped_by_condition_name(
Expand Down Expand Up @@ -71,11 +71,8 @@ def get_the_base_eligible_campaigns(self, campaign_group: list[rules.CampaignCon
return base_eligible_campaigns
return []

def check_base_eligibility(self, iteration: rules.Iteration | None) -> bool:
def check_base_eligibility(self, iteration: rules.Iteration) -> bool:
"""Return cohorts for which person is base eligible."""

if not iteration:
return False # pragma: no cover
iteration_cohorts: set[str] = {
cohort.cohort_label for cohort in iteration.iteration_cohorts if cohort.cohort_label
}
Expand All @@ -100,7 +97,7 @@ def evaluate_eligibility_by_iteration_rules(

status_with_reasons: dict[eligibility.Status, list[eligibility.Reason]] = defaultdict()

for iteration in [cc.current_iteration for cc in campaign_group if cc.current_iteration]:
for iteration in [cc.current_iteration for cc in campaign_group]:
# Until we see a worse status, we assume someone is actionable for this iteration.
worst_status = eligibility.Status.actionable
exclusion_reasons, actionable_reasons = [], []
Expand Down
34 changes: 33 additions & 1 deletion tests/fixtures/builders/model/rule.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from datetime import UTC, date, datetime, timedelta
from operator import attrgetter
from random import randint

from polyfactory import Use
Expand Down Expand Up @@ -27,13 +28,44 @@ class IterationFactory(ModelFactory[rules.Iteration]):
iteration_date = Use(past_date)


class CampaignConfigFactory(ModelFactory[rules.CampaignConfig]):
class RawCampaignConfigFactory(ModelFactory[rules.CampaignConfig]):
iterations = Use(IterationFactory.batch, size=2)

start_date = Use(past_date)
end_date = Use(future_date)


class CampaignConfigFactory(RawCampaignConfigFactory):
@classmethod
def build(cls, **kwargs) -> rules.CampaignConfig:
"""Ensure invariants are met:
* no iterations with duplicate iteration dates
* must have iteration active from campaign start date"""
processed_kwargs = cls.process_kwargs(**kwargs)
start_date: date = processed_kwargs["start_date"]
iterations: list[rules.Iteration] = processed_kwargs["iterations"]

CampaignConfigFactory.fix_iteration_date_invariants(iterations, start_date)

data = super().build(**processed_kwargs).dict()
return cls.__model__(**data)

@staticmethod
def fix_iteration_date_invariants(iterations: list[rules.Iteration], start_date: date) -> None:
iterations.sort(key=attrgetter("iteration_date"))
iterations[0].iteration_date = start_date

seen: set[date] = set()
previous: date = iterations[0].iteration_date
for iteration in iterations:
current = iteration.iteration_date if iteration.iteration_date >= previous else previous + timedelta(days=1)
while current in seen:
current += timedelta(days=1)
seen.add(current)
iteration.iteration_date = current
previous = current


class PersonAgeSuppressionRuleFactory(IterationRuleFactory):
type = rules.RuleType.suppression
name = rules.RuleName("Exclude too young less than 75")
Expand Down
Empty file added tests/unit/model/__init__.py
Empty file.
61 changes: 61 additions & 0 deletions tests/unit/model/test_rules.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import pytest
from dateutil.relativedelta import relativedelta
from faker import Faker

from tests.fixtures.builders.model.rule import IterationFactory, RawCampaignConfigFactory


def test_campaign_must_have_at_least_one_iteration():
# Given

# When, Then
with pytest.raises(
ValueError,
match=r"1 validation error for CampaignConfig\n"
r"iterations\n"
r".*List should have at least 1 item",
):
RawCampaignConfigFactory.build(iterations=[])


def test_campaign_start_date_must_be_before_end_date(faker: Faker):
# Given
start_date = faker.date_object()
end_date = start_date - relativedelta(days=1)

# When, Then
with pytest.raises(
ValueError,
match=r"1 validation error for CampaignConfig\n"
r".*start date .* after end date",
):
RawCampaignConfigFactory.build(start_date=start_date, end_date=end_date)


def test_iteration_with_overlapping_start_dates_not_allowed(faker: Faker):
# Given
start_date = faker.date_object()
iteration1 = IterationFactory.build(iteration_date=start_date)
iteration2 = IterationFactory.build(iteration_date=start_date)

# When, Then
with pytest.raises(
ValueError,
match=r"1 validation error for CampaignConfig\n"
r".*2 iterations with iteration date",
):
RawCampaignConfigFactory.build(start_date=start_date, iterations=[iteration1, iteration2])


def test_iteration_must_have_active_iteration_from_its_start(faker: Faker):
# Given
start_date = faker.date_object()
iteration = IterationFactory.build(iteration_date=start_date + relativedelta(days=1))

# When, Then
with pytest.raises(
ValueError,
match=r"1 validation error for CampaignConfig\n"
r".*1st iteration starts later",
):
RawCampaignConfigFactory.build(start_date=start_date, iterations=[iteration])
28 changes: 1 addition & 27 deletions tests/unit/services/calculators/test_eligibility_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pytest
from faker import Faker
from freezegun import freeze_time
from hamcrest import assert_that, contains_exactly, empty, has_item, has_items
from hamcrest import assert_that, contains_exactly, has_item, has_items

from eligibility_signposting_api.model import rules
from eligibility_signposting_api.model import rules as rules_model
Expand Down Expand Up @@ -247,32 +247,6 @@ def test_simple_rule_only_excludes_from_live_iteration(faker: Faker):
)


@freeze_time("2025-04-25")
def test_campaign_with_no_active_iteration_not_considered(faker: Faker):
# Given
nhs_number = NHSNumber(faker.nhs_number())

person_rows = person_rows_builder(nhs_number)
campaign_configs = [
rule_builder.CampaignConfigFactory.build(
target="RSV",
iterations=[
rule_builder.IterationFactory.build(
iteration_date=rules_model.IterationDate(datetime.date(2025, 4, 26)),
)
],
)
]

calculator = EligibilityCalculator(person_rows, campaign_configs)

# When
actual = calculator.evaluate_eligibility()

# Then
assert_that(actual, is_eligibility_status().with_conditions(empty()))


@pytest.mark.parametrize(
("rule_type", "expected_status"),
[
Expand Down
Loading