Skip to content

Commit a30990f

Browse files
authored
Merge pull request #12 from RSS-Engineering/submodel-deserialization-fixes
Submodel Deserialization Fixes
2 parents 73b70db + 97b6514 commit a30990f

File tree

5 files changed

+112
-53
lines changed

5 files changed

+112
-53
lines changed

pydanticrud/backends/dynamodb.py

Lines changed: 58 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
1-
from typing import Optional, Set
1+
from typing import Optional, Set, Union
22
import logging
33
import json
44
from datetime import datetime
5-
from base64 import b64encode, b64decode
65

76
import boto3
87
from boto3.dynamodb.conditions import Key, Attr
@@ -80,25 +79,19 @@ def _to_epoch_float(dt):
8079

8180
SERIALIZE_MAP = {
8281
"number": str, # float or decimal
83-
"string": lambda d: d.isoformat() if isinstance(d, datetime) else d, # string, datetime
82+
"string": str,
83+
"string:date-time": lambda d: d.isoformat(),
8484
"boolean": lambda d: 1 if d else 0,
8585
"object": json.dumps,
8686
"array": json.dumps,
87-
"anyOf": str, # FIXME - this could be more complicated. This is a hacky fix.
8887
}
8988

9089

91-
def do_nothing(x):
92-
return x
93-
94-
9590
DESERIALIZE_MAP = {
9691
"number": float,
97-
"string": do_nothing,
9892
"boolean": bool,
9993
"object": json.loads,
10094
"array": json.loads,
101-
"anyOf": do_nothing, # FIXME - this could be more complicated. This is a hacky fix.
10295
}
10396

10497

@@ -120,22 +113,48 @@ def index_definition(index_name, keys, gsi=False):
120113

121114
class DynamoSerializer:
122115
def __init__(self, schema):
123-
self.schema = schema
116+
self.properties = schema["properties"]
117+
self.definitions = schema.get("definitions")
118+
119+
def _get_type_possibilities(self, field_name) -> Set[tuple]:
120+
field_properties = self.properties[field_name]
121+
122+
possible_types = []
123+
if "anyOf" in field_properties:
124+
possible_types.extend([r.get("$ref", r) for r in field_properties["anyOf"]])
125+
else:
126+
possible_types.append(field_properties.get("$ref", field_properties))
127+
128+
def type_from_definition(definition_signature: Union[str, dict]) -> dict:
129+
if isinstance(definition_signature, str):
130+
t = definition_signature.split('/')[-1]
131+
return self.definitions[t]
132+
return definition_signature
133+
134+
type_dicts = [
135+
type_from_definition(t)
136+
for t in possible_types
137+
]
138+
139+
return set([
140+
(t['type'], t.get('format', ''))
141+
for t in type_dicts
142+
])
124143

125144
def _serialize_field(self, field_name, value):
126-
definition = self.schema.get("definitions")
127-
schema = self.schema["properties"]
128-
if definition:
129-
for k, v in definition.items():
130-
schema[k.lower()] = v
131-
schema = self.schema["properties"]
132-
field_type = schema[field_name].get("type", "anyOf")
133-
try:
134-
if any([field_name in self.schema['required'], value is not None]):
135-
return SERIALIZE_MAP[field_type](value)
136-
except KeyError:
137-
log.debug(f"No serializer for field_type {field_type}")
138-
return value # do nothing but log it.
145+
field_types = self._get_type_possibilities(field_name)
146+
if value is not None:
147+
for t in field_types:
148+
try:
149+
type_signature = ":".join(t).rstrip(':')
150+
try:
151+
return SERIALIZE_MAP[type_signature](value)
152+
except KeyError:
153+
return SERIALIZE_MAP[t[0]](value)
154+
except (ValueError, TypeError, KeyError):
155+
pass
156+
157+
return value
139158

