diff --git a/django_mongodb_backend/base.py b/django_mongodb_backend/base.py index e28fdc26b..9eb505823 100644 --- a/django_mongodb_backend/base.py +++ b/django_mongodb_backend/base.py @@ -4,6 +4,7 @@ from django.core.exceptions import ImproperlyConfigured from django.db.backends.base.base import BaseDatabaseWrapper from django.utils.asyncio import async_unsafe +from django.utils.functional import cached_property from pymongo.collection import Collection from pymongo.driver_info import DriverInfo from pymongo.mongo_client import MongoClient @@ -149,13 +150,13 @@ def get_database(self): return OperationDebugWrapper(self) return self.database - def __getattr__(self, attr): - """Connect to the database the first time `database` is accessed.""" - if attr == "database": - if self.connection is None: - self.connect() - return getattr(self, attr) - raise AttributeError(attr) + @cached_property + def database(self): + """Connect to the database the first time it's accessed.""" + if self.connection is None: + self.connect() + # Cache the database attribute set by init_connection_state() + return self.database def init_connection_state(self): self.database = self.connection[self.settings_dict["NAME"]] diff --git a/tests/backend_/test_base.py b/tests/backend_/test_base.py index 2ce48bbeb..5c599f73b 100644 --- a/tests/backend_/test_base.py +++ b/tests/backend_/test_base.py @@ -1,5 +1,6 @@ from django.core.exceptions import ImproperlyConfigured from django.db import connection +from django.db.backends.signals import connection_created from django.test import SimpleTestCase, TestCase from django_mongodb_backend.base import DatabaseWrapper @@ -21,3 +22,23 @@ def test_set_autocommit(self): self.assertIs(connection.get_autocommit(), False) connection.set_autocommit(True) self.assertIs(connection.get_autocommit(), True) + + def test_connection_created_database_attr(self): + """ + connection.database is available in the connection_created signal. + """ + data = {} + + def receiver(sender, connection, **kwargs): # noqa: ARG001 + data["database"] = connection.database + + connection_created.connect(receiver) + connection.close() + # Accessing database implicitly connects. + connection.database # noqa: B018 + self.assertIs(data["database"], connection.database) + connection.close() + connection_created.disconnect(receiver) + data.clear() + connection.connect() + self.assertEqual(data, {})