Skip to content
This repository was archived by the owner on Aug 19, 2025. It is now read-only.

Commit 3b3257a

Browse files
author
Micheal Gendy
committed
add support for bulk_update
1 parent 0df624e commit 3b3257a

File tree

3 files changed

+153
-18
lines changed

3 files changed

+153
-18
lines changed

docs/making_queries.md

Lines changed: 29 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -53,15 +53,15 @@ notes = await Note.objects.filter(completed=True).all()
5353

5454
There are some special operators defined automatically on every column:
5555

56-
* `in` - SQL `IN` operator.
57-
* `exact` - filter instances matching exact value.
58-
* `iexact` - same as `exact` but case-insensitive.
59-
* `contains` - filter instances containing value.
60-
* `icontains` - same as `contains` but case-insensitive.
61-
* `lt` - filter instances having value `Less Than`.
62-
* `lte` - filter instances having value `Less Than Equal`.
63-
* `gt` - filter instances having value `Greater Than`.
64-
* `gte` - filter instances having value `Greater Than Equal`.
56+
- `in` - SQL `IN` operator.
57+
- `exact` - filter instances matching exact value.
58+
- `iexact` - same as `exact` but case-insensitive.
59+
- `contains` - filter instances containing value.
60+
- `icontains` - same as `contains` but case-insensitive.
61+
- `lt` - filter instances having value `Less Than`.
62+
- `lte` - filter instances having value `Less Than Equal`.
63+
- `gt` - filter instances having value `Greater Than`.
64+
- `gte` - filter instances having value `Greater Than Equal`.
6565

6666
Example usage:
6767

@@ -84,7 +84,7 @@ notes = await Note.objects.filter(Note.columns.id.in_([1, 2, 3])).all()
8484
Here `Note.columns` refers to the columns of the underlying SQLAlchemy table.
8585

8686
!!! note
87-
Note that `Note.columns` returns SQLAlchemy table columns, whereas `Note.fields` returns `orm` fields.
87+
Note that `Note.columns` returns SQLAlchemy table columns, whereas `Note.fields` returns `orm` fields.
8888

8989
### .limit()
9090

