Skip to content

Commit 78367da

Browse files
Improves Rules model - use enums where I know what the values mean, Literals where I don't. Also deserialise dates properly.
1 parent c10c3df commit 78367da

File tree

5 files changed

+93
-41
lines changed

5 files changed

+93
-41
lines changed

src/eligibility_signposting_api/model/rules.py

Lines changed: 78 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,48 @@
11
from __future__ import annotations
22

3-
from typing import NewType
3+
import typing
4+
from datetime import date, datetime
5+
from enum import Enum
6+
from typing import Literal, NewType
47

5-
from pydantic import BaseModel, Field
8+
from pydantic import BaseModel, Field, field_serializer, field_validator
9+
10+
if typing.TYPE_CHECKING:
11+
from pydantic import SerializationInfo
612

7-
Campaign = NewType("Campaign", str)
813
BucketName = NewType("BucketName", str)
14+
CampaignName = NewType("CampaignName", str)
15+
CampaignVersion = NewType("CampaignVersion", str)
16+
CampaignID = NewType("CampaignID", str)
17+
IterationName = NewType("IterationName", str)
18+
IterationVersion = NewType("IterationVersion", str)
19+
IterationID = NewType("IterationID", str)
20+
RuleName = NewType("RuleName", str)
21+
RuleDescription = NewType("RuleDescription", str)
22+
RulePriority = NewType("RulePriority", int)
23+
RuleAttributeLevel = NewType("RuleAttributeLevel", str)
24+
RuleAttributeName = NewType("RuleAttributeName", str)
25+
RuleComparator = NewType("RuleComparator", str)
26+
StartDate = NewType("StartDate", date)
27+
EndDate = NewType("EndDate", date)
28+
29+
30+
class RuleType(str, Enum):
31+
filter = "F"
32+
suppression = "S"
33+
redirect = "R"
34+
35+
36+
class RuleOperator(str, Enum):
37+
lt = "<"
38+
gt = ">"
39+
year_gt = "Y>"
40+
not_in = "not_in"
41+
equals = "="
42+
lte = "<="
43+
ne = "!="
44+
date_gte = "D>="
45+
member_of = "MemberOf"
946

1047

1148
class IterationCohort(BaseModel):
@@ -16,57 +53,72 @@ class IterationCohort(BaseModel):
1653

1754

1855
class IterationRule(BaseModel):
19-
type: str | None = Field(None, alias="Type")
20-
name: str | None = Field(None, alias="Name")
21-
description: str | None = Field(None, alias="Description")
22-
priority: int | None = Field(None, alias="Priority")
23-
attribute_level: str | None = Field(None, alias="AttributeLevel")
24-
attribute_name: str | None = Field(None, alias="AttributeName")
25-
operator: str | None = Field(None, alias="Operator")
26-
comparator: str | None = Field(None, alias="Comparator")
56+
type: RuleType = Field(..., alias="Type")
57+
name: RuleName = Field(..., alias="Name")
58+
description: RuleDescription = Field(..., alias="Description")
59+
priority: RulePriority = Field(..., alias="Priority")
60+
attribute_level: RuleAttributeLevel = Field(..., alias="AttributeLevel")
61+
attribute_name: RuleAttributeName = Field(..., alias="AttributeName")
62+
operator: RuleOperator = Field(..., alias="Operator")
63+
comparator: RuleComparator = Field(..., alias="Comparator")
2764
attribute_target: str | None = Field(None, alias="AttributeTarget")
2865
comms_routing: str | None = Field(None, alias="CommsRouting")
2966

3067
model_config = {"populate_by_name": True}
3168

3269

