diff --git a/pydantic_extra_types/pendulum_dt.py b/pydantic_extra_types/pendulum_dt.py index f306529..fde1729 100644 --- a/pydantic_extra_types/pendulum_dt.py +++ b/pydantic_extra_types/pendulum_dt.py @@ -186,7 +186,65 @@ def __get_pydantic_core_schema__(cls, source: type[Any], handler: GetCoreSchemaH Returns: A Pydantic CoreSchema with the Duration validation. """ - return core_schema.no_info_wrap_validator_function(cls._validate, core_schema.timedelta_schema()) + return core_schema.no_info_wrap_validator_function( + cls._validate, + core_schema.timedelta_schema(), + serialization=core_schema.plain_serializer_function_ser_schema( + lambda instance: instance.to_iso8601_string() + ), + ) + + def to_iso8601_string(self) -> str: + """ + Convert a Duration object to an ISO 8601 string. + + In addition to the standard ISO 8601 format, this method also supports the representation of fractions of a second and negative durations. + + Args: + duration (Duration): The Duration object. + + Returns: + str: The ISO 8601 string representation of the duration. + """ + # Extracting components from the Duration object + years = self.years + months = self.months + days = self._days + hours = self.hours + minutes = self.minutes + seconds = self.remaining_seconds + milliseconds = self.microseconds // 1000 + microseconds = self.microseconds % 1000 + + # Constructing the ISO 8601 duration string + iso_duration = 'P' + if years or months or days: + if years: + iso_duration += f'{years}Y' + if months: + iso_duration += f'{months}M' + if days: + iso_duration += f'{days}D' + + if hours or minutes or seconds or milliseconds or microseconds: + iso_duration += 'T' + if hours: + iso_duration += f'{hours}H' + if minutes: + iso_duration += f'{minutes}M' + if seconds or milliseconds or microseconds: + iso_duration += f'{seconds}' + if milliseconds or microseconds: + iso_duration += f'.{milliseconds:03d}' + if microseconds: + iso_duration += f'{microseconds:03d}' + iso_duration += 'S' + + # Prefix with '-' if the duration is negative + if self.total_seconds() < 0: + iso_duration = '-' + iso_duration + + return iso_duration @classmethod def _validate(cls, value: Any, handler: core_schema.ValidatorFunctionWrapHandler) -> Duration: @@ -219,10 +277,27 @@ def _validate(cls, value: Any, handler: core_schema.ValidatorFunctionWrapHandler microseconds=value.microseconds, ) + assert isinstance(value, str) try: - parsed = parse(value, exact=True) - if not isinstance(parsed, timedelta): + # https://github.com/python-pendulum/pendulum/issues/532 + if value.startswith('-'): + parsed = parse(value.lstrip('-'), exact=True) + else: + parsed = parse(value, exact=True) + if not isinstance(parsed, _Duration): raise ValueError(f'value is not a valid duration it is a {type(parsed)}') - return Duration(seconds=parsed.total_seconds()) + if value.startswith('-'): + parsed = -parsed + + return Duration( + years=parsed.years, + months=parsed.months, + weeks=parsed.weeks, + days=parsed.remaining_days, + hours=parsed.hours, + minutes=parsed.minutes, + seconds=parsed.remaining_seconds, + microseconds=parsed.microseconds, + ) except Exception as exc: raise PydanticCustomError('value_error', 'value is not a valid duration') from exc diff --git a/tests/test_pendulum_dt.py b/tests/test_pendulum_dt.py index 7635b5d..c482fd2 100644 --- a/tests/test_pendulum_dt.py +++ b/tests/test_pendulum_dt.py @@ -283,6 +283,31 @@ def test_pendulum_duration_from_serialized(delta_t_str): assert isinstance(model.delta_t, pendulum.Duration) +@pytest.mark.parametrize( + 'duration', + [ + (Duration(months=1)), + (Duration(weeks=1)), + (Duration(milliseconds=1)), + (Duration(microseconds=1)), + (Duration(days=1)), + (Duration(hours=1)), + (Duration(minutes=1)), + (Duration(seconds=1)), + (Duration(months=2, days=5)), + (Duration(weeks=3, hours=12)), + (Duration(days=10, minutes=30)), + (Duration(weeks=1, days=2, hours=3)), + (Duration(seconds=30, milliseconds=500)), + ], +) +def test_pendulum_duration_serialization_roundtrip(duration): + adapter = TypeAdapter(Duration) + serialized = adapter.dump_python(duration) + deserialized = TypeAdapter.validate_python(adapter, serialized) + assert deserialized == duration + + def get_invalid_dt_common(): return [ None,