Skip to content

Commit 26e2fc8

Browse files
authored
Merge pull request #1887 from bagerard/fix_changed_fields_issue_same_id_in_nested_doc2
Fix bug where an EmbeddedDocument with the same id as its parent would not be tracked for changes
2 parents 8e18484 + f89214f commit 26e2fc8

File tree

3 files changed

+79
-29
lines changed

3 files changed

+79
-29
lines changed

mongoengine/base/document.py

Lines changed: 25 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -500,66 +500,63 @@ def _clear_changed_fields(self):
500500

501501
self._changed_fields = []
502502

503-
def _nestable_types_changed_fields(self, changed_fields, key, data, inspected):
503+
def _nestable_types_changed_fields(self, changed_fields, base_key, data):
504+
"""Inspect nested data for changed fields
505+
506+
:param changed_fields: Previously collected changed fields
507+
:param base_key: The base key that must be used to prepend changes to this data
508+
:param data: data to inspect for changes
509+
"""
504510
# Loop list / dict fields as they contain documents
505511
# Determine the iterator to use
506512
if not hasattr(data, 'items'):
507513
iterator = enumerate(data)
508514
else:
509515
iterator = data.iteritems()
510516

511-
for index, value in iterator:
512-
list_key = '%s%s.' % (key, index)
517+
for index_or_key, value in iterator:
518+
item_key = '%s%s.' % (base_key, index_or_key)
513519
# don't check anything lower if this key is already marked
514520
# as changed.
515-
if list_key[:-1] in changed_fields:
521+
if item_key[:-1] in changed_fields:
516522
continue
523+
517524
if hasattr(value, '_get_changed_fields'):
518-
changed = value._get_changed_fields(inspected)
519-
changed_fields += ['%s%s' % (list_key, k)
520-
for k in changed if k]
525+
changed = value._get_changed_fields()
526+
changed_fields += ['%s%s' % (item_key, k) for k in changed if k]
521527
elif isinstance(value, (list, tuple, dict)):
522528
self._nestable_types_changed_fields(
523-
changed_fields, list_key, value, inspected)
529+
changed_fields, item_key, value)
524530

525-
def _get_changed_fields(self, inspected=None):
531+
def _get_changed_fields(self):
526532
"""Return a list of all fields that have explicitly been changed.
527533
"""
528534
EmbeddedDocument = _import_class('EmbeddedDocument')
529-
DynamicEmbeddedDocument = _import_class('DynamicEmbeddedDocument')
530535
ReferenceField = _import_class('ReferenceField')
531536
GenericReferenceField = _import_class('GenericReferenceField')
532537
SortedListField = _import_class('SortedListField')
533538

534539
changed_fields = []
535540
changed_fields += getattr(self, '_changed_fields', [])
536541

537-
inspected = inspected or set()
538-
if hasattr(self, 'id') and isinstance(self.id, Hashable):
539-
if self.id in inspected:
540-
return changed_fields
541-
inspected.add(self.id)
542-
543542
for field_name in self._fields_ordered:
544543
db_field_name = self._db_field_map.get(field_name, field_name)
545544
key = '%s.' % db_field_name
546545
data = self._data.get(field_name, None)
547546
field = self._fields.get(field_name)
548547

549-
if hasattr(data, 'id'):
550-
if data.id in inspected:
551-
continue
552-
if isinstance(field, ReferenceField):
548+
if db_field_name in changed_fields:
549+
# Whole field already marked as changed, no need to go further
550+
continue
551+
552+
if isinstance(field, ReferenceField): # Don't follow referenced documents
553553
continue
554-
elif (
555-
isinstance(data, (EmbeddedDocument, DynamicEmbeddedDocument)) and
556-
db_field_name not in changed_fields
557-
):
554+
555+
if isinstance(data, EmbeddedDocument):
558556
# Find all embedded fields that have been changed
559-
changed = data._get_changed_fields(inspected)
557+
changed = data._get_changed_fields()
560558
changed_fields += ['%s%s' % (key, k) for k in changed if k]
561-
elif (isinstance(data, (list, tuple, dict)) and
562-
db_field_name not in changed_fields):
559+
elif isinstance(data, (list, tuple, dict)):
563560
if (hasattr(field, 'field') and
564561
isinstance(field.field, (ReferenceField, GenericReferenceField))):
565562
continue
@@ -570,7 +567,7 @@ def _get_changed_fields(self, inspected=None):
570567
continue
571568

572569
self._nestable_types_changed_fields(
573-
changed_fields, key, data, inspected)
570+
changed_fields, key, data)
574571
return changed_fields
575572

576573
def _delta(self):

tests/document/instance.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1422,6 +1422,60 @@ class User(self.Person):
14221422
self.assertEqual(person.age, 21)
14231423
self.assertEqual(person.active, False)
14241424

1425+
def test__get_changed_fields_same_ids_reference_field_does_not_enters_infinite_loop(self):
1426+
# Refers to Issue #1685
1427+
class EmbeddedChildModel(EmbeddedDocument):
1428+
id = DictField(primary_key=True)
1429+
1430+
class ParentModel(Document):
1431+
child = EmbeddedDocumentField(
1432+
EmbeddedChildModel)
1433+
1434+
emb = EmbeddedChildModel(id={'1': [1]})
1435+
ParentModel(children=emb)._get_changed_fields()
1436+
1437+
def test__get_changed_fields_same_ids_reference_field_does_not_enters_infinite_loop(self):
1438+
class User(Document):
1439+
id = IntField(primary_key=True)
1440+
name = StringField()
1441+
1442+
class Message(Document):
1443+
id = IntField(primary_key=True)
1444+
author = ReferenceField(User)
1445+
1446+
Message.drop_collection()
1447+
1448+
# All objects share the same id, but each in a different collection
1449+
user = User(id=1, name='user-name').save()
1450+
message = Message(id=1, author=user).save()
1451+
1452+
message.author.name = 'tutu'
1453+
self.assertEqual(message._get_changed_fields(), [])
1454+
self.assertEqual(user._get_changed_fields(), ['name'])
1455+
1456+
def test__get_changed_fields_same_ids_embedded(self):
1457+
# Refers to Issue #1768
1458+
class User(EmbeddedDocument):
1459+
id = IntField()
1460+
name = StringField()
1461+
1462+
class Message(Document):
1463+
id = IntField(primary_key=True)
1464+
author = EmbeddedDocumentField(User)
1465+
1466+
Message.drop_collection()
1467+
1468+
# All objects share the same id, but each in a different collection
1469+
user = User(id=1, name='user-name')#.save()
1470+
message = Message(id=1, author=user).save()
1471+
1472+
message.author.name = 'tutu'
1473+
self.assertEqual(message._get_changed_fields(), ['author.name'])
1474+
message.save()
1475+
1476+
message_fetched = Message.objects.with_id(message.id)
1477+
self.assertEqual(message_fetched.author.name, 'tutu')
1478+
14251479
def test_query_count_when_saving(self):
14261480
"""Ensure references don't cause extra fetches when saving"""
14271481
class Organization(Document):

tests/test_dereference.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1029,7 +1029,6 @@ class Baz(Document):
10291029
self.assertEqual(type(foo.bar), Bar)
10301030
self.assertEqual(type(foo.baz), Baz)
10311031

1032-
10331032
def test_document_reload_reference_integrity(self):
10341033
"""
10351034
Ensure reloading a document with multiple similar id

0 commit comments

Comments
 (0)