Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ check_untyped_defs = true
disallow_untyped_decorators = true
disallow_any_explicit = false
disallow_any_generics = true
disallow_untyped_calls = true
disallow_untyped_calls = false
disallow_untyped_defs = true
ignore_errors = false
ignore_missing_imports = true
Expand Down
16 changes: 11 additions & 5 deletions pytest_factoryboy/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,23 @@
try:
from factory.declarations import PostGenerationContext
except ImportError: # factory_boy < 3.2.0
from factory.builder import PostGenerationContext
from factory.builder import ( # type: ignore[attr-defined, no-redef]
PostGenerationContext,
)

if pytest_version.release >= (8, 1):

def getfixturedefs(fixturemanager: FixtureManager, fixturename: str, node: Node) -> Sequence[FixtureDef] | None:
def getfixturedefs(
fixturemanager: FixtureManager, fixturename: str, node: Node
) -> Sequence[FixtureDef[object]] | None:
return fixturemanager.getfixturedefs(fixturename, node)

else:

def getfixturedefs(fixturemanager: FixtureManager, fixturename: str, node: Node) -> Sequence[FixtureDef] | None:
return fixturemanager.getfixturedefs(fixturename, node.nodeid)
def getfixturedefs(
fixturemanager: FixtureManager, fixturename: str, node: Node
) -> Sequence[FixtureDef[object]] | None:
return fixturemanager.getfixturedefs(fixturename, node.nodeid) # type: ignore[arg-type]


if pytest_version.release >= (8, 4):
Expand All @@ -35,4 +41,4 @@ def getfixturedefs(fixturemanager: FixtureManager, fixturename: str, node: Node)
else:
from _pytest.fixtures import FixtureFunction

PytestFixtureT: TypeAlias = FixtureFunction
PytestFixtureT: TypeAlias = FixtureFunction # type: ignore[misc, no-redef]
111 changes: 59 additions & 52 deletions pytest_factoryboy/fixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,19 @@
from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, cast, overload

import factory
import factory.builder
import factory.declarations
import factory.enums
import inflection
from typing_extensions import ParamSpec, TypeAlias
from factory.base import Factory
from factory.builder import BuildStep, DeclarationSet, StepBuilder
from factory.declarations import (
NotProvided,
PostGeneration,
PostGenerationDeclaration,
PostGenerationMethodCall,
RelatedFactory,
SubFactory,
)
from typing_extensions import ParamSpec

from .compat import PostGenerationContext
from .fixturegen import create_fixture
Expand All @@ -27,9 +35,8 @@

from .plugin import Request as FactoryboyRequest

FactoryType: TypeAlias = type[factory.Factory]
F = TypeVar("F", bound=FactoryType)
T = TypeVar("T")
U = TypeVar("U")
T_co = TypeVar("T_co", covariant=True)
P = ParamSpec("P")

Expand All @@ -38,9 +45,9 @@


@dataclass(eq=False)
class DeferredFunction:
class DeferredFunction(Generic[T]):
name: str
factory: FactoryType
factory: type[Factory[T]]
is_related: bool
function: Callable[[SubRequest], Any]

Expand All @@ -67,24 +74,24 @@ def named_model(model_cls: type[T], name: str) -> type[T]:
# register(AuthorFactory, ...)
#
# @register
# class AuthorFactory(factory.Factory): ...
# class AuthorFactory(Factory): ...
@overload
def register(factory_class: F, _name: str | None = None, **kwargs: Any) -> F: ...
def register(factory_class: type[Factory[T]], _name: str | None = None, **kwargs: Any) -> type[Factory[T]]: ...


# @register(...)
# class AuthorFactory(factory.Factory): ...
# class AuthorFactory(Factory): ...
@overload
def register(*, _name: str | None = None, **kwargs: Any) -> Callable[[F], F]: ...
def register(*, _name: str | None = None, **kwargs: Any) -> Callable[[type[Factory[T]]], type[Factory[T]]]: ...


