diff --git a/poetry.lock b/poetry.lock index ac424155..88936cb2 100644 --- a/poetry.lock +++ b/poetry.lock @@ -888,25 +888,6 @@ files = [ [package.extras] tests = ["pytest"] -[[package]] -name = "factory-boy" -version = "3.3.3" -description = "A versatile test fixtures replacement based on thoughtbot's factory_bot for Ruby." -optional = false -python-versions = ">=3.8" -groups = ["dev"] -files = [ - {file = "factory_boy-3.3.3-py2.py3-none-any.whl", hash = "sha256:1c39e3289f7e667c4285433f305f8d506efc2fe9c73aaea4151ebd5cdea394fc"}, - {file = "factory_boy-3.3.3.tar.gz", hash = "sha256:866862d226128dfac7f2b4160287e899daf54f2612778327dd03d0e2cb1e3d03"}, -] - -[package.dependencies] -Faker = ">=0.7.0" - -[package.extras] -dev = ["Django", "Pillow", "SQLAlchemy", "coverage", "flake8", "isort", "mongoengine", "mongomock", "mypy", "tox", "wheel (>=0.32.0)", "zest.releaser[recommended]"] -doc = ["Sphinx", "sphinx-rtd-theme", "sphinxcontrib-spelling"] - [[package]] name = "faker" version = "37.1.0" @@ -1848,6 +1829,31 @@ files = [ {file = "ply-3.11.tar.gz", hash = "sha256:00c7c1aaa88358b9c765b6d3000c6eec0ba42abca5351b095321aef446081da3"}, ] +[[package]] +name = "polyfactory" +version = "2.20.0" +description = "Mock data generation factories" +optional = false +python-versions = "<4.0,>=3.8" +groups = ["dev"] +files = [ + {file = "polyfactory-2.20.0-py3-none-any.whl", hash = "sha256:6a808454bb03afacf54abeeb50d79b86c9e5b8476efc2bc3788e5ece26dd561a"}, + {file = "polyfactory-2.20.0.tar.gz", hash = "sha256:86017160f05332baadb5eaf89885e1ba7bb447a3140e46ba4546848c76cbdec5"}, +] + +[package.dependencies] +faker = ">=5.0.0" +typing-extensions = ">=4.6.0" + +[package.extras] +attrs = ["attrs (>=22.2.0)"] +beanie = ["beanie", "pydantic[email]", "pymongo (<4.9)"] +full = ["attrs", "beanie", "msgspec", "odmantic", "pydantic", "sqlalchemy"] +msgspec = ["msgspec"] +odmantic = ["odmantic (<1.0.0)", "pydantic[email]"] +pydantic = ["pydantic[email] (>=1.10)"] +sqlalchemy = ["sqlalchemy (>=1.4.29)"] + [[package]] name = "propcache" version = "0.3.1" @@ -3099,4 +3105,4 @@ propcache = ">=0.2.1" [metadata] lock-version = "2.1" python-versions = "^3.13" -content-hash = "9ddc81f25040399c771d440e69689b07da3f37f32686a6f41d3ef1843228b029" +content-hash = "8ce1a4f2f462b9d778ce3ffddbc85bc30d70b17c0877342c76136c947ebcf333" diff --git a/pyproject.toml b/pyproject.toml index fe08435e..651b3321 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,7 +48,7 @@ aiohttp = "^3.11.12" awscli = "^1.37.24" awscli-local = "^0.22.0" pyhamcrest = "^2.1.0" -factory-boy = "^3.3.3" +polyfactory = "^2.20.0" pyright = "^1.1.394" brunns-matchers = "^2.9.0" localstack = "^4.1.1" diff --git a/src/eligibility_signposting_api/config.py b/src/eligibility_signposting_api/config.py index 726f5b34..9f5682ab 100644 --- a/src/eligibility_signposting_api/config.py +++ b/src/eligibility_signposting_api/config.py @@ -22,6 +22,7 @@ def config() -> dict[str, Any]: "dynamodb_endpoint": URL(os.getenv("DYNAMODB_ENDPOINT", "http://localhost:4566")), "aws_secret_access_key": AwsSecretAccessKey(os.getenv("AWS_SECRET_ACCESS_KEY", "dummy_secret")), "log_level": LOG_LEVEL, + "rules_bucket_name": AwsAccessKey(os.getenv("RULES_BUCKET_NAME", "test-rules-bucket")), } diff --git a/src/eligibility_signposting_api/model/rules.py b/src/eligibility_signposting_api/model/rules.py new file mode 100644 index 00000000..2abdc9e0 --- /dev/null +++ b/src/eligibility_signposting_api/model/rules.py @@ -0,0 +1,127 @@ +from __future__ import annotations + +import typing +from datetime import date, datetime +from enum import Enum +from typing import Literal, NewType + +from pydantic import BaseModel, Field, field_serializer, field_validator + +if typing.TYPE_CHECKING: + from pydantic import SerializationInfo + +BucketName = NewType("BucketName", str) +CampaignName = NewType("CampaignName", str) +CampaignVersion = NewType("CampaignVersion", str) +CampaignID = NewType("CampaignID", str) +IterationName = NewType("IterationName", str) +IterationVersion = NewType("IterationVersion", str) +IterationID = NewType("IterationID", str) +RuleName = NewType("RuleName", str) +RuleDescription = NewType("RuleDescription", str) +RulePriority = NewType("RulePriority", int) +RuleAttributeLevel = NewType("RuleAttributeLevel", str) +RuleAttributeName = NewType("RuleAttributeName", str) +RuleComparator = NewType("RuleComparator", str) +StartDate = NewType("StartDate", date) +EndDate = NewType("EndDate", date) + + +class RuleType(str, Enum): + filter = "F" + suppression = "S" + redirect = "R" + + +class RuleOperator(str, Enum): + lt = "<" + gt = ">" + year_gt = "Y>" + not_in = "not_in" + equals = "=" + lte = "<=" + ne = "!=" + date_gte = "D>=" + member_of = "MemberOf" + + +class IterationCohort(BaseModel): + cohort_label: str | None = Field(None, alias="CohortLabel") + priority: int | None = Field(None, alias="Priority") + + model_config = {"populate_by_name": True} + + +class IterationRule(BaseModel): + type: RuleType = Field(..., alias="Type") + name: RuleName = Field(..., alias="Name") + description: RuleDescription = Field(..., alias="Description") + priority: RulePriority = Field(..., alias="Priority") + attribute_level: RuleAttributeLevel = Field(..., alias="AttributeLevel") + attribute_name: RuleAttributeName = Field(..., alias="AttributeName") + operator: RuleOperator = Field(..., alias="Operator") + comparator: RuleComparator = Field(..., alias="Comparator") + attribute_target: str | None = Field(None, alias="AttributeTarget") + comms_routing: str | None = Field(None, alias="CommsRouting") + + model_config = {"populate_by_name": True} + + +class Iteration(BaseModel): + id: IterationID = Field(..., alias="ID") + default_comms_routing: str | None = Field(None, alias="DefaultCommsRouting") + version: IterationVersion = Field(..., alias="Version") + name: IterationName = Field(..., alias="Name") + iteration_date: str | None = Field(None, alias="IterationDate") + iteration_number: int | None = Field(None, alias="IterationNumber") + comms_type: Literal["I", "R"] = Field(..., alias="CommsType") + approval_minimum: int | None = Field(None, alias="ApprovalMinimum") + approval_maximum: int | None = Field(None, alias="ApprovalMaximum") + type: Literal["A", "M", "S"] = Field(..., alias="Type") + iteration_cohorts: list[IterationCohort] | None = Field(None, alias="IterationCohorts") + iteration_rules: list[IterationRule] | None = Field(None, alias="IterationRules") + + model_config = {"populate_by_name": True} + + +class CampaignConfig(BaseModel): + id: CampaignID = Field(..., alias="ID") + version: CampaignVersion = Field(..., alias="Version") + name: CampaignName = Field(..., alias="Name") + type: Literal["V", "S"] = Field(..., alias="Type") + target: Literal["COVID", "FLU", "MMR", "RSV"] = Field(..., alias="Target") + manager: str | None = Field(None, alias="Manager") + approver: str | None = Field(None, alias="Approver") + reviewer: str | None = Field(None, alias="Reviewer") + iteration_frequency: Literal["X", "D", "W", "M", "Q", "A"] = Field(..., alias="IterationFrequency") + iteration_type: Literal["A", "M", "S"] = Field(..., alias="IterationType") + iteration_time: str | None = Field(None, alias="IterationTime") + default_comms_routing: str | None = Field(None, alias="DefaultCommsRouting") + start_date: StartDate = Field(..., alias="StartDate") + end_date: EndDate = Field(..., alias="EndDate") + approval_minimum: int | None = Field(None, alias="ApprovalMinimum") + approval_maximum: int | None = Field(None, alias="ApprovalMaximum") + iterations: list[Iteration] | None = Field(None, alias="Iterations") + + model_config = { + "populate_by_name": True, + "arbitrary_types_allowed": True, + } + + @field_validator("start_date", "end_date", mode="before") + @classmethod + def parse_dates(cls, v: str | date) -> date: + if isinstance(v, date): + return v + return datetime.strptime(v, "%Y%m%d").date() # noqa: DTZ007 + + @field_serializer("start_date", "end_date", when_used="always") + @staticmethod + def serialize_dates(v: date, _info: SerializationInfo) -> str: + return v.strftime("%Y%m%d") + + +class Rules(BaseModel): + campaign_config: CampaignConfig = Field(..., alias="CampaignConfig") + + model_config = {"populate_by_name": True} diff --git a/src/eligibility_signposting_api/repos/factory.py b/src/eligibility_signposting_api/repos/factory.py index fe909323..c65d8539 100644 --- a/src/eligibility_signposting_api/repos/factory.py +++ b/src/eligibility_signposting_api/repos/factory.py @@ -3,6 +3,7 @@ from boto3 import Session from boto3.resources.base import ServiceResource +from botocore.client import BaseClient from wireup import Inject, service from yarl import URL @@ -27,3 +28,10 @@ def dynamodb_resource_factory( session: Session, dynamodb_endpoint: Annotated[URL, Inject(param="dynamodb_endpoint")] ) -> ServiceResource: return session.resource("dynamodb", endpoint_url=str(dynamodb_endpoint)) + + +@service(qualifier="s3") +def s3_service_factory( + session: Session, dynamodb_endpoint: Annotated[URL, Inject(param="dynamodb_endpoint")] +) -> BaseClient: + return session.client("s3", endpoint_url=str(dynamodb_endpoint)) diff --git a/src/eligibility_signposting_api/repos/rules_repo.py b/src/eligibility_signposting_api/repos/rules_repo.py new file mode 100644 index 00000000..762f230c --- /dev/null +++ b/src/eligibility_signposting_api/repos/rules_repo.py @@ -0,0 +1,24 @@ +import json +from typing import Annotated + +from botocore.client import BaseClient +from wireup import Inject, service + +from eligibility_signposting_api.model.rules import BucketName, CampaignConfig, CampaignName, Rules + + +@service +class RulesRepo: + def __init__( + self, + s3_client: Annotated[BaseClient, Inject(qualifier="s3")], + bucket_name: Annotated[BucketName, Inject(param="rules_bucket_name")], + ) -> None: + super().__init__() + self.s3_client = s3_client + self.bucket_name = bucket_name + + def get_campaign_config(self, campaign: CampaignName) -> CampaignConfig: + response = self.s3_client.get_object(Bucket=self.bucket_name, Key=f"{campaign}.json") + body = response["Body"].read() + return Rules.model_validate(json.loads(body)).campaign_config diff --git a/tests/integration/repo/test_rules_repo.py b/tests/integration/repo/test_rules_repo.py new file mode 100644 index 00000000..50b2b167 --- /dev/null +++ b/tests/integration/repo/test_rules_repo.py @@ -0,0 +1,42 @@ +import json +from collections.abc import Generator + +import pytest +from botocore.client import BaseClient +from hamcrest import assert_that + +from eligibility_signposting_api.model.rules import BucketName, CampaignConfig +from eligibility_signposting_api.repos.rules_repo import RulesRepo +from tests.integration.conftest import AWS_REGION +from tests.utils.builders import CampaignConfigFactory, random_str +from tests.utils.rules.rules import is_campaign_config + + +@pytest.fixture +def bucket(s3_client: BaseClient) -> Generator[BucketName]: + bucket_name = BucketName(random_str(63)) + s3_client.create_bucket(Bucket=bucket_name, CreateBucketConfiguration={"LocationConstraint": AWS_REGION}) + yield bucket_name + s3_client.delete_bucket(Bucket=bucket_name) + + +@pytest.fixture +def campaign(s3_client: BaseClient, bucket: BucketName) -> Generator[CampaignConfig]: + campaign: CampaignConfig = CampaignConfigFactory.build() + campaign_data = {"CampaignConfig": campaign.model_dump(by_alias=True)} + s3_client.put_object( + Bucket=bucket, Key=f"{campaign.name}.json", Body=json.dumps(campaign_data), ContentType="application/json" + ) + yield campaign + s3_client.delete_object(Bucket=bucket, Key=f"{campaign.name}.json") + + +def test_get_campaign_config(s3_client: BaseClient, bucket: BucketName, campaign: CampaignConfig): + # Given + repo = RulesRepo(s3_client, bucket) + + # When + actual = repo.get_campaign_config(campaign.name) + + # Then + assert_that(actual, is_campaign_config().with_id(campaign.id).and_name(campaign.name).and_version(campaign.version)) diff --git a/tests/utils/builders.py b/tests/utils/builders.py index e69de29b..6f0ab4fe 100644 --- a/tests/utils/builders.py +++ b/tests/utils/builders.py @@ -0,0 +1,30 @@ +import random +import string + +from polyfactory import Use +from polyfactory.factories.pydantic_factory import ModelFactory + +from eligibility_signposting_api.model.rules import CampaignConfig, Iteration, IterationCohort, IterationRule + + +class IterationCohortFactory(ModelFactory[IterationCohort]): + __model__ = IterationCohort + + +class IterationRuleFactory(ModelFactory[IterationRule]): + __model__ = IterationRule + + +class IterationFactory(ModelFactory[Iteration]): + __model__ = Iteration + iteration_cohorts = Use(IterationCohortFactory.batch, size=2) + iteration_rules = Use(IterationRuleFactory.batch, size=2) + + +class CampaignConfigFactory(ModelFactory[CampaignConfig]): + __model__ = CampaignConfig + iterations = Use(IterationFactory.batch, size=2) + + +def random_str(length: int) -> str: + return "".join(random.choice(string.ascii_lowercase + string.digits) for _ in range(length)) # noqa: S311 diff --git a/tests/utils/rules/rules.py b/tests/utils/rules/rules.py new file mode 100644 index 00000000..3b363fa5 --- /dev/null +++ b/tests/utils/rules/rules.py @@ -0,0 +1,64 @@ +from brunns.matchers.utils import append_matcher_description, describe_field_match, describe_field_mismatch +from hamcrest import anything +from hamcrest.core.base_matcher import BaseMatcher +from hamcrest.core.description import Description +from hamcrest.core.helpers.wrap_matcher import wrap_matcher +from hamcrest.core.matcher import Matcher + +from eligibility_signposting_api.model.rules import CampaignConfig, CampaignID, CampaignName, CampaignVersion + +ANYTHING = anything() + + +class CampaignConfigMatcher(BaseMatcher[CampaignConfig]): + def __init__(self): + super().__init__() + self.id_: Matcher[CampaignID] = ANYTHING + self.name: Matcher[CampaignName] = ANYTHING + self.version: Matcher[CampaignVersion] = ANYTHING + + def describe_to(self, description: Description) -> None: + description.append_text("CampaignConfig with") + append_matcher_description(self.id_, "id", description) + append_matcher_description(self.name, "name", description) + append_matcher_description(self.version, "version", description) + + def _matches(self, item: CampaignConfig) -> bool: + return self.id_.matches(item.id) and self.name.matches(item.name) and self.version.matches(item.version) + + def describe_mismatch(self, item: CampaignConfig, mismatch_description: Description) -> None: + mismatch_description.append_text("was CampaignConfig with") + describe_field_mismatch(self.id_, "id", item.id, mismatch_description) + describe_field_mismatch(self.name, "name", item.name, mismatch_description) + describe_field_mismatch(self.version, "version", item.version, mismatch_description) + + def describe_match(self, item: CampaignConfig, match_description: Description) -> None: + match_description.append_text("was CampaignConfig with") + describe_field_match(self.id_, "id", item.id, match_description) + describe_field_match(self.name, "name", item.name, match_description) + describe_field_match(self.version, "version", item.version, match_description) + + def with_id(self, id_: CampaignID | Matcher[CampaignID]): + self.id_ = wrap_matcher(id_) + return self + + def and_id(self, id_: CampaignID | Matcher[CampaignID]): + return self.with_id(id_) + + def with_name(self, name: CampaignName | Matcher[CampaignName]): + self.name = wrap_matcher(name) + return self + + def and_name(self, name: CampaignName | Matcher[CampaignName]): + return self.with_name(name) + + def with_version(self, version: CampaignVersion | Matcher[CampaignVersion]): + self.version = wrap_matcher(version) + return self + + def and_version(self, version: CampaignVersion | Matcher[CampaignVersion]): + return self.with_version(version) + + +def is_campaign_config() -> Matcher[CampaignConfig]: + return CampaignConfigMatcher()