140159
def serialize_record(self, data_dict) -> dict:
141160
"""
@@ -147,18 +166,19 @@ def serialize_record(self, data_dict) -> dict:
147166
}
148167

149168
def _deserialize_field(self, field_name, value):
150-
definition = self.schema.get("definitions")
151-
schema = self.schema["properties"]
152-
if definition:
153-
for k, v in definition.items():
154-
schema[k.lower()] = v
155-
field_type = schema[field_name].get("type", "anyOf")
156-
try:
157-
if any([field_name in self.schema['required'], value is not None]):
158-
return DESERIALIZE_MAP[field_type](value)
159-
except KeyError:
160-
log.debug(f"No deserializer for field_type {field_type}")
161-
return value # do nothing but log it.
169+
field_types = self._get_type_possibilities(field_name)
170+
if value is not None:
171+
for t in field_types:
172+
try:
173+
type_signature = ":".join(t).rstrip(':')
174+
try:
175+
return DESERIALIZE_MAP[type_signature](value)
176+
except KeyError:
177+
return DESERIALIZE_MAP[t[0]](value)
178+
except (ValueError, TypeError, KeyError):
179+
pass
180+
181+
return value
162182

163183
def deserialize_record(self, data_dict) -> dict:
164184
"""
@@ -213,7 +233,7 @@ def _key_param_to_dict(self, key):
213233
}
214234
if self.range_key:
215235
if not isinstance(key, tuple) or not len(key) == 2:
216-
raise ValueError(f"{self.table_name} needs both a hash_key and a range_key to delete a record.")
236+
raise ValueError(f"{self.table_name} needs both a hash_key and a range_key.")
217237
_key = {
218238
self.hash_key: key[0],
219239
self.range_key: key[1]
@@ -256,7 +276,7 @@ def initialize(self):
256276
{
257277
"AttributeName": attr,
258278
"AttributeType": DYNAMO_TYPE_MAP.get(
259-
schema["properties"][attr].get("type", "anyOf"), "S"
279+
schema["properties"][attr].get("type"), "S"
260280
),
261281
}
262282
for attr in self.possible_keys

pydanticrud/main.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import pydantic.error_wrappers
12
from pydantic import BaseModel as PydanticBaseModel
23
from pydantic.main import ModelMetaclass
34
from rule_engine import Rule
@@ -25,6 +26,9 @@ def __len__(self):
2526
def __iter__(self):
2627
return self
2728

29+
def __getitem__(self, indices):
30+
return self.records.__getitem__(indices)
31+
2832
def __next__(self):
2933
try:
3034
member = self.records[self._current_index]

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "pydanticrud"
3-
version = "0.1.0"
3+
version = "0.2.0"
44
description = "Supercharge your Pydantic models with CRUD methods and a pluggable backend"
55
authors = ["Timothy Farrell <[email protected]>"]
66
license = "MIT"

tests/test_dynamodb.py

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
1-
from typing import Dict, List, Optional
2-
from base64 import b64decode, b64encode
1+
from typing import Dict, List, Optional, Union
32
from decimal import Decimal
43
from datetime import datetime
5-
import json
6-
from uuid import uuid4
4+
from uuid import uuid4, UUID
75
import random
86

97
import docker
@@ -32,6 +30,7 @@ class SimpleKeyModel(BaseModel):
3230
enabled: bool
3331
data: Dict[int, int] = None
3432
items: List[int]
33+
hash: UUID
3534

3635
class Config:
3736
title = "ModelTitle123"
@@ -80,11 +79,17 @@ class Ticket(PydanticBaseModel):
8079
number: str
8180

8281

82+
class SomethingElse(PydanticBaseModel):
83+
herp: bool
84+
derp: int
85+
86+
8387
class NestedModel(BaseModel):
8488
account: str
8589
sort_date_key: str
8690
expires: str
8791
ticket: Optional[Ticket]
92+
other: Union[Ticket, SomethingElse]
8893

8994
class Config:
9095
title = "NestedModelTitle123"
@@ -116,6 +121,7 @@ def simple_model_data_generator(**kwargs):
116121
enabled=random.choice((True, False)),
117122
data={random.randint(0, 1000): random.randint(0, 1000)},
118123
items=[random.randint(0, 100000), random.randint(0, 100000), random.randint(0, 100000)],
124+
hash=uuid4()
119125
)
120126
data.update(kwargs)
121127
return data
@@ -141,9 +147,20 @@ def nested_model_data_generator(include_ticket=True, **kwargs):
141147
expires=future_datetime(days=1, hours=random.randint(1, 12), minutes=random.randint(1, 58)).isoformat(),
142148
ticket={
143149
'created_time': random_datetime().isoformat(),
144-
'number': random.randint(0, 1000)
150+
'number': str(random.randint(0, 1000))
151+
152+
} if include_ticket else None,
153+
other=random.choice([
154+
{
155+
'created_time': random_datetime().isoformat(),
156+
'number': str(random.randint(0, 1000))
157+
158+
}, {
159+
'herp': random.choice([True, False]),
160+
'derp': random.randint(0, 1000)
145161

146-
} if include_ticket else None
162+
}
163+
])
147164
)
148165
data.update(kwargs)
149166
return data
@@ -243,7 +260,7 @@ def alias_query_data(alias_table):
243260
AliasKeyModel.delete(datum["name"])
244261

