-
Notifications
You must be signed in to change notification settings - Fork 97
add support for bulk_update #148
base: master
Are you sure you want to change the base?
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,3 +1,5 @@ | ||
| import enum | ||
| import json | ||
| import typing | ||
|
|
||
| import databases | ||
|
|
@@ -20,6 +22,8 @@ | |
| "lte": "__le__", | ||
| } | ||
|
|
||
| MODEL = typing.TypeVar("MODEL", bound="Model") | ||
|
||
|
|
||
|
|
||
| def _update_auto_now_fields(values, fields): | ||
| for key, value in fields.items(): | ||
|
|
@@ -28,6 +32,15 @@ def _update_auto_now_fields(values, fields): | |
| return values | ||
|
|
||
|
|
||
| def _convert_value(value): | ||
| if isinstance(value, dict): | ||
| return json.dumps(value) | ||
| elif isinstance(value, enum.Enum): | ||
| return value.name | ||
| else: | ||
| return value | ||
|
|
||
|
|
||
| class ModelRegistry: | ||
| def __init__(self, database: databases.Database) -> None: | ||
| self.database = database | ||
|
|
@@ -454,6 +467,41 @@ async def update(self, **kwargs) -> None: | |
|
|
||
| await self.database.execute(expr) | ||
|
|
||
| async def bulk_update( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sorry I should've noticed this earlier, apologies for that.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah I agree with you it needs to be more readable
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @aminalaee Any updates ? |
||
| self, objs: typing.List[MODEL], fields: typing.List[str] | ||
| ) -> None: | ||
| fields = { | ||
| key: field.validator | ||
| for key, field in self.model_cls.fields.items() | ||
| if key in fields | ||
| } | ||
| validator = typesystem.Schema(fields=fields) | ||
| new_objs = [ | ||
| _update_auto_now_fields(validator.validate(value), self.model_cls.fields) | ||
| for value in [ | ||
| { | ||
| key: _convert_value(value) | ||
| for key, value in obj.__dict__.items() | ||
| if key in fields | ||
| } | ||
| for obj in objs | ||
| ] | ||
| ] | ||
| expr = ( | ||
| self.table.update() | ||
| .where(self.table.c.id == sqlalchemy.bindparam("id")) | ||
aminalaee marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| .values( | ||
| { | ||
| field: sqlalchemy.bindparam(field) | ||
| for obj in new_objs | ||
| for field in obj.keys() | ||
| } | ||
| ) | ||
| ) | ||
| pk_list = [{"id": obj.pk} for obj in objs] | ||
| joined_list = [{**pk, **value} for pk, value in zip(pk_list, new_objs)] | ||
| await self.database.execute_many(str(expr), joined_list) | ||
|
|
||
| async def get_or_create( | ||
| self, defaults: typing.Dict[str, typing.Any], **kwargs | ||
| ) -> typing.Tuple[typing.Any, bool]: | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.