Skip to content

Commit 406805d

Browse files
committed
refactor connection creation to use Django's APIs
1 parent a06efe3 commit 406805d

File tree

3 files changed

+33
-31
lines changed

3 files changed

+33
-31
lines changed

django_mongodb/base.py

Lines changed: 26 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1+
import contextlib
2+
13
from django.db.backends.base.base import BaseDatabaseWrapper
2-
from django.db.backends.signals import connection_created
34
from pymongo.collection import Collection
45
from pymongo.mongo_client import MongoClient
56

@@ -128,11 +129,6 @@ def _isnull_operator(a, b):
128129
introspection_class = DatabaseIntrospection
129130
ops_class = DatabaseOperations
130131

131-
def __init__(self, *args, **kwargs):
132-
super().__init__(*args, **kwargs)
133-
self.connected = False
134-
del self.connection
135-
136132
def get_collection(self, name, **kwargs):
137133
collection = Collection(self.database, name, **kwargs)
138134
if self.queries_logged:
@@ -145,44 +141,45 @@ def get_database(self):
145141
return self.database
146142

147143
def __getattr__(self, attr):
148-
"""
149-
Connect to the database the first time `connection` or `database` are
150-
accessed.
151-
"""
152-
if attr in ["connection", "database"]:
153-
assert not self.connected
154-
self._connect()
144+
"""Connect to the database the first time `database` is accessed."""
145+
if attr == "database":
146+
if self.connection is None:
147+
self.connect()
155148
return getattr(self, attr)
156149
raise AttributeError(attr)
157150

158-
def _connect(self):
159-
settings_dict = self.settings_dict
160-
self.connection = MongoClient(
161-
host=settings_dict["HOST"] or None,
162-
port=int(settings_dict["PORT"] or 27017),
163-
username=settings_dict.get("USER"),
164-
password=settings_dict.get("PASSWORD"),
165-
**settings_dict["OPTIONS"],
166-
)
167-
db_name = settings_dict["NAME"]
151+
def init_connection_state(self):
152+
db_name = self.settings_dict["NAME"]
168153
if db_name:
169154
self.database = self.connection[db_name]
155+
super().init_connection_state()
156+
157+
def get_connection_params(self):
158+
settings_dict = self.settings_dict
159+
return {
160+
"host": settings_dict["HOST"] or None,
161+
"port": int(settings_dict["PORT"] or 27017),
162+
"username": settings_dict.get("USER"),
163+
"password": settings_dict.get("PASSWORD"),
164+
**settings_dict["OPTIONS"],
165+
}
170166

171-
self.connected = True
172-
connection_created.send(sender=self.__class__, connection=self)
167+
def get_new_connection(self, conn_params):
168+
return MongoClient(**conn_params)
173169

174170
def _commit(self):
175171
pass
176172

177173
def _rollback(self):
178174
pass
179175

176+
def set_autocommit(self, autocommit, force_begin_transaction_with_broken_autocommit=False):
177+
pass
178+
180179
def close(self):
181-
if self.connected:
182-
self.connection.close()
183-
del self.connection
180+
super().close()
181+
with contextlib.suppress(AttributeError):
184182
del self.database
185-
self.connected = False
186183

187184
def cursor(self):
188185
return Cursor()

django_mongodb/features.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,6 @@ class DatabaseFeatures(BaseDatabaseFeatures):
7171
"backends.tests.BackendTestCase.test_is_usable_after_database_disconnects",
7272
# Connection creation doesn't follow the usual Django API.
7373
"backends.tests.ThreadTests.test_pass_connection_between_threads",
74-
"backends.tests.ThreadTests.test_closing_non_shared_connections",
7574
"backends.tests.ThreadTests.test_default_connection_thread_local",
7675
# Union as subquery is not mapping the parent parameter and collections:
7776
# https://github.com/mongodb-labs/django-mongodb/issues/156

django_mongodb/query.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from django.db.models.sql.constants import INNER
1010
from django.db.models.sql.datastructures import Join
1111
from django.db.models.sql.where import AND, OR, XOR, ExtraWhere, NothingNode, WhereNode
12+
from django.utils.functional import cached_property
1213
from pymongo.errors import BulkWriteError, DuplicateKeyError, PyMongoError
1314

1415

@@ -47,7 +48,6 @@ def __init__(self, compiler):
4748
self.query = compiler.query
4849
self._negated = False
4950
self.ordering = []
50-
self.collection = self.compiler.collection
5151
self.collection_name = self.compiler.collection_name
5252
self.mongo_query = getattr(compiler.query, "raw_query", {})
5353
self.subqueries = None
@@ -60,6 +60,12 @@ def __init__(self, compiler):
6060
# subquery.
6161
self.subquery_lookup = None
6262

63+
@cached_property
64+
def collection(self):
65+
# Initialize this lazily since `compiler.collection` connects to the
66+
# database.
67+
return self.compiler.collection
68+
6369
def __repr__(self):
6470
return f"<MongoQuery: {self.mongo_query!r} ORDER {self.ordering!r}>"
6571

0 commit comments

Comments
 (0)