Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 26 additions & 20 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ aiohttp = "^3.11.12"
awscli = "^1.37.24"
awscli-local = "^0.22.0"
pyhamcrest = "^2.1.0"
factory-boy = "^3.3.3"
polyfactory = "^2.20.0"
pyright = "^1.1.394"
brunns-matchers = "^2.9.0"
localstack = "^4.1.1"
Expand Down
1 change: 1 addition & 0 deletions src/eligibility_signposting_api/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def config() -> dict[str, Any]:
"dynamodb_endpoint": URL(os.getenv("DYNAMODB_ENDPOINT", "http://localhost:4566")),
"aws_secret_access_key": AwsSecretAccessKey(os.getenv("AWS_SECRET_ACCESS_KEY", "dummy_secret")),
"log_level": LOG_LEVEL,
"rules_bucket_name": AwsAccessKey(os.getenv("RULES_BUCKET_NAME", "test-rules-bucket")),
}


Expand Down
127 changes: 127 additions & 0 deletions src/eligibility_signposting_api/model/rules.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
from __future__ import annotations

import typing
from datetime import date, datetime
from enum import Enum
from typing import Literal, NewType

from pydantic import BaseModel, Field, field_serializer, field_validator

if typing.TYPE_CHECKING:
from pydantic import SerializationInfo

BucketName = NewType("BucketName", str)
CampaignName = NewType("CampaignName", str)
CampaignVersion = NewType("CampaignVersion", str)
CampaignID = NewType("CampaignID", str)
IterationName = NewType("IterationName", str)
IterationVersion = NewType("IterationVersion", str)
IterationID = NewType("IterationID", str)
RuleName = NewType("RuleName", str)
RuleDescription = NewType("RuleDescription", str)
RulePriority = NewType("RulePriority", int)
RuleAttributeLevel = NewType("RuleAttributeLevel", str)
RuleAttributeName = NewType("RuleAttributeName", str)
RuleComparator = NewType("RuleComparator", str)
StartDate = NewType("StartDate", date)
EndDate = NewType("EndDate", date)


class RuleType(str, Enum):
filter = "F"
suppression = "S"
redirect = "R"


class RuleOperator(str, Enum):
lt = "<"
gt = ">"
year_gt = "Y>"
not_in = "not_in"
equals = "="
lte = "<="
ne = "!="
date_gte = "D>="
member_of = "MemberOf"


class IterationCohort(BaseModel):
cohort_label: str | None = Field(None, alias="CohortLabel")
priority: int | None = Field(None, alias="Priority")

model_config = {"populate_by_name": True}


class IterationRule(BaseModel):
type: RuleType = Field(..., alias="Type")
name: RuleName = Field(..., alias="Name")
description: RuleDescription = Field(..., alias="Description")
priority: RulePriority = Field(..., alias="Priority")
attribute_level: RuleAttributeLevel = Field(..., alias="AttributeLevel")
attribute_name: RuleAttributeName = Field(..., alias="AttributeName")
operator: RuleOperator = Field(..., alias="Operator")
comparator: RuleComparator = Field(..., alias="Comparator")
attribute_target: str | None = Field(None, alias="AttributeTarget")
comms_routing: str | None = Field(None, alias="CommsRouting")

model_config = {"populate_by_name": True}


class Iteration(BaseModel):
id: IterationID = Field(..., alias="ID")
default_comms_routing: str | None = Field(None, alias="DefaultCommsRouting")
version: IterationVersion = Field(..., alias="Version")
name: IterationName = Field(..., alias="Name")
iteration_date: str | None = Field(None, alias="IterationDate")
iteration_number: int | None = Field(None, alias="IterationNumber")
comms_type: Literal["I", "R"] = Field(..., alias="CommsType")
approval_minimum: int | None = Field(None, alias="ApprovalMinimum")
approval_maximum: int | None = Field(None, alias="ApprovalMaximum")
type: Literal["A", "M", "S"] = Field(..., alias="Type")
iteration_cohorts: list[IterationCohort] | None = Field(None, alias="IterationCohorts")
iteration_rules: list[IterationRule] | None = Field(None, alias="IterationRules")

model_config = {"populate_by_name": True}


