Skip to content

Commit 87b71b1

Browse files
Neutronlulamureki
andauthored
Add type hints to seq()'s increment_by argument (#560)
* Add type hints for `increment_by` parameter in `seq` function * Update CHANGELOG * Fix mypy errors in `seq()` type hints --------- Co-authored-by: Rust Saiargaliev <fly.amureki@gmail.com>
1 parent 6e197a6 commit 87b71b1

File tree

2 files changed

+46
-25
lines changed

2 files changed

+46
-25
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/).
1111

1212
### Changed
1313
- docs: Update seq import in basic usage
14+
- Add type hints to `seq()`'s `increment_by` argument
1415

1516
### Removed
1617
- Drop mentions of model_mommy from the project. The old migration script is available in [the GitHub gist](https://gist.github.com/amureki/168b545105cb3e71f824351ffff507dc).

model_bakery/utils.py

Lines changed: 45 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import datetime
2+
import decimal
23
import importlib
34
import inspect
45
import itertools
@@ -50,7 +51,45 @@ def get_calling_module(levels_back: int) -> ModuleType | None:
5051
return inspect.getmodule(frame)
5152

5253

53-
def seq(value, increment_by=1, start=None, suffix=None):
54+
def _seq_datetime(
55+
value: datetime.datetime | datetime.date | datetime.time,
56+
increment_by: int | float | decimal.Decimal | datetime.timedelta,
57+
):
58+
if not isinstance(increment_by, datetime.timedelta):
59+
raise TypeError("datetime values require a timedelta increment_by")
60+
if type(value) is datetime.date:
61+
date: datetime.datetime = datetime.datetime.combine(
62+
value, datetime.datetime.now().time()
63+
)
64+
elif type(value) is datetime.time:
65+
date = datetime.datetime.combine(datetime.date.today(), value)
66+
elif isinstance(value, datetime.datetime):
67+
date = value
68+
else:
69+
raise TypeError("Unexpected value type")
70+
71+
epoch_datetime = datetime.datetime(1970, 1, 1, tzinfo=date.tzinfo)
72+
start_seconds = (date - epoch_datetime).total_seconds()
73+
increment_seconds = increment_by.total_seconds()
74+
for n in itertools.count(increment_seconds, increment_seconds):
75+
series_date = tz_aware(
76+
datetime.datetime.fromtimestamp(start_seconds + n, tz=datetime.timezone.utc)
77+
)
78+
79+
if type(value) is datetime.time:
80+
yield series_date.time()
81+
elif type(value) is datetime.date:
82+
yield series_date.date()
83+
else:
84+
yield series_date
85+
86+
87+
def seq(
88+
value,
89+
increment_by: int | float | decimal.Decimal | datetime.timedelta = 1,
90+
start: int | float | None = None,
91+
suffix=None,
92+
):
5493
"""Generate a sequence of values based on a running count.
5594
5695
This function can be used to generate sequences of `int`, `float`,
@@ -60,8 +99,8 @@ def seq(value, increment_by=1, start=None, suffix=None):
6099
Args:
61100
value (object): the value at which to begin generation (this will
62101
be ignored for types `datetime`, `date`, and `time`)
63-
increment_by (`int` or `float`, optional): the amount by which to
64-
increment for each generated value (defaults to `1`)
102+
increment_by (`int` or `float` or `decimal.Decimal` or `datetime.timedelta`, optional):
103+
the amount by which to increment for each generated value (defaults to `1`)
65104
start (`int` or `float`, optional): the value at which the sequence
66105
will begin to add to `value` (if `value` is a `str`, `start` will
67106
be appended to it)
@@ -75,29 +114,10 @@ def seq(value, increment_by=1, start=None, suffix=None):
75114
_validate_sequence_parameters(value, increment_by, start, suffix)
76115

77116
if isinstance(value, (datetime.datetime, datetime.date, datetime.time)):
78-
if type(value) is datetime.date:
79-
date = datetime.datetime.combine(value, datetime.datetime.now().time())
80-
elif type(value) is datetime.time:
81-
date = datetime.datetime.combine(datetime.date.today(), value)
82-
else:
83-
date = value
84-
85-
# convert to epoch time
86-
epoch_datetime = datetime.datetime(1970, 1, 1, tzinfo=date.tzinfo)
87-
start = (date - epoch_datetime).total_seconds()
88-
increment_by = increment_by.total_seconds()
89-
for n in itertools.count(increment_by, increment_by):
90-
series_date = tz_aware(
91-
datetime.datetime.fromtimestamp(start + n, tz=datetime.timezone.utc)
92-
)
93-
94-
if type(value) is datetime.time:
95-
yield series_date.time()
96-
elif type(value) is datetime.date:
97-
yield series_date.date()
98-
else:
99-
yield series_date
117+
yield from _seq_datetime(value, increment_by)
100118
else:
119+
if isinstance(increment_by, datetime.timedelta):
120+
raise TypeError("non-datetime values do not support timedelta increment_by")
101121
for n in itertools.count(
102122
increment_by if start is None else start, increment_by
103123
):

0 commit comments

Comments
 (0)