Skip to content

Commit dc5c869

Browse files
committed
feat: Support Decimal type in coordinates
1 parent 6a66fa7 commit dc5c869

File tree

2 files changed

+124
-22
lines changed

2 files changed

+124
-22
lines changed

pydantic_extra_types/coordinate.py

Lines changed: 41 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,25 @@
66
from __future__ import annotations
77

88
from dataclasses import dataclass
9-
from typing import Any, ClassVar, Tuple
9+
from decimal import Decimal
10+
from typing import Any, ClassVar, Tuple, Union
1011

1112
from pydantic import GetCoreSchemaHandler
1213
from pydantic._internal import _repr
1314
from pydantic_core import ArgsKwargs, PydanticCustomError, core_schema
1415

16+
LatitudeType = Union[float, Decimal]
17+
LongitudeType = Union[float, Decimal]
18+
CoordinateType = Tuple[LatitudeType, LongitudeType]
19+
1520

1621
class Latitude(float):
1722
"""Latitude value should be between -90 and 90, inclusive.
1823
24+
Supports both float and Decimal types.
25+
1926
```py
27+
from decimal import Decimal
2028
from pydantic import BaseModel
2129
from pydantic_extra_types.coordinate import Latitude
2230
@@ -25,9 +33,10 @@ class Location(BaseModel):
2533
latitude: Latitude
2634
2735
28-
location = Location(latitude=41.40338)
29-
print(location)
30-
# > latitude=41.40338
36+
# Using float
37+
location1 = Location(latitude=41.40338)
38+
# Using Decimal
39+
location2 = Location(latitude=Decimal('41.40338'))
3140
```
3241
"""
3342

@@ -36,13 +45,21 @@ class Location(BaseModel):
3645