class CampaignConfig(BaseModel):
id: CampaignID = Field(..., alias="ID")
version: CampaignVersion = Field(..., alias="Version")
name: CampaignName = Field(..., alias="Name")
type: Literal["V", "S"] = Field(..., alias="Type")
target: Literal["COVID", "FLU", "MMR", "RSV"] = Field(..., alias="Target")
manager: str | None = Field(None, alias="Manager")
approver: str | None = Field(None, alias="Approver")
reviewer: str | None = Field(None, alias="Reviewer")
iteration_frequency: Literal["X", "D", "W", "M", "Q", "A"] = Field(..., alias="IterationFrequency")
iteration_type: Literal["A", "M", "S"] = Field(..., alias="IterationType")
iteration_time: str | None = Field(None, alias="IterationTime")
default_comms_routing: str | None = Field(None, alias="DefaultCommsRouting")
start_date: StartDate = Field(..., alias="StartDate")
end_date: EndDate = Field(..., alias="EndDate")
approval_minimum: int | None = Field(None, alias="ApprovalMinimum")
approval_maximum: int | None = Field(None, alias="ApprovalMaximum")
iterations: list[Iteration] | None = Field(None, alias="Iterations")

model_config = {
"populate_by_name": True,
"arbitrary_types_allowed": True,
}

@field_validator("start_date", "end_date", mode="before")
@classmethod
def parse_dates(cls, v: str | date) -> date:
if isinstance(v, date):
return v
return datetime.strptime(v, "%Y%m%d").date() # noqa: DTZ007

@field_serializer("start_date", "end_date", when_used="always")
@staticmethod
def serialize_dates(v: date, _info: SerializationInfo) -> str:
return v.strftime("%Y%m%d")


class Rules(BaseModel):
campaign_config: CampaignConfig = Field(..., alias="CampaignConfig")

model_config = {"populate_by_name": True}
8 changes: 8 additions & 0 deletions src/eligibility_signposting_api/repos/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from boto3 import Session
from boto3.resources.base import ServiceResource
from botocore.client import BaseClient
from wireup import Inject, service
from yarl import URL

Expand All @@ -27,3 +28,10 @@ def dynamodb_resource_factory(
session: Session, dynamodb_endpoint: Annotated[URL, Inject(param="dynamodb_endpoint")]
) -> ServiceResource:
return session.resource("dynamodb", endpoint_url=str(dynamodb_endpoint))


@service(qualifier="s3")
def s3_service_factory(
session: Session, dynamodb_endpoint: Annotated[URL, Inject(param="dynamodb_endpoint")]
) -> BaseClient:
return session.client("s3", endpoint_url=str(dynamodb_endpoint))
24 changes: 24 additions & 0 deletions src/eligibility_signposting_api/repos/rules_repo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import json
from typing import Annotated

from botocore.client import BaseClient
from wireup import Inject, service

from eligibility_signposting_api.model.rules import BucketName, CampaignConfig, CampaignName, Rules


@service
class RulesRepo:
def __init__(
self,
s3_client: Annotated[BaseClient, Inject(qualifier="s3")],
bucket_name: Annotated[BucketName, Inject(param="rules_bucket_name")],
) -> None:
super().__init__()
self.s3_client = s3_client
self.bucket_name = bucket_name

def get_campaign_config(self, campaign: CampaignName) -> CampaignConfig:
response = self.s3_client.get_object(Bucket=self.bucket_name, Key=f"{campaign}.json")
body = response["Body"].read()
return Rules.model_validate(json.loads(body)).campaign_config
42 changes: 42 additions & 0 deletions tests/integration/repo/test_rules_repo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import json
from collections.abc import Generator

import pytest
from botocore.client import BaseClient
from hamcrest import assert_that

from eligibility_signposting_api.model.rules import BucketName, CampaignConfig
from eligibility_signposting_api.repos.rules_repo import RulesRepo
from tests.integration.conftest import AWS_REGION
from tests.utils.builders import CampaignConfigFactory, random_str
from tests.utils.rules.rules import is_campaign_config


@pytest.fixture
def bucket(s3_client: BaseClient) -> Generator[BucketName]:
bucket_name = BucketName(random_str(63))
s3_client.create_bucket(Bucket=bucket_name, CreateBucketConfiguration={"LocationConstraint": AWS_REGION})
yield bucket_name
s3_client.delete_bucket(Bucket=bucket_name)