3370
class Iteration(BaseModel):
34-
id: str = Field(..., alias="ID")
71+
id: IterationID = Field(..., alias="ID")
3572
default_comms_routing: str | None = Field(None, alias="DefaultCommsRouting")
36-
version: int | None = Field(None, alias="Version")
37-
name: str | None = Field(None, alias="Name")
73+
version: IterationVersion = Field(..., alias="Version")
74+
name: IterationName = Field(..., alias="Name")
3875
iteration_date: str | None = Field(None, alias="IterationDate")
3976
iteration_number: int | None = Field(None, alias="IterationNumber")
40-
comms_type: str | None = Field(None, alias="CommsType")
77+
comms_type: Literal["I", "R"] = Field(..., alias="CommsType")
4178
approval_minimum: int | None = Field(None, alias="ApprovalMinimum")
4279
approval_maximum: int | None = Field(None, alias="ApprovalMaximum")
43-
type: str | None = Field(None, alias="Type")
80+
type: Literal["A", "M", "S"] = Field(..., alias="Type")
4481
iteration_cohorts: list[IterationCohort] | None = Field(None, alias="IterationCohorts")
4582
iteration_rules: list[IterationRule] | None = Field(None, alias="IterationRules")
4683

4784
model_config = {"populate_by_name": True}
4885

4986

5087
class CampaignConfig(BaseModel):
51-
id: str = Field(..., alias="ID")
52-
version: int = Field(..., alias="Version")
53-
name: str = Field(..., alias="Name")
54-
type: str | None = Field(None, alias="Type")
55-
target: str | None = Field(None, alias="Target")
88+
id: CampaignID = Field(..., alias="ID")
89+
version: CampaignVersion = Field(..., alias="Version")
90+
name: CampaignName = Field(..., alias="Name")
91+
type: Literal["V", "S"] = Field(..., alias="Type")
92+
target: Literal["COVID", "FLU", "MMR", "RSV"] = Field(..., alias="Target")
5693
manager: str | None = Field(None, alias="Manager")
5794
approver: str | None = Field(None, alias="Approver")
5895
reviewer: str | None = Field(None, alias="Reviewer")
59-
iteration_frequency: str | None = Field(None, alias="IterationFrequency")
60-
iteration_type: str | None = Field(None, alias="IterationType")
96+
iteration_frequency: Literal["X", "D", "W", "M", "Q", "A"] = Field(..., alias="IterationFrequency")
97+
iteration_type: Literal["A", "M", "S"] = Field(..., alias="IterationType")
6198
iteration_time: str | None = Field(None, alias="IterationTime")
6299
default_comms_routing: str | None = Field(None, alias="DefaultCommsRouting")
63-
start_date: str | None = Field(None, alias="StartDate")
64-
end_date: str | None = Field(None, alias="EndDate")
100+
start_date: StartDate = Field(..., alias="StartDate")
101+
end_date: EndDate = Field(..., alias="EndDate")
65102
approval_minimum: int | None = Field(None, alias="ApprovalMinimum")
66103
approval_maximum: int | None = Field(None, alias="ApprovalMaximum")
67104
iterations: list[Iteration] | None = Field(None, alias="Iterations")
68105

69-
model_config = {"populate_by_name": True}
106+
model_config = {
107+
"populate_by_name": True,
108+
"arbitrary_types_allowed": True,
109+
}
110+
111+
@field_validator("start_date", "end_date", mode="before")
112+
@classmethod
113+
def parse_dates(cls, v: str | date) -> date:
114+
if isinstance(v, date):
115+
return v
116+
return datetime.strptime(v, "%Y%m%d").date() # noqa: DTZ007
117+
118+
@field_serializer("start_date", "end_date", when_used="always")
119+
@staticmethod
120+
def serialize_dates(v: date, _info: SerializationInfo) -> str:
121+
return v.strftime("%Y%m%d")
70122

71123

72124
class Rules(BaseModel):

src/eligibility_signposting_api/repos/rules_repo.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from botocore.client import BaseClient
55
from wireup import Inject, service
66

7-
from eligibility_signposting_api.model.rules import BucketName, Campaign, CampaignConfig, Rules
7+
from eligibility_signposting_api.model.rules import BucketName, CampaignConfig, CampaignName, Rules
88

99