3746
@classmethod
3847
def __get_pydantic_core_schema__(cls, source: type[Any], handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
39-
return core_schema.float_schema(ge=cls.min, le=cls.max)
48+
return core_schema.union_schema(
49+
[
50+
core_schema.float_schema(ge=cls.min, le=cls.max),
51+
core_schema.decimal_schema(ge=Decimal(cls.min), le=Decimal(cls.max)),
52+
]
53+
)
4054

4155

4256
class Longitude(float):
4357
"""Longitude value should be between -180 and 180, inclusive.
4458
59+
Supports both float and Decimal types.
60+
4561
```py
62+
from decimal import Decimal
4663
from pydantic import BaseModel
4764
4865
from pydantic_extra_types.coordinate import Longitude
@@ -52,9 +69,10 @@ class Location(BaseModel):
5269
longitude: Longitude
5370
5471
55-
location = Location(longitude=2.17403)
56-
print(location)
57-
# > longitude=2.17403
72+
# Using float
73+
location1 = Location(longitude=2.17403)
74+
# Using Decimal
75+
location2 = Location(longitude=Decimal('2.17403'))
5876
```
5977
"""
6078

@@ -63,7 +81,12 @@ class Location(BaseModel):
6381

6482
@classmethod
6583
def __get_pydantic_core_schema__(cls, source: type[Any], handler: GetCoreSchemaHandler) -> core_schema.CoreSchema:
66-
return core_schema.float_schema(ge=cls.min, le=cls.max)
84+
return core_schema.union_schema(
85+
[
86+
core_schema.float_schema(ge=cls.min, le=cls.max),
87+
core_schema.decimal_schema(ge=Decimal(cls.min), le=Decimal(cls.max)),
88+
]
89+
)
6790

6891

6992
@dataclass
@@ -73,10 +96,11 @@ class Coordinate(_repr.Representation):
7396
You can use the `Coordinate` data type for storing coordinates. Coordinates can be
7497
defined using one of the following formats:
7598
76-
1. Tuple: `(Latitude, Longitude)`. For example: `(41.40338, 2.17403)`.
99+
1. Tuple: `(Latitude, Longitude)`. For example: `(41.40338, 2.17403)` or `(Decimal('41.40338'), Decimal('2.17403'))`.
77100
2. `Coordinate` instance: `Coordinate(latitude=Latitude, longitude=Longitude)`.
78101
79102
```py
103+
from decimal import Decimal
80104
from pydantic import BaseModel
81105
82106
from pydantic_extra_types.coordinate import Coordinate
@@ -86,7 +110,12 @@ class Location(BaseModel):
86110
coordinate: Coordinate
87111
88112
89-
location = Location(coordinate=(41.40338, 2.17403))
113+
# Using float values
114+
location1 = Location(coordinate=(41.40338, 2.17403))
115+
# > coordinate=Coordinate(latitude=41.40338, longitude=2.17403)
116+
117+
# Using Decimal values
118+
location2 = Location(coordinate=(Decimal('41.40338'), Decimal('2.17403')))
90119
# > coordinate=Coordinate(latitude=41.40338, longitude=2.17403)
91120
```
92121
"""
@@ -102,7 +131,7 @@ def __get_pydantic_core_schema__(cls, source: type[Any], handler: GetCoreSchemaH
102131
core_schema.no_info_wrap_validator_function(cls._parse_str, core_schema.str_schema()),
103132
core_schema.no_info_wrap_validator_function(
104133
cls._parse_tuple,
105-
handler.generate_schema(Tuple[float, float]),
134+
handler.generate_schema(CoordinateType),
106135
),
107136
handler(source),
108137
]

tests/test_coordinate.py

Lines changed: 83 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from decimal import Decimal
12
from re import Pattern
23
from typing import Any, Optional
34

@@ -34,7 +35,14 @@ class Lng(BaseModel):
3435
(Coordinate(latitude=0, longitude=0), (0, 0), None),
3536
(ArgsKwargs(args=()), (0, 0), None),
3637
(ArgsKwargs(args=(1, 0.0)), (1.0, 0), None),
37-
# # Invalid coordinates
38+
# Decimal test cases
39+
((Decimal('20.0'), Decimal('10.0')), (Decimal('20.0'), Decimal('10.0')), None),
40+
((Decimal('-90.0'), Decimal('0.0')), (Decimal('-90.0'), Decimal('0.0')), None),
41+
((Decimal('45.678'), Decimal('-123.456')), (Decimal('45.678'), Decimal('-123.456')), None),
42+
(Coordinate(Decimal('20.0'), Decimal('10.0')), (Decimal('20.0'), Decimal('10.0')), None),
43+
(Coordinate(latitude=Decimal('0'), longitude=Decimal('0')), (Decimal('0'), Decimal('0')), None),
44+
(ArgsKwargs(args=(Decimal('1'), Decimal('0.0'))), (Decimal('1.0'), Decimal('0.0')), None),
45+
# Invalid coordinates
3846
((), None, 'Field required'), # Empty tuple
3947
((10.0,), None, 'Field required'), # Tuple with only one value
4048
(('ten, '), None, 'string is not recognized as a valid coordinate'),
@@ -49,10 +57,9 @@ class Lng(BaseModel):
4957
(2, None, 'Input should be a dictionary or an instance of Coordinate'), # Wrong type
5058
],
5159
)
52-
def test_format_for_coordinate(coord: (Any, Any), result: (float, float), error: Optional[Pattern]):
60+
def test_format_for_coordinate(coord: (Any, Any), result: (float | Decimal, float | Decimal), error: Optional[Pattern]):
5361
if error is None:
5462
_coord: Coordinate = Coord(coord=coord).coord
55-
print('vars(_coord)', vars(_coord))
5663
assert _coord.latitude == result[0]
5764
assert _coord.longitude == result[1]
5865
else:
@@ -69,6 +76,16 @@ def test_format_for_coordinate(coord: (Any, Any), result: (float, float), error:
6976
# Invalid coordinates
7077
((-91.0, 0.0), 'Input should be greater than or equal to -90'),
7178
((50.0, 181.0), 'Input should be less than or equal to 180'),
79+
# Valid Decimal coordinates
80+
((Decimal('-90.0'), Decimal('0.0')), None),
81+
((Decimal('50.0'), Decimal('180.0')), None),
82+
((Decimal('-89.999999'), Decimal('179.999999')), None),
83+
((Decimal('0.0'), Decimal('0.0')), None),
84+
# Invalid Decimal coordinates
85+
((Decimal('-90.1'), Decimal('0.0')), 'Input should be greater than or equal to -90'),
86+
((Decimal('50.0'), Decimal('180.1')), 'Input should be less than or equal to 180'),
87+
((Decimal('90.1'), Decimal('0.0')), 'Input should be less than or equal to 90'),
88+
((Decimal('0.0'), Decimal('-180.1')), 'Input should be greater than or equal to -180'),
7289
],
7390
)
7491
def test_limit_for_coordinate(coord: (Any, Any), error: Optional[Pattern]):
@@ -91,17 +108,21 @@ def test_limit_for_coordinate(coord: (Any, Any), error: Optional[Pattern]):
91108
('90.0', True),
92109
(-90.0, True),
93110
('-90.0', True),
111+
(Decimal('90.0'), True),
112+
(Decimal('-90.0'), True),
94113
# Unvalid latitude
95114
(91.0, False),
96115
(-91.0, False),
116+
(Decimal('91.0'), False),
117+
(Decimal('-91.0'), False),
97118
],
98119
)
99120
def test_format_latitude(latitude: float, valid: bool):
100121
if valid:
101122
_lat = Lat(lat=latitude).lat
102123
assert _lat == float(latitude)
103124
else:
104-
with pytest.raises(ValidationError, match='1 validation error for Lat'):
125+
with pytest.raises(ValidationError, match='2 validation errors for Lat'):
105126
Lat(lat=latitude)
106127

107128

@@ -119,46 +140,89 @@ def test_format_latitude(latitude: float, valid: bool):
119140
(-91.0, True),
120141
(180.0, True),
121142
(-180.0, True),
143+
(Decimal('180.0'), True),
144+
(Decimal('-180.0'), True),
122145
# Unvalid latitude
123146
(181.0, False),
124147
(-181.0, False),
148+
(Decimal('181.0'), False),
149+
(Decimal('-181.0'), False),
125150
],
126151
)
127152
def test_format_longitude(longitude: float, valid: bool):
128153
if valid:
129154
_lng = Lng(lng=longitude).lng
130155
assert _lng == float(longitude)
131156
else:
132-
with pytest.raises(ValidationError, match='1 validation error for Lng'):
157+
with pytest.raises(ValidationError, match='2 validation errors for Lng'):
133158
Lng(lng=longitude)
134159

135160

136161
def test_str_repr():
162+
# Float tests
137163
assert str(Coord(coord=(20.0, 10.0)).coord) == '20.0,10.0'
138164
assert str(Coord(coord=('20.0, 10.0')).coord) == '20.0,10.0'
139165
assert repr(Coord(coord=(20.0, 10.0)).coord) == 'Coordinate(latitude=20.0, longitude=10.0)'
166+
# Decimal tests
167+
assert str(Coord(coord=(Decimal('20.0'), Decimal('10.0'))).coord) == '20.0,10.0'
168+
assert str(Coord(coord=(Decimal('20.000'), Decimal('10.000'))).coord) == '20.000,10.000'
169+
assert (
170+
repr(Coord(coord=(Decimal('20.0'), Decimal('10.0'))).coord)
171+
== "Coordinate(latitude=Decimal('20.0'), longitude=Decimal('10.0'))"
172+
)
140173

141174

142175
def test_eq():
176+
# Float tests
143177
assert Coord(coord=(20.0, 10.0)).coord != Coord(coord='20.0,11.0').coord
144178
assert Coord(coord=('20.0, 10.0')).coord != Coord(coord='20.0,11.0').coord
145179
assert Coord(coord=('20.0, 10.0')).coord != Coord(coord='20.0,11.0').coord
146180
assert Coord(coord=(20.0, 10.0)).coord == Coord(coord='20.0,10.0').coord
147181

182+
# Decimal tests
183+
assert Coord(coord=(Decimal('20.0'), Decimal('10.0'))).coord == Coord(coord='20.0,10.0').coord
184+
assert Coord(coord=(Decimal('20.0'), Decimal('10.0'))).coord == Coord(coord=(20.0, 10.0)).coord
185+
assert (
186+
Coord(coord=(Decimal('20.0'), Decimal('10.0'))).coord != Coord(coord=(Decimal('20.0'), Decimal('11.0'))).coord
187+
)
188+
assert (
189+
Coord(coord=(Decimal('20.000'), Decimal('10.000'))).coord
190+
== Coord(coord=(Decimal('20.0'), Decimal('10.0'))).coord
191+
)
192+
148193

149194
def test_hashable():
195+
# Float tests
150196
assert hash(Coord(coord=(20.0, 10.0)).coord) == hash(Coord(coord=(20.0, 10.0)).coord)
151197
assert hash(Coord(coord=(20.0, 11.0)).coord) != hash(Coord(coord=(20.0, 10.0)).coord)
152198

199+
# Decimal tests
200+
assert hash(Coord(coord=(Decimal('20.0'), Decimal('10.0'))).coord) == hash(
201+
Coord(coord=(Decimal('20.0'), Decimal('10.0'))).coord
202+
)
203+
assert hash(Coord(coord=(Decimal('20.0'), Decimal('10.0'))).coord) == hash(Coord(coord=(20.0, 10.0)).coord)
204+
assert hash(Coord(coord=(Decimal('20.0'), Decimal('11.0'))).coord) != hash(
205+
Coord(coord=(Decimal('20.0'), Decimal('10.0'))).coord
206+
)
207+
assert hash(Coord(coord=(Decimal('20.000'), Decimal('10.000'))).coord) == hash(
208+
Coord(coord=(Decimal('20.0'), Decimal('10.0'))).coord
209+
)
210+
153211

154212
def test_json_schema():
155213
class Model(BaseModel):
156214
value: Coordinate
157215

158216
assert Model.model_json_schema(mode='validation')['$defs']['Coordinate'] == {
159217
'properties': {
160-
'latitude': {'maximum': 90.0, 'minimum': -90.0, 'title': 'Latitude', 'type': 'number'},
161-
'longitude': {'maximum': 180.0, 'minimum': -180.0, 'title': 'Longitude', 'type': 'number'},
218+
'latitude': {
219+
'anyOf': [{'maximum': 90.0, 'minimum': -90.0, 'type': 'number'}, {'type': 'string'}],
220+
'title': 'Latitude',
221+
},
222+
'longitude': {
223+
'anyOf': [{'maximum': 180.0, 'minimum': -180.0, 'type': 'number'}, {'type': 'string'}],
224+
'title': 'Longitude',
225+
},
162226
},
163227
'required': ['latitude', 'longitude'],
164228
'title': 'Coordinate',
@@ -170,7 +234,10 @@ class Model(BaseModel):
170234
{
171235
'maxItems': 2,
172236
'minItems': 2,
173-
'prefixItems': [{'type': 'number'}, {'type': 'number'}],
237+
'prefixItems': [
238+
{'anyOf': [{'type': 'number'}, {'type': 'string'}]},
239+
{'anyOf': [{'type': 'number'}, {'type': 'string'}]},
240+
],
174241
'type': 'array',
175242
},
176243
{'type': 'string'},
@@ -181,8 +248,14 @@ class Model(BaseModel):
181248
'$defs': {
182249
'Coordinate': {
183250
'properties': {
184-
'latitude': {'maximum': 90.0, 'minimum': -90.0, 'title': 'Latitude', 'type': 'number'},
185-
'longitude': {'maximum': 180.0, 'minimum': -180.0, 'title': 'Longitude', 'type': 'number'},
251+
'latitude': {
252+
'anyOf': [{'maximum': 90.0, 'minimum': -90.0, 'type': 'number'}, {'type': 'string'}],
253+
'title': 'Latitude',
254+
},
255+
'longitude': {
256+
'anyOf': [{'maximum': 180.0, 'minimum': -180.0, 'type': 'number'}, {'type': 'string'}],
257+
'title': 'Longitude',
258+
},
186259
},
187260
'required': ['latitude', 'longitude'],
188261
'title': 'Coordinate',

0 commit comments

Comments
 (0)