Skip to content

Commit f89214f

Browse files
committed
Fixes bug where an EmbeddedDocument that shares the same id of its parent Document could be missing updates when .save was called
Fixes #1768, Fixes #1685
1 parent d17cac8 commit f89214f

File tree

3 files changed

+79
-29
lines changed

3 files changed

+79
-29
lines changed

mongoengine/base/document.py

Lines changed: 25 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -503,65 +503,62 @@ def _clear_changed_fields(self):
503503

504504
self._changed_fields = []
505505

506-
def _nestable_types_changed_fields(self, changed_fields, key, data, inspected):
506+
def _nestable_types_changed_fields(self, changed_fields, base_key, data):
507+
"""Inspect nested data for changed fields
508+
509+
:param changed_fields: Previously collected changed fields
510+
:param base_key: The base key that must be used to prepend changes to this data
511+
:param data: data to inspect for changes
512+
"""
507513
# Loop list / dict fields as they contain documents
508514
# Determine the iterator to use
509515
if not hasattr(data, 'items'):
510516
iterator = enumerate(data)
511517
else:
512518
iterator = data.iteritems()
513519

514-
for index, value in iterator:
515-
list_key = '%s%s.' % (key, index)
520+
for index_or_key, value in iterator:
521+
item_key = '%s%s.' % (base_key, index_or_key)
516522
# don't check anything lower if this key is already marked
517523
# as changed.
518-
if list_key[:-1] in changed_fields:
524+
if item_key[:-1] in changed_fields:
519525
continue
526+
520527
if hasattr(value, '_get_changed_fields'):
521-
changed = value._get_changed_fields(inspected)
522-
changed_fields += ['%s%s' % (list_key, k)
523-
for k in changed if k]
528+
changed = value._get_changed_fields()
529+
changed_fields += ['%s%s' % (item_key, k) for k in changed if k]
524530
elif isinstance(value, (list, tuple, dict)):
525531
self._nestable_types_changed_fields(
526-
changed_fields, list_key, value, inspected)
532+
changed_fields, item_key, value)
527533

528-
def _get_changed_fields(self, inspected=None):
534+
def _get_changed_fields(self):
529535
"""Return a list of all fields that have explicitly been changed.
530536
"""
531537
EmbeddedDocument = _import_class('EmbeddedDocument')
532-
DynamicEmbeddedDocument = _import_class('DynamicEmbeddedDocument')
533538
ReferenceField = _import_class('ReferenceField')
534539
SortedListField = _import_class('SortedListField')
535540

536541
changed_fields = []
537542
changed_fields += getattr(self, '_changed_fields', [])
538543

539-
inspected = inspected or set()
540-
if hasattr(self, 'id') and isinstance(self.id, Hashable):
541-
if self.id in inspected:
542-
return changed_fields
543-
inspected.add(self.id)
544-
545544
for field_name in self._fields_ordered:
546545
db_field_name = self._db_field_map.get(field_name, field_name)
547546
key = '%s.' % db_field_name
548547
data = self._data.get(field_name, None)
549548
field = self._fields.get(field_name)
550549

551-
if hasattr(data, 'id'):
552-
if data.id in inspected:
553-
continue
554-
if isinstance(field, ReferenceField):
550+
if db_field_name in changed_fields:
551+
# Whole field already marked as changed, no need to go further
552+
continue
553+
554+
if isinstance(field, ReferenceField): # Don't follow referenced documents
555555
continue
556-
elif (
557-
isinstance(data, (EmbeddedDocument, DynamicEmbeddedDocument)) and
558-
db_field_name not in changed_fields
559-
):
556+
557+
if isinstance(data, EmbeddedDocument):
560558
# Find all embedded fields that have been changed
561-
changed = data._get_changed_fields(inspected)
559+
changed = data._get_changed_fields()
562560
changed_fields += ['%s%s' % (key, k) for k in changed if k]
563-
elif (isinstance(data, (list, tuple, dict)) and
564-
db_field_name not in changed_fields):
561+
elif isinstance(data, (list, tuple, dict)):
565562
if (hasattr(field, 'field') and
566563
isinstance(field.field, ReferenceField)):
567564
continue
@@ -572,7 +569,7 @@ def _get_changed_fields(self, inspected=None):
572569
continue
573570

574571
self._nestable_types_changed_fields(
575-
changed_fields, key, data, inspected)
572+
changed_fields, key, data)
576573
return changed_fields
577574

578575
def _delta(self):

tests/document/instance.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1422,6 +1422,60 @@ class User(self.Person):
14221422
self.assertEqual(person.age, 21)
14231423
self.assertEqual(person.active, False)
14241424

1425+
def test__get_changed_fields_same_ids_reference_field_does_not_enters_infinite_loop(self):
1426+
# Refers to Issue #1685
1427+
class EmbeddedChildModel(EmbeddedDocument):
1428+
id = DictField(primary_key=True)
1429+
1430+
class ParentModel(Document):
1431+
child = EmbeddedDocumentField(
1432+
EmbeddedChildModel)
1433+
1434+
emb = EmbeddedChildModel(id={'1': [1]})
1435+
ParentModel(children=emb)._get_changed_fields()
1436+
1437+
def test__get_changed_fields_same_ids_reference_field_does_not_enters_infinite_loop(self):
1438+
class User(Document):
1439+
id = IntField(primary_key=True)
1440+
name = StringField()
1441+
1442+
class Message(Document):
1443+
id = IntField(primary_key=True)
1444+
author = ReferenceField(User)
1445+
1446+
Message.drop_collection()
1447+
1448+
# All objects share the same id, but each in a different collection
1449+
user = User(id=1, name='user-name').save()
1450+
message = Message(id=1, author=user).save()
1451+
1452+
message.author.name = 'tutu'
1453+
self.assertEqual(message._get_changed_fields(), [])
1454+
self.assertEqual(user._get_changed_fields(), ['name'])
1455+
1456+
def test__get_changed_fields_same_ids_embedded(self):
1457+
# Refers to Issue #1768
1458+
class User(EmbeddedDocument):
1459+
id = IntField()
1460+
name = StringField()
1461+
1462+
class Message(Document):
1463+
id = IntField(primary_key=True)
1464+
author = EmbeddedDocumentField(User)
1465+
1466+
Message.drop_collection()
1467+
1468+
# All objects share the same id, but each in a different collection
1469+
user = User(id=1, name='user-name')#.save()
1470+
message = Message(id=1, author=user).save()
1471+
1472+
message.author.name = 'tutu'
1473+
self.assertEqual(message._get_changed_fields(), ['author.name'])
1474+
message.save()
1475+
1476+
message_fetched = Message.objects.with_id(message.id)
1477+
self.assertEqual(message_fetched.author.name, 'tutu')
1478+
14251479
def test_query_count_when_saving(self):
14261480
"""Ensure references don't cause extra fetches when saving"""
14271481
class Organization(Document):

tests/test_dereference.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1029,7 +1029,6 @@ class Baz(Document):
10291029
self.assertEqual(type(foo.bar), Bar)
10301030
self.assertEqual(type(foo.baz), Baz)
10311031

1032-
10331032
def test_document_reload_reference_integrity(self):
10341033
"""
10351034
Ensure reloading a document with multiple similar id

0 commit comments

Comments
 (0)