@@ -119,7 +119,7 @@ notes = await Note.objects.order_by("text", "-id").all()
119119
```
120120

121121
!!! note
122-
This will sort by ascending `text` and descending `id`.
122+
This will sort by ascending `text` and descending `id`.
123123

124124
## Returning results
125125

@@ -146,10 +146,10 @@ await Note.objects.create(text="Send invoices.", completed=True)
146146
You need to pass a list of dictionaries of required fields to create multiple objects:
147147

148148
```python
149-
await Product.objects.bulk_create(
149+
await Note.objects.bulk_create(
150150
[
151-
{"data": {"foo": 123}, "value": 123.456, "status": StatusEnum.RELEASED},
152-
{"data": {"foo": 456}, "value": 456.789, "status": StatusEnum.DRAFT},
151+
{"text": "Buy the groceries", "completed": False},
152+
{"text": "Call Mum.", "completed": True},
153153

154154
]
155155
)
@@ -209,7 +209,7 @@ note = await Note.objects.get(id=1)
209209
```
210210

211211
!!! note
212-
`.get()` expects to find only one instance. This can raise `NoMatch` or `MultipleMatches`.
212+
`.get()` expects to find only one instance. This can raise `NoMatch` or `MultipleMatches`.
213213

214214
### .update()
215215

@@ -233,6 +233,18 @@ note = await Note.objects.first()
233233
await note.update(completed=True)
234234
```
235235

236+
### .bulk_update()
237+
238+
You can also bulk update multiple objects at once by passing a list of objects and a list of fields to update.
239+
240+
```python
241+
notes = await Note.objects.all()
242+
for note in notes :
243+
note.completed = True
244+
245+
await Note.objects.bulk_update(notes, fields=["completed"])
246+
```
247+
236248
## Convenience Methods
237249

238250
### .get_or_create()
@@ -250,8 +262,7 @@ This will query a `Note` with `text` as `"Going to car wash"`,
250262
if it doesn't exist, it will use `defaults` argument to create the new instance.
251263

252264
!!! note
253-
Since `get_or_create()` is doing a [get()](#get), it can raise `MultipleMatches` exception.
254-
265+
Since `get_or_create()` is doing a [get()](#get), it can raise `MultipleMatches` exception.
255266

256267
### .update_or_create()
257268

@@ -269,4 +280,4 @@ if an instance is found, it will use the `defaults` argument to update the insta
269280
If it matches no records, it will use the comibnation of arguments to create the new instance.
270281

271282
!!! note
272-
Since `update_or_create()` is doing a [get()](#get), it can raise `MultipleMatches` exception.
283+
Since `update_or_create()` is doing a [get()](#get), it can raise `MultipleMatches` exception.

orm/models.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import enum
2+
import json
13
import typing
24

35
import databases
@@ -20,6 +22,8 @@
2022
"lte": "__le__",
2123
}
2224

25+
MODEL = typing.TypeVar("MODEL", bound="Model")
26+
2327

2428
def _update_auto_now_fields(values, fields):
2529
for key, value in fields.items():
@@ -28,6 +32,15 @@ def _update_auto_now_fields(values, fields):
2832
return values
2933

3034

35+
def _convert_value(value):
36+
if isinstance(value, dict):
37+
return json.dumps(value)
38+
elif isinstance(value, enum.Enum):
39+
return value.name
40+
else:
41+
return value
42+
43+
3144
class ModelRegistry:
3245
def __init__(self, database: databases.Database) -> None:
3346
self.database = database
@@ -454,6 +467,41 @@ async def update(self, **kwargs) -> None:
454467

455468
await self.database.execute(expr)
456469

470+
async def bulk_update(
471+
self, objs: typing.List[MODEL], fields: typing.List[str]
472+
) -> None:
473+
fields = {
474+
key: field.validator
475+
for key, field in self.model_cls.fields.items()
476+
if key in fields
477+
}
478+
validator = typesystem.Schema(fields=fields)
479+
new_objs = [
480+
_update_auto_now_fields(validator.validate(value), self.model_cls.fields)
481+
for value in [
482+
{
483+
key: _convert_value(value)
484+
for key, value in obj.__dict__.items()
485+
if key in fields
486+
}
487+
for obj in objs
488+
]
489+
]
490+
expr = (
491+
self.table.update()
492+
.where(self.table.c.id == sqlalchemy.bindparam("id"))
493+
.values(
494+
{
495+
field: sqlalchemy.bindparam(field)
496+
for obj in new_objs
497+
for field in obj.keys()
498+
}
499+
)
500+
)
501+
pk_list = [{"id": obj.pk} for obj in objs]
502+
joined_list = [{**pk, **value} for pk, value in zip(pk_list, new_objs)]
503+
await self.database.execute_many(str(expr), joined_list)
504+
457505
async def get_or_create(
458506
self, defaults: typing.Dict[str, typing.Any], **kwargs
459507
) -> typing.Tuple[typing.Any, bool]:

tests/test_columns.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,3 +159,79 @@ async def test_bulk_create():
159159
assert products[1].data == {"foo": 456}
160160
assert products[1].value == 456.789
161161
assert products[1].status == StatusEnum.DRAFT
162+
163+
164+
async def test_bulk_update():
165+
await Product.objects.bulk_create(
166+
[
167+
{
168+
"created": "2020-01-01T00:00:00Z",
169+
"data": {"foo": 123},
170+
"value": 123.456,
171+
"status": StatusEnum.RELEASED,
172+
},
173+
{
174+
"created": "2020-01-01T00:00:00Z",
175+
"data": {"foo": 456},
176+
"value": 456.789,
177+
"status": StatusEnum.DRAFT,
178+
},
179+
]
180+
)
181+
products = await Product.objects.all()
182+
products[0].created = "2021-01-01T00:00:00Z"
183+
products[1].created = "2022-01-01T00:00:00Z"
184+
products[0].status = StatusEnum.DRAFT
185+
products[1].status = StatusEnum.RELEASED
186+
products[0].data = {"foo": 1234}
187+
products[1].data = {"foo": 5678}
188+
products[0].value = 1234.567
189+
products[1].value = 5678.891
190+
await Product.objects.bulk_update(
191+
products, fields=["created", "status", "data", "value"]
192+
)
193+
products = await Product.objects.all()
194+
assert products[0].created == datetime.datetime(2021, 1, 1, 0, 0, 0)
195+
assert products[1].created == datetime.datetime(2022, 1, 1, 0, 0, 0)
196+
assert products[0].status == StatusEnum.DRAFT
197+
assert products[1].status == StatusEnum.RELEASED
198+
assert products[0].data == {"foo": 1234}
199+
assert products[1].data == {"foo": 5678}
200+
assert products[0].value == 1234.567
201+
assert products[1].value == 5678.891
202+
203+
204+
async def test_bulk_update_with_relation():
205+
class Album(orm.Model):
206+
registry = models
207+
fields = {
208+
"id": orm.Integer(primary_key=True),
209+
"name": orm.Text(),
210+
}
211+
212+
class Track(orm.Model):
213+
registry = models
214+
fields = {
215+
"id": orm.Integer(primary_key=True),
216+
"name": orm.Text(),
217+
"album": orm.ForeignKey(Album),
218+
}
219+
220+
await models.create_all()
221+
222+
album = await Album.objects.create(name="foo")
223+
album2 = await Album.objects.create(name="bar")
224+
225+
await Track.objects.bulk_create(
226+
[
227+
{"name": "foo", "album": album},
228+
{"name": "bar", "album": album},
229+
]
230+
)
231+
tracks = await Track.objects.all()
232+
for track in tracks:
233+
track.album = album2
234+
await Track.objects.bulk_update(tracks, fields=["album"])
235+
tracks = await Track.objects.all()
236+
assert tracks[0].album.pk == album2.pk
237+
assert tracks[1].album.pk == album2.pk

0 commit comments

Comments
 (0)