diff --git a/django_mongodb_backend/base.py b/django_mongodb_backend/base.py index 7f337cf8..abfed412 100644 --- a/django_mongodb_backend/base.py +++ b/django_mongodb_backend/base.py @@ -1,8 +1,11 @@ import contextlib +import logging import os from django.core.exceptions import ImproperlyConfigured +from django.db import DEFAULT_DB_ALIAS from django.db.backends.base.base import BaseDatabaseWrapper +from django.db.backends.utils import debug_transaction from django.utils.asyncio import async_unsafe from django.utils.functional import cached_property from pymongo.collection import Collection @@ -32,6 +35,9 @@ def __exit__(self, exception_type, exception_value, exception_traceback): pass +logger = logging.getLogger("django.db.backends.base") + + class DatabaseWrapper(BaseDatabaseWrapper): data_types = { "AutoField": "int", @@ -142,6 +148,17 @@ def _isnull_operator(a, b): ops_class = DatabaseOperations validation_class = DatabaseValidation + def __init__(self, settings_dict, alias=DEFAULT_DB_ALIAS): + super().__init__(settings_dict, alias=alias) + self.session = None + # Tracks whether the connection is in a transaction managed by + # django_mongodb_backend.transaction.atomic. `in_atomic_block` isn't + # used in case Django's atomic() (used internally in Django) is called + # within this package's atomic(). + self.in_atomic_block_mongo = False + # Current number of nested 'atomic' calls. + self.nested_atomics = 0 + def get_collection(self, name, **kwargs): collection = Collection(self.database, name, **kwargs) if self.queries_logged: @@ -212,6 +229,10 @@ def close(self): def close_pool(self): """Close the MongoClient.""" + # Clear commit hooks and session. + self.run_on_commit = [] + if self.session: + self._end_session() connection = self.connection if connection is None: return @@ -230,3 +251,57 @@ def cursor(self): def get_database_version(self): """Return a tuple of the database's version.""" return tuple(self.connection.server_info()["versionArray"]) + + ## Transaction API for django_mongodb_backend.transaction.atomic() + @async_unsafe + def start_transaction_mongo(self): + if self.session is None: + self.session = self.connection.start_session() + with debug_transaction(self, "session.start_transaction()"): + self.session.start_transaction() + + @async_unsafe + def commit_mongo(self): + if self.session: + with debug_transaction(self, "session.commit_transaction()"): + self.session.commit_transaction() + self._end_session() + self.run_and_clear_commit_hooks() + + @async_unsafe + def rollback_mongo(self): + if self.session: + with debug_transaction(self, "session.abort_transaction()"): + self.session.abort_transaction() + self._end_session() + self.run_on_commit = [] + + def _end_session(self): + self.session.end_session() + self.session = None + + def on_commit(self, func, robust=False): + """ + Copied from BaseDatabaseWrapper.on_commit() except that it checks + in_atomic_block_mongo instead of in_atomic_block. + """ + if not callable(func): + raise TypeError("on_commit()'s callback must be a callable.") + if self.in_atomic_block_mongo: + # Transaction in progress; save for execution on commit. + # The first item in the tuple (an empty list) is normally the + # savepoint IDs, which isn't applicable on MongoDB. + self.run_on_commit.append(([], func, robust)) + else: + # No transaction in progress; execute immediately. + if robust: + try: + func() + except Exception as e: + logger.exception( + "Error calling %s in on_commit() (%s).", + func.__qualname__, + e, + ) + else: + func() diff --git a/django_mongodb_backend/compiler.py b/django_mongodb_backend/compiler.py index 9d6e1cc1..9eb71476 100644 --- a/django_mongodb_backend/compiler.py +++ b/django_mongodb_backend/compiler.py @@ -696,7 +696,9 @@ def execute_sql(self, returning_fields=None): @wrap_database_errors def insert(self, docs, returning_fields=None): """Store a list of documents using field columns as element names.""" - inserted_ids = self.collection.insert_many(docs).inserted_ids + inserted_ids = self.collection.insert_many( + docs, session=self.connection.session + ).inserted_ids return [(x,) for x in inserted_ids] if returning_fields else [] @cached_property @@ -777,7 +779,9 @@ def execute_sql(self, result_type): @wrap_database_errors def update(self, criteria, pipeline): - return self.collection.update_many(criteria, pipeline).matched_count + return self.collection.update_many( + criteria, pipeline, session=self.connection.session + ).matched_count def check_query(self): super().check_query() diff --git a/django_mongodb_backend/query.py b/django_mongodb_backend/query.py index 25122cc5..fb627657 100644 --- a/django_mongodb_backend/query.py +++ b/django_mongodb_backend/query.py @@ -63,7 +63,9 @@ def delete(self): """Execute a delete query.""" if self.compiler.subqueries: raise NotSupportedError("Cannot use QuerySet.delete() when a subquery is required.") - return self.compiler.collection.delete_many(self.match_mql).deleted_count + return self.compiler.collection.delete_many( + self.match_mql, session=self.compiler.connection.session + ).deleted_count @wrap_database_errors def get_cursor(self): @@ -71,7 +73,9 @@ def get_cursor(self): Return a pymongo CommandCursor that can be iterated on to give the results of the query. """ - return self.compiler.collection.aggregate(self.get_pipeline()) + return self.compiler.collection.aggregate( + self.get_pipeline(), session=self.compiler.connection.session + ) def get_pipeline(self): pipeline = [] diff --git a/django_mongodb_backend/queryset.py b/django_mongodb_backend/queryset.py index 4a2d884d..b02496fe 100644 --- a/django_mongodb_backend/queryset.py +++ b/django_mongodb_backend/queryset.py @@ -35,7 +35,7 @@ def __init__(self, pipeline, using, model): def _execute_query(self): connection = connections[self.using] collection = connection.get_collection(self.model._meta.db_table) - self.cursor = collection.aggregate(self.pipeline) + self.cursor = collection.aggregate(self.pipeline, session=connection.session) def __str__(self): return str(self.pipeline) diff --git a/django_mongodb_backend/transaction.py b/django_mongodb_backend/transaction.py new file mode 100644 index 00000000..7657bf91 --- /dev/null +++ b/django_mongodb_backend/transaction.py @@ -0,0 +1,61 @@ +from contextlib import ContextDecorator + +from django.db import DEFAULT_DB_ALIAS, DatabaseError +from django.db.transaction import get_connection, on_commit + +__all__ = [ + "atomic", + "on_commit", # convenience alias +] + + +class Atomic(ContextDecorator): + """ + Guarantee the atomic execution of a given block. + + Simplified from django.db.transaction. + """ + + def __init__(self, using): + self.using = using + + def __enter__(self): + connection = get_connection(self.using) + if connection.in_atomic_block_mongo: + # Track the number of nested atomic() calls. + connection.nested_atomics += 1 + else: + # Start a transaction for the outermost atomic(). + connection.start_transaction_mongo() + connection.in_atomic_block_mongo = True + + def __exit__(self, exc_type, exc_value, traceback): + connection = get_connection(self.using) + if connection.nested_atomics: + # Exiting inner atomic. + connection.nested_atomics -= 1 + else: + # Reset flag when exiting outer atomic. + connection.in_atomic_block_mongo = False + if exc_type is None: + # atomic() exited without an error. + if not connection.in_atomic_block_mongo: + # Commit transaction if outer atomic(). + try: + connection.commit_mongo() + except DatabaseError: + connection.rollback_mongo() + else: + # atomic() exited with an error. + if not connection.in_atomic_block_mongo: + # Rollback transaction if outer atomic(). + connection.rollback_mongo() + + +def atomic(using=None): + # Bare decorator: @atomic -- although the first argument is called `using`, + # it's actually the function being decorated. + if callable(using): + return Atomic(DEFAULT_DB_ALIAS)(using) + # Decorator: @atomic(...) or context manager: with atomic(...): ... + return Atomic(using) diff --git a/docs/source/releases/5.2.x.rst b/docs/source/releases/5.2.x.rst index b672aed2..21b8a93e 100644 --- a/docs/source/releases/5.2.x.rst +++ b/docs/source/releases/5.2.x.rst @@ -13,6 +13,7 @@ New features - Added subquery support for :class:`~.fields.EmbeddedModelArrayField`. - Added the ``options`` parameter to :func:`~django_mongodb_backend.utils.parse_uri`. +- Added support for :ref:`database transactions `. - Added :class:`~.fields.PolymorphicEmbeddedModelField` and :class:`~.fields.PolymorphicEmbeddedModelArrayField` for storing a model instance or list of model instances that may be of more than one model class. diff --git a/docs/source/topics/index.rst b/docs/source/topics/index.rst index 63ff9a25..6e06b812 100644 --- a/docs/source/topics/index.rst +++ b/docs/source/topics/index.rst @@ -9,4 +9,5 @@ know: :maxdepth: 2 embedded-models + transactions known-issues diff --git a/docs/source/topics/known-issues.rst b/docs/source/topics/known-issues.rst index 5dbadf03..df71a946 100644 --- a/docs/source/topics/known-issues.rst +++ b/docs/source/topics/known-issues.rst @@ -80,11 +80,12 @@ Database functions Transaction management ====================== -Query execution uses Django and MongoDB's default behavior of autocommit mode. -Each query is immediately committed to the database. +By default, query execution uses Django and MongoDB's default behavior of autocommit +mode. Each query is immediately committed to the database. Django's :doc:`transaction management APIs ` -are not supported. +are not supported. Instead, this package provides its own :doc:`transaction APIs +`. Database introspection ====================== diff --git a/docs/source/topics/transactions.rst b/docs/source/topics/transactions.rst new file mode 100644 index 00000000..d316b46c --- /dev/null +++ b/docs/source/topics/transactions.rst @@ -0,0 +1,142 @@ +============ +Transactions +============ + +.. versionadded:: 5.2.0b2 + +.. module:: django_mongodb_backend.transaction + +MongoDB supports :doc:`transactions ` if it's +configured as a :doc:`replica set ` or a :doc:`sharded +cluster `. + +Because MongoDB transactions have some limitations and are not meant to be used +as freely as SQL transactions, :doc:`Django's transactions APIs +`, including most notably +:func:`django.db.transaction.atomic`, function as no-ops. + +Instead, Django MongoDB Backend provides its own +:func:`django_mongodb_backend.transaction.atomic` function. + +Outside of a transaction, query execution uses Django and MongoDB's default +behavior of autocommit mode. Each query is immediately committed to the +database. + +Controlling transactions +======================== + +.. function:: atomic(using=None) + + Atomicity is the defining property of database transactions. ``atomic`` + allows creating a block of code within which the atomicity on the database + is guaranteed. If the block of code is successfully completed, the changes + are committed to the database. If there is an exception, the changes are + rolled back. + + ``atomic`` is usable both as a :py:term:`decorator`:: + + from django_mongodb_backend import transaction + + + @transaction.atomic + def viewfunc(request): + # This code executes inside a transaction. + do_stuff() + + and as a :py:term:`context manager`:: + + from django_mongodb_backend import transaction + + + def viewfunc(request): + # This code executes in autocommit mode (Django's default). + do_stuff() + + with transaction.atomic(): + # This code executes inside a transaction. + do_more_stuff() + + .. admonition:: Avoid catching exceptions inside ``atomic``! + + When exiting an ``atomic`` block, Django looks at whether it's exited + normally or with an exception to determine whether to commit or roll + back. If you catch and handle exceptions inside an ``atomic`` block, + you may hide from Django the fact that a problem has happened. This can + result in unexpected behavior. + + This is mostly a concern for :exc:`~django.db.DatabaseError` and its + subclasses such as :exc:`~django.db.IntegrityError`. After such an + error, the transaction is broken and Django will perform a rollback at + the end of the ``atomic`` block. + + .. admonition:: You may need to manually revert app state when rolling back a transaction. + + The values of a model's fields won't be reverted when a transaction + rollback happens. This could lead to an inconsistent model state unless + you manually restore the original field values. + + For example, given ``MyModel`` with an ``active`` field, this snippet + ensures that the ``if obj.active`` check at the end uses the correct + value if updating ``active`` to ``True`` fails in the transaction:: + + from django_mongodb_backend import transaction + from django.db import DatabaseError + + obj = MyModel(active=False) + obj.active = True + try: + with transaction.atomic(): + obj.save() + except DatabaseError: + obj.active = False + + if obj.active: + ... + + This also applies to any other mechanism that may hold app state, such + as caching or global variables. For example, if the code proactively + updates data in the cache after saving an object, it's recommended to + use :ref:`transaction.on_commit() ` + instead, to defer cache alterations until the transaction is actually + committed. + + ``atomic`` takes a ``using`` argument which should be the name of a + database. If this argument isn't provided, Django uses the ``"default"`` + database. + +.. admonition:: Performance considerations + + Open transactions have a performance cost for your MongoDB server. To + minimize this overhead, keep your transactions as short as possible. This + is especially important if you're using :func:`atomic` in long-running + processes, outside of Django's request / response cycle. + +Performing actions after commit +=============================== + +The :func:`atomic` function supports Django's +:func:`~django.db.transaction.on_commit` API to :ref:`perform actions after a +transaction successfully commits `. + +For convenience, :func:`~django.db.transaction.on_commit` is aliased at +``django_mongodb_backend.transaction.on_commit`` so you can use both:: + + from django_mongodb_backend import transaction + + + transaction.atomic() + transaction.on_commit(...) + +.. _transactions-limitations: + +Limitations +=========== + +MongoDB's transaction limitations that are applicable to Django are: + +- :meth:`QuerySet.union() ` is not + supported inside a transaction. +- Savepoints (i.e. nested :func:`~django.db.transaction.atomic` blocks) aren't + supported. The outermost :func:`~django.db.transaction.atomic` will start + a transaction while any inner :func:`~django.db.transaction.atomic` blocks + have no effect. diff --git a/tests/raw_query_/test_raw_aggregate.py b/tests/raw_query_/test_raw_aggregate.py index f16ce3cb..99dcd5fa 100644 --- a/tests/raw_query_/test_raw_aggregate.py +++ b/tests/raw_query_/test_raw_aggregate.py @@ -6,8 +6,9 @@ from django.core.exceptions import FieldDoesNotExist from django.db import connection -from django.test import TestCase +from django.test import TestCase, skipUnlessDBFeature +from django_mongodb_backend import transaction from django_mongodb_backend.queryset import RawQuerySet from .models import ( @@ -326,3 +327,13 @@ def test_bool(self): def test_len(self): self.assertEqual(len(Book.objects.raw_aggregate([])), 4) self.assertEqual(len(Book.objects.raw_aggregate([{"$match": {"id": 0}}])), 0) + + @skipUnlessDBFeature("_supports_transactions") + def test_transaction(self): + count = Author.objects.count() + with self.assertRaisesMessage(Exception, "Oops"), transaction.atomic(): + Author.objects.update(last_name="Haddock") + # Changes within in a transaction are visible to raw_aggregate(). + query = [{"$match": {"last_name": "Haddock"}}] + self.assertEqual(len(Author.objects.raw_aggregate(query)), count) + raise Exception("Oops") diff --git a/tests/transaction_hooks_/__init__.py b/tests/transaction_hooks_/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/transaction_hooks_/models.py b/tests/transaction_hooks_/models.py new file mode 100644 index 00000000..000fd71c --- /dev/null +++ b/tests/transaction_hooks_/models.py @@ -0,0 +1,5 @@ +from django.db import models + + +class Thing(models.Model): + num = models.IntegerField() diff --git a/tests/transaction_hooks_/tests.py b/tests/transaction_hooks_/tests.py new file mode 100644 index 00000000..5f904d8a --- /dev/null +++ b/tests/transaction_hooks_/tests.py @@ -0,0 +1,209 @@ +from django.db import connection +from django.test import TransactionTestCase, skipUnlessDBFeature + +from django_mongodb_backend import transaction + +from .models import Thing + + +class ForcedError(Exception): + pass + + +@skipUnlessDBFeature("_supports_transactions") +class TestConnectionOnCommit(TransactionTestCase): + """Largely copied from Django's test/transaction_hooks.""" + + available_apps = ["transaction_hooks_"] + + def setUp(self): + self.notified = [] + + def notify(self, id_): + if id_ == "error": + raise ForcedError() + self.notified.append(id_) + + def do(self, num): + """Create a Thing instance and notify about it.""" + Thing.objects.create(num=num) + transaction.on_commit(lambda: self.notify(num)) + + def assertDone(self, nums): + self.assertNotified(nums) + self.assertEqual(sorted(t.num for t in Thing.objects.all()), sorted(nums)) + + def assertNotified(self, nums): + self.assertEqual(self.notified, nums) + + def test_executes_immediately_if_no_transaction(self): + self.do(1) + self.assertDone([1]) + + def test_robust_if_no_transaction(self): + def robust_callback(): + raise ForcedError("robust callback") + + with self.assertLogs("django.db.backends.base", "ERROR") as cm: + transaction.on_commit(robust_callback, robust=True) + self.do(1) + + self.assertDone([1]) + log_record = cm.records[0] + self.assertEqual( + log_record.getMessage(), + "Error calling TestConnectionOnCommit.test_robust_if_no_transaction." + ".robust_callback in on_commit() (robust callback).", + ) + self.assertIsNotNone(log_record.exc_info) + raised_exception = log_record.exc_info[1] + self.assertIsInstance(raised_exception, ForcedError) + self.assertEqual(str(raised_exception), "robust callback") + + def test_robust_transaction(self): + def robust_callback(): + raise ForcedError("robust callback") + + with self.assertLogs("django.db.backends", "ERROR") as cm, transaction.atomic(): + transaction.on_commit(robust_callback, robust=True) + self.do(1) + + self.assertDone([1]) + log_record = cm.records[0] + self.assertEqual( + log_record.getMessage(), + "Error calling TestConnectionOnCommit.test_robust_transaction.." + "robust_callback in on_commit() during transaction (robust callback).", + ) + self.assertIsNotNone(log_record.exc_info) + raised_exception = log_record.exc_info[1] + self.assertIsInstance(raised_exception, ForcedError) + self.assertEqual(str(raised_exception), "robust callback") + + def test_delays_execution_until_after_transaction_commit(self): + with transaction.atomic(): + self.do(1) + self.assertNotified([]) + self.assertDone([1]) + + def test_does_not_execute_if_transaction_rolled_back(self): + try: + with transaction.atomic(): + self.do(1) + raise ForcedError() + except ForcedError: + pass + + self.assertDone([]) + + def test_executes_only_after_final_transaction_committed(self): + with transaction.atomic(): + with transaction.atomic(): + self.do(1) + self.assertNotified([]) + self.assertNotified([]) + self.assertDone([1]) + + def test_no_hooks_run_from_failed_transaction(self): + """If outer transaction fails, no hooks from within it run.""" + try: + with transaction.atomic(): + with transaction.atomic(): + self.do(1) + raise ForcedError() + except ForcedError: + pass + + self.assertDone([]) + + def test_runs_hooks_in_order_registered(self): + with transaction.atomic(): + self.do(1) + with transaction.atomic(): + self.do(2) + self.do(3) + + self.assertDone([1, 2, 3]) + + def test_hooks_cleared_after_successful_commit(self): + with transaction.atomic(): + self.do(1) + with transaction.atomic(): + self.do(2) + + self.assertDone([1, 2]) # not [1, 1, 2] + + def test_hooks_cleared_after_rollback(self): + try: + with transaction.atomic(): + self.do(1) + raise ForcedError() + except ForcedError: + pass + + with transaction.atomic(): + self.do(2) + + self.assertDone([2]) + + @skipUnlessDBFeature("test_db_allows_multiple_connections") + def test_hooks_cleared_on_reconnect(self): + with transaction.atomic(): + self.do(1) + connection.close_pool() + + connection.connect() + + with transaction.atomic(): + self.do(2) + + self.assertDone([2]) + + def test_error_in_hook_doesnt_prevent_clearing_hooks(self): + try: + with transaction.atomic(): + transaction.on_commit(lambda: self.notify("error")) + except ForcedError: + pass + + with transaction.atomic(): + self.do(1) + + self.assertDone([1]) + + def test_db_query_in_hook(self): + with transaction.atomic(): + Thing.objects.create(num=1) + transaction.on_commit(lambda: [self.notify(t.num) for t in Thing.objects.all()]) + + self.assertDone([1]) + + def test_transaction_in_hook(self): + def on_commit(): + with transaction.atomic(): + t = Thing.objects.create(num=1) + self.notify(t.num) + + with transaction.atomic(): + transaction.on_commit(on_commit) + + self.assertDone([1]) + + def test_hook_in_hook(self): + def on_commit(i, add_hook): + with transaction.atomic(): + if add_hook: + transaction.on_commit(lambda: on_commit(i + 10, False)) + t = Thing.objects.create(num=i) + self.notify(t.num) + + with transaction.atomic(): + transaction.on_commit(lambda: on_commit(1, True)) + transaction.on_commit(lambda: on_commit(2, True)) + + self.assertDone([1, 11, 2, 12]) + + def test_raises_exception_non_callable(self): + msg = "on_commit()'s callback must be a callable." + with self.assertRaisesMessage(TypeError, msg): + transaction.on_commit(None) diff --git a/tests/transactions_/__init__.py b/tests/transactions_/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/transactions_/models.py b/tests/transactions_/models.py new file mode 100644 index 00000000..bf1a635d --- /dev/null +++ b/tests/transactions_/models.py @@ -0,0 +1,13 @@ +from django.db import models + + +class Reporter(models.Model): + first_name = models.CharField(max_length=30) + last_name = models.CharField(max_length=30) + email = models.EmailField() + + class Meta: + ordering = ("first_name", "last_name") + + def __str__(self): + return f"{self.first_name} {self.last_name}".strip() diff --git a/tests/transactions_/tests.py b/tests/transactions_/tests.py new file mode 100644 index 00000000..0a2b4ced --- /dev/null +++ b/tests/transactions_/tests.py @@ -0,0 +1,152 @@ +from django.db import DatabaseError +from django.test import TransactionTestCase, skipIfDBFeature, skipUnlessDBFeature + +from django_mongodb_backend import transaction + +from .models import Reporter + + +@skipUnlessDBFeature("_supports_transactions") +class AtomicTests(TransactionTestCase): + """Largely copied from Django's test/transactions.""" + + available_apps = ["transactions_"] + + def test_decorator_syntax_commit(self): + @transaction.atomic + def make_reporter(): + return Reporter.objects.create(first_name="Tintin") + + reporter = make_reporter() + self.assertSequenceEqual(Reporter.objects.all(), [reporter]) + + def test_decorator_syntax_rollback(self): + @transaction.atomic + def make_reporter(): + Reporter.objects.create(first_name="Haddock") + raise Exception("Oops, that's his last name") + + with self.assertRaisesMessage(Exception, "Oops"): + make_reporter() + self.assertSequenceEqual(Reporter.objects.all(), []) + + def test_alternate_decorator_syntax_commit(self): + @transaction.atomic() + def make_reporter(): + return Reporter.objects.create(first_name="Tintin") + + reporter = make_reporter() + self.assertSequenceEqual(Reporter.objects.all(), [reporter]) + + def test_alternate_decorator_syntax_rollback(self): + @transaction.atomic() + def make_reporter(): + Reporter.objects.create(first_name="Haddock") + raise Exception("Oops, that's his last name") + + with self.assertRaisesMessage(Exception, "Oops"): + make_reporter() + self.assertSequenceEqual(Reporter.objects.all(), []) + + def test_commit(self): + with transaction.atomic(): + reporter = Reporter.objects.create(first_name="Tintin") + self.assertSequenceEqual(Reporter.objects.all(), [reporter]) + + def test_rollback(self): + with self.assertRaisesMessage(Exception, "Oops"), transaction.atomic(): + Reporter.objects.create(first_name="Haddock") + raise Exception("Oops, that's his last name") + self.assertSequenceEqual(Reporter.objects.all(), []) + + def test_nested_commit_commit(self): + with transaction.atomic(): + reporter1 = Reporter.objects.create(first_name="Tintin") + with transaction.atomic(): + reporter2 = Reporter.objects.create(first_name="Archibald", last_name="Haddock") + self.assertSequenceEqual(Reporter.objects.all(), [reporter2, reporter1]) + + def test_nested_rollback_commit(self): + with self.assertRaisesMessage(Exception, "Oops"), transaction.atomic(): + Reporter.objects.create(last_name="Tintin") + with transaction.atomic(): + Reporter.objects.create(last_name="Haddock") + raise Exception("Oops, that's his first name") + self.assertSequenceEqual(Reporter.objects.all(), []) + + def test_nested_rollback_rollback(self): + with self.assertRaisesMessage(Exception, "Oops"), transaction.atomic(): + Reporter.objects.create(last_name="Tintin") + with self.assertRaisesMessage(Exception, "Oops"): + with transaction.atomic(): + Reporter.objects.create(first_name="Haddock") + raise Exception("Oops, that's his last name") + raise Exception("Oops, that's his first name") + self.assertSequenceEqual(Reporter.objects.all(), []) + + def test_reuse_commit_commit(self): + atomic = transaction.atomic() + with atomic: + reporter1 = Reporter.objects.create(first_name="Tintin") + with atomic: + reporter2 = Reporter.objects.create(first_name="Archibald", last_name="Haddock") + self.assertSequenceEqual(Reporter.objects.all(), [reporter2, reporter1]) + + def test_reuse_rollback_commit(self): + atomic = transaction.atomic() + with self.assertRaisesMessage(Exception, "Oops"), atomic: + Reporter.objects.create(last_name="Tintin") + with atomic: + Reporter.objects.create(last_name="Haddock") + raise Exception("Oops, that's his first name") + self.assertSequenceEqual(Reporter.objects.all(), []) + + def test_reuse_rollback_rollback(self): + atomic = transaction.atomic() + with self.assertRaisesMessage(Exception, "Oops"), atomic: + Reporter.objects.create(last_name="Tintin") + with self.assertRaisesMessage(Exception, "Oops"): + with atomic: + Reporter.objects.create(first_name="Haddock") + raise Exception("Oops, that's his last name") + raise Exception("Oops, that's his first name") + self.assertSequenceEqual(Reporter.objects.all(), []) + + def test_rollback_update(self): + r = Reporter.objects.create(last_name="Tintin") + with self.assertRaisesMessage(Exception, "Oops"), transaction.atomic(): + Reporter.objects.update(last_name="Haddock") + # The update is visible in the transaction. + r.refresh_from_db() + self.assertEqual(r.last_name, "Haddock") + raise Exception("Oops") + # But is now rolled back. + r.refresh_from_db() + self.assertEqual(r.last_name, "Tintin") + + def test_rollback_delete(self): + r = Reporter.objects.create(last_name="Tintin") + with self.assertRaisesMessage(Exception, "Oops"), transaction.atomic(): + Reporter.objects.all().delete() + raise Exception("Oops") + self.assertSequenceEqual(Reporter.objects.all(), [r]) + + def test_wrap_callable_instance(self): + """Atomic can wrap callable instances.""" + + class Callable: + def __call__(self): + pass + + transaction.atomic(Callable()) # Must not raise an exception + + +@skipIfDBFeature("_supports_transactions") +class AtomicNotSupportedTests(TransactionTestCase): + available_apps = ["transactions_"] + + def test_not_supported(self): + # If transactions aren't supported, MongoDB raises an error: + # "Transaction numbers are only allowed on a replica set member or mongos" + with self.assertRaises(DatabaseError), transaction.atomic(): + Reporter.objects.create(first_name="Haddock")