def register(
factory_class: F | None = None,
factory_class: type[Factory[T]] | None = None,
_name: str | None = None,
*,
_caller_locals: Box[dict[str, Any]] | None = None,
**kwargs: Any,
) -> F | Callable[[F], F]:
) -> type[Factory[T]] | Callable[[type[Factory[T]]], type[Factory[T]]]:
r"""Register fixtures for the factory class.

:param factory_class: Factory class to register.
Expand All @@ -97,7 +104,7 @@ def register(

if factory_class is None:

def register_(factory_class: F) -> F:
def register_(factory_class: type[Factory[T]]) -> type[Factory[T]]:
return register(factory_class, _name=_name, _caller_locals=_caller_locals, **kwargs)

return register_
Expand Down Expand Up @@ -131,7 +138,7 @@ def register_(factory_class: F) -> F:


def generate_fixtures(
factory_class: FactoryType,
factory_class: type[Factory[T]],
model_name: str,
factory_name: str,
overrides: Mapping[str, Any],
Expand Down Expand Up @@ -193,23 +200,23 @@ def create_fixture_with_related(
def make_declaration_fixturedef(
attr_name: str,
value: Any,
factory_class: FactoryType,
factory_class: type[Factory[T]],
related: list[str],
) -> Callable[..., Any]:
"""Create the FixtureDef for a factory declaration."""
if isinstance(value, (factory.SubFactory, factory.RelatedFactory)):
subfactory_class = value.get_factory()
if isinstance(value, (SubFactory, RelatedFactory)):
subfactory_class: type[Factory[object]] = value.get_factory()
subfactory_deps = get_deps(subfactory_class, factory_class)

args = list(subfactory_deps)
if isinstance(value, factory.RelatedFactory):
if isinstance(value, RelatedFactory):
related_model = get_model_name(subfactory_class)
args.append(related_model)
related.append(related_model)
related.append(attr_name)
related.extend(subfactory_deps)

if isinstance(value, factory.SubFactory):
if isinstance(value, SubFactory):
args.append(inflection.underscore(subfactory_class._meta.model.__name__))

return create_fixture_with_related(
Expand All @@ -219,10 +226,10 @@ def make_declaration_fixturedef(
)

deps: list[str] # makes mypy happy
if isinstance(value, factory.PostGeneration):
if isinstance(value, PostGeneration):
value = None
deps = []
elif isinstance(value, factory.PostGenerationMethodCall):
elif isinstance(value, PostGenerationMethodCall):
value = value.method_arg
deps = []
elif isinstance(value, LazyFixture):
Expand Down Expand Up @@ -258,7 +265,7 @@ def inject_into_caller(name: str, function: Callable[..., Any], locals_: Box[dic
locals_.value[name] = function


def get_model_name(factory_class: FactoryType) -> str:
def get_model_name(factory_class: type[Factory[T]]) -> str:
"""Get model fixture name by factory."""
model_cls = factory_class._meta.model

Expand All @@ -278,14 +285,14 @@ def get_model_name(factory_class: FactoryType) -> str:
return model_name


def get_factory_name(factory_class: FactoryType) -> str:
def get_factory_name(factory_class: type[Factory[T]]) -> str:
"""Get factory fixture name by factory."""
return inflection.underscore(factory_class.__name__)


def get_deps(
factory_class: FactoryType,
parent_factory_class: FactoryType | None = None,
factory_class: type[Factory[T]],
parent_factory_class: type[Factory[U]] | None = None,
model_name: str | None = None,
) -> list[str]:
"""Get factory dependencies.
Expand All @@ -296,11 +303,13 @@ def get_deps(
parent_model_name = get_model_name(parent_factory_class) if parent_factory_class is not None else None

def is_dep(value: Any) -> bool:
if isinstance(value, factory.RelatedFactory):
if isinstance(value, RelatedFactory):
return False
if isinstance(value, factory.SubFactory) and get_model_name(value.get_factory()) == parent_model_name:
return False
if isinstance(value, factory.declarations.PostGenerationDeclaration):
if isinstance(value, SubFactory):
subfactory_class: type[Factory[object]] = value.get_factory()
if get_model_name(subfactory_class) == parent_model_name:
return False
if isinstance(value, PostGenerationDeclaration):
# Dependency on extracted value
return True

Expand Down Expand Up @@ -334,7 +343,7 @@ def disable_method(method: MethodType) -> Iterator[None]:
setattr(klass, method.__name__, old_method)


def model_fixture(request: SubRequest, factory_name: str) -> Any:
def model_fixture(request: SubRequest, factory_name: str) -> object:
"""Model fixture implementation."""
factoryboy_request: FactoryboyRequest = request.getfixturevalue("factoryboy_request")

Expand All @@ -345,21 +354,19 @@ def model_fixture(request: SubRequest, factory_name: str) -> Any:
fixture_name = request.fixturename
prefix = "".join((fixture_name, SEPARATOR))

factory_class: FactoryType = request.getfixturevalue(factory_name)
factory_class: type[Factory[object]] = request.getfixturevalue(factory_name)

# Create model fixture instance
Factory: FactoryType = cast(FactoryType, type("Factory", (factory_class,), {}))
NewFactory: type[Factory[object]] = cast(type[Factory[object]], type("Factory", (factory_class,), {}))
# equivalent to:
# class Factory(factory_class):
# pass
# it just makes mypy understand it.

Factory._meta.base_declarations = {
k: v
for k, v in Factory._meta.base_declarations.items()
if not isinstance(v, factory.declarations.PostGenerationDeclaration)
NewFactory._meta.base_declarations = {
k: v for k, v in NewFactory._meta.base_declarations.items() if not isinstance(v, PostGenerationDeclaration)
}
Factory._meta.post_declarations = factory.builder.DeclarationSet()
NewFactory._meta.post_declarations = DeclarationSet()

kwargs = {}
for key in factory_class._meta.pre_declarations:
Expand All @@ -368,25 +375,25 @@ def model_fixture(request: SubRequest, factory_name: str) -> Any:
kwargs[key] = evaluate(request, request.getfixturevalue(argname))

strategy = factory.enums.CREATE_STRATEGY
builder = factory.builder.StepBuilder(Factory._meta, kwargs, strategy)
step = factory.builder.BuildStep(builder=builder, sequence=Factory._meta.next_sequence())
builder = StepBuilder(NewFactory._meta, kwargs, strategy)
step = BuildStep(builder=builder, sequence=NewFactory._meta.next_sequence())

# FactoryBoy invokes the `_after_postgeneration` method, but we will instead call it manually later,
# once we are able to evaluate all the related fixtures.
with disable_method(Factory._after_postgeneration):
instance = Factory(**kwargs)
with disable_method(NewFactory._after_postgeneration): # type: ignore[arg-type] # https://github.com/python/mypy/issues/14235
Copy link

Copilot AI Jun 28, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] Consider revisiting the use of the type ignore comment here; if possible, update the type annotations to resolve the underlying type issue instead of suppressing it.

Suggested change
with disable_method(NewFactory._after_postgeneration): # type: ignore[arg-type] # https://github.com/python/mypy/issues/14235
with disable_method(cast(Callable[..., Any], NewFactory._after_postgeneration)):

Copilot uses AI. Check for mistakes.
instance = NewFactory(**kwargs)

# Cache the instance value on pytest level so that the fixture can be resolved before the return
request._fixturedef.cached_result = (instance, 0, None)
request._fixture_defs[fixture_name] = request._fixturedef

# Defer post-generation declarations
deferred: list[DeferredFunction] = []
deferred: list[DeferredFunction[object]] = []

for attr in factory_class._meta.post_declarations.sorted():
decl = factory_class._meta.post_declarations.declarations[attr]

if isinstance(decl, factory.RelatedFactory):
if isinstance(decl, RelatedFactory):
deferred.append(make_deferred_related(factory_class, fixture_name, attr))
else:
argname = "".join((prefix, attr))
Expand All @@ -405,7 +412,7 @@ def model_fixture(request: SubRequest, factory_name: str) -> Any:
# that `value_provided` should be falsy
postgen_value = evaluate(request, request.getfixturevalue(argname))
postgen_context = PostGenerationContext(
value_provided=(postgen_value is not factory.declarations.NotProvided),
value_provided=(postgen_value is not NotProvided),
value=postgen_value,
extra=extra,
)
Expand All @@ -420,7 +427,7 @@ def model_fixture(request: SubRequest, factory_name: str) -> Any:
return instance


def make_deferred_related(factory: FactoryType, fixture: str, attr: str) -> DeferredFunction:
def make_deferred_related(factory: type[Factory[T]], fixture: str, attr: str) -> DeferredFunction[T]:
"""Make deferred function for the related factory declaration.

:param factory: Factory class.
Expand All @@ -443,14 +450,14 @@ def deferred_impl(request: SubRequest) -> Any:


def make_deferred_postgen(
step: factory.builder.BuildStep,
factory_class: FactoryType,
step: BuildStep,
factory_class: type[Factory[T]],
fixture: str,
instance: Any,
attr: str,
declaration: factory.declarations.PostGenerationDeclaration,
declaration: PostGenerationDeclaration,
context: PostGenerationContext,
) -> DeferredFunction:
) -> DeferredFunction[T]:
"""Make deferred function for the post-generation declaration.

:param step: factory_boy builder step.
Expand All @@ -476,7 +483,7 @@ def deferred_impl(request: SubRequest) -> Any:
)


def factory_fixture(request: SubRequest, factory_class: F) -> F:
def factory_fixture(request: SubRequest, factory_class: type[Factory[T]]) -> type[Factory[T]]:
"""Factory fixture implementation."""
return factory_class

Expand All @@ -486,7 +493,7 @@ def attr_fixture(request: SubRequest, value: T) -> T:
return value


def subfactory_fixture(request: SubRequest, factory_class: FactoryType) -> Any:
def subfactory_fixture(request: SubRequest, factory_class: type[Factory[object]]) -> Any:
"""SubFactory/RelatedFactory fixture implementation."""
fixture = inflection.underscore(factory_class._meta.model.__name__)
return request.getfixturevalue(fixture)
Expand Down
Loading
Loading