diff --git a/django_mongodb/base.py b/django_mongodb/base.py index caa1f63a8..9b3828cea 100644 --- a/django_mongodb/base.py +++ b/django_mongodb/base.py @@ -1,5 +1,6 @@ +import contextlib + from django.db.backends.base.base import BaseDatabaseWrapper -from django.db.backends.signals import connection_created from pymongo.collection import Collection from pymongo.mongo_client import MongoClient @@ -128,11 +129,6 @@ def _isnull_operator(a, b): introspection_class = DatabaseIntrospection ops_class = DatabaseOperations - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.connected = False - del self.connection - def get_collection(self, name, **kwargs): collection = Collection(self.database, name, **kwargs) if self.queries_logged: @@ -145,31 +141,31 @@ def get_database(self): return self.database def __getattr__(self, attr): - """ - Connect to the database the first time `connection` or `database` are - accessed. - """ - if attr in ["connection", "database"]: - assert not self.connected - self._connect() + """Connect to the database the first time `database` is accessed.""" + if attr == "database": + if self.connection is None: + self.connect() return getattr(self, attr) raise AttributeError(attr) - def _connect(self): - settings_dict = self.settings_dict - self.connection = MongoClient( - host=settings_dict["HOST"] or None, - port=int(settings_dict["PORT"] or 27017), - username=settings_dict.get("USER"), - password=settings_dict.get("PASSWORD"), - **settings_dict["OPTIONS"], - ) - db_name = settings_dict["NAME"] + def init_connection_state(self): + db_name = self.settings_dict["NAME"] if db_name: self.database = self.connection[db_name] + super().init_connection_state() + + def get_connection_params(self): + settings_dict = self.settings_dict + return { + "host": settings_dict["HOST"] or None, + "port": int(settings_dict["PORT"] or 27017), + "username": settings_dict.get("USER"), + "password": settings_dict.get("PASSWORD"), + **settings_dict["OPTIONS"], + } - self.connected = True - connection_created.send(sender=self.__class__, connection=self) + def get_new_connection(self, conn_params): + return MongoClient(**conn_params) def _commit(self): pass @@ -177,12 +173,13 @@ def _commit(self): def _rollback(self): pass + def set_autocommit(self, autocommit, force_begin_transaction_with_broken_autocommit=False): + pass + def close(self): - if self.connected: - self.connection.close() - del self.connection + super().close() + with contextlib.suppress(AttributeError): del self.database - self.connected = False def cursor(self): return Cursor() diff --git a/django_mongodb/compiler.py b/django_mongodb/compiler.py index d01d9e946..0b7241141 100644 --- a/django_mongodb/compiler.py +++ b/django_mongodb/compiler.py @@ -386,9 +386,9 @@ def build_query(self, columns=None): try: expr = where.as_mql(self, self.connection) if where else {} except FullResultSet: - query.mongo_query = {} + query.match_mql = {} else: - query.mongo_query = {"$expr": expr} + query.match_mql = {"$expr": expr} if extra_fields: query.extra_fields = self.get_project_fields(extra_fields, force_expression=True) query.subqueries = self.subqueries @@ -722,7 +722,7 @@ def execute_sql(self, result_type): prepared = prepared.as_mql(self, self.connection) values[field.column] = prepared try: - criteria = self.build_query().mongo_query + criteria = self.build_query().match_mql except EmptyResultSet: return 0 is_empty = not bool(values) diff --git a/django_mongodb/features.py b/django_mongodb/features.py index b17f3abe7..4c9b32d1b 100644 --- a/django_mongodb/features.py +++ b/django_mongodb/features.py @@ -3,6 +3,7 @@ class DatabaseFeatures(BaseDatabaseFeatures): + minimum_database_version = (5, 0) allow_sliced_subqueries_with_in = False allows_multiple_constraints_on_same_fields = False can_create_inline_fk = False @@ -71,7 +72,6 @@ class DatabaseFeatures(BaseDatabaseFeatures): "backends.tests.BackendTestCase.test_is_usable_after_database_disconnects", # Connection creation doesn't follow the usual Django API. "backends.tests.ThreadTests.test_pass_connection_between_threads", - "backends.tests.ThreadTests.test_closing_non_shared_connections", "backends.tests.ThreadTests.test_default_connection_thread_local", # Union as subquery is not mapping the parent parameter and collections: # https://github.com/mongodb-labs/django-mongodb/issues/156 diff --git a/django_mongodb/query.py b/django_mongodb/query.py index 92ad9e3e2..049775205 100644 --- a/django_mongodb/query.py +++ b/django_mongodb/query.py @@ -42,14 +42,9 @@ class MongoQuery: def __init__(self, compiler): self.compiler = compiler - self.connection = compiler.connection - self.ops = compiler.connection.ops self.query = compiler.query - self._negated = False self.ordering = [] - self.collection = self.compiler.collection - self.collection_name = self.compiler.collection_name - self.mongo_query = getattr(compiler.query, "raw_query", {}) + self.match_mql = {} self.subqueries = None self.lookup_pipeline = None self.project_fields = None @@ -61,14 +56,14 @@ def __init__(self, compiler): self.subquery_lookup = None def __repr__(self): - return f"" + return f"" @wrap_database_errors def delete(self): """Execute a delete query.""" if self.compiler.subqueries: raise NotSupportedError("Cannot use QuerySet.delete() when a subquery is required.") - return self.collection.delete_many(self.mongo_query).deleted_count + return self.compiler.collection.delete_many(self.match_mql).deleted_count @wrap_database_errors def get_cursor(self): @@ -76,7 +71,7 @@ def get_cursor(self): Return a pymongo CommandCursor that can be iterated on to give the results of the query. """ - return self.collection.aggregate(self.get_pipeline()) + return self.compiler.collection.aggregate(self.get_pipeline()) def get_pipeline(self): pipeline = [] @@ -84,8 +79,8 @@ def get_pipeline(self): pipeline.extend(self.lookup_pipeline) for query in self.subqueries or (): pipeline.extend(query.get_pipeline()) - if self.mongo_query: - pipeline.append({"$match": self.mongo_query}) + if self.match_mql: + pipeline.append({"$match": self.match_mql}) if self.aggregation_pipeline: pipeline.extend(self.aggregation_pipeline) if self.project_fields: