diff --git a/django_mongodb/compiler.py b/django_mongodb/compiler.py index 314dd39f4..d8e323114 100644 --- a/django_mongodb/compiler.py +++ b/django_mongodb/compiler.py @@ -90,7 +90,13 @@ def check_query(self): """Check if the current query is supported by the database.""" if self.query.is_empty(): raise EmptyResultSet() - if self.query.distinct: + if self.query.distinct or getattr( + # In the case of Query.distinct().count(), the distinct attribute + # will be set on the inner_query. + getattr(self.query, "inner_query", None), + "distinct", + None, + ): # This is a heuristic to detect QuerySet.datetimes() and dates(). # "datetimefield" and "datefield" are the names of the annotations # the methods use. A user could annotate with the same names which diff --git a/django_mongodb/features.py b/django_mongodb/features.py index f0bba7d00..093032573 100644 --- a/django_mongodb/features.py +++ b/django_mongodb/features.py @@ -91,7 +91,6 @@ class DatabaseFeatures(BaseDatabaseFeatures): "annotations.tests.NonAggregateAnnotationTestCase.test_annotation_with_m2m", "annotations.tests.NonAggregateAnnotationTestCase.test_chaining_annotation_filter_with_m2m", "annotations.tests.NonAggregateAnnotationTestCase.test_mti_annotations", - "lookup.tests.LookupTests.test_lookup_collision", "expressions.test_queryset_values.ValuesExpressionsTests.test_values_list_expression", "expressions.test_queryset_values.ValuesExpressionsTests.test_values_list_expression_flat", "expressions.tests.IterableLookupInnerExpressionsTests.test_expressions_in_lookups_join_choice", @@ -380,6 +379,7 @@ def django_test_expected_failures(self): "timezones.tests.NewDatabaseTests.test_query_datetimes_in_other_timezone", }, "QuerySet.distinct() is not supported.": { + "lookup.tests.LookupTests.test_lookup_collision_distinct", "update.tests.AdvancedTests.test_update_all", }, "QuerySet.extra() is not supported.": { diff --git a/django_mongodb/query.py b/django_mongodb/query.py index fc684f896..e48b99363 100644 --- a/django_mongodb/query.py +++ b/django_mongodb/query.py @@ -63,12 +63,8 @@ def count(self, limit=None, skip=None): Return the number of objects that would be returned, if this query was executed, up to `limit`, skipping `skip`. """ - kwargs = {} - if limit is not None: - kwargs["limit"] = limit - if skip is not None: - kwargs["skip"] = skip - return self.collection.count_documents(self.mongo_query, **kwargs) + result = list(self.get_cursor(count=True, limit=limit, skip=skip)) + return result[0]["__count"] if result else 0 def order_by(self, ordering): """ @@ -95,7 +91,16 @@ def delete(self): return self.collection.delete_many(self.mongo_query, **options).deleted_count @wrap_database_errors - def get_cursor(self): + def get_cursor(self, count=False, limit=None, skip=None): + """ + Return a pymongo CommandCursor that can be iterated on to give the + results of the query. + + If `count` is True, return a single document with the number of + documents that match the query. + + Use `limit` or `skip` to override those options of the query. + """ if self.query.low_mark == self.query.high_mark: return [] fields = {} @@ -129,10 +134,16 @@ def get_cursor(self): pipeline.append({"$project": fields}) if self.ordering: pipeline.append({"$sort": dict(self.ordering)}) - if self.query.low_mark > 0: + if skip is not None: + pipeline.append({"$skip": skip}) + elif self.query.low_mark > 0: pipeline.append({"$skip": self.query.low_mark}) - if self.query.high_mark is not None: + if limit is not None: + pipeline.append({"$limit": limit}) + elif self.query.high_mark is not None: pipeline.append({"$limit": self.query.high_mark - self.query.low_mark}) + if count: + pipeline.append({"$group": {"_id": None, "__count": {"$sum": 1}}}) return self.collection.aggregate(pipeline)