Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 79 additions & 4 deletions pydantic_extra_types/pendulum_dt.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,65 @@ def __get_pydantic_core_schema__(cls, source: type[Any], handler: GetCoreSchemaH
Returns:
A Pydantic CoreSchema with the Duration validation.
"""
return core_schema.no_info_wrap_validator_function(cls._validate, core_schema.timedelta_schema())
return core_schema.no_info_wrap_validator_function(
cls._validate,
core_schema.timedelta_schema(),
serialization=core_schema.plain_serializer_function_ser_schema(
lambda instance: instance.to_iso8601_string()
),
)

def to_iso8601_string(self) -> str:
"""
Convert a Duration object to an ISO 8601 string.

In addition to the standard ISO 8601 format, this method also supports the representation of fractions of a second and negative durations.

Args:
duration (Duration): The Duration object.

Returns:
str: The ISO 8601 string representation of the duration.
"""
# Extracting components from the Duration object
years = self.years
months = self.months
days = self._days
hours = self.hours
minutes = self.minutes
seconds = self.remaining_seconds
milliseconds = self.microseconds // 1000
microseconds = self.microseconds % 1000

# Constructing the ISO 8601 duration string
iso_duration = 'P'
if years or months or days:
if years:
iso_duration += f'{years}Y'
if months:
iso_duration += f'{months}M'
if days:
iso_duration += f'{days}D'

if hours or minutes or seconds or milliseconds or microseconds:
iso_duration += 'T'
if hours:
iso_duration += f'{hours}H'
if minutes:
iso_duration += f'{minutes}M'
if seconds or milliseconds or microseconds:
iso_duration += f'{seconds}'
if milliseconds or microseconds:
iso_duration += f'.{milliseconds:03d}'
if microseconds:
iso_duration += f'{microseconds:03d}'
iso_duration += 'S'

# Prefix with '-' if the duration is negative
if self.total_seconds() < 0:
iso_duration = '-' + iso_duration

return iso_duration

@classmethod
def _validate(cls, value: Any, handler: core_schema.ValidatorFunctionWrapHandler) -> Duration:
Expand Down Expand Up @@ -219,10 +277,27 @@ def _validate(cls, value: Any, handler: core_schema.ValidatorFunctionWrapHandler
microseconds=value.microseconds,
)

assert isinstance(value, str)
try:
parsed = parse(value, exact=True)
if not isinstance(parsed, timedelta):
# https://github.com/python-pendulum/pendulum/issues/532
if value.startswith('-'):
parsed = parse(value.lstrip('-'), exact=True)
else:
parsed = parse(value, exact=True)
if not isinstance(parsed, _Duration):
raise ValueError(f'value is not a valid duration it is a {type(parsed)}')
return Duration(seconds=parsed.total_seconds())
if value.startswith('-'):
parsed = -parsed

return Duration(
years=parsed.years,
months=parsed.months,
weeks=parsed.weeks,
days=parsed.remaining_days,
hours=parsed.hours,
minutes=parsed.minutes,
seconds=parsed.remaining_seconds,
microseconds=parsed.microseconds,
)
except Exception as exc:
raise PydanticCustomError('value_error', 'value is not a valid duration') from exc
25 changes: 25 additions & 0 deletions tests/test_pendulum_dt.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,31 @@ def test_pendulum_duration_from_serialized(delta_t_str):
assert isinstance(model.delta_t, pendulum.Duration)


@pytest.mark.parametrize(
'duration',
[
(Duration(months=1)),
(Duration(weeks=1)),
(Duration(milliseconds=1)),
(Duration(microseconds=1)),
(Duration(days=1)),
(Duration(hours=1)),
(Duration(minutes=1)),
(Duration(seconds=1)),
(Duration(months=2, days=5)),
(Duration(weeks=3, hours=12)),
(Duration(days=10, minutes=30)),
(Duration(weeks=1, days=2, hours=3)),
(Duration(seconds=30, milliseconds=500)),
],
)
def test_pendulum_duration_serialization_roundtrip(duration):
adapter = TypeAdapter(Duration)
serialized = adapter.dump_python(duration)
deserialized = TypeAdapter.validate_python(adapter, serialized)
assert deserialized == duration


def get_invalid_dt_common():
return [
None,
Expand Down