-
Notifications
You must be signed in to change notification settings - Fork 96
add support for bulk_update #148
base: master
Are you sure you want to change the base?
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,5 @@ | ||
import enum | ||
import json | ||
import typing | ||
|
||
import databases | ||
|
@@ -20,6 +22,8 @@ | |
"lte": "__le__", | ||
} | ||
|
||
MODEL = typing.TypeVar("MODEL", bound="Model") | ||
|
||
|
||
|
||
def _update_auto_now_fields(values, fields): | ||
for key, value in fields.items(): | ||
|
@@ -28,6 +32,15 @@ def _update_auto_now_fields(values, fields): | |
return values | ||
|
||
|
||
def _convert_value(value): | ||
if isinstance(value, dict): | ||
return json.dumps(value) | ||
elif isinstance(value, enum.Enum): | ||
return value.name | ||
else: | ||
return value | ||
|
||
|
||
class ModelRegistry: | ||
def __init__(self, database: databases.Database) -> None: | ||
self.database = database | ||
|
@@ -454,6 +467,41 @@ async def update(self, **kwargs) -> None: | |
|
||
await self.database.execute(expr) | ||
|
||
async def bulk_update( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sorry I should've noticed this earlier, apologies for that. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah I agree with you it needs to be more readable There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @aminalaee Any updates ? |
||
self, objs: typing.List[MODEL], fields: typing.List[str] | ||
) -> None: | ||
fields = { | ||
key: field.validator | ||
for key, field in self.model_cls.fields.items() | ||
if key in fields | ||
} | ||
validator = typesystem.Schema(fields=fields) | ||
new_objs = [ | ||
_update_auto_now_fields(validator.validate(value), self.model_cls.fields) | ||
for value in [ | ||
{ | ||
key: _convert_value(value) | ||
for key, value in obj.__dict__.items() | ||
if key in fields | ||
} | ||
for obj in objs | ||
] | ||
] | ||
expr = ( | ||
self.table.update() | ||
.where(self.table.c.id == sqlalchemy.bindparam("id")) | ||
aminalaee marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
.values( | ||
{ | ||
field: sqlalchemy.bindparam(field) | ||
for obj in new_objs | ||
for field in obj.keys() | ||
} | ||
) | ||
) | ||
pk_list = [{"id": obj.pk} for obj in objs] | ||
joined_list = [{**pk, **value} for pk, value in zip(pk_list, new_objs)] | ||
await self.database.execute_many(str(expr), joined_list) | ||
|
||
async def get_or_create( | ||
self, defaults: typing.Dict[str, typing.Any], **kwargs | ||
) -> typing.Tuple[typing.Any, bool]: | ||
|
Uh oh!
There was an error while loading. Please reload this page.