From 3b3257a7dde4c71bc1cbe2453369ad2c465a87d3 Mon Sep 17 00:00:00 2001 From: Micheal Gendy Date: Tue, 11 Jan 2022 11:13:32 +0200 Subject: [PATCH 1/7] add support for bulk_update --- docs/making_queries.md | 47 ++++++++++++++++---------- orm/models.py | 48 ++++++++++++++++++++++++++ tests/test_columns.py | 76 ++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 153 insertions(+), 18 deletions(-) diff --git a/docs/making_queries.md b/docs/making_queries.md index f7b9ec1..b36de48 100644 --- a/docs/making_queries.md +++ b/docs/making_queries.md @@ -53,15 +53,15 @@ notes = await Note.objects.filter(completed=True).all() There are some special operators defined automatically on every column: -* `in` - SQL `IN` operator. -* `exact` - filter instances matching exact value. -* `iexact` - same as `exact` but case-insensitive. -* `contains` - filter instances containing value. -* `icontains` - same as `contains` but case-insensitive. -* `lt` - filter instances having value `Less Than`. -* `lte` - filter instances having value `Less Than Equal`. -* `gt` - filter instances having value `Greater Than`. -* `gte` - filter instances having value `Greater Than Equal`. +- `in` - SQL `IN` operator. +- `exact` - filter instances matching exact value. +- `iexact` - same as `exact` but case-insensitive. +- `contains` - filter instances containing value. +- `icontains` - same as `contains` but case-insensitive. +- `lt` - filter instances having value `Less Than`. +- `lte` - filter instances having value `Less Than Equal`. +- `gt` - filter instances having value `Greater Than`. +- `gte` - filter instances having value `Greater Than Equal`. Example usage: @@ -84,7 +84,7 @@ notes = await Note.objects.filter(Note.columns.id.in_([1, 2, 3])).all() Here `Note.columns` refers to the columns of the underlying SQLAlchemy table. !!! note - Note that `Note.columns` returns SQLAlchemy table columns, whereas `Note.fields` returns `orm` fields. +Note that `Note.columns` returns SQLAlchemy table columns, whereas `Note.fields` returns `orm` fields. ### .limit() @@ -119,7 +119,7 @@ notes = await Note.objects.order_by("text", "-id").all() ``` !!! note - This will sort by ascending `text` and descending `id`. +This will sort by ascending `text` and descending `id`. ## Returning results @@ -146,10 +146,10 @@ await Note.objects.create(text="Send invoices.", completed=True) You need to pass a list of dictionaries of required fields to create multiple objects: ```python -await Product.objects.bulk_create( +await Note.objects.bulk_create( [ - {"data": {"foo": 123}, "value": 123.456, "status": StatusEnum.RELEASED}, - {"data": {"foo": 456}, "value": 456.789, "status": StatusEnum.DRAFT}, + {"text": "Buy the groceries", "completed": False}, + {"text": "Call Mum.", "completed": True}, ] ) @@ -209,7 +209,7 @@ note = await Note.objects.get(id=1) ``` !!! note - `.get()` expects to find only one instance. This can raise `NoMatch` or `MultipleMatches`. +`.get()` expects to find only one instance. This can raise `NoMatch` or `MultipleMatches`. ### .update() @@ -233,6 +233,18 @@ note = await Note.objects.first() await note.update(completed=True) ``` +### .bulk_update() + +You can also bulk update multiple objects at once by passing a list of objects and a list of fields to update. + +```python +notes = await Note.objects.all() +for note in notes : + note.completed = True + +await Note.objects.bulk_update(notes, fields=["completed"]) +``` + ## Convenience Methods ### .get_or_create() @@ -250,8 +262,7 @@ This will query a `Note` with `text` as `"Going to car wash"`, if it doesn't exist, it will use `defaults` argument to create the new instance. !!! note - Since `get_or_create()` is doing a [get()](#get), it can raise `MultipleMatches` exception. - +Since `get_or_create()` is doing a [get()](#get), it can raise `MultipleMatches` exception. ### .update_or_create() @@ -269,4 +280,4 @@ if an instance is found, it will use the `defaults` argument to update the insta If it matches no records, it will use the comibnation of arguments to create the new instance. !!! note - Since `update_or_create()` is doing a [get()](#get), it can raise `MultipleMatches` exception. +Since `update_or_create()` is doing a [get()](#get), it can raise `MultipleMatches` exception. diff --git a/orm/models.py b/orm/models.py index b402814..1bdfdd8 100644 --- a/orm/models.py +++ b/orm/models.py @@ -1,3 +1,5 @@ +import enum +import json import typing import databases @@ -20,6 +22,8 @@ "lte": "__le__", } +MODEL = typing.TypeVar("MODEL", bound="Model") + def _update_auto_now_fields(values, fields): for key, value in fields.items(): @@ -28,6 +32,15 @@ def _update_auto_now_fields(values, fields): return values +def _convert_value(value): + if isinstance(value, dict): + return json.dumps(value) + elif isinstance(value, enum.Enum): + return value.name + else: + return value + + class ModelRegistry: def __init__(self, database: databases.Database) -> None: self.database = database @@ -454,6 +467,41 @@ async def update(self, **kwargs) -> None: await self.database.execute(expr) + async def bulk_update( + self, objs: typing.List[MODEL], fields: typing.List[str] + ) -> None: + fields = { + key: field.validator + for key, field in self.model_cls.fields.items() + if key in fields + } + validator = typesystem.Schema(fields=fields) + new_objs = [ + _update_auto_now_fields(validator.validate(value), self.model_cls.fields) + for value in [ + { + key: _convert_value(value) + for key, value in obj.__dict__.items() + if key in fields + } + for obj in objs + ] + ] + expr = ( + self.table.update() + .where(self.table.c.id == sqlalchemy.bindparam("id")) + .values( + { + field: sqlalchemy.bindparam(field) + for obj in new_objs + for field in obj.keys() + } + ) + ) + pk_list = [{"id": obj.pk} for obj in objs] + joined_list = [{**pk, **value} for pk, value in zip(pk_list, new_objs)] + await self.database.execute_many(str(expr), joined_list) + async def get_or_create( self, defaults: typing.Dict[str, typing.Any], **kwargs ) -> typing.Tuple[typing.Any, bool]: diff --git a/tests/test_columns.py b/tests/test_columns.py index 278aecd..e628759 100644 --- a/tests/test_columns.py +++ b/tests/test_columns.py @@ -159,3 +159,79 @@ async def test_bulk_create(): assert products[1].data == {"foo": 456} assert products[1].value == 456.789 assert products[1].status == StatusEnum.DRAFT + + +async def test_bulk_update(): + await Product.objects.bulk_create( + [ + { + "created": "2020-01-01T00:00:00Z", + "data": {"foo": 123}, + "value": 123.456, + "status": StatusEnum.RELEASED, + }, + { + "created": "2020-01-01T00:00:00Z", + "data": {"foo": 456}, + "value": 456.789, + "status": StatusEnum.DRAFT, + }, + ] + ) + products = await Product.objects.all() + products[0].created = "2021-01-01T00:00:00Z" + products[1].created = "2022-01-01T00:00:00Z" + products[0].status = StatusEnum.DRAFT + products[1].status = StatusEnum.RELEASED + products[0].data = {"foo": 1234} + products[1].data = {"foo": 5678} + products[0].value = 1234.567 + products[1].value = 5678.891 + await Product.objects.bulk_update( + products, fields=["created", "status", "data", "value"] + ) + products = await Product.objects.all() + assert products[0].created == datetime.datetime(2021, 1, 1, 0, 0, 0) + assert products[1].created == datetime.datetime(2022, 1, 1, 0, 0, 0) + assert products[0].status == StatusEnum.DRAFT + assert products[1].status == StatusEnum.RELEASED + assert products[0].data == {"foo": 1234} + assert products[1].data == {"foo": 5678} + assert products[0].value == 1234.567 + assert products[1].value == 5678.891 + + +async def test_bulk_update_with_relation(): + class Album(orm.Model): + registry = models + fields = { + "id": orm.Integer(primary_key=True), + "name": orm.Text(), + } + + class Track(orm.Model): + registry = models + fields = { + "id": orm.Integer(primary_key=True), + "name": orm.Text(), + "album": orm.ForeignKey(Album), + } + + await models.create_all() + + album = await Album.objects.create(name="foo") + album2 = await Album.objects.create(name="bar") + + await Track.objects.bulk_create( + [ + {"name": "foo", "album": album}, + {"name": "bar", "album": album}, + ] + ) + tracks = await Track.objects.all() + for track in tracks: + track.album = album2 + await Track.objects.bulk_update(tracks, fields=["album"]) + tracks = await Track.objects.all() + assert tracks[0].album.pk == album2.pk + assert tracks[1].album.pk == album2.pk From 7baf305d535a01141690234defdb3d4d1a68bcb6 Mon Sep 17 00:00:00 2001 From: Micheal Gendy Date: Tue, 11 Jan 2022 11:38:39 +0200 Subject: [PATCH 2/7] change id to model pk name --- orm/models.py | 4 ++-- tests/test_columns.py | 36 ------------------------------------ tests/test_foreignkey.py | 19 +++++++++++++++++++ 3 files changed, 21 insertions(+), 38 deletions(-) diff --git a/orm/models.py b/orm/models.py index 1bdfdd8..26b1205 100644 --- a/orm/models.py +++ b/orm/models.py @@ -489,7 +489,7 @@ async def bulk_update( ] expr = ( self.table.update() - .where(self.table.c.id == sqlalchemy.bindparam("id")) + .where(self.table.c.id == sqlalchemy.bindparam(self.pkname)) .values( { field: sqlalchemy.bindparam(field) @@ -498,7 +498,7 @@ async def bulk_update( } ) ) - pk_list = [{"id": obj.pk} for obj in objs] + pk_list = [{self.pkname: obj.pk} for obj in objs] joined_list = [{**pk, **value} for pk, value in zip(pk_list, new_objs)] await self.database.execute_many(str(expr), joined_list) diff --git a/tests/test_columns.py b/tests/test_columns.py index e628759..5fc9888 100644 --- a/tests/test_columns.py +++ b/tests/test_columns.py @@ -199,39 +199,3 @@ async def test_bulk_update(): assert products[1].data == {"foo": 5678} assert products[0].value == 1234.567 assert products[1].value == 5678.891 - - -async def test_bulk_update_with_relation(): - class Album(orm.Model): - registry = models - fields = { - "id": orm.Integer(primary_key=True), - "name": orm.Text(), - } - - class Track(orm.Model): - registry = models - fields = { - "id": orm.Integer(primary_key=True), - "name": orm.Text(), - "album": orm.ForeignKey(Album), - } - - await models.create_all() - - album = await Album.objects.create(name="foo") - album2 = await Album.objects.create(name="bar") - - await Track.objects.bulk_create( - [ - {"name": "foo", "album": album}, - {"name": "bar", "album": album}, - ] - ) - tracks = await Track.objects.all() - for track in tracks: - track.album = album2 - await Track.objects.bulk_update(tracks, fields=["album"]) - tracks = await Track.objects.all() - assert tracks[0].album.pk == album2.pk - assert tracks[1].album.pk == album2.pk diff --git a/tests/test_foreignkey.py b/tests/test_foreignkey.py index 1ab6175..8b9e5ef 100644 --- a/tests/test_foreignkey.py +++ b/tests/test_foreignkey.py @@ -278,3 +278,22 @@ async def test_nullable_foreign_key(): assert member.email == "dev@encode.io" assert member.team.pk is None + + +async def test_bulk_update_with_relation(): + album = await Album.objects.create(name="foo") + album2 = await Album.objects.create(name="bar") + + await Track.objects.bulk_create( + [ + {"name": "foo", "album": album, "position": 1, "title": "foo"}, + {"name": "bar", "album": album, "position": 2, "title": "bar"}, + ] + ) + tracks = await Track.objects.all() + for track in tracks: + track.album = album2 + await Track.objects.bulk_update(tracks, fields=["album"]) + tracks = await Track.objects.all() + assert tracks[0].album.pk == album2.pk + assert tracks[1].album.pk == album2.pk From aaa98281ed927aca01f3ca752128711a73c0dfa5 Mon Sep 17 00:00:00 2001 From: Micheal Gendy Date: Tue, 11 Jan 2022 12:12:03 +0200 Subject: [PATCH 3/7] fix tests issue --- tests/test_columns.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/test_columns.py b/tests/test_columns.py index 5fc9888..7db2227 100644 --- a/tests/test_columns.py +++ b/tests/test_columns.py @@ -165,13 +165,13 @@ async def test_bulk_update(): await Product.objects.bulk_create( [ { - "created": "2020-01-01T00:00:00Z", + "created_day": datetime.date.today(), "data": {"foo": 123}, "value": 123.456, "status": StatusEnum.RELEASED, }, { - "created": "2020-01-01T00:00:00Z", + "created_day": datetime.date.today(), "data": {"foo": 456}, "value": 456.789, "status": StatusEnum.DRAFT, @@ -179,8 +179,8 @@ async def test_bulk_update(): ] ) products = await Product.objects.all() - products[0].created = "2021-01-01T00:00:00Z" - products[1].created = "2022-01-01T00:00:00Z" + products[0].created_day = datetime.date.today() - datetime.timedelta(days=1) + products[1].created_day = datetime.date.today() - datetime.timedelta(days=1) products[0].status = StatusEnum.DRAFT products[1].status = StatusEnum.RELEASED products[0].data = {"foo": 1234} @@ -188,11 +188,11 @@ async def test_bulk_update(): products[0].value = 1234.567 products[1].value = 5678.891 await Product.objects.bulk_update( - products, fields=["created", "status", "data", "value"] + products, fields=["created_day", "status", "data", "value"] ) products = await Product.objects.all() - assert products[0].created == datetime.datetime(2021, 1, 1, 0, 0, 0) - assert products[1].created == datetime.datetime(2022, 1, 1, 0, 0, 0) + assert products[0].created_day == datetime.date.today() - datetime.timedelta(days=1) + assert products[1].created_day == datetime.date.today() - datetime.timedelta(days=1) assert products[0].status == StatusEnum.DRAFT assert products[1].status == StatusEnum.RELEASED assert products[0].data == {"foo": 1234} From a02da8ee047bdb43fbb77a74ab241276a0a7231e Mon Sep 17 00:00:00 2001 From: Micheal Gendy Date: Tue, 11 Jan 2022 12:19:58 +0200 Subject: [PATCH 4/7] fix tests issue --- tests/test_columns.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_columns.py b/tests/test_columns.py index 7db2227..3fd7656 100644 --- a/tests/test_columns.py +++ b/tests/test_columns.py @@ -185,8 +185,8 @@ async def test_bulk_update(): products[1].status = StatusEnum.RELEASED products[0].data = {"foo": 1234} products[1].data = {"foo": 5678} - products[0].value = 1234.567 - products[1].value = 5678.891 + products[0].value = 345.5 + products[1].value = 789.8 await Product.objects.bulk_update( products, fields=["created_day", "status", "data", "value"] ) @@ -197,5 +197,5 @@ async def test_bulk_update(): assert products[1].status == StatusEnum.RELEASED assert products[0].data == {"foo": 1234} assert products[1].data == {"foo": 5678} - assert products[0].value == 1234.567 - assert products[1].value == 5678.891 + assert products[0].value == 345.5 + assert products[1].value == 789.8 From cea75153f344d8b8e67965f400ddceb8ccd6c0fd Mon Sep 17 00:00:00 2001 From: Micheal Gendy Date: Tue, 11 Jan 2022 12:52:02 +0200 Subject: [PATCH 5/7] fix lint issues --- docs/making_queries.md | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/docs/making_queries.md b/docs/making_queries.md index b36de48..b36a2d3 100644 --- a/docs/making_queries.md +++ b/docs/making_queries.md @@ -53,15 +53,15 @@ notes = await Note.objects.filter(completed=True).all() There are some special operators defined automatically on every column: -- `in` - SQL `IN` operator. -- `exact` - filter instances matching exact value. -- `iexact` - same as `exact` but case-insensitive. -- `contains` - filter instances containing value. -- `icontains` - same as `contains` but case-insensitive. -- `lt` - filter instances having value `Less Than`. -- `lte` - filter instances having value `Less Than Equal`. -- `gt` - filter instances having value `Greater Than`. -- `gte` - filter instances having value `Greater Than Equal`. +* `in` - SQL `IN` operator. +* `exact` - filter instances matching exact value. +* `iexact` - same as `exact` but case-insensitive. +* `contains` - filter instances containing value. +* `icontains` - same as `contains` but case-insensitive. +* `lt` - filter instances having value `Less Than`. +* `lte` - filter instances having value `Less Than Equal`. +* `gt` - filter instances having value `Greater Than`. +* `gte` - filter instances having value `Greater Than Equal`. Example usage: @@ -84,7 +84,7 @@ notes = await Note.objects.filter(Note.columns.id.in_([1, 2, 3])).all() Here `Note.columns` refers to the columns of the underlying SQLAlchemy table. !!! note -Note that `Note.columns` returns SQLAlchemy table columns, whereas `Note.fields` returns `orm` fields. + Note that `Note.columns` returns SQLAlchemy table columns, whereas `Note.fields` returns `orm` fields. ### .limit() @@ -119,7 +119,7 @@ notes = await Note.objects.order_by("text", "-id").all() ``` !!! note -This will sort by ascending `text` and descending `id`. + This will sort by ascending `text` and descending `id`. ## Returning results @@ -209,7 +209,7 @@ note = await Note.objects.get(id=1) ``` !!! note -`.get()` expects to find only one instance. This can raise `NoMatch` or `MultipleMatches`. + `.get()` expects to find only one instance. This can raise `NoMatch` or `MultipleMatches`. ### .update() @@ -262,7 +262,7 @@ This will query a `Note` with `text` as `"Going to car wash"`, if it doesn't exist, it will use `defaults` argument to create the new instance. !!! note -Since `get_or_create()` is doing a [get()](#get), it can raise `MultipleMatches` exception. + Since `get_or_create()` is doing a [get()](#get), it can raise `MultipleMatches` exception. ### .update_or_create() @@ -280,4 +280,4 @@ if an instance is found, it will use the `defaults` argument to update the insta If it matches no records, it will use the comibnation of arguments to create the new instance. !!! note -Since `update_or_create()` is doing a [get()](#get), it can raise `MultipleMatches` exception. + Since `update_or_create()` is doing a [get()](#get), it can raise `MultipleMatches` exception. From 592ebba6ee34f2fc213208655d7c35c7c66188bf Mon Sep 17 00:00:00 2001 From: Micheal Gendy Date: Wed, 2 Feb 2022 16:15:55 +0200 Subject: [PATCH 6/7] dynamically get pk column --- orm/models.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/orm/models.py b/orm/models.py index 26b1205..ea9264f 100644 --- a/orm/models.py +++ b/orm/models.py @@ -489,7 +489,9 @@ async def bulk_update( ] expr = ( self.table.update() - .where(self.table.c.id == sqlalchemy.bindparam(self.pkname)) + .where( + getattr(self.table.c, self.pkname) == sqlalchemy.bindparam(self.pkname) + ) .values( { field: sqlalchemy.bindparam(field) @@ -498,7 +500,7 @@ async def bulk_update( } ) ) - pk_list = [{self.pkname: obj.pk} for obj in objs] + pk_list = [{self.pkname: getattr(obj, self.pkname)} for obj in objs] joined_list = [{**pk, **value} for pk, value in zip(pk_list, new_objs)] await self.database.execute_many(str(expr), joined_list) From 6f7963581ae8468354c7f590f9ec8aaa60d8125a Mon Sep 17 00:00:00 2001 From: Micheal Gendy Date: Wed, 2 Feb 2022 18:03:01 +0200 Subject: [PATCH 7/7] improve readability --- orm/models.py | 44 +++++++++++++++++++------------------------- 1 file changed, 19 insertions(+), 25 deletions(-) diff --git a/orm/models.py b/orm/models.py index ea9264f..2766bd2 100644 --- a/orm/models.py +++ b/orm/models.py @@ -22,8 +22,6 @@ "lte": "__le__", } -MODEL = typing.TypeVar("MODEL", bound="Model") - def _update_auto_now_fields(values, fields): for key, value in fields.items(): @@ -468,7 +466,7 @@ async def update(self, **kwargs) -> None: await self.database.execute(expr) async def bulk_update( - self, objs: typing.List[MODEL], fields: typing.List[str] + self, objs: typing.List["Model"], fields: typing.List[str] ) -> None: fields = { key: field.validator @@ -477,29 +475,25 @@ async def bulk_update( } validator = typesystem.Schema(fields=fields) new_objs = [ - _update_auto_now_fields(validator.validate(value), self.model_cls.fields) - for value in [ - { - key: _convert_value(value) - for key, value in obj.__dict__.items() - if key in fields - } - for obj in objs - ] + { + key: _convert_value(value) + for key, value in obj.__dict__.items() + if key in fields + } + for obj in objs ] - expr = ( - self.table.update() - .where( - getattr(self.table.c, self.pkname) == sqlalchemy.bindparam(self.pkname) - ) - .values( - { - field: sqlalchemy.bindparam(field) - for obj in new_objs - for field in obj.keys() - } - ) - ) + new_objs = [ + _update_auto_now_fields(validator.validate(obj), self.model_cls.fields) + for obj in new_objs + ] + pk_column = getattr(self.table.c, self.pkname) + expr = self.table.update().where(pk_column == sqlalchemy.bindparam(self.pkname)) + kwargs = { + field: sqlalchemy.bindparam(field) + for obj in new_objs + for field in obj.keys() + } + expr = expr.values(kwargs) pk_list = [{self.pkname: getattr(obj, self.pkname)} for obj in objs] joined_list = [{**pk, **value} for pk, value in zip(pk_list, new_objs)] await self.database.execute_many(str(expr), joined_list)