Skip to content

Commit a9bf52c

Browse files
committed
Prevent certificate tree corruption
1 parent 965329e commit a9bf52c

File tree

2 files changed

+50
-6
lines changed

2 files changed

+50
-6
lines changed

morango/models/certificates.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
"""
66
import json
77
import string
8+
import logging
89

910
import mptt.models
1011
from django.core.management import call_command
@@ -24,7 +25,10 @@
2425
from morango.errors import NonceDoesNotExist
2526
from morango.errors import NonceExpired
2627
from morango.utils import _assert
27-
28+
from django.db import transaction, connection
29+
from morango.sync.backends.utils import load_backend
30+
from contextlib import contextmanager
31+
from django.db.utils import OperationalError
2832

2933
class Certificate(mptt.models.MPTTModel, UUIDModelMixin):
3034

@@ -246,6 +250,37 @@ def verify(self, value, signature):
246250
def get_scope(self):
247251
return self.scope_definition.get_scope(self.scope_params)
248252

253+
@contextmanager
254+
def _attempt_lock_mptt(self):
255+
from morango.sync.utils import lock_partitions
256+
257+
DBBackend = load_backend(connection)
258+
259+
with transaction.atomic():
260+
# Call get_root on the parent as it is already saved in the DB
261+
root_id = self.parent.get_root().id if self.parent else self.id
262+
263+
# lock the partitions in our scope to prevent MPTT tree corruption during concurrent certificate creation
264+
lock_partitions(DBBackend, sync_filter=Filter(root_id) if root_id else None)
265+
yield
266+
267+
@contextmanager
268+
def _lock_mptt(self):
269+
try:
270+
with self._attempt_lock_mptt():
271+
yield
272+
except OperationalError as e:
273+
if "deadlock detected" in e.args[0]:
274+
logging.error("Deadlock detected when attempting to lock MPTT partitions, retrying once more")
275+
with self._attempt_lock_mptt():
276+
yield
277+
else:
278+
raise
279+
280+
def save(self, *args, **kwargs):
281+
with self._lock_mptt():
282+
super().save(*args, **kwargs)
283+
249284
def __str__(self):
250285
if self.scope_definition:
251286
return self.scope_definition.get_description(self.scope_params)

morango/sync/syncsession.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from requests.adapters import HTTPAdapter
1515
from requests.exceptions import HTTPError
1616
from requests.packages.urllib3.util.retry import Retry
17+
from django.db import transaction, connection
1718

1819
from .session import SessionWrapper
1920
from morango.api.serializers import CertificateSerializer
@@ -28,9 +29,11 @@
2829
from morango.errors import MorangoResumeSyncError
2930
from morango.errors import MorangoServerDoesNotAllowNewCertPush
3031
from morango.models.certificates import Certificate
32+
from morango.models.certificates import Filter
3133
from morango.models.certificates import Key
3234
from morango.models.core import InstanceIDModel
3335
from morango.models.core import SyncSession
36+
from morango.sync.backends.utils import load_backend
3437
from morango.sync.context import CompositeSessionContext
3538
from morango.sync.context import LocalSessionContext
3639
from morango.sync.context import NetworkSessionContext
@@ -39,13 +42,15 @@
3942
from morango.sync.utils import SyncSignalGroup
4043
from morango.utils import CAPABILITIES
4144
from morango.utils import pid_exists
45+
from morango.sync.utils import lock_partitions
4246

4347
if GZIP_BUFFER_POST in CAPABILITIES:
4448
from gzip import GzipFile
4549

4650

4751
logger = logging.getLogger(__name__)
4852

53+
DBBackend = load_backend(connection)
4954

5055
def _join_with_logical_operator(lst, operator):
5156
op = ") {operator} (".format(operator=operator)
@@ -351,11 +356,15 @@ def certificate_signing_request(
351356
cert_chain_response = self._get_certificate_chain(
352357
params={"ancestors_of": parent_cert.id}
353358
)
354-
355-
# upon receiving cert chain from server, we attempt to save the chain into our records
356-
Certificate.save_certificate_chain(
357-
cert_chain_response.json(), expected_last_id=parent_cert.id
358-
)
359+
cert_chain = cert_chain_response.json()
360+
with transaction.atomic():
361+
lock_partitions(DBBackend, sync_filter=Filter(cert_chain[0]["id"]))
362+
# check again, now that we have a lock
363+
if not Certificate.objects.filter(id=parent_cert.id).exists():
364+
# upon receiving cert chain from server, we attempt to save the chain into our records
365+
Certificate.save_certificate_chain(
366+
cert_chain, expected_last_id=parent_cert.id
367+
)
359368

360369
csr_key = Key()
361370
# build up data for csr

0 commit comments

Comments
 (0)