@pytest.fixture
def campaign(s3_client: BaseClient, bucket: BucketName) -> Generator[CampaignConfig]:
campaign: CampaignConfig = CampaignConfigFactory.build()
campaign_data = {"CampaignConfig": campaign.model_dump(by_alias=True)}
s3_client.put_object(
Bucket=bucket, Key=f"{campaign.name}.json", Body=json.dumps(campaign_data), ContentType="application/json"
)
yield campaign
s3_client.delete_object(Bucket=bucket, Key=f"{campaign.name}.json")


def test_get_campaign_config(s3_client: BaseClient, bucket: BucketName, campaign: CampaignConfig):
# Given
repo = RulesRepo(s3_client, bucket)

# When
actual = repo.get_campaign_config(campaign.name)

# Then
assert_that(actual, is_campaign_config().with_id(campaign.id).and_name(campaign.name).and_version(campaign.version))
30 changes: 30 additions & 0 deletions tests/utils/builders.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import random
import string

from polyfactory import Use
from polyfactory.factories.pydantic_factory import ModelFactory

from eligibility_signposting_api.model.rules import CampaignConfig, Iteration, IterationCohort, IterationRule


class IterationCohortFactory(ModelFactory[IterationCohort]):
__model__ = IterationCohort


class IterationRuleFactory(ModelFactory[IterationRule]):
__model__ = IterationRule


class IterationFactory(ModelFactory[Iteration]):
__model__ = Iteration
iteration_cohorts = Use(IterationCohortFactory.batch, size=2)
iteration_rules = Use(IterationRuleFactory.batch, size=2)


class CampaignConfigFactory(ModelFactory[CampaignConfig]):
__model__ = CampaignConfig
iterations = Use(IterationFactory.batch, size=2)


def random_str(length: int) -> str:
return "".join(random.choice(string.ascii_lowercase + string.digits) for _ in range(length)) # noqa: S311
64 changes: 64 additions & 0 deletions tests/utils/rules/rules.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from brunns.matchers.utils import append_matcher_description, describe_field_match, describe_field_mismatch
from hamcrest import anything
from hamcrest.core.base_matcher import BaseMatcher
from hamcrest.core.description import Description
from hamcrest.core.helpers.wrap_matcher import wrap_matcher
from hamcrest.core.matcher import Matcher

from eligibility_signposting_api.model.rules import CampaignConfig, CampaignID, CampaignName, CampaignVersion

ANYTHING = anything()


class CampaignConfigMatcher(BaseMatcher[CampaignConfig]):
def __init__(self):
super().__init__()
self.id_: Matcher[CampaignID] = ANYTHING
self.name: Matcher[CampaignName] = ANYTHING
self.version: Matcher[CampaignVersion] = ANYTHING

def describe_to(self, description: Description) -> None:
description.append_text("CampaignConfig with")
append_matcher_description(self.id_, "id", description)
append_matcher_description(self.name, "name", description)
append_matcher_description(self.version, "version", description)

def _matches(self, item: CampaignConfig) -> bool:
return self.id_.matches(item.id) and self.name.matches(item.name) and self.version.matches(item.version)

def describe_mismatch(self, item: CampaignConfig, mismatch_description: Description) -> None:
mismatch_description.append_text("was CampaignConfig with")
describe_field_mismatch(self.id_, "id", item.id, mismatch_description)
describe_field_mismatch(self.name, "name", item.name, mismatch_description)
describe_field_mismatch(self.version, "version", item.version, mismatch_description)

def describe_match(self, item: CampaignConfig, match_description: Description) -> None:
match_description.append_text("was CampaignConfig with")
describe_field_match(self.id_, "id", item.id, match_description)
describe_field_match(self.name, "name", item.name, match_description)
describe_field_match(self.version, "version", item.version, match_description)

def with_id(self, id_: CampaignID | Matcher[CampaignID]):
self.id_ = wrap_matcher(id_)
return self

def and_id(self, id_: CampaignID | Matcher[CampaignID]):
return self.with_id(id_)

def with_name(self, name: CampaignName | Matcher[CampaignName]):
self.name = wrap_matcher(name)
return self

def and_name(self, name: CampaignName | Matcher[CampaignName]):
return self.with_name(name)

def with_version(self, version: CampaignVersion | Matcher[CampaignVersion]):
self.version = wrap_matcher(version)
return self

def and_version(self, version: CampaignVersion | Matcher[CampaignVersion]):
return self.with_version(version)


def is_campaign_config() -> Matcher[CampaignConfig]:
return CampaignConfigMatcher()
Loading