Skip to content

Commit ca8ccea

Browse files
committed
Add support for Union typing
1 parent 61a0b73 commit ca8ccea

File tree

2 files changed

+50
-27
lines changed

2 files changed

+50
-27
lines changed

pydanticrud/backends/dynamodb.py

Lines changed: 29 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,6 @@ def _to_epoch_float(dt):
8484
"boolean": lambda d: 1 if d else 0,
8585
"object": json.dumps,
8686
"array": json.dumps,
87-
"anyOf": str, # FIXME - this could be more complicated. This is a hacky fix.
8887
}
8988

9089

@@ -98,7 +97,6 @@ def do_nothing(x):
9897
"boolean": bool,
9998
"object": json.loads,
10099
"array": json.loads,
101-
"anyOf": do_nothing, # FIXME - this could be more complicated. This is a hacky fix.
102100
}
103101

104102

@@ -123,20 +121,28 @@ def __init__(self, schema):
123121
self.properties = schema["properties"]
124122
self.definitions = schema.get("definitions")
125123

126-
def _get_type(self, field_name):
127-
if "$ref" in self.properties[field_name]:
128-
def_name = self.properties[field_name]["$ref"].split('/')[-1]
129-
return self.definitions[def_name].get("type")
130-
return self.properties[field_name].get("type", "anyOf")
124+
def _get_type_possibilities(self, field_name) -> Set[str]:
125+
field_properties = self.properties[field_name]
126+
127+
possible_types = []
128+
if "anyOf" in field_properties:
129+
possible_types.extend([r.get("$ref", r.get("type")) for r in field_properties["anyOf"]])
130+
else:
131+
possible_types.append(field_properties.get("$ref", field_properties.get("type")))
132+
return set([
133+
(self.definitions[t.split('/')[-1]].get("type") if t.startswith('#/') else t)
134+
for t in possible_types
135+
])
131136

132137
def _serialize_field(self, field_name, value):
133-
field_type = self._get_type(field_name)
134-
try:
135-
if value is not None:
136-
return SERIALIZE_MAP[field_type](value)
137-
except KeyError:
138-
log.debug(f"No serializer for field_type {field_type}")
139-
return value # do nothing but log it.
138+
field_types = self._get_type_possibilities(field_name)
139+
if value is not None:
140+
for t in field_types:
141+
try:
142+
return SERIALIZE_MAP[t](value)
143+
except (ValueError, TypeError, KeyError):
144+
pass
145+
return value
140146

141147
def serialize_record(self, data_dict) -> dict:
142148
"""
@@ -148,13 +154,14 @@ def serialize_record(self, data_dict) -> dict:
148154
}
149155

150156
def _deserialize_field(self, field_name, value):
151-
field_type = self._get_type(field_name)
152-
try:
153-
if value is not None:
154-
return DESERIALIZE_MAP[field_type](value)
155-
except KeyError:
156-
log.debug(f"No deserializer for field_type {field_type}")
157-
return value # do nothing but log it.
157+
field_types = self._get_type_possibilities(field_name)
158+
if value is not None:
159+
for t in field_types:
160+
try:
161+
return DESERIALIZE_MAP[t](value)
162+
except (ValueError, TypeError, KeyError):
163+
pass
164+
return value
158165

159166
def deserialize_record(self, data_dict) -> dict:
160167
"""
@@ -252,7 +259,7 @@ def initialize(self):
252259
{
253260
"AttributeName": attr,
254261
"AttributeType": DYNAMO_TYPE_MAP.get(
255-
schema["properties"][attr].get("type", "anyOf"), "S"
262+
schema["properties"][attr].get("type"), "S"
256263
),
257264
}
258265
for attr in self.possible_keys

tests/test_dynamodb.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Dict, List, Optional
1+
from typing import Dict, List, Optional, Union
22
from base64 import b64decode, b64encode
33
from decimal import Decimal
44
from datetime import datetime
@@ -80,11 +80,17 @@ class Ticket(PydanticBaseModel):
8080
number: str
8181

8282

83+
class SomethingElse(PydanticBaseModel):
84+
herp: bool
85+
derp: int
86+
87+
8388
class NestedModel(BaseModel):
8489
account: str
8590
sort_date_key: str
8691
expires: str
8792
ticket: Optional[Ticket]
93+
other: Union[Ticket, SomethingElse]
8894

8995
class Config:
9096
title = "NestedModelTitle123"
@@ -141,9 +147,20 @@ def nested_model_data_generator(include_ticket=True, **kwargs):
141147
expires=future_datetime(days=1, hours=random.randint(1, 12), minutes=random.randint(1, 58)).isoformat(),
142148
ticket={
143149
'created_time': random_datetime().isoformat(),
144-
'number': random.randint(0, 1000)
150+
'number': str(random.randint(0, 1000))
151+
152+
} if include_ticket else None,
153+
other=random.choice([
154+
{
155+
'created_time': random_datetime().isoformat(),
156+
'number': str(random.randint(0, 1000))
145157

146-
} if include_ticket else None
158+
}, {
159+
'herp': random.choice([True, False]),
160+
'derp': random.randint(0, 1000)
161+
162+
}
163+
])
147164
)
148165
data.update(kwargs)
149166
return data
@@ -464,6 +481,7 @@ def test_query_with_nested_model(dynamo, nested_query_data):
464481
assert isinstance(m.ticket, Ticket)
465482
assert m.ticket.created_time is not None
466483
assert m.ticket.number is not None
484+
assert isinstance(m.other, (Ticket, SomethingElse))
467485

468486

469487
def test_query_with_nested_model_optional(dynamo, nested_query_data_empty_ticket):
@@ -486,5 +504,3 @@ def test_get_alias_model_data(dynamo, alias_query_data):
486504
data = alias_model_data_generator()
487505
res = AliasKeyModel.get(alias_query_data[0]['name'])
488506
assert res.dict(by_alias=True) == alias_query_data[0]
489-
490-

0 commit comments

Comments
 (0)