Skip to content

refactor connection creation to use Django's APIs #178

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Nov 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 26 additions & 29 deletions django_mongodb/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import contextlib

from django.db.backends.base.base import BaseDatabaseWrapper
from django.db.backends.signals import connection_created
from pymongo.collection import Collection
from pymongo.mongo_client import MongoClient

Expand Down Expand Up @@ -128,11 +129,6 @@ def _isnull_operator(a, b):
introspection_class = DatabaseIntrospection
ops_class = DatabaseOperations

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.connected = False
del self.connection

def get_collection(self, name, **kwargs):
collection = Collection(self.database, name, **kwargs)
if self.queries_logged:
Expand All @@ -145,44 +141,45 @@ def get_database(self):
return self.database

def __getattr__(self, attr):
"""
Connect to the database the first time `connection` or `database` are
accessed.
"""
if attr in ["connection", "database"]:
assert not self.connected
self._connect()
"""Connect to the database the first time `database` is accessed."""
if attr == "database":
if self.connection is None:
self.connect()
return getattr(self, attr)
raise AttributeError(attr)

def _connect(self):
settings_dict = self.settings_dict
self.connection = MongoClient(
host=settings_dict["HOST"] or None,
port=int(settings_dict["PORT"] or 27017),
username=settings_dict.get("USER"),
password=settings_dict.get("PASSWORD"),
**settings_dict["OPTIONS"],
)
db_name = settings_dict["NAME"]
def init_connection_state(self):
db_name = self.settings_dict["NAME"]
if db_name:
self.database = self.connection[db_name]
super().init_connection_state()

def get_connection_params(self):
settings_dict = self.settings_dict
return {
"host": settings_dict["HOST"] or None,
"port": int(settings_dict["PORT"] or 27017),
"username": settings_dict.get("USER"),
"password": settings_dict.get("PASSWORD"),
**settings_dict["OPTIONS"],
}

self.connected = True
connection_created.send(sender=self.__class__, connection=self)
def get_new_connection(self, conn_params):
return MongoClient(**conn_params)

def _commit(self):
pass

def _rollback(self):
pass

def set_autocommit(self, autocommit, force_begin_transaction_with_broken_autocommit=False):
pass

def close(self):
if self.connected:
self.connection.close()
del self.connection
super().close()
with contextlib.suppress(AttributeError):
del self.database
self.connected = False

def cursor(self):
return Cursor()
Expand Down
6 changes: 3 additions & 3 deletions django_mongodb/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,9 +386,9 @@ def build_query(self, columns=None):
try:
expr = where.as_mql(self, self.connection) if where else {}
except FullResultSet:
query.mongo_query = {}
query.match_mql = {}
else:
query.mongo_query = {"$expr": expr}
query.match_mql = {"$expr": expr}
if extra_fields:
query.extra_fields = self.get_project_fields(extra_fields, force_expression=True)
query.subqueries = self.subqueries
Expand Down Expand Up @@ -722,7 +722,7 @@ def execute_sql(self, result_type):
prepared = prepared.as_mql(self, self.connection)
values[field.column] = prepared
try:
criteria = self.build_query().mongo_query
criteria = self.build_query().match_mql
except EmptyResultSet:
return 0
is_empty = not bool(values)
Expand Down
2 changes: 1 addition & 1 deletion django_mongodb/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@


class DatabaseFeatures(BaseDatabaseFeatures):
minimum_database_version = (5, 0)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is based on 5.0 being what's tested against, correct?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct. Recall that we bumped Mongo 4.4 (copied from Pymongo's CI, I think) to 5.0 in 8bbfc61.

allow_sliced_subqueries_with_in = False
allows_multiple_constraints_on_same_fields = False
can_create_inline_fk = False
Expand Down Expand Up @@ -71,7 +72,6 @@ class DatabaseFeatures(BaseDatabaseFeatures):
"backends.tests.BackendTestCase.test_is_usable_after_database_disconnects",
# Connection creation doesn't follow the usual Django API.
"backends.tests.ThreadTests.test_pass_connection_between_threads",
"backends.tests.ThreadTests.test_closing_non_shared_connections",
"backends.tests.ThreadTests.test_default_connection_thread_local",
# Union as subquery is not mapping the parent parameter and collections:
# https://github.com/mongodb-labs/django-mongodb/issues/156
Expand Down
17 changes: 6 additions & 11 deletions django_mongodb/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,9 @@ class MongoQuery:

def __init__(self, compiler):
self.compiler = compiler
self.connection = compiler.connection
self.ops = compiler.connection.ops
self.query = compiler.query
self._negated = False
self.ordering = []
self.collection = self.compiler.collection
self.collection_name = self.compiler.collection_name
self.mongo_query = getattr(compiler.query, "raw_query", {})
self.match_mql = {}
self.subqueries = None
self.lookup_pipeline = None
self.project_fields = None
Expand All @@ -61,31 +56,31 @@ def __init__(self, compiler):
self.subquery_lookup = None

def __repr__(self):
return f"<MongoQuery: {self.mongo_query!r} ORDER {self.ordering!r}>"
return f"<MongoQuery: {self.match_mql!r} ORDER {self.ordering!r}>"

@wrap_database_errors
def delete(self):
"""Execute a delete query."""
if self.compiler.subqueries:
raise NotSupportedError("Cannot use QuerySet.delete() when a subquery is required.")
return self.collection.delete_many(self.mongo_query).deleted_count
return self.compiler.collection.delete_many(self.match_mql).deleted_count

@wrap_database_errors
def get_cursor(self):
"""
Return a pymongo CommandCursor that can be iterated on to give the
results of the query.
"""
return self.collection.aggregate(self.get_pipeline())
return self.compiler.collection.aggregate(self.get_pipeline())

def get_pipeline(self):
pipeline = []
if self.lookup_pipeline:
pipeline.extend(self.lookup_pipeline)
for query in self.subqueries or ():
pipeline.extend(query.get_pipeline())
if self.mongo_query:
pipeline.append({"$match": self.mongo_query})
if self.match_mql:
pipeline.append({"$match": self.match_mql})
if self.aggregation_pipeline:
pipeline.extend(self.aggregation_pipeline)
if self.project_fields:
Expand Down