1313from typing import TYPE_CHECKING , Any , Callable , Generic , TypeVar , cast , overload
1414
1515import factory
16- import factory .builder
17- import factory .declarations
1816import factory .enums
1917import inflection
20- from typing_extensions import ParamSpec , TypeAlias
18+ from factory .base import Factory
19+ from factory .builder import BuildStep , DeclarationSet , StepBuilder
20+ from factory .declarations import (
21+ NotProvided ,
22+ PostGeneration ,
23+ PostGenerationDeclaration ,
24+ PostGenerationMethodCall ,
25+ RelatedFactory ,
26+ SubFactory ,
27+ )
28+ from typing_extensions import ParamSpec
2129
2230from .compat import PostGenerationContext
2331from .fixturegen import create_fixture
2735
2836 from .plugin import Request as FactoryboyRequest
2937
30- FactoryType : TypeAlias = type [factory .Factory ]
31- F = TypeVar ("F" , bound = FactoryType )
3238T = TypeVar ("T" )
39+ U = TypeVar ("U" )
3340T_co = TypeVar ("T_co" , covariant = True )
3441P = ParamSpec ("P" )
3542
3845
3946
4047@dataclass (eq = False )
41- class DeferredFunction :
48+ class DeferredFunction ( Generic [ T ]) :
4249 name : str
43- factory : FactoryType
50+ factory : type [ Factory [ T ]]
4451 is_related : bool
4552 function : Callable [[SubRequest ], Any ]
4653
@@ -67,24 +74,24 @@ def named_model(model_cls: type[T], name: str) -> type[T]:
6774# register(AuthorFactory, ...)
6875#
6976# @register
70- # class AuthorFactory(factory. Factory): ...
77+ # class AuthorFactory(Factory): ...
7178@overload
72- def register (factory_class : F , _name : str | None = None , ** kwargs : Any ) -> F : ...
79+ def register (factory_class : type [ Factory [ T ]] , _name : str | None = None , ** kwargs : Any ) -> type [ Factory [ T ]] : ...
7380
7481
7582# @register(...)
76- # class AuthorFactory(factory. Factory): ...
83+ # class AuthorFactory(Factory): ...
7784@overload
78- def register (* , _name : str | None = None , ** kwargs : Any ) -> Callable [[F ], F ]: ...
85+ def register (* , _name : str | None = None , ** kwargs : Any ) -> Callable [[type [ Factory [ T ]]], type [ Factory [ T ]] ]: ...
7986
8087
8188def register (
82- factory_class : F | None = None ,
89+ factory_class : type [ Factory [ T ]] | None = None ,
8390 _name : str | None = None ,
8491 * ,
8592 _caller_locals : Box [dict [str , Any ]] | None = None ,
8693 ** kwargs : Any ,
87- ) -> F | Callable [[F ], F ]:
94+ ) -> type [ Factory [ T ]] | Callable [[type [ Factory [ T ]]], type [ Factory [ T ]] ]:
8895 r"""Register fixtures for the factory class.
8996
9097 :param factory_class: Factory class to register.
@@ -97,7 +104,7 @@ def register(
97104
98105 if factory_class is None :
99106
100- def register_ (factory_class : F ) -> F :
107+ def register_ (factory_class : type [ Factory [ T ]] ) -> type [ Factory [ T ]] :
101108 return register (factory_class , _name = _name , _caller_locals = _caller_locals , ** kwargs )
102109
103110 return register_
@@ -131,7 +138,7 @@ def register_(factory_class: F) -> F:
131138
132139
133140def generate_fixtures (
134- factory_class : FactoryType ,
141+ factory_class : type [ Factory [ T ]] ,
135142 model_name : str ,
136143 factory_name : str ,
137144 overrides : Mapping [str , Any ],
@@ -193,23 +200,23 @@ def create_fixture_with_related(
193200def make_declaration_fixturedef (
194201 attr_name : str ,
195202 value : Any ,
196- factory_class : FactoryType ,
203+ factory_class : type [ Factory [ T ]] ,
197204 related : list [str ],
198205) -> Callable [..., Any ]:
199206 """Create the FixtureDef for a factory declaration."""
200- if isinstance (value , (factory . SubFactory , factory . RelatedFactory )):
201- subfactory_class = value .get_factory ()
207+ if isinstance (value , (SubFactory , RelatedFactory )):
208+ subfactory_class : type [ Factory [ object ]] = value .get_factory ()
202209 subfactory_deps = get_deps (subfactory_class , factory_class )
203210
204211 args = list (subfactory_deps )
205- if isinstance (value , factory . RelatedFactory ):
212+ if isinstance (value , RelatedFactory ):
206213 related_model = get_model_name (subfactory_class )
207214 args .append (related_model )
208215 related .append (related_model )
209216 related .append (attr_name )
210217 related .extend (subfactory_deps )
211218
212- if isinstance (value , factory . SubFactory ):
219+ if isinstance (value , SubFactory ):
213220 args .append (inflection .underscore (subfactory_class ._meta .model .__name__ ))
214221
215222 return create_fixture_with_related (
@@ -219,10 +226,10 @@ def make_declaration_fixturedef(
219226 )
220227
221228 deps : list [str ] # makes mypy happy
222- if isinstance (value , factory . PostGeneration ):
229+ if isinstance (value , PostGeneration ):
223230 value = None
224231 deps = []
225- elif isinstance (value , factory . PostGenerationMethodCall ):
232+ elif isinstance (value , PostGenerationMethodCall ):
226233 value = value .method_arg
227234 deps = []
228235 elif isinstance (value , LazyFixture ):
@@ -258,7 +265,7 @@ def inject_into_caller(name: str, function: Callable[..., Any], locals_: Box[dic
258265 locals_ .value [name ] = function
259266
260267
261- def get_model_name (factory_class : FactoryType ) -> str :
268+ def get_model_name (factory_class : type [ Factory [ T ]] ) -> str :
262269 """Get model fixture name by factory."""
263270 model_cls = factory_class ._meta .model
264271
@@ -278,14 +285,14 @@ def get_model_name(factory_class: FactoryType) -> str:
278285 return model_name
279286
280287
281- def get_factory_name (factory_class : FactoryType ) -> str :
288+ def get_factory_name (factory_class : type [ Factory [ T ]] ) -> str :
282289 """Get factory fixture name by factory."""
283290 return inflection .underscore (factory_class .__name__ )
284291
285292
286293def get_deps (
287- factory_class : FactoryType ,
288- parent_factory_class : FactoryType | None = None ,
294+ factory_class : type [ Factory [ T ]] ,
295+ parent_factory_class : type [ Factory [ U ]] | None = None ,
289296 model_name : str | None = None ,
290297) -> list [str ]:
291298 """Get factory dependencies.
@@ -296,11 +303,13 @@ def get_deps(
296303 parent_model_name = get_model_name (parent_factory_class ) if parent_factory_class is not None else None
297304
298305 def is_dep (value : Any ) -> bool :
299- if isinstance (value , factory . RelatedFactory ):
306+ if isinstance (value , RelatedFactory ):
300307 return False
301- if isinstance (value , factory .SubFactory ) and get_model_name (value .get_factory ()) == parent_model_name :
302- return False
303- if isinstance (value , factory .declarations .PostGenerationDeclaration ):
308+ if isinstance (value , SubFactory ):
309+ subfactory_class : type [Factory [object ]] = value .get_factory ()
310+ if get_model_name (subfactory_class ) == parent_model_name :
311+ return False
312+ if isinstance (value , PostGenerationDeclaration ):
304313 # Dependency on extracted value
305314 return True
306315
@@ -334,7 +343,7 @@ def disable_method(method: MethodType) -> Iterator[None]:
334343 setattr (klass , method .__name__ , old_method )
335344
336345
337- def model_fixture (request : SubRequest , factory_name : str ) -> Any :
346+ def model_fixture (request : SubRequest , factory_name : str ) -> object :
338347 """Model fixture implementation."""
339348 factoryboy_request : FactoryboyRequest = request .getfixturevalue ("factoryboy_request" )
340349
@@ -345,21 +354,19 @@ def model_fixture(request: SubRequest, factory_name: str) -> Any:
345354 fixture_name = request .fixturename
346355 prefix = "" .join ((fixture_name , SEPARATOR ))
347356
348- factory_class : FactoryType = request .getfixturevalue (factory_name )
357+ factory_class : type [ Factory [ object ]] = request .getfixturevalue (factory_name )
349358
350359 # Create model fixture instance
351- Factory : FactoryType = cast (FactoryType , type ("Factory" , (factory_class ,), {}))
360+ NewFactory : type [ Factory [ object ]] = cast (type [ Factory [ object ]] , type ("Factory" , (factory_class ,), {}))
352361 # equivalent to:
353362 # class Factory(factory_class):
354363 # pass
355364 # it just makes mypy understand it.
356365
357- Factory ._meta .base_declarations = {
358- k : v
359- for k , v in Factory ._meta .base_declarations .items ()
360- if not isinstance (v , factory .declarations .PostGenerationDeclaration )
366+ NewFactory ._meta .base_declarations = {
367+ k : v for k , v in NewFactory ._meta .base_declarations .items () if not isinstance (v , PostGenerationDeclaration )
361368 }
362- Factory ._meta .post_declarations = factory . builder . DeclarationSet ()
369+ NewFactory ._meta .post_declarations = DeclarationSet ()
363370
364371 kwargs = {}
365372 for key in factory_class ._meta .pre_declarations :
@@ -368,25 +375,25 @@ def model_fixture(request: SubRequest, factory_name: str) -> Any:
368375 kwargs [key ] = evaluate (request , request .getfixturevalue (argname ))
369376
370377 strategy = factory .enums .CREATE_STRATEGY
371- builder = factory . builder . StepBuilder (Factory ._meta , kwargs , strategy )
372- step = factory . builder . BuildStep (builder = builder , sequence = Factory ._meta .next_sequence ())
378+ builder = StepBuilder (NewFactory ._meta , kwargs , strategy )
379+ step = BuildStep (builder = builder , sequence = NewFactory ._meta .next_sequence ())
373380
374381 # FactoryBoy invokes the `_after_postgeneration` method, but we will instead call it manually later,
375382 # once we are able to evaluate all the related fixtures.
376- with disable_method (Factory ._after_postgeneration ):
377- instance = Factory (** kwargs )
383+ with disable_method (NewFactory ._after_postgeneration ): # type: ignore[arg-type] # https://github.com/python/mypy/issues/14235
384+ instance = NewFactory (** kwargs )
378385
379386 # Cache the instance value on pytest level so that the fixture can be resolved before the return
380387 request ._fixturedef .cached_result = (instance , 0 , None )
381388 request ._fixture_defs [fixture_name ] = request ._fixturedef
382389
383390 # Defer post-generation declarations
384- deferred : list [DeferredFunction ] = []
391+ deferred : list [DeferredFunction [ object ] ] = []
385392
386393 for attr in factory_class ._meta .post_declarations .sorted ():
387394 decl = factory_class ._meta .post_declarations .declarations [attr ]
388395
389- if isinstance (decl , factory . RelatedFactory ):
396+ if isinstance (decl , RelatedFactory ):
390397 deferred .append (make_deferred_related (factory_class , fixture_name , attr ))
391398 else :
392399 argname = "" .join ((prefix , attr ))
@@ -405,7 +412,7 @@ def model_fixture(request: SubRequest, factory_name: str) -> Any:
405412 # that `value_provided` should be falsy
406413 postgen_value = evaluate (request , request .getfixturevalue (argname ))
407414 postgen_context = PostGenerationContext (
408- value_provided = (postgen_value is not factory . declarations . NotProvided ),
415+ value_provided = (postgen_value is not NotProvided ),
409416 value = postgen_value ,
410417 extra = extra ,
411418 )
@@ -420,7 +427,7 @@ def model_fixture(request: SubRequest, factory_name: str) -> Any:
420427 return instance
421428
422429
423- def make_deferred_related (factory : FactoryType , fixture : str , attr : str ) -> DeferredFunction :
430+ def make_deferred_related (factory : type [ Factory [ T ]] , fixture : str , attr : str ) -> DeferredFunction [ T ] :
424431 """Make deferred function for the related factory declaration.
425432
426433 :param factory: Factory class.
@@ -443,14 +450,14 @@ def deferred_impl(request: SubRequest) -> Any:
443450
444451
445452def make_deferred_postgen (
446- step : factory . builder . BuildStep ,
447- factory_class : FactoryType ,
453+ step : BuildStep ,
454+ factory_class : type [ Factory [ T ]] ,
448455 fixture : str ,
449456 instance : Any ,
450457 attr : str ,
451- declaration : factory . declarations . PostGenerationDeclaration ,
458+ declaration : PostGenerationDeclaration ,
452459 context : PostGenerationContext ,
453- ) -> DeferredFunction :
460+ ) -> DeferredFunction [ T ] :
454461 """Make deferred function for the post-generation declaration.
455462
456463 :param step: factory_boy builder step.
@@ -476,7 +483,7 @@ def deferred_impl(request: SubRequest) -> Any:
476483 )
477484
478485
479- def factory_fixture (request : SubRequest , factory_class : F ) -> F :
486+ def factory_fixture (request : SubRequest , factory_class : type [ Factory [ T ]] ) -> type [ Factory [ T ]] :
480487 """Factory fixture implementation."""
481488 return factory_class
482489
@@ -486,7 +493,7 @@ def attr_fixture(request: SubRequest, value: T) -> T:
486493 return value
487494
488495
489- def subfactory_fixture (request : SubRequest , factory_class : FactoryType ) -> Any :
496+ def subfactory_fixture (request : SubRequest , factory_class : type [ Factory [ object ]] ) -> Any :
490497 """SubFactory/RelatedFactory fixture implementation."""
491498 fixture = inflection .underscore (factory_class ._meta .model .__name__ )
492499 return request .getfixturevalue (fixture )
0 commit comments