Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
15 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions docs/examples/fields/test_example_10.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from dataclasses import dataclass

from polyfactory.decorators import post_generated
from polyfactory.factories import DataclassFactory
from polyfactory.fields import CallableParam


@dataclass
class Person:
name: str
age_next_year: int


class PersonFactoryWithParamValueSpecifiedInFactory(DataclassFactory[Person]):
"""In this factory, the next_years_age_from_calculator must be passed at build time."""

next_years_age_from_calculator = CallableParam[int](lambda age: age + 1, age=20)

@post_generated
@classmethod
def age_next_year(cls, next_years_age_from_calculator: int) -> int:
return next_years_age_from_calculator


def test_factory__in_factory() -> None:
person = PersonFactoryWithParamValueSpecifiedInFactory.build()

assert isinstance(person, Person)
assert not hasattr(person, "next_years_age_from_calculator")
assert person.age_next_year == 21


class PersonFactoryWithParamValueSetAtBuild(DataclassFactory[Person]):
"""In this factory, the next_years_age_from_calculator must be passed at build time."""

next_years_age_from_calculator = CallableParam[int](age=20)

@post_generated
@classmethod
def age_next_year(cls, next_years_age_from_calculator: int) -> int:
return next_years_age_from_calculator


def test_factory__build_time() -> None:
person = PersonFactoryWithParamValueSpecifiedInFactory.build(next_years_age_from_calculator=lambda age: age + 1)

assert isinstance(person, Person)
assert not hasattr(person, "next_years_age_from_calculator")
assert person.age_next_year == 21
52 changes: 52 additions & 0 deletions docs/examples/fields/test_example_9.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from dataclasses import dataclass
from typing import List

from polyfactory.decorators import post_generated
from polyfactory.factories import DataclassFactory
from polyfactory.fields import Param


@dataclass
class Pet:
name: str
sound: str


class PetFactoryWithParamValueSetAtBuild(DataclassFactory[Pet]):
"""In this factory, the name_choices must be passed at build time."""

name_choices = Param[List[str]]()

@post_generated
@classmethod
def name(cls, name_choices: List[str]) -> str:
return cls.__random__.choice(name_choices)


def test_factory__build_time() -> None:
names = ["Ralph", "Roxy"]
pet = PetFactoryWithParamValueSetAtBuild.build(name_choices=names)

assert isinstance(pet, Pet)
assert not hasattr(pet, "name_choices")
assert pet.name in names


class PetFactoryWithParamSpecififiedInFactory(DataclassFactory[Pet]):
"""In this factory, the name_choices are specified in the
factory and do not need to be passed at build time."""

name_choices = Param[List[str]](["Ralph", "Roxy"])

@post_generated
@classmethod
def name(cls, name_choices: List[str]) -> str:
return cls.__random__.choice(name_choices)


def test_factory__in_factory() -> None:
pet = PetFactoryWithParamSpecififiedInFactory.build()

assert isinstance(pet, Pet)
assert not hasattr(pet, "name_choices")
assert pet.name in ["Ralph", "Roxy"]
30 changes: 30 additions & 0 deletions docs/usage/fields.rst
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,36 @@ The signature for use is: ``cb: Callable, *args, **defaults`` it can receive an
callable should be: ``name: str, values: dict[str, Any], *args, **defaults``. The already generated values are mapped by
name in the values dictionary.


The ``Param`` Field
-------------------

The :class:`Param <polyfactory.fields.Param>` class denotes a constant parameter that can be referenced by other fields at build but is not mapped to the final object. This is useful for passing values needed by other factory fields but that are not part of object being built.

The Param type can either accept a constant value at the definition of the factory, or its value can be set at build time.

If neither a value is provided at the definition of the factory nor at build time, an exception will be raised.

.. literalinclude:: /examples/fields/test_example_9.py
:caption: Using the ``Param`` field
:language: python


The ``CallableParam`` Field
---------------------------

The :class:`CallableParam <polyfactory.fields.CallableParam>` class denotes a callable parameter with a return value that may be referenced by other fields during build but is not mapped to the final object. Optional keyword arguments may be passed to the callable as part of the field definition on the factory. Any additional keyword arguments passed to the build method will also not be passed to the final object.

