Skip to content

Commit 5febcd3

Browse files
committed
refactor connection creation to use Django's APIs
1 parent 86d5885 commit 5febcd3

File tree

3 files changed

+28
-33
lines changed

3 files changed

+28
-33
lines changed

django_mongodb/base.py

Lines changed: 26 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1+
import contextlib
2+
13
from django.db.backends.base.base import BaseDatabaseWrapper
2-
from django.db.backends.signals import connection_created
34
from pymongo.collection import Collection
45
from pymongo.mongo_client import MongoClient
56

@@ -128,11 +129,6 @@ def _isnull_operator(a, b):
128129
introspection_class = DatabaseIntrospection
129130
ops_class = DatabaseOperations
130131

131-
def __init__(self, *args, **kwargs):
132-
super().__init__(*args, **kwargs)
133-
self.connected = False
134-
del self.connection
135-
136132
def get_collection(self, name, **kwargs):
137133
collection = Collection(self.database, name, **kwargs)
138134
if self.queries_logged:
@@ -145,44 +141,45 @@ def get_database(self):
145141
return self.database
146142

147143
def __getattr__(self, attr):
148-
"""
149-
Connect to the database the first time `connection` or `database` are
150-
accessed.
151-
"""
152-
if attr in ["connection", "database"]:
153-
assert not self.connected
154-
self._connect()
144+
"""Connect to the database the first time `database` is accessed."""
145+
if attr == "database":
146+
if self.connection is None:
147+
self.connect()
155148
return getattr(self, attr)
156149
raise AttributeError(attr)
157150

158-
def _connect(self):
159-
settings_dict = self.settings_dict
160-
self.connection = MongoClient(
161-
host=settings_dict["HOST"] or None,
162-
port=int(settings_dict["PORT"] or 27017),
163-
username=settings_dict.get("USER"),
164-
password=settings_dict.get("PASSWORD"),
165-
**settings_dict["OPTIONS"],
166-
)
167-
db_name = settings_dict["NAME"]
151+
def init_connection_state(self):
152+
db_name = self.settings_dict["NAME"]
168153
if db_name:
169154
self.database = self.connection[db_name]
155+
super().init_connection_state()
156+
157+
def get_connection_params(self):
158+
settings_dict = self.settings_dict
159+
return {
160+
"host": settings_dict["HOST"] or None,
161+
"port": int(settings_dict["PORT"] or 27017),
162+
"username": settings_dict.get("USER"),
163+
"password": settings_dict.get("PASSWORD"),
164+
**settings_dict["OPTIONS"],
165+
}
170166

171-
self.connected = True
172-
connection_created.send(sender=self.__class__, connection=self)
167+
def get_new_connection(self, conn_params):
168+
return MongoClient(**conn_params)
173169

174170
def _commit(self):
175171
pass
176172

177173
def _rollback(self):
178174
pass
179175

176+
def set_autocommit(self, autocommit, force_begin_transaction_with_broken_autocommit=False):
177+
pass
178+
180179
def close(self):
181-
if self.connected:
182-
self.connection.close()
183-
del self.connection
180+
super().close()
181+
with contextlib.suppress(AttributeError):
184182
del self.database
185-
self.connected = False
186183

187184
def cursor(self):
188185
return Cursor()

django_mongodb/features.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,6 @@ class DatabaseFeatures(BaseDatabaseFeatures):
7171
"backends.tests.BackendTestCase.test_is_usable_after_database_disconnects",
7272
# Connection creation doesn't follow the usual Django API.
7373
"backends.tests.ThreadTests.test_pass_connection_between_threads",
74-
"backends.tests.ThreadTests.test_closing_non_shared_connections",
7574
"backends.tests.ThreadTests.test_default_connection_thread_local",
7675
# Union as subquery is not mapping the parent parameter and collections:
7776
# https://github.com/mongodb-labs/django-mongodb/issues/156

django_mongodb/query.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@ def __init__(self, compiler):
4747
self.query = compiler.query
4848
self._negated = False
4949
self.ordering = []
50-
self.collection = self.compiler.collection
5150
self.collection_name = self.compiler.collection_name
5251
self.mongo_query = getattr(compiler.query, "raw_query", {})
5352
self.subqueries = None
@@ -68,15 +67,15 @@ def delete(self):
6867
"""Execute a delete query."""
6968
if self.compiler.subqueries:
7069
raise NotSupportedError("Cannot use QuerySet.delete() when a subquery is required.")
71-
return self.collection.delete_many(self.mongo_query).deleted_count
70+
return self.compiler.collection.delete_many(self.mongo_query).deleted_count
7271

7372
@wrap_database_errors
7473
def get_cursor(self):
7574
"""
7675
Return a pymongo CommandCursor that can be iterated on to give the
7776
results of the query.
7877
"""
79-
return self.collection.aggregate(self.get_pipeline())
78+
return self.compiler.collection.aggregate(self.get_pipeline())
8079

8180
def get_pipeline(self):
8281
pipeline = []

0 commit comments

Comments
 (0)