245262

246-
@pytest.fixture(scope="module")
263+
@pytest.fixture
247264
def nested_query_data(nested_table):
248265
presets = [dict()] * 5
249266
data = [datum for datum in [nested_model_data_generator(**i) for i in presets]]
@@ -257,7 +274,7 @@ def nested_query_data(nested_table):
257274
NestedModel.delete((datum[NestedModel.Config.hash_key], datum[NestedModel.Config.range_key]))
258275

259276

260-
@pytest.fixture(scope="module")
277+
@pytest.fixture
261278
def nested_query_data_empty_ticket(nested_table):
262279
presets = [dict()] * 5
263280
data = [datum for datum in [nested_model_data_generator(include_ticket=False, **i) for i in presets]]
@@ -460,14 +477,16 @@ def test_query_scan_complex(dynamo, complex_query_data):
460477

461478
def test_query_with_nested_model(dynamo, nested_query_data):
462479
res = NestedModel.query()
463-
res_data = [m.ticket for m in res]
464-
assert any(elem is not None for elem in res_data)
480+
for m in res:
481+
assert isinstance(m.ticket, Ticket)
482+
assert m.ticket.created_time is not None
483+
assert m.ticket.number is not None
484+
assert isinstance(m.other, (Ticket, SomethingElse))
465485

466486

467487
def test_query_with_nested_model_optional(dynamo, nested_query_data_empty_ticket):
468488
res = NestedModel.query()
469-
res_data = [m.ticket for m in res]
470-
assert any(elem is None for elem in res_data)
489+
assert all([m.ticket is None for m in res])
471490

472491

473492
def test_query_alias_save(dynamo):
@@ -485,5 +504,3 @@ def test_get_alias_model_data(dynamo, alias_query_data):
485504
data = alias_model_data_generator()
486505
res = AliasKeyModel.get(alias_query_data[0]['name'])
487506
assert res.dict(by_alias=True) == alias_query_data[0]
488-
489-

tests/test_model.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@ def __init__(self, cfg):
1111
def get(cls, id):
1212
pass
1313

14+
@classmethod
15+
def query(cls, id):
16+
pass
17+
1418

1519
class Model(BaseModel):
1620
id: int
@@ -38,5 +42,19 @@ def test_model_backend_get():
3842
assert m.total == 3.0
3943

4044

45+
def test_model_backend_query():
46+
with patch.object(
47+
FalseBackend, "query", return_value=[dict(id=1, name="two", total=3.0)]
48+
) as mock_query:
49+
m = Model.query(2)
50+
51+
mock_query.assert_called_with(2)
52+
assert m.count is None # In this case the backend did not provide a total count.
53+
assert len(m[:]) == 1 # .. but we can cast to a list and get that length.
54+
assert m[0].id == 1
55+
assert m[0].name == "two"
56+
assert m[0].total == 3.0
57+
58+
4159
def test_model_table_name_from_title():
4260
assert Model.get_table_name() == Model.Config.title.lower()

0 commit comments

Comments
 (0)