Skip to content

Commit e26daa5

Browse files
Use builder for rules test.
1 parent 757256b commit e26daa5

File tree

5 files changed

+72
-40
lines changed

5 files changed

+72
-40
lines changed

poetry.lock

Lines changed: 26 additions & 20 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ aiohttp = "^3.11.12"
4848
awscli = "^1.37.24"
4949
awscli-local = "^0.22.0"
5050
pyhamcrest = "^2.1.0"
51-
factory-boy = "^3.3.3"
51+
polyfactory = "^2.20.0"
5252
pyright = "^1.1.394"
5353
brunns-matchers = "^2.9.0"
5454
localstack = "^4.1.1"

src/eligibility_signposting_api/model/rules.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ class IterationCohort(BaseModel):
1212
cohort_label: str | None = Field(None, alias="CohortLabel")
1313
priority: int | None = Field(None, alias="Priority")
1414

15+
model_config = {"populate_by_name": True}
16+
1517

1618
class IterationRule(BaseModel):
1719
type: str | None = Field(None, alias="Type")
@@ -25,6 +27,8 @@ class IterationRule(BaseModel):
2527
attribute_target: str | None = Field(None, alias="AttributeTarget")
2628
comms_routing: str | None = Field(None, alias="CommsRouting")
2729

30+
model_config = {"populate_by_name": True}
31+
2832

2933
class Iteration(BaseModel):
3034
id: str = Field(..., alias="ID")
@@ -40,6 +44,8 @@ class Iteration(BaseModel):
4044
iteration_cohorts: list[IterationCohort] | None = Field(None, alias="IterationCohorts")
4145
iteration_rules: list[IterationRule] | None = Field(None, alias="IterationRules")
4246

47+
model_config = {"populate_by_name": True}
48+
4349

4450
class CampaignConfig(BaseModel):
4551
id: str = Field(..., alias="ID")
@@ -60,6 +66,10 @@ class CampaignConfig(BaseModel):
6066
approval_maximum: int | None = Field(None, alias="ApprovalMaximum")
6167
iterations: list[Iteration] | None = Field(None, alias="Iterations")
6268

69+
model_config = {"populate_by_name": True}
70+
6371

6472
class Rules(BaseModel):
6573
campaign_config: CampaignConfig = Field(..., alias="CampaignConfig")
74+
75+
model_config = {"populate_by_name": True}
Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
11
import json
2-
import uuid
32
from collections.abc import Generator
43

54
import pytest
65
from botocore.client import BaseClient
76
from hamcrest import assert_that
87

9-
from eligibility_signposting_api.model.rules import BucketName, Campaign
8+
from eligibility_signposting_api.model.rules import BucketName, CampaignConfig
109
from eligibility_signposting_api.repos.rules_repo import RulesRepo
1110
from tests.integration.conftest import AWS_REGION
12-
from tests.utils.builders import random_int, random_str
11+
from tests.utils.builders import CampaignConfigFactory, random_str
1312
from tests.utils.rules.campaign import is_campaign_config
1413

1514

@@ -22,25 +21,22 @@ def bucket(s3_client: BaseClient) -> Generator[BucketName]:
2221

2322

2423
@pytest.fixture
25-
def campaign(s3_client: BaseClient, bucket: BucketName) -> Generator[tuple[Campaign, str, str]]:
26-
campaign_name = Campaign(random_str(10))
27-
id_ = f"{uuid.uuid4()}"
28-
version = random_int(maximum=10)
29-
campaign_data = {"CampaignConfig": {"ID": id_, "Version": version, "Name": campaign_name}}
24+
def campaign(s3_client: BaseClient, bucket: BucketName) -> Generator[CampaignConfig]:
25+
campaign: CampaignConfig = CampaignConfigFactory.build()
26+
campaign_data = {"CampaignConfig": campaign.model_dump(by_alias=True)}
3027
s3_client.put_object(
31-
Bucket=bucket, Key=f"{campaign_name}.json", Body=json.dumps(campaign_data), ContentType="application/json"
28+
Bucket=bucket, Key=f"{campaign.name}.json", Body=json.dumps(campaign_data), ContentType="application/json"
3229
)
33-
yield campaign_name, id_, version
34-
s3_client.delete_object(Bucket=bucket, Key=f"{campaign_name}.json")
30+
yield campaign
31+
s3_client.delete_object(Bucket=bucket, Key=f"{campaign.name}.json")
3532

3633

37-
def test_get_campaign_config(s3_client: BaseClient, bucket: BucketName, campaign: tuple[Campaign, str, str]):
34+
def test_get_campaign_config(s3_client: BaseClient, bucket: BucketName, campaign: CampaignConfig):
3835
# Given
39-
campaign_name, id_, version = campaign
4036
repo = RulesRepo(s3_client, bucket)
4137

4238
# When
43-
actual = repo.get_campaign_config(campaign_name)
39+
actual = repo.get_campaign_config(campaign.name)
4440

4541
# Then
46-
assert_that(actual, is_campaign_config().with_id(id_).and_name(campaign_name).and_version(version))
42+
assert_that(actual, is_campaign_config().with_id(campaign.id).and_name(campaign.name).and_version(campaign.version))

tests/utils/builders.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,30 @@
11
import random
22
import string
33

4+
from polyfactory import Use
5+
from polyfactory.factories.pydantic_factory import ModelFactory
46

5-
def random_str(length: int) -> str:
6-
return "".join(random.choice(string.ascii_lowercase + string.digits) for _ in range(length)) # noqa: S311
7+
from eligibility_signposting_api.model.rules import CampaignConfig, Iteration, IterationCohort, IterationRule
8+
9+
10+
class IterationCohortFactory(ModelFactory[IterationCohort]):
11+
__model__ = IterationCohort
12+
13+
14+
class IterationRuleFactory(ModelFactory[IterationRule]):
15+
__model__ = IterationRule
716

817

9-
def random_int(minimum: int = 1, maximum: int = 10) -> int:
10-
return random.randint(minimum, maximum) # noqa: S311
18+
class IterationFactory(ModelFactory[Iteration]):
19+
__model__ = Iteration
20+
iteration_cohorts: Use(IterationCohortFactory.batch, size=2)
21+
iteration_rules: Use(IterationRuleFactory.batch, size=2)
22+
23+
24+
class CampaignConfigFactory(ModelFactory[CampaignConfig]):
25+
__model__ = CampaignConfig
26+
iterations = Use(IterationFactory.batch, size=2)
27+
28+
29+
def random_str(length: int) -> str:
30+
return "".join(random.choice(string.ascii_lowercase + string.digits) for _ in range(length)) # noqa: S311

0 commit comments

Comments
 (0)