Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 41 additions & 12 deletions pydantic_extra_types/coordinate.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,25 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, ClassVar, Tuple
from decimal import Decimal
from typing import Any, ClassVar, Tuple, Union

from pydantic import GetCoreSchemaHandler
from pydantic._internal import _repr
from pydantic_core import ArgsKwargs, PydanticCustomError, core_schema

LatitudeType = Union[float, Decimal]
LongitudeType = Union[float, Decimal]
CoordinateType = Tuple[LatitudeType, LongitudeType]


class Latitude(float):
"""Latitude value should be between -90 and 90, inclusive.

Supports both float and Decimal types.

```py
from decimal import Decimal
from pydantic import BaseModel
from pydantic_extra_types.coordinate import Latitude

Expand All @@ -25,9 +33,10 @@ class Location(BaseModel):
latitude: Latitude


location = Location(latitude=41.40338)
print(location)
# > latitude=41.40338
# Using float
location1 = Location(latitude=41.40338)
# Using Decimal
location2 = Location(latitude=Decimal('41.40338'))
```
"""

Expand All @@ -36,13 +45,21 @@ class Location(BaseModel):

@classmethod
def __get_pydantic_core_schema__(cls, source: type[Any], handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
return core_schema.float_schema(ge=cls.min, le=cls.max)
return core_schema.union_schema(
[
core_schema.float_schema(ge=cls.min, le=cls.max),
core_schema.decimal_schema(ge=Decimal(cls.min), le=Decimal(cls.max)),
]
)


class Longitude(float):
"""Longitude value should be between -180 and 180, inclusive.

Supports both float and Decimal types.

```py
from decimal import Decimal
from pydantic import BaseModel

from pydantic_extra_types.coordinate import Longitude
Expand All @@ -52,9 +69,10 @@ class Location(BaseModel):
longitude: Longitude


location = Location(longitude=2.17403)
print(location)
# > longitude=2.17403
# Using float
location1 = Location(longitude=2.17403)
# Using Decimal
location2 = Location(longitude=Decimal('2.17403'))
```
"""

Expand All @@ -63,7 +81,12 @@ class Location(BaseModel):

@classmethod
def __get_pydantic_core_schema__(cls, source: type[Any], handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
return core_schema.float_schema(ge=cls.min, le=cls.max)
return core_schema.union_schema(
[
core_schema.float_schema(ge=cls.min, le=cls.max),
core_schema.decimal_schema(ge=Decimal(cls.min), le=Decimal(cls.max)),
]
)


@dataclass
Expand All @@ -73,10 +96,11 @@ class Coordinate(_repr.Representation):
You can use the `Coordinate` data type for storing coordinates. Coordinates can be
defined using one of the following formats:

1. Tuple: `(Latitude, Longitude)`. For example: `(41.40338, 2.17403)`.
1. Tuple: `(Latitude, Longitude)`. For example: `(41.40338, 2.17403)` or `(Decimal('41.40338'), Decimal('2.17403'))`.
2. `Coordinate` instance: `Coordinate(latitude=Latitude, longitude=Longitude)`.

```py
from decimal import Decimal
from pydantic import BaseModel

from pydantic_extra_types.coordinate import Coordinate
Expand All @@ -86,7 +110,12 @@ class Location(BaseModel):
coordinate: Coordinate


location = Location(coordinate=(41.40338, 2.17403))
# Using float values
location1 = Location(coordinate=(41.40338, 2.17403))
# > coordinate=Coordinate(latitude=41.40338, longitude=2.17403)

# Using Decimal values
location2 = Location(coordinate=(Decimal('41.40338'), Decimal('2.17403')))
# > coordinate=Coordinate(latitude=41.40338, longitude=2.17403)
```
"""
Expand All @@ -102,7 +131,7 @@ def __get_pydantic_core_schema__(cls, source: type[Any], handler: GetCoreSchemaH
core_schema.no_info_wrap_validator_function(cls._parse_str, core_schema.str_schema()),
core_schema.no_info_wrap_validator_function(
cls._parse_tuple,
handler.generate_schema(Tuple[float, float]),
handler.generate_schema(CoordinateType),
),
handler(source),
]
Expand Down
97 changes: 86 additions & 11 deletions tests/test_coordinate.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from decimal import Decimal
from re import Pattern
from typing import Any, Optional
from typing import Any, Optional, Union

import pytest
from pydantic import BaseModel, ValidationError
Expand Down Expand Up @@ -34,7 +35,14 @@ class Lng(BaseModel):
(Coordinate(latitude=0, longitude=0), (0, 0), None),
(ArgsKwargs(args=()), (0, 0), None),
(ArgsKwargs(args=(1, 0.0)), (1.0, 0), None),
# # Invalid coordinates
# Decimal test cases
((Decimal('20.0'), Decimal('10.0')), (Decimal('20.0'), Decimal('10.0')), None),
((Decimal('-90.0'), Decimal('0.0')), (Decimal('-90.0'), Decimal('0.0')), None),
((Decimal('45.678'), Decimal('-123.456')), (Decimal('45.678'), Decimal('-123.456')), None),
(Coordinate(Decimal('20.0'), Decimal('10.0')), (Decimal('20.0'), Decimal('10.0')), None),
(Coordinate(latitude=Decimal('0'), longitude=Decimal('0')), (Decimal('0'), Decimal('0')), None),
(ArgsKwargs(args=(Decimal('1'), Decimal('0.0'))), (Decimal('1.0'), Decimal('0.0')), None),
# Invalid coordinates
((), None, 'Field required'), # Empty tuple
((10.0,), None, 'Field required'), # Tuple with only one value
(('ten, '), None, 'string is not recognized as a valid coordinate'),
Expand All @@ -49,10 +57,11 @@ class Lng(BaseModel):
(2, None, 'Input should be a dictionary or an instance of Coordinate'), # Wrong type
],
)
def test_format_for_coordinate(coord: (Any, Any), result: (float, float), error: Optional[Pattern]):
def test_format_for_coordinate(
coord: (Any, Any), result: (Union[float, Decimal], Union[float, Decimal]), error: Optional[Pattern]
):
if error is None:
_coord: Coordinate = Coord(coord=coord).coord
print('vars(_coord)', vars(_coord))
assert _coord.latitude == result[0]
assert _coord.longitude == result[1]
else:
Expand All @@ -69,6 +78,16 @@ def test_format_for_coordinate(coord: (Any, Any), result: (float, float), error:
# Invalid coordinates
((-91.0, 0.0), 'Input should be greater than or equal to -90'),
((50.0, 181.0), 'Input should be less than or equal to 180'),
# Valid Decimal coordinates
((Decimal('-90.0'), Decimal('0.0')), None),
((Decimal('50.0'), Decimal('180.0')), None),
((Decimal('-89.999999'), Decimal('179.999999')), None),
((Decimal('0.0'), Decimal('0.0')), None),
# Invalid Decimal coordinates
((Decimal('-90.1'), Decimal('0.0')), 'Input should be greater than or equal to -90'),
((Decimal('50.0'), Decimal('180.1')), 'Input should be less than or equal to 180'),
((Decimal('90.1'), Decimal('0.0')), 'Input should be less than or equal to 90'),
((Decimal('0.0'), Decimal('-180.1')), 'Input should be greater than or equal to -180'),
],
)
def test_limit_for_coordinate(coord: (Any, Any), error: Optional[Pattern]):
Expand All @@ -91,17 +110,21 @@ def test_limit_for_coordinate(coord: (Any, Any), error: Optional[Pattern]):
('90.0', True),
(-90.0, True),
('-90.0', True),
(Decimal('90.0'), True),
(Decimal('-90.0'), True),
# Unvalid latitude
(91.0, False),
(-91.0, False),
(Decimal('91.0'), False),
(Decimal('-91.0'), False),
],
)
def test_format_latitude(latitude: float, valid: bool):
if valid:
_lat = Lat(lat=latitude).lat
assert _lat == float(latitude)
else:
with pytest.raises(ValidationError, match='1 validation error for Lat'):
with pytest.raises(ValidationError, match='2 validation errors for Lat'):
Lat(lat=latitude)


Expand All @@ -119,46 +142,89 @@ def test_format_latitude(latitude: float, valid: bool):
(-91.0, True),
(180.0, True),
(-180.0, True),
(Decimal('180.0'), True),
(Decimal('-180.0'), True),
# Unvalid latitude
(181.0, False),
(-181.0, False),
(Decimal('181.0'), False),
(Decimal('-181.0'), False),
],
)
def test_format_longitude(longitude: float, valid: bool):
if valid:
_lng = Lng(lng=longitude).lng
assert _lng == float(longitude)
else:
with pytest.raises(ValidationError, match='1 validation error for Lng'):
with pytest.raises(ValidationError, match='2 validation errors for Lng'):
Lng(lng=longitude)


def test_str_repr():
# Float tests
assert str(Coord(coord=(20.0, 10.0)).coord) == '20.0,10.0'
assert str(Coord(coord=('20.0, 10.0')).coord) == '20.0,10.0'
assert repr(Coord(coord=(20.0, 10.0)).coord) == 'Coordinate(latitude=20.0, longitude=10.0)'
# Decimal tests
assert str(Coord(coord=(Decimal('20.0'), Decimal('10.0'))).coord) == '20.0,10.0'
assert str(Coord(coord=(Decimal('20.000'), Decimal('10.000'))).coord) == '20.000,10.000'
assert (
repr(Coord(coord=(Decimal('20.0'), Decimal('10.0'))).coord)
== "Coordinate(latitude=Decimal('20.0'), longitude=Decimal('10.0'))"
)


def test_eq():
# Float tests
assert Coord(coord=(20.0, 10.0)).coord != Coord(coord='20.0,11.0').coord
assert Coord(coord=('20.0, 10.0')).coord != Coord(coord='20.0,11.0').coord
assert Coord(coord=('20.0, 10.0')).coord != Coord(coord='20.0,11.0').coord
assert Coord(coord=(20.0, 10.0)).coord == Coord(coord='20.0,10.0').coord

# Decimal tests
assert Coord(coord=(Decimal('20.0'), Decimal('10.0'))).coord == Coord(coord='20.0,10.0').coord
assert Coord(coord=(Decimal('20.0'), Decimal('10.0'))).coord == Coord(coord=(20.0, 10.0)).coord
assert (
Coord(coord=(Decimal('20.0'), Decimal('10.0'))).coord != Coord(coord=(Decimal('20.0'), Decimal('11.0'))).coord
)
assert (
Coord(coord=(Decimal('20.000'), Decimal('10.000'))).coord
== Coord(coord=(Decimal('20.0'), Decimal('10.0'))).coord
)


def test_hashable():
# Float tests
assert hash(Coord(coord=(20.0, 10.0)).coord) == hash(Coord(coord=(20.0, 10.0)).coord)
assert hash(Coord(coord=(20.0, 11.0)).coord) != hash(Coord(coord=(20.0, 10.0)).coord)

# Decimal tests
assert hash(Coord(coord=(Decimal('20.0'), Decimal('10.0'))).coord) == hash(
Coord(coord=(Decimal('20.0'), Decimal('10.0'))).coord
)
assert hash(Coord(coord=(Decimal('20.0'), Decimal('10.0'))).coord) == hash(Coord(coord=(20.0, 10.0)).coord)
assert hash(Coord(coord=(Decimal('20.0'), Decimal('11.0'))).coord) != hash(
Coord(coord=(Decimal('20.0'), Decimal('10.0'))).coord
)
assert hash(Coord(coord=(Decimal('20.000'), Decimal('10.000'))).coord) == hash(
Coord(coord=(Decimal('20.0'), Decimal('10.0'))).coord
)


def test_json_schema():
class Model(BaseModel):
value: Coordinate

assert Model.model_json_schema(mode='validation')['$defs']['Coordinate'] == {
'properties': {
'latitude': {'maximum': 90.0, 'minimum': -90.0, 'title': 'Latitude', 'type': 'number'},
'longitude': {'maximum': 180.0, 'minimum': -180.0, 'title': 'Longitude', 'type': 'number'},
'latitude': {
'anyOf': [{'maximum': 90.0, 'minimum': -90.0, 'type': 'number'}, {'type': 'string'}],
'title': 'Latitude',
},
'longitude': {
'anyOf': [{'maximum': 180.0, 'minimum': -180.0, 'type': 'number'}, {'type': 'string'}],
'title': 'Longitude',
},
},
'required': ['latitude', 'longitude'],
'title': 'Coordinate',
Expand All @@ -170,7 +236,10 @@ class Model(BaseModel):
{
'maxItems': 2,
'minItems': 2,
'prefixItems': [{'type': 'number'}, {'type': 'number'}],
'prefixItems': [
{'anyOf': [{'type': 'number'}, {'type': 'string'}]},
{'anyOf': [{'type': 'number'}, {'type': 'string'}]},
],
'type': 'array',
},
{'type': 'string'},
Expand All @@ -181,8 +250,14 @@ class Model(BaseModel):
'$defs': {
'Coordinate': {
'properties': {
'latitude': {'maximum': 90.0, 'minimum': -90.0, 'title': 'Latitude', 'type': 'number'},
'longitude': {'maximum': 180.0, 'minimum': -180.0, 'title': 'Longitude', 'type': 'number'},
'latitude': {
'anyOf': [{'maximum': 90.0, 'minimum': -90.0, 'type': 'number'}, {'type': 'string'}],
'title': 'Latitude',
},
'longitude': {
'anyOf': [{'maximum': 180.0, 'minimum': -180.0, 'type': 'number'}, {'type': 'string'}],
'title': 'Longitude',
},
},
'required': ['latitude', 'longitude'],
'title': 'Coordinate',
Expand Down
34 changes: 24 additions & 10 deletions tests/test_json_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,10 +139,11 @@
{
'properties': {
'x': {
'maximum': 90.0,
'minimum': -90.0,
'anyOf': [
{'maximum': 90.0, 'minimum': -90.0, 'type': 'number'},
{'type': 'string'},
],
'title': 'X',
'type': 'number',
}
},
'required': ['x'],
Expand All @@ -155,10 +156,11 @@
{
'properties': {
'x': {
'maximum': 180.0,
'minimum': -180.0,
'anyOf': [
{'maximum': 180.0, 'minimum': -180.0, 'type': 'number'},
{'type': 'string'},
],
'title': 'X',
'type': 'number',
}
},
'required': ['x'],
Expand All @@ -172,8 +174,20 @@
'$defs': {
'Coordinate': {
'properties': {
'latitude': {'maximum': 90.0, 'minimum': -90.0, 'title': 'Latitude', 'type': 'number'},
'longitude': {'maximum': 180.0, 'minimum': -180.0, 'title': 'Longitude', 'type': 'number'},
'latitude': {
'anyOf': [
{'maximum': 90.0, 'minimum': -90.0, 'type': 'number'},
{'type': 'string'},
],
'title': 'Latitude',
},
'longitude': {
'anyOf': [
{'maximum': 180.0, 'minimum': -180.0, 'type': 'number'},
{'type': 'string'},
],
'title': 'Longitude',
},
},
'required': ['latitude', 'longitude'],
'title': 'Coordinate',
Expand All @@ -188,8 +202,8 @@
'maxItems': 2,
'minItems': 2,
'prefixItems': [
{'type': 'number'},
{'type': 'number'},
{'anyOf': [{'type': 'number'}, {'type': 'string'}]},
{'anyOf': [{'type': 'number'}, {'type': 'string'}]},
],
'type': 'array',
},
Expand Down