The CallableParam type can either accept a callable provided at the definition of the factory, or its value can be passed at build time. The callable is executed at the beginning of build.

If neither a value is provided at the definition of the factory nor at build time, an exception will be raised.

The difference between a Param and a CallableParam is that the CallableParam is always executed at build time. If you need to pass an unmapped callable to the factory that should not automatically be executed at build time, use a Param.

.. literalinclude:: /examples/fields/test_example_10.py
:caption: Using the ``CallableParam`` field
:language: python

Factories as Fields
---------------------------

Expand Down
4 changes: 4 additions & 0 deletions polyfactory/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,7 @@ class MissingBuildKwargException(FactoryException):

class MissingDependencyException(FactoryException, ImportError):
"""Missing dependency exception - used when a dependency is not installed"""


class MissingParamException(FactoryException):
"""Missing parameter exception - used when a required Param is not provided"""
44 changes: 40 additions & 4 deletions polyfactory/factories/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,14 @@
MIN_COLLECTION_LENGTH,
RANDOMIZE_COLLECTION_LENGTH,
)
from polyfactory.exceptions import ConfigurationException, MissingBuildKwargException, ParameterException
from polyfactory.exceptions import (
ConfigurationException,
MissingBuildKwargException,
MissingParamException,
ParameterException,
)
from polyfactory.field_meta import Null
from polyfactory.fields import Fixture, Ignore, PostGenerated, Require, Use
from polyfactory.fields import BaseParam, Fixture, Ignore, IsNotPassed, PostGenerated, Require, Use
from polyfactory.utils.helpers import (
flatten_annotation,
get_collection_type,
Expand Down Expand Up @@ -965,6 +970,30 @@ def _check_declared_fields_exist_in_model(cls) -> None:
if isinstance(field_value, (Use, PostGenerated, Ignore, Require)):
raise ConfigurationException(error_message)

@classmethod
def _handle_factory_params(cls, params: dict[str, BaseParam], **kwargs: Any) -> dict[str, Any]:
"""Get the factory parameters.

:param params: A dict of field name to Param instances.
:param kwargs: Any build kwargs.

:returns: A dict of fieldname mapped to realized Param values.
"""

try:
return {name: param.to_value(kwargs.get(name, IsNotPassed)) for name, param in params.items()}
except MissingParamException as e:
msg = "Missing required kwargs"
raise MissingBuildKwargException(msg) from e

@classmethod
def get_factory_params(cls) -> dict[str, BaseParam]:
"""Get the factory parameters.

:returns: A dict of field name to Param instances.
"""
return {name: item for name, item in cls.__dict__.items() if isinstance(item, BaseParam)}

@classmethod
def process_kwargs(cls, **kwargs: Any) -> dict[str, Any]:
"""Process the given kwargs and generate values for the factory's model.
Expand All @@ -980,6 +1009,9 @@ def process_kwargs(cls, **kwargs: Any) -> dict[str, Any]:
result: dict[str, Any] = {**kwargs}
generate_post: dict[str, PostGenerated] = {}

params = cls.get_factory_params()
result.update(cls._handle_factory_params(params, **kwargs))

for field_meta in cls.get_model_fields():
field_build_parameters = cls.extract_field_build_parameters(field_meta=field_meta, build_args=kwargs)
if cls.should_set_field_value(field_meta, **kwargs) and not cls.should_use_default_value(field_meta):
Expand Down Expand Up @@ -1016,7 +1048,7 @@ def process_kwargs(cls, **kwargs: Any) -> dict[str, Any]:
for field_name, post_generator in generate_post.items():
result[field_name] = post_generator.to_value(field_name, result)

return result
return {key: value for key, value in result.items() if key not in params}

@classmethod
def process_kwargs_coverage(cls, **kwargs: Any) -> abc.Iterable[dict[str, Any]]:
Expand All @@ -1034,6 +1066,9 @@ def process_kwargs_coverage(cls, **kwargs: Any) -> abc.Iterable[dict[str, Any]]:
result: dict[str, Any] = {**kwargs}
generate_post: dict[str, PostGenerated] = {}

params = cls.get_factory_params()
result.update(cls._handle_factory_params(params, **kwargs))

for field_meta in cls.get_model_fields():
field_build_parameters = cls.extract_field_build_parameters(field_meta=field_meta, build_args=kwargs)

Expand Down Expand Up @@ -1069,7 +1104,8 @@ def process_kwargs_coverage(cls, **kwargs: Any) -> abc.Iterable[dict[str, Any]]:
for resolved in resolve_kwargs_coverage(result):
for field_name, post_generator in generate_post.items():
resolved[field_name] = post_generator.to_value(field_name, resolved)
yield resolved

yield {key: value for key, value in resolved.items() if key not in params}

@classmethod
def build(cls, **kwargs: Any) -> T:
Expand Down
121 changes: 120 additions & 1 deletion polyfactory/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@

from typing_extensions import ParamSpec

from polyfactory.exceptions import ParameterException
from polyfactory.exceptions import MissingParamException, ParameterException

T = TypeVar("T")
U = TypeVar("U")
P = ParamSpec("P")


Expand Down Expand Up @@ -114,3 +115,121 @@ def to_value(self) -> Any:

msg = "fixture has not been registered using the register_factory decorator"
raise ParameterException(msg)


class NotPassed:
"""Indicates a parameter was not passed to a factory field and must be
passed at build time.
"""


IsNotPassed = NotPassed()


class BaseParam(Generic[T, U]):
"""Base class for parameters.

This class is used to pass a parameters that can be referenced by other
fields but will not be passed to the final object.

It is generic over the type of the parameter that will be used during build
and also the method used to generate that value (e.g. as a constant or a
callable).
"""

def to_value(self, from_build: U | NotPassed = IsNotPassed) -> T:
"""Determines the value of the parameter.

This method must be implemented in subclasses.

:param from_build: The value passed at build time.
:returns: The value
:raises: NotImplementedError
"""
msg = "to_value must be implemented in subclasses"
raise NotImplementedError(msg) # pragma: no cover


class Param(Generic[T], BaseParam[T, T]):
"""A constant parameter that can be used by other fields but will not be
passed to the final object.

If a value for the parameter is not passed in the field's definition, it must
be passed at build time. Otherwise, a MissingParamException will be raised.
"""

__slots__ = ("param",)

def __init__(self, param: T | NotPassed = IsNotPassed) -> None:
"""Designate a parameter.

:param param: A constant or an unpassed value that can be referenced later
"""
self.param = param

def to_value(self, from_build: T | NotPassed = IsNotPassed) -> T:
"""Determines the value to use at build time

If a value was passed to the constructor, it will be used. Otherwise, the value
passed at build time will be used. If no value was passed at build time, a
MissingParamException will be raised.

:param args: from_build: The value passed at build time (if any).
:returns: The value
:raises: MissingParamException
"""
if self.param is IsNotPassed:
if from_build is not IsNotPassed:
return cast(T, from_build)
msg = "Param value was not passed at build time"
raise MissingParamException(msg)
return cast(T, self.param)


class CallableParam(Generic[T], BaseParam[T, Callable[..., T]]):
"""A callable parameter that can be used by other fields but will not be
passed to the final object.

The callable may be passed optional keyword arguments via the constructor
of this class. The callable will be invoked with the passed keyword
arguments and any positional arguments passed at build time.

If a callable for the parameter is not passed in the field's definition, it must
be passed at build time. Otherwise, a MissingParamException will be raised.
"""

__slots__ = (
"kwargs",
"param",
)

def __init__(
self,
param: Callable[..., T] | NotPassed = IsNotPassed,
**kwargs: Any,
) -> None:
"""Designate field as a callable parameter.

:param param: A callable that will be evaluated at build time.
:param kwargs: Any kwargs to pass to the callable.
"""
self.param = param
self.kwargs = kwargs

def to_value(self, from_build: Callable[..., T] | NotPassed = IsNotPassed) -> T:
"""Determine the value to use at build time.

If a value was passed to the constructor, it will be used. Otherwise, the value
passed at build time will be used. If no value was passed at build time, a
MissingParamException will be raised.

:param args: from_build: The callable passed at build time (if any).
:returns: The value
:raises: MissingParamException
"""
if self.param is IsNotPassed:
if from_build is not IsNotPassed:
return cast(Callable[..., T], from_build)(**self.kwargs)
msg = "Param value was not passed at build time"
raise MissingParamException(msg)
return cast(Callable[..., T], self.param)(**self.kwargs)
Loading
Loading