1010
@service
@@ -18,7 +18,7 @@ def __init__(
1818
self.s3_client = s3_client
1919
self.bucket_name = bucket_name
2020

21-
def get_campaign_config(self, campaign: Campaign) -> CampaignConfig:
21+
def get_campaign_config(self, campaign: CampaignName) -> CampaignConfig:
2222
response = self.s3_client.get_object(Bucket=self.bucket_name, Key=f"{campaign}.json")
2323
body = response["Body"].read()
2424
return Rules.model_validate(json.loads(body)).campaign_config

tests/integration/repo/test_rules_repo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from eligibility_signposting_api.repos.rules_repo import RulesRepo
1010
from tests.integration.conftest import AWS_REGION
1111
from tests.utils.builders import CampaignConfigFactory, random_str
12-
from tests.utils.rules.campaign import is_campaign_config
12+
from tests.utils.rules.rules import is_campaign_config
1313

1414

1515
@pytest.fixture

tests/utils/builders.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ class IterationRuleFactory(ModelFactory[IterationRule]):
1717

1818
class IterationFactory(ModelFactory[Iteration]):
1919
__model__ = Iteration
20-
iteration_cohorts: Use(IterationCohortFactory.batch, size=2)
21-
iteration_rules: Use(IterationRuleFactory.batch, size=2)
20+
iteration_cohorts = Use(IterationCohortFactory.batch, size=2)
21+
iteration_rules = Use(IterationRuleFactory.batch, size=2)
2222

2323

2424
class CampaignConfigFactory(ModelFactory[CampaignConfig]):
Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,17 @@
55
from hamcrest.core.helpers.wrap_matcher import wrap_matcher
66
from hamcrest.core.matcher import Matcher
77

8-
from eligibility_signposting_api.model.rules import CampaignConfig
8+
from eligibility_signposting_api.model.rules import CampaignConfig, CampaignID, CampaignName, CampaignVersion
99

1010
ANYTHING = anything()
1111

1212

1313
class CampaignConfigMatcher(BaseMatcher[CampaignConfig]):
1414
def __init__(self):
1515
super().__init__()
16-
self.id_: Matcher[str] = ANYTHING
17-
self.name: Matcher[str] = ANYTHING
18-
self.version: Matcher[str] = ANYTHING
16+
self.id_: Matcher[CampaignID] = ANYTHING
17+
self.name: Matcher[CampaignName] = ANYTHING
18+
self.version: Matcher[CampaignVersion] = ANYTHING
1919

2020
def describe_to(self, description: Description) -> None:
2121
description.append_text("CampaignConfig with")
@@ -38,25 +38,25 @@ def describe_match(self, item: CampaignConfig, match_description: Description) -
3838
describe_field_match(self.name, "name", item.name, match_description)
3939
describe_field_match(self.version, "version", item.version, match_description)
4040

41-
def with_id(self, id_: str | Matcher[str]):
41+
def with_id(self, id_: CampaignID | Matcher[CampaignID]):
4242
self.id_ = wrap_matcher(id_)
4343
return self
4444

45-
def and_id(self, id_: str | Matcher[str]):
45+
def and_id(self, id_: CampaignID | Matcher[CampaignID]):
4646
return self.with_id(id_)
4747

48-
def with_name(self, name: str | Matcher[str]):
48+
def with_name(self, name: CampaignName | Matcher[CampaignName]):
4949
self.name = wrap_matcher(name)
5050
return self
5151

52-
def and_name(self, name: str | Matcher[str]):
52+
def and_name(self, name: CampaignName | Matcher[CampaignName]):
5353
return self.with_name(name)
5454

55-
def with_version(self, version: str | Matcher[str]):
55+
def with_version(self, version: CampaignVersion | Matcher[CampaignVersion]):
5656
self.version = wrap_matcher(version)
5757
return self
5858

59-
def and_version(self, version: str | Matcher[str]):
59+
def and_version(self, version: CampaignVersion | Matcher[CampaignVersion]):
6060
return self.with_version(version)
6161

6262

0 commit comments

Comments
 (0)