Skip to content

Commit 4d911cb

Browse files
authored
fix: update resolution order to favour provider types (litestar-org#832)
1 parent ef2e655 commit 4d911cb

File tree

2 files changed

+53
-10
lines changed

2 files changed

+53
-10
lines changed

polyfactory/factories/base.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -784,6 +784,10 @@ def get_field_value( # noqa: C901, PLR0911, PLR0912
784784

785785
return cls.get_field_value(cls.__random__.choice(children), field_build_parameters, build_context)
786786

787+
provider_map = cls.get_provider_map()
788+
if provider := (provider_map.get(field_meta.annotation) or provider_map.get(unwrapped_annotation)):
789+
return provider()
790+
787791
if BaseFactory.is_factory_type(annotation=unwrapped_annotation):
788792
if not field_build_parameters and unwrapped_annotation in build_context["seen_models"]:
789793
return None if is_optional(field_meta.annotation) else Null
@@ -844,10 +848,6 @@ def get_field_value( # noqa: C901, PLR0911, PLR0912
844848
build_context=build_context,
845849
)
846850

847-
provider_map = cls.get_provider_map()
848-
if provider := (provider_map.get(field_meta.annotation) or provider_map.get(unwrapped_annotation)):
849-
return provider()
850-
851851
if is_type_var(unwrapped_annotation):
852852
return create_random_string(cls.__random__, min_length=1, max_length=10)
853853

@@ -913,6 +913,12 @@ def get_field_value_coverage( # noqa: C901,PLR0912
913913
field_meta=unwrapped_annotation_meta,
914914
)
915915

916+
elif provider := (
917+
(provider_map := cls.get_provider_map()).get(field_meta.annotation)
918+
or provider_map.get(unwrapped_annotation)
919+
):
920+
yield CoverageContainerCallable(provider)
921+
916922
elif BaseFactory.is_factory_type(annotation=unwrapped_annotation):
917923
yield CoverageContainer(
918924
cls._get_or_create_factory(model=unwrapped_annotation).coverage(
@@ -935,12 +941,6 @@ def get_field_value_coverage( # noqa: C901,PLR0912
935941

936942
yield handle_collection_type_coverage(child_meta, origin, cls, build_context=build_context)
937943

938-
elif provider := (
939-
(provider_map := cls.get_provider_map()).get(field_meta.annotation)
940-
or provider_map.get(unwrapped_annotation)
941-
):
942-
yield CoverageContainerCallable(provider)
943-
944944
elif is_type_var(unwrapped_annotation):
945945
yield create_random_string(cls.__random__, min_length=1, max_length=10)
946946

tests/test_provider_map.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,12 @@
33

44
import pytest
55

6+
from pydantic import BaseModel
7+
68
from polyfactory.exceptions import ParameterException
79
from polyfactory.factories.base import BaseFactory
810
from polyfactory.factories.dataclass_factory import DataclassFactory
11+
from polyfactory.factories.pydantic_factory import ModelFactory
912

1013

1114
def test_provider_map() -> None:
@@ -60,6 +63,46 @@ def get_provider_map(cls) -> dict[Any, Callable[[], Any]]:
6063
assert all(result.foo == "any" for result in coverage_result)
6164

6265

66+
def test_provider_map_takes_priority_over_factory_type() -> None:
67+
"""Custom providers should take precedence over built-in factory type resolution."""
68+
69+
@dataclass
70+
class Inner:
71+
value: str
72+
73+
@dataclass
74+
class Outer:
75+
inner: Inner
76+
77+
sentinel = Inner(value="from_provider")
78+
79+
class OuterFactory(DataclassFactory[Outer]):
80+
@classmethod
81+
def get_provider_map(cls) -> dict[Any, Callable[[], Any]]:
82+
return {Inner: lambda: sentinel, **super().get_provider_map()}
83+
84+
assert OuterFactory.build().inner is sentinel
85+
assert all(result.inner is sentinel for result in OuterFactory.coverage())
86+
87+
88+
def test_provider_map_takes_priority_over_pydantic_factory_type() -> None:
89+
"""Custom providers should take precedence for Pydantic model fields."""
90+
91+
class InnerModel(BaseModel):
92+
value: str
93+
94+
class OuterModel(BaseModel):
95+
inner: InnerModel
96+
97+
class OuterFactory(ModelFactory[OuterModel]):
98+
@classmethod
99+
def get_provider_map(cls) -> dict[Any, Callable[[], Any]]:
100+
return {InnerModel: lambda: InnerModel(value="from_provider"), **super().get_provider_map()}
101+
102+
assert OuterFactory.build().inner.value == "from_provider"
103+
assert all(result.inner.value == "from_provider" for result in OuterFactory.coverage())
104+
105+
63106
def test_add_custom_provider() -> None:
64107
class CustomType:
65108
def __init__(self, _: Any) -> None:

0 commit comments

Comments
 (0)