Skip to content

Commit c68e79f

Browse files
committed
Merge post-hooks feature branch and resolve conflicts
- Merged origin/post-hooks branch which adds post_build and post_generate hooks - Resolved merge conflicts in tests/test_pydantic_factory.py - Kept both sets of tests (PEP 695 tests and post_build test)
2 parents af24258 + 9742675 commit c68e79f

File tree

5 files changed

+92
-4
lines changed

5 files changed

+92
-4
lines changed

polyfactory/factories/base.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1117,6 +1117,25 @@ def process_kwargs(cls, **kwargs: Any) -> dict[str, Any]:
11171117
for field_name, post_generator in generate_post.items():
11181118
result[field_name] = post_generator.to_value(field_name, result)
11191119

1120+
return cls.post_generate(result)
1121+
1122+
@classmethod
1123+
def post_build(cls, model: T) -> T:
1124+
"""Post-create hook. Helpful for building additional database associations or running logic which requires the
1125+
fully-created model.
1126+
1127+
:param model: The created model instance.
1128+
:returns: The (optionally) mutated model.
1129+
"""
1130+
return model
1131+
1132+
@classmethod
1133+
def post_generate(cls, result: dict[str, Any]) -> dict[str, Any]:
1134+
"""Post-generate hook. Helpful for mutating or adding additional fields right before model creation.
1135+
1136+
:param result: The kwargs that will be passed to the model.
1137+
:returns: The (optionally) mutated kwargs.
1138+
"""
11201139
return result
11211140

11221141
@classmethod
@@ -1177,7 +1196,11 @@ def build(cls, *_: Any, **kwargs: Any) -> T:
11771196
:returns: An instance of type T.
11781197
11791198
"""
1180-
return cast("T", cls.__model__(**cls.process_kwargs(**kwargs)))
1199+
created_model = cast("T", cls.__model__(**cls.process_kwargs(**kwargs)))
1200+
1201+
cls.post_build(created_model)
1202+
1203+
return created_model
11811204

11821205
@classmethod
11831206
def batch(cls, size: int, **kwargs: Any) -> list[T]:
@@ -1202,6 +1225,7 @@ def coverage(cls, **kwargs: Any) -> abc.Iterator[T]:
12021225
"""
12031226
for data in cls.process_kwargs_coverage(**kwargs):
12041227
instance = cls.__model__(**data)
1228+
cls.post_build(instance)
12051229
yield cast("T", instance)
12061230

12071231
@classmethod

polyfactory/factories/pydantic_factory.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ def from_field_info(
148148
min_collection_length: int | None = None,
149149
max_collection_length: int | None = None,
150150
) -> PydanticFieldMeta:
151-
"""Create an instance from a pydantic field info.
151+
"""Create an instance from a pydantic field info. Used by `get_model_fields` to generate field list for a model.
152152
153153
:param field_name: The name of the field.
154154
:param field_info: A pydantic FieldInfo instance.
@@ -545,7 +545,11 @@ def build(
545545

546546
processed_kwargs = cls.process_kwargs(**kwargs)
547547

548-
return cls._create_model(kwargs["_build_context"], **processed_kwargs)
548+
created_model = cls._create_model(kwargs["_build_context"], **processed_kwargs)
549+
550+
cls.post_build(created_model)
551+
552+
return created_model
549553

550554
@classmethod
551555
def _get_build_context(cls, build_context: BaseBuildContext | PydanticBuildContext | None) -> PydanticBuildContext:
@@ -592,7 +596,11 @@ def coverage(cls, factory_use_construct: bool = False, **kwargs: Any) -> abc.Ite
592596
)
593597

594598
for data in cls.process_kwargs_coverage(**kwargs):
595-
yield cls._create_model(_build_context=kwargs["_build_context"], **data)
599+
created_model = cls._create_model(_build_context=kwargs["_build_context"], **data)
600+
601+
cls.post_build(created_model)
602+
603+
yield created_model
596604

597605
@classmethod
598606
def is_custom_root_field(cls, field_meta: FieldMeta) -> bool:

tests/test_factory_fields.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,25 @@ def caption(cls, is_long: bool) -> str:
167167
assert result.caption == "just this"
168168

169169

170+
def test_post_build_classmethod() -> None:
171+
@dataclass
172+
class Model:
173+
i: int
174+
j: int
175+
176+
class Factory(DataclassFactory[Model]):
177+
__model__ = Model
178+
179+
@classmethod
180+
def post_build(cls, model: Model) -> Model:
181+
model.i = model.j + 10
182+
return model
183+
184+
result = Factory.build()
185+
186+
assert result.i == result.j + 10
187+
188+
170189
@pytest.mark.parametrize(
171190
"factory_field",
172191
[

tests/test_pydantic_factory.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1528,3 +1528,20 @@ class FooFactory(ModelFactory[Foo]):
15281528

15291529
instance = FooFactory.build()
15301530
assert instance.name == "John" # Should use the overridden alias
1531+
1532+
1533+
def test_post_build_classmethod() -> None:
1534+
class Model(BaseModel):
1535+
i: int
1536+
j: int
1537+
1538+
class Factory(ModelFactory[Model]):
1539+
__model__ = Model
1540+
1541+
@classmethod
1542+
def post_build(cls, model: Model) -> Model:
1543+
model.i = model.j + 10
1544+
return model
1545+
1546+
result = Factory.build()
1547+
assert result.i == result.j + 10

tests/test_type_coverage_generation.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,26 @@ def i(cls, j: int) -> int:
202202
assert results[0].i == results[0].j + 10
203203

204204

205+
def test_coverage_post_build() -> None:
206+
@dataclass
207+
class Model:
208+
i: int
209+
j: int
210+
211+
class Factory(DataclassFactory[Model]):
212+
__model__ = Model
213+
214+
@classmethod
215+
def post_build(cls, model: Model) -> Model:
216+
model.i = model.j + 10
217+
return model
218+
219+
results = list(Factory.coverage())
220+
assert len(results) == 1
221+
222+
assert results[0].i == results[0].j + 10
223+
224+
205225
class CustomInt:
206226
def __init__(self, value: int) -> None:
207227
self.value = value

0 commit comments

Comments
 (0)