diff --git a/pydantic_extra_types/coordinate.py b/pydantic_extra_types/coordinate.py index cfd14fe..a709eb2 100644 --- a/pydantic_extra_types/coordinate.py +++ b/pydantic_extra_types/coordinate.py @@ -6,17 +6,25 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Any, ClassVar, Tuple +from decimal import Decimal +from typing import Any, ClassVar, Tuple, Union from pydantic import GetCoreSchemaHandler from pydantic._internal import _repr from pydantic_core import ArgsKwargs, PydanticCustomError, core_schema +LatitudeType = Union[float, Decimal] +LongitudeType = Union[float, Decimal] +CoordinateType = Tuple[LatitudeType, LongitudeType] + class Latitude(float): """Latitude value should be between -90 and 90, inclusive. + Supports both float and Decimal types. + ```py + from decimal import Decimal from pydantic import BaseModel from pydantic_extra_types.coordinate import Latitude @@ -25,9 +33,10 @@ class Location(BaseModel): latitude: Latitude - location = Location(latitude=41.40338) - print(location) - # > latitude=41.40338 + # Using float + location1 = Location(latitude=41.40338) + # Using Decimal + location2 = Location(latitude=Decimal('41.40338')) ``` """ @@ -36,13 +45,21 @@ class Location(BaseModel): @classmethod def __get_pydantic_core_schema__(cls, source: type[Any], handler: GetCoreSchemaHandler) -> core_schema.CoreSchema: - return core_schema.float_schema(ge=cls.min, le=cls.max) + return core_schema.union_schema( + [ + core_schema.float_schema(ge=cls.min, le=cls.max), + core_schema.decimal_schema(ge=Decimal(cls.min), le=Decimal(cls.max)), + ] + ) class Longitude(float): """Longitude value should be between -180 and 180, inclusive. + Supports both float and Decimal types. + ```py + from decimal import Decimal from pydantic import BaseModel from pydantic_extra_types.coordinate import Longitude @@ -52,9 +69,10 @@ class Location(BaseModel): longitude: Longitude - location = Location(longitude=2.17403) - print(location) - # > longitude=2.17403 + # Using float + location1 = Location(longitude=2.17403) + # Using Decimal + location2 = Location(longitude=Decimal('2.17403')) ``` """ @@ -63,7 +81,12 @@ class Location(BaseModel): @classmethod def __get_pydantic_core_schema__(cls, source: type[Any], handler: GetCoreSchemaHandler) -> core_schema.CoreSchema: - return core_schema.float_schema(ge=cls.min, le=cls.max) + return core_schema.union_schema( + [ + core_schema.float_schema(ge=cls.min, le=cls.max), + core_schema.decimal_schema(ge=Decimal(cls.min), le=Decimal(cls.max)), + ] + ) @dataclass @@ -73,10 +96,11 @@ class Coordinate(_repr.Representation): You can use the `Coordinate` data type for storing coordinates. Coordinates can be defined using one of the following formats: - 1. Tuple: `(Latitude, Longitude)`. For example: `(41.40338, 2.17403)`. + 1. Tuple: `(Latitude, Longitude)`. For example: `(41.40338, 2.17403)` or `(Decimal('41.40338'), Decimal('2.17403'))`. 2. `Coordinate` instance: `Coordinate(latitude=Latitude, longitude=Longitude)`. ```py + from decimal import Decimal from pydantic import BaseModel from pydantic_extra_types.coordinate import Coordinate @@ -86,7 +110,12 @@ class Location(BaseModel): coordinate: Coordinate - location = Location(coordinate=(41.40338, 2.17403)) + # Using float values + location1 = Location(coordinate=(41.40338, 2.17403)) + # > coordinate=Coordinate(latitude=41.40338, longitude=2.17403) + + # Using Decimal values + location2 = Location(coordinate=(Decimal('41.40338'), Decimal('2.17403'))) # > coordinate=Coordinate(latitude=41.40338, longitude=2.17403) ``` """ @@ -102,7 +131,7 @@ def __get_pydantic_core_schema__(cls, source: type[Any], handler: GetCoreSchemaH core_schema.no_info_wrap_validator_function(cls._parse_str, core_schema.str_schema()), core_schema.no_info_wrap_validator_function( cls._parse_tuple, - handler.generate_schema(Tuple[float, float]), + handler.generate_schema(CoordinateType), ), handler(source), ] diff --git a/tests/test_coordinate.py b/tests/test_coordinate.py index e75ab86..eb9e1c0 100644 --- a/tests/test_coordinate.py +++ b/tests/test_coordinate.py @@ -1,5 +1,6 @@ +from decimal import Decimal from re import Pattern -from typing import Any, Optional +from typing import Any, Optional, Union import pytest from pydantic import BaseModel, ValidationError @@ -34,7 +35,14 @@ class Lng(BaseModel): (Coordinate(latitude=0, longitude=0), (0, 0), None), (ArgsKwargs(args=()), (0, 0), None), (ArgsKwargs(args=(1, 0.0)), (1.0, 0), None), - # # Invalid coordinates + # Decimal test cases + ((Decimal('20.0'), Decimal('10.0')), (Decimal('20.0'), Decimal('10.0')), None), + ((Decimal('-90.0'), Decimal('0.0')), (Decimal('-90.0'), Decimal('0.0')), None), + ((Decimal('45.678'), Decimal('-123.456')), (Decimal('45.678'), Decimal('-123.456')), None), + (Coordinate(Decimal('20.0'), Decimal('10.0')), (Decimal('20.0'), Decimal('10.0')), None), + (Coordinate(latitude=Decimal('0'), longitude=Decimal('0')), (Decimal('0'), Decimal('0')), None), + (ArgsKwargs(args=(Decimal('1'), Decimal('0.0'))), (Decimal('1.0'), Decimal('0.0')), None), + # Invalid coordinates ((), None, 'Field required'), # Empty tuple ((10.0,), None, 'Field required'), # Tuple with only one value (('ten, '), None, 'string is not recognized as a valid coordinate'), @@ -49,10 +57,11 @@ class Lng(BaseModel): (2, None, 'Input should be a dictionary or an instance of Coordinate'), # Wrong type ], ) -def test_format_for_coordinate(coord: (Any, Any), result: (float, float), error: Optional[Pattern]): +def test_format_for_coordinate( + coord: (Any, Any), result: (Union[float, Decimal], Union[float, Decimal]), error: Optional[Pattern] +): if error is None: _coord: Coordinate = Coord(coord=coord).coord - print('vars(_coord)', vars(_coord)) assert _coord.latitude == result[0] assert _coord.longitude == result[1] else: @@ -69,6 +78,16 @@ def test_format_for_coordinate(coord: (Any, Any), result: (float, float), error: # Invalid coordinates ((-91.0, 0.0), 'Input should be greater than or equal to -90'), ((50.0, 181.0), 'Input should be less than or equal to 180'), + # Valid Decimal coordinates + ((Decimal('-90.0'), Decimal('0.0')), None), + ((Decimal('50.0'), Decimal('180.0')), None), + ((Decimal('-89.999999'), Decimal('179.999999')), None), + ((Decimal('0.0'), Decimal('0.0')), None), + # Invalid Decimal coordinates + ((Decimal('-90.1'), Decimal('0.0')), 'Input should be greater than or equal to -90'), + ((Decimal('50.0'), Decimal('180.1')), 'Input should be less than or equal to 180'), + ((Decimal('90.1'), Decimal('0.0')), 'Input should be less than or equal to 90'), + ((Decimal('0.0'), Decimal('-180.1')), 'Input should be greater than or equal to -180'), ], ) def test_limit_for_coordinate(coord: (Any, Any), error: Optional[Pattern]): @@ -91,9 +110,13 @@ def test_limit_for_coordinate(coord: (Any, Any), error: Optional[Pattern]): ('90.0', True), (-90.0, True), ('-90.0', True), + (Decimal('90.0'), True), + (Decimal('-90.0'), True), # Unvalid latitude (91.0, False), (-91.0, False), + (Decimal('91.0'), False), + (Decimal('-91.0'), False), ], ) def test_format_latitude(latitude: float, valid: bool): @@ -101,7 +124,7 @@ def test_format_latitude(latitude: float, valid: bool): _lat = Lat(lat=latitude).lat assert _lat == float(latitude) else: - with pytest.raises(ValidationError, match='1 validation error for Lat'): + with pytest.raises(ValidationError, match='2 validation errors for Lat'): Lat(lat=latitude) @@ -119,9 +142,13 @@ def test_format_latitude(latitude: float, valid: bool): (-91.0, True), (180.0, True), (-180.0, True), + (Decimal('180.0'), True), + (Decimal('-180.0'), True), # Unvalid latitude (181.0, False), (-181.0, False), + (Decimal('181.0'), False), + (Decimal('-181.0'), False), ], ) def test_format_longitude(longitude: float, valid: bool): @@ -129,27 +156,60 @@ def test_format_longitude(longitude: float, valid: bool): _lng = Lng(lng=longitude).lng assert _lng == float(longitude) else: - with pytest.raises(ValidationError, match='1 validation error for Lng'): + with pytest.raises(ValidationError, match='2 validation errors for Lng'): Lng(lng=longitude) def test_str_repr(): + # Float tests assert str(Coord(coord=(20.0, 10.0)).coord) == '20.0,10.0' assert str(Coord(coord=('20.0, 10.0')).coord) == '20.0,10.0' assert repr(Coord(coord=(20.0, 10.0)).coord) == 'Coordinate(latitude=20.0, longitude=10.0)' + # Decimal tests + assert str(Coord(coord=(Decimal('20.0'), Decimal('10.0'))).coord) == '20.0,10.0' + assert str(Coord(coord=(Decimal('20.000'), Decimal('10.000'))).coord) == '20.000,10.000' + assert ( + repr(Coord(coord=(Decimal('20.0'), Decimal('10.0'))).coord) + == "Coordinate(latitude=Decimal('20.0'), longitude=Decimal('10.0'))" + ) def test_eq(): + # Float tests assert Coord(coord=(20.0, 10.0)).coord != Coord(coord='20.0,11.0').coord assert Coord(coord=('20.0, 10.0')).coord != Coord(coord='20.0,11.0').coord assert Coord(coord=('20.0, 10.0')).coord != Coord(coord='20.0,11.0').coord assert Coord(coord=(20.0, 10.0)).coord == Coord(coord='20.0,10.0').coord + # Decimal tests + assert Coord(coord=(Decimal('20.0'), Decimal('10.0'))).coord == Coord(coord='20.0,10.0').coord + assert Coord(coord=(Decimal('20.0'), Decimal('10.0'))).coord == Coord(coord=(20.0, 10.0)).coord + assert ( + Coord(coord=(Decimal('20.0'), Decimal('10.0'))).coord != Coord(coord=(Decimal('20.0'), Decimal('11.0'))).coord + ) + assert ( + Coord(coord=(Decimal('20.000'), Decimal('10.000'))).coord + == Coord(coord=(Decimal('20.0'), Decimal('10.0'))).coord + ) + def test_hashable(): + # Float tests assert hash(Coord(coord=(20.0, 10.0)).coord) == hash(Coord(coord=(20.0, 10.0)).coord) assert hash(Coord(coord=(20.0, 11.0)).coord) != hash(Coord(coord=(20.0, 10.0)).coord) + # Decimal tests + assert hash(Coord(coord=(Decimal('20.0'), Decimal('10.0'))).coord) == hash( + Coord(coord=(Decimal('20.0'), Decimal('10.0'))).coord + ) + assert hash(Coord(coord=(Decimal('20.0'), Decimal('10.0'))).coord) == hash(Coord(coord=(20.0, 10.0)).coord) + assert hash(Coord(coord=(Decimal('20.0'), Decimal('11.0'))).coord) != hash( + Coord(coord=(Decimal('20.0'), Decimal('10.0'))).coord + ) + assert hash(Coord(coord=(Decimal('20.000'), Decimal('10.000'))).coord) == hash( + Coord(coord=(Decimal('20.0'), Decimal('10.0'))).coord + ) + def test_json_schema(): class Model(BaseModel): @@ -157,8 +217,14 @@ class Model(BaseModel): assert Model.model_json_schema(mode='validation')['$defs']['Coordinate'] == { 'properties': { - 'latitude': {'maximum': 90.0, 'minimum': -90.0, 'title': 'Latitude', 'type': 'number'}, - 'longitude': {'maximum': 180.0, 'minimum': -180.0, 'title': 'Longitude', 'type': 'number'}, + 'latitude': { + 'anyOf': [{'maximum': 90.0, 'minimum': -90.0, 'type': 'number'}, {'type': 'string'}], + 'title': 'Latitude', + }, + 'longitude': { + 'anyOf': [{'maximum': 180.0, 'minimum': -180.0, 'type': 'number'}, {'type': 'string'}], + 'title': 'Longitude', + }, }, 'required': ['latitude', 'longitude'], 'title': 'Coordinate', @@ -170,7 +236,10 @@ class Model(BaseModel): { 'maxItems': 2, 'minItems': 2, - 'prefixItems': [{'type': 'number'}, {'type': 'number'}], + 'prefixItems': [ + {'anyOf': [{'type': 'number'}, {'type': 'string'}]}, + {'anyOf': [{'type': 'number'}, {'type': 'string'}]}, + ], 'type': 'array', }, {'type': 'string'}, @@ -181,8 +250,14 @@ class Model(BaseModel): '$defs': { 'Coordinate': { 'properties': { - 'latitude': {'maximum': 90.0, 'minimum': -90.0, 'title': 'Latitude', 'type': 'number'}, - 'longitude': {'maximum': 180.0, 'minimum': -180.0, 'title': 'Longitude', 'type': 'number'}, + 'latitude': { + 'anyOf': [{'maximum': 90.0, 'minimum': -90.0, 'type': 'number'}, {'type': 'string'}], + 'title': 'Latitude', + }, + 'longitude': { + 'anyOf': [{'maximum': 180.0, 'minimum': -180.0, 'type': 'number'}, {'type': 'string'}], + 'title': 'Longitude', + }, }, 'required': ['latitude', 'longitude'], 'title': 'Coordinate', diff --git a/tests/test_json_schema.py b/tests/test_json_schema.py index f7b7e83..36f80e8 100644 --- a/tests/test_json_schema.py +++ b/tests/test_json_schema.py @@ -139,10 +139,11 @@ { 'properties': { 'x': { - 'maximum': 90.0, - 'minimum': -90.0, + 'anyOf': [ + {'maximum': 90.0, 'minimum': -90.0, 'type': 'number'}, + {'type': 'string'}, + ], 'title': 'X', - 'type': 'number', } }, 'required': ['x'], @@ -155,10 +156,11 @@ { 'properties': { 'x': { - 'maximum': 180.0, - 'minimum': -180.0, + 'anyOf': [ + {'maximum': 180.0, 'minimum': -180.0, 'type': 'number'}, + {'type': 'string'}, + ], 'title': 'X', - 'type': 'number', } }, 'required': ['x'], @@ -172,8 +174,20 @@ '$defs': { 'Coordinate': { 'properties': { - 'latitude': {'maximum': 90.0, 'minimum': -90.0, 'title': 'Latitude', 'type': 'number'}, - 'longitude': {'maximum': 180.0, 'minimum': -180.0, 'title': 'Longitude', 'type': 'number'}, + 'latitude': { + 'anyOf': [ + {'maximum': 90.0, 'minimum': -90.0, 'type': 'number'}, + {'type': 'string'}, + ], + 'title': 'Latitude', + }, + 'longitude': { + 'anyOf': [ + {'maximum': 180.0, 'minimum': -180.0, 'type': 'number'}, + {'type': 'string'}, + ], + 'title': 'Longitude', + }, }, 'required': ['latitude', 'longitude'], 'title': 'Coordinate', @@ -188,8 +202,8 @@ 'maxItems': 2, 'minItems': 2, 'prefixItems': [ - {'type': 'number'}, - {'type': 'number'}, + {'anyOf': [{'type': 'number'}, {'type': 'string'}]}, + {'anyOf': [{'type': 'number'}, {'type': 'string'}]}, ], 'type': 'array', },