Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 62 additions & 51 deletions polyfactory/value_generators/constrained_numbers.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from __future__ import annotations

from decimal import Decimal
from math import ceil, floor, ulp
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like I'm failing the Python 3.8 tests because I'm importing ulp, which was only added in Python 3.9. I'll probably have to implement an alternative way of computing the correct increment for a given floating-point number.

from sys import float_info
from typing import TYPE_CHECKING, Any, Protocol, TypeVar, cast
from typing import TYPE_CHECKING, Protocol, TypeVar, cast

from polyfactory.exceptions import ParameterException
from polyfactory.value_generators.primitives import create_random_decimal, create_random_float, create_random_integer
Expand Down Expand Up @@ -99,8 +100,8 @@ def is_multiply_of_multiple_of_in_range(
return False


def passes_pydantic_multiple_validator(value: T, multiple_of: T) -> bool:
"""Determine whether a given value passes the pydantic multiple_of validation.
def is_almost_multiple_of(value: T, multiple_of: T) -> bool:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should keep the passes_pydantic_multiple_)validator as it is and deprecate it and then create a new is_almost_multiple_of function as a replacement.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, we could just do something like this:

def passes_pydantic_multiple_validator(value: T, multiple_of: T)  -> bool:
    return is_almost_mulitple_of(value, multiple_of)

That is, changing the logic in it is fine since it's doing the same, but we would need to keep the old name around.

"""Determine whether a given ``value`` is a close enough to a multiple of ``multiple_of``.

:param value: A numeric value.
:param multiple_of: Another numeric value.
Expand All @@ -110,23 +111,33 @@ def passes_pydantic_multiple_validator(value: T, multiple_of: T) -> bool:
"""
if multiple_of == 0:
return True
mod = float(value) / float(multiple_of) % 1
return almost_equal_floats(mod, 0.0) or almost_equal_floats(mod, 1.0)
mod = value % multiple_of
return almost_equal_floats(float(mod), 0.0) or almost_equal_floats(float(abs(mod)), float(abs(multiple_of)))


def get_increment(t_type: type[T]) -> T:
def get_increment(value: T, t_type: type[T]) -> T:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would be a breaking change. Instead, we should have another get_increment_v2 function and deprecate this one.

"""Get a small increment base to add to constrained values, i.e. lt/gt entries.

:param t_type: A value of type T.
:param value: A value of type T.
:param t_type: The type of ``value``.

:returns: An increment T.
"""
values: dict[Any, Any] = {
int: 1,
float: float_info.epsilon,
Decimal: Decimal("0.001"),
}
return cast("T", values[t_type])
# See https://github.com/python/mypy/issues/17045 for why the redundant casts are ignored.
if t_type == int:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All these checks should maybe use is_safe_subclass so that we handle custom int as well. Same for float and Decimal.

return cast("T", 1)
if t_type == float:
# When ``value`` is large in magnitude, we need to choose an increment that is large enough
# to not be rounded away, but when ``value`` small in magnitude, we need to prevent the
# incerement from vanishing. ``float_info.epsilon`` is defined as the smallest delta that
# can be represented between 1.0 and the next largest number, but it's not sufficient for
# larger values. ``ulp(x)`` will return smallest delta that can be added to ``x``.
return cast("T", max(ulp(value), float_info.epsilon)) # type: ignore[redundant-cast]
if t_type == Decimal:
return cast("T", Decimal("0.001")) # type: ignore[redundant-cast]

msg = f"invalid t_type: {t_type}"
raise AssertionError(msg)


def get_value_or_none(
Expand All @@ -147,14 +158,14 @@ def get_value_or_none(
if ge is not None:
minimum_value = ge
elif gt is not None:
minimum_value = gt + get_increment(t_type)
minimum_value = gt + get_increment(gt, t_type)
else:
minimum_value = None

if le is not None:
maximum_value = le
elif lt is not None:
maximum_value = lt - get_increment(t_type)
maximum_value = lt - get_increment(lt, t_type)
else:
maximum_value = None
return minimum_value, maximum_value
Expand Down Expand Up @@ -210,33 +221,36 @@ def get_constrained_number_range(
return minimum, maximum


def generate_constrained_number(
def generate_constrained_multiple_of(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above regarding it being breaking changes.

random: Random,
minimum: T | None,
maximum: T | None,
multiple_of: T | None,
method: "NumberGeneratorProtocol[T]",
multiple_of: T,
) -> T:
"""Generate a constrained number, output depends on the passed in callbacks.
"""Generate a constrained multiple of ``multiple_of``.

:param random: An instance of random.
:param minimum: A minimum value.
:param maximum: A maximum value.
:param multiple_of: A multiple of value.
:param method: A function that generates numbers of type T.

:returns: A value of type T.
"""
if minimum is None or maximum is None:
return multiple_of if multiple_of is not None else method(random=random)
if multiple_of is None:
return method(random=random, minimum=minimum, maximum=maximum)
if multiple_of >= minimum:
return multiple_of
result = minimum
while not passes_pydantic_multiple_validator(result, multiple_of):
result = round(method(random=random, minimum=minimum, maximum=maximum) / multiple_of) * multiple_of
return result

# Regardless of the type of ``multiple_of``, we can generate a valid multiple of it by
# multiplying it with any integer, which we call a multiplier. We will randomly generate the
# multiplier as a random integer, but we need to translate the original bounds, if any, to the
# correct bounds on the multiplier so that the resulting product will meet the original
# constraints.

if multiple_of < 0:
minimum, maximum = maximum, minimum

multiplier_min = ceil(minimum / multiple_of) if minimum is not None else None
multiplier_max = floor(maximum / multiple_of) if maximum is not None else None
multiplier = create_random_integer(random=random, minimum=multiplier_min, maximum=multiplier_max)

return multiplier * multiple_of


def handle_constrained_int(
Expand Down Expand Up @@ -269,13 +283,11 @@ def handle_constrained_int(
multiple_of=multiple_of,
random=random,
)
return generate_constrained_number(
random=random,
minimum=minimum,
maximum=maximum,
multiple_of=multiple_of,
method=create_random_integer,
)

if multiple_of is None:
return create_random_integer(random=random, minimum=minimum, maximum=maximum)

return generate_constrained_multiple_of(random=random, minimum=minimum, maximum=maximum, multiple_of=multiple_of)


def handle_constrained_float(
Expand Down Expand Up @@ -308,13 +320,10 @@ def handle_constrained_float(
random=random,
)

return generate_constrained_number(
random=random,
minimum=minimum,
maximum=maximum,
multiple_of=multiple_of,
method=create_random_float,
)
if multiple_of is None:
return create_random_float(random=random, minimum=minimum, maximum=maximum)

return generate_constrained_multiple_of(random=random, minimum=minimum, maximum=maximum, multiple_of=multiple_of)


def validate_max_digits(
Expand Down Expand Up @@ -422,13 +431,15 @@ def handle_constrained_decimal(
if max_digits is not None:
validate_max_digits(max_digits=max_digits, minimum=minimum, decimal_places=decimal_places)

generated_decimal = generate_constrained_number(
random=random,
minimum=minimum,
maximum=maximum,
multiple_of=multiple_of,
method=create_random_decimal,
)
if multiple_of is None:
generated_decimal = create_random_decimal(random=random, minimum=minimum, maximum=maximum)
else:
generated_decimal = generate_constrained_multiple_of(
random=random,
minimum=minimum,
maximum=maximum,
multiple_of=multiple_of,
)

if max_digits is not None or decimal_places is not None:
return handle_decimal_length(
Expand Down
61 changes: 43 additions & 18 deletions tests/constraints/test_decimal_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Optional, cast

import pytest
from hypothesis import given
from hypothesis import assume, given
from hypothesis.strategies import decimals, integers

from pydantic import BaseModel, condecimal
Expand All @@ -13,11 +13,24 @@
from polyfactory.value_generators.constrained_numbers import (
handle_constrained_decimal,
handle_decimal_length,
is_almost_multiple_of,
is_multiply_of_multiple_of_in_range,
passes_pydantic_multiple_validator,
)


def assume_max_digits(x: Decimal, max_digits: int) -> None:
"""
Signal to Hypothesis that ``x`` should have at most ``max_digits`` significant digits.

This is different than the ``decimals()`` strategy function's ``places`` keyword argument, which
only counts the digits after the decimal point when the number is written without an exponent.

E.g. 12.51 has 4 significant digits but 2 decimal places.
"""

assume(len(x.as_tuple().digits) <= max_digits)


def test_handle_constrained_decimal_without_constraints() -> None:
result = handle_constrained_decimal(
random=Random(),
Expand Down Expand Up @@ -162,7 +175,7 @@ def test_handle_constrained_decimal_handles_multiple_of(multiple_of: Decimal) ->
random=Random(),
multiple_of=multiple_of,
)
assert passes_pydantic_multiple_validator(result, multiple_of)
assert is_almost_multiple_of(result, multiple_of)
else:
with pytest.raises(ParameterException):
handle_constrained_decimal(
Expand All @@ -185,15 +198,17 @@ def test_handle_constrained_decimal_handles_multiple_of(multiple_of: Decimal) ->
max_value=1000000000,
),
)
def test_handle_constrained_decimal_handles_multiple_of_with_lt(val1: Decimal, val2: Decimal) -> None:
multiple_of, max_value = sorted([val1, val2])
def test_handle_constrained_decimal_handles_multiple_of_with_lt(max_value: Decimal, multiple_of: Decimal) -> None:
if multiple_of != Decimal("0"):
assume_max_digits(max_value, 10)
assume_max_digits(multiple_of, 10)
result = handle_constrained_decimal(
random=Random(),
multiple_of=multiple_of,
lt=max_value,
)
assert passes_pydantic_multiple_validator(result, multiple_of)
assert result < max_value
assert is_almost_multiple_of(result, multiple_of)
else:
with pytest.raises(ParameterException):
handle_constrained_decimal(
Expand All @@ -217,15 +232,17 @@ def test_handle_constrained_decimal_handles_multiple_of_with_lt(val1: Decimal, v
max_value=1000000000,
),
)
def test_handle_constrained_decimal_handles_multiple_of_with_le(val1: Decimal, val2: Decimal) -> None:
multiple_of, max_value = sorted([val1, val2])
def test_handle_constrained_decimal_handles_multiple_of_with_le(max_value: Decimal, multiple_of: Decimal) -> None:
if multiple_of != Decimal("0"):
assume_max_digits(max_value, 10)
assume_max_digits(multiple_of, 10)
result = handle_constrained_decimal(
random=Random(),
multiple_of=multiple_of,
le=max_value,
)
assert passes_pydantic_multiple_validator(result, multiple_of)
assert result <= max_value
assert is_almost_multiple_of(result, multiple_of)
else:
with pytest.raises(ParameterException):
handle_constrained_decimal(
Expand All @@ -249,15 +266,17 @@ def test_handle_constrained_decimal_handles_multiple_of_with_le(val1: Decimal, v
max_value=1000000000,
),
)
def test_handle_constrained_decimal_handles_multiple_of_with_ge(val1: Decimal, val2: Decimal) -> None:
min_value, multiple_of = sorted([val1, val2])
def test_handle_constrained_decimal_handles_multiple_of_with_ge(min_value: Decimal, multiple_of: Decimal) -> None:
if multiple_of != Decimal("0"):
assume_max_digits(min_value, 10)
assume_max_digits(multiple_of, 10)
result = handle_constrained_decimal(
random=Random(),
multiple_of=multiple_of,
ge=min_value,
)
assert passes_pydantic_multiple_validator(result, multiple_of)
assert min_value <= result
assert is_almost_multiple_of(result, multiple_of)
else:
with pytest.raises(ParameterException):
handle_constrained_decimal(
Expand All @@ -281,15 +300,17 @@ def test_handle_constrained_decimal_handles_multiple_of_with_ge(val1: Decimal, v
max_value=1000000000,
),
)
def test_handle_constrained_decimal_handles_multiple_of_with_gt(val1: Decimal, val2: Decimal) -> None:
min_value, multiple_of = sorted([val1, val2])
def test_handle_constrained_decimal_handles_multiple_of_with_gt(min_value: Decimal, multiple_of: Decimal) -> None:
if multiple_of != Decimal("0"):
assume_max_digits(min_value, 10)
assume_max_digits(multiple_of, 10)
result = handle_constrained_decimal(
random=Random(),
multiple_of=multiple_of,
gt=min_value,
)
assert passes_pydantic_multiple_validator(result, multiple_of)
assert min_value < result
assert is_almost_multiple_of(result, multiple_of)
else:
with pytest.raises(ParameterException):
handle_constrained_decimal(
Expand Down Expand Up @@ -322,21 +343,25 @@ def test_handle_constrained_decimal_handles_multiple_of_with_gt(val1: Decimal, v
def test_handle_constrained_decimal_handles_multiple_of_with_ge_and_le(
val1: Decimal,
val2: Decimal,
val3: Decimal,
multiple_of: Decimal,
) -> None:
min_value, multiple_of, max_value = sorted([val1, val2, val3])
min_value, max_value = sorted([val1, val2])
if multiple_of != Decimal("0") and is_multiply_of_multiple_of_in_range(
minimum=min_value,
maximum=max_value,
multiple_of=multiple_of,
):
assume_max_digits(min_value, 10)
assume_max_digits(max_value, 10)
assume_max_digits(multiple_of, 10)
result = handle_constrained_decimal(
random=Random(),
multiple_of=multiple_of,
ge=min_value,
le=max_value,
)
assert passes_pydantic_multiple_validator(result, multiple_of)
assert min_value <= result <= max_value
assert is_almost_multiple_of(result, multiple_of)
else:
with pytest.raises(ParameterException):
handle_constrained_decimal(
Expand Down
Loading