Skip to content

Commit dfd1e81

Browse files
committed
[FIX]: async_update's missing operators
1 parent c9a8961 commit dfd1e81

File tree

2 files changed

+74
-37
lines changed

2 files changed

+74
-37
lines changed

mongoengine/queryset/base.py

Lines changed: 19 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -2526,55 +2526,37 @@ async def async_update(
25262526
:param upsert: insert if document doesn't exist (default False)
25272527
:param multi: update multiple documents (default True)
25282528
:param write_concern: write concern options
2529-
:param update: update operations to perform
2529+
:param update: Django-style update keyword arguments
25302530
:returns: number of documents affected
25312531
"""
25322532
from mongoengine.connection import DEFAULT_CONNECTION_NAME
25332533

25342534
alias = self._document._meta.get("db_alias", DEFAULT_CONNECTION_NAME)
25352535
ensure_async_connection(alias)
25362536

2537-
if not update:
2538-
raise OperationError("No update parameters passed")
2537+
if not update and not upsert:
2538+
raise OperationError("No update parameters, would remove data")
25392539

25402540
queryset = self.clone()
25412541
query = queryset._query
25422542

2543-
# Process the update dict to handle field names and operators
2544-
update_dict = {}
2545-
for key, value in update.items():
2546-
# Handle operators like inc__field, set__field, etc.
2547-
if "__" in key and key.split("__")[0] in [
2548-
"inc",
2549-
"set",
2550-
"unset",
2551-
"push",
2552-
"pull",
2553-
"pull_all",
2554-
"addToSet",
2555-
]:
2556-
op, field = key.split("__", 1)
2557-
# Convert pull_all to pullAll for MongoDB
2558-
if op == "pull_all":
2559-
mongo_op = "$pullAll"
2560-
else:
2561-
mongo_op = f"${op}"
2562-
if mongo_op not in update_dict:
2563-
update_dict[mongo_op] = {}
2564-
field_name = queryset._document._translate_field_name(field)
2565-
update_dict[mongo_op][field_name] = value
2566-
elif key.startswith("$"):
2567-
# Direct MongoDB operator
2568-
update_dict[key] = {}
2569-
for field_name, field_value in value.items():
2570-
field_name = queryset._document._translate_field_name(field_name)
2571-
update_dict[key][field_name] = field_value
2543+
# Use transform.update to handle all operators consistently with sync version
2544+
if "__raw__" in update and isinstance(update["__raw__"], list):
2545+
# Case of Update with Aggregation Pipeline
2546+
update_dict = [
2547+
transform.update(queryset._document, **{"__raw__": u})
2548+
for u in update["__raw__"]
2549+
]
2550+
else:
2551+
update_dict = transform.update(queryset._document, **update)
2552+
2553+
# If doing an atomic upsert on an inheritable class
2554+
# then ensure we add _cls to the update operation
2555+
if upsert and "_cls" in query:
2556+
if "$set" in update_dict:
2557+
update_dict["$set"]["_cls"] = queryset._document._class_name
25722558
else:
2573-
# Direct field update - wrap in $set
2574-
if "$set" not in update_dict:
2575-
update_dict["$set"] = {}
2576-
key = queryset._document._translate_field_name(key)
2577-
update_dict["$set"][key] = value
2559+
update_dict["$set"] = {"_cls": queryset._document._class_name}
25782560

25792561
collection = await self._async_get_collection()
25802562

tests/test_async_queryset.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,61 @@ async def test_async_update(self):
282282
old_authors = await AsyncAuthor.objects.filter(age=55).async_to_list()
283283
assert len(old_authors) == 2
284284

285+
@pytest.mark.asyncio
286+
async def test_async_update_add_to_set(self):
287+
"""Test async_update() with add_to_set operator."""
288+
# Create a book with some tags
289+
book = await AsyncBook.objects.async_create(
290+
title="Test Book", pages=100, tags=["python"]
291+
)
292+
293+
# Add a new tag using add_to_set
294+
updated = await AsyncBook.objects.filter(id=book.id).async_update(
295+
add_to_set__tags="mongodb"
296+
)
297+
assert updated == 1
298+
299+
# Verify the tag was added
300+
book = await AsyncBook.objects.async_get(id=book.id)
301+
assert "python" in book.tags
302+
assert "mongodb" in book.tags
303+
assert len(book.tags) == 2
304+
305+
# Try to add duplicate tag - should not add
306+
await AsyncBook.objects.filter(id=book.id).async_update(
307+
add_to_set__tags="python"
308+
)
309+
book = await AsyncBook.objects.async_get(id=book.id)
310+
assert len(book.tags) == 2 # Still 2, no duplicate
311+
312+
@pytest.mark.asyncio
313+
async def test_async_update_operators(self):
314+
"""Test async_update() with various operators."""
315+
# Create a book
316+
book = await AsyncBook.objects.async_create(
317+
title="Test Book", pages=100, tags=["a", "b"]
318+
)
319+
320+
# Test push operator
321+
await AsyncBook.objects.filter(id=book.id).async_update(push__tags="c")
322+
book = await AsyncBook.objects.async_get(id=book.id)
323+
assert book.tags == ["a", "b", "c"]
324+
325+
# Test pull operator
326+
await AsyncBook.objects.filter(id=book.id).async_update(pull__tags="b")
327+
book = await AsyncBook.objects.async_get(id=book.id)
328+
assert book.tags == ["a", "c"]
329+
330+
# Test set operator
331+
await AsyncBook.objects.filter(id=book.id).async_update(set__pages=200)
332+
book = await AsyncBook.objects.async_get(id=book.id)
333+
assert book.pages == 200
334+
335+
# Test unset operator
336+
await AsyncBook.objects.filter(id=book.id).async_update(unset__pages=1)
337+
book = await AsyncBook.objects.async_get(id=book.id)
338+
assert book.pages is None
339+
285340
@pytest.mark.asyncio
286341
async def test_async_update_one(self):
287342
"""Test async_update_one() method."""

0 commit comments

Comments
 (0)