diff --git a/CHANGELOG.md b/CHANGELOG.md index 48417e8..653b2d2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,9 @@ List of the most important changes for each release. +## 0.8.5 +- Prevents MPTT corruption that occurs with concurrent certificate creation + ## 0.8.4 - Adds syncable_objects model manager to use by default for all syncing operations, to allow default objects to be overridden diff --git a/morango/__init__.py b/morango/__init__.py index fa3ddd8..af46754 100644 --- a/morango/__init__.py +++ b/morango/__init__.py @@ -1 +1 @@ -__version__ = "0.8.4" +__version__ = "0.8.5" diff --git a/morango/models/certificates.py b/morango/models/certificates.py index d6e9371..5cef6e7 100644 --- a/morango/models/certificates.py +++ b/morango/models/certificates.py @@ -5,6 +5,7 @@ """ import json import string +import logging import mptt.models from django.core.management import call_command @@ -24,7 +25,10 @@ from morango.errors import NonceDoesNotExist from morango.errors import NonceExpired from morango.utils import _assert - +from django.db import transaction, connection +from morango.sync.backends.utils import load_backend +from contextlib import contextmanager +from django.db.utils import OperationalError class Certificate(mptt.models.MPTTModel, UUIDModelMixin): @@ -246,6 +250,37 @@ def verify(self, value, signature): def get_scope(self): return self.scope_definition.get_scope(self.scope_params) + @contextmanager + def _attempt_lock_mptt(self): + from morango.sync.utils import lock_partitions + + DBBackend = load_backend(connection) + + with transaction.atomic(): + # Call get_root on the parent as it is already saved in the DB + root_id = self.parent.get_root().id if self.parent else self.id + + # lock the partitions in our scope to prevent MPTT tree corruption during concurrent certificate creation + lock_partitions(DBBackend, sync_filter=Filter(root_id) if root_id else None) + yield + + @contextmanager + def _lock_mptt(self): + try: + with self._attempt_lock_mptt(): + yield + except OperationalError as e: + if "deadlock detected" in e.args[0]: + logging.error("Deadlock detected when attempting to lock MPTT partitions, retrying once more") + with self._attempt_lock_mptt(): + yield + else: + raise + + def save(self, *args, **kwargs): + with self._lock_mptt(): + super().save(*args, **kwargs) + def __str__(self): if self.scope_definition: return self.scope_definition.get_description(self.scope_params) diff --git a/morango/sync/syncsession.py b/morango/sync/syncsession.py index 0cc69e4..b3f6160 100644 --- a/morango/sync/syncsession.py +++ b/morango/sync/syncsession.py @@ -14,6 +14,7 @@ from requests.adapters import HTTPAdapter from requests.exceptions import HTTPError from requests.packages.urllib3.util.retry import Retry +from django.db import transaction, connection from .session import SessionWrapper from morango.api.serializers import CertificateSerializer @@ -28,9 +29,11 @@ from morango.errors import MorangoResumeSyncError from morango.errors import MorangoServerDoesNotAllowNewCertPush from morango.models.certificates import Certificate +from morango.models.certificates import Filter from morango.models.certificates import Key from morango.models.core import InstanceIDModel from morango.models.core import SyncSession +from morango.sync.backends.utils import load_backend from morango.sync.context import CompositeSessionContext from morango.sync.context import LocalSessionContext from morango.sync.context import NetworkSessionContext @@ -39,6 +42,7 @@ from morango.sync.utils import SyncSignalGroup from morango.utils import CAPABILITIES from morango.utils import pid_exists +from morango.sync.utils import lock_partitions if GZIP_BUFFER_POST in CAPABILITIES: from gzip import GzipFile @@ -46,6 +50,7 @@ logger = logging.getLogger(__name__) +DBBackend = load_backend(connection) def _join_with_logical_operator(lst, operator): op = ") {operator} (".format(operator=operator) @@ -351,11 +356,15 @@ def certificate_signing_request( cert_chain_response = self._get_certificate_chain( params={"ancestors_of": parent_cert.id} ) - - # upon receiving cert chain from server, we attempt to save the chain into our records - Certificate.save_certificate_chain( - cert_chain_response.json(), expected_last_id=parent_cert.id - ) + cert_chain = cert_chain_response.json() + with transaction.atomic(): + lock_partitions(DBBackend, sync_filter=Filter(cert_chain[0]["id"])) + # check again, now that we have a lock + if not Certificate.objects.filter(id=parent_cert.id).exists(): + # upon receiving cert chain from server, we attempt to save the chain into our records + Certificate.save_certificate_chain( + cert_chain, expected_last_id=parent_cert.id + ) csr_key = Key() # build up data for csr