diff --git a/.github/workflows/test-python.yml b/.github/workflows/test-python.yml index 24a9829a0..eb6e4ca83 100644 --- a/.github/workflows/test-python.yml +++ b/.github/workflows/test-python.yml @@ -51,6 +51,7 @@ jobs: - name: Run tests run: > python3 django_repo/tests/runtests.py --settings mongodb_settings -v 2 + admin_filters aggregation aggregation_regress annotations diff --git a/django_mongodb/aggregates.py b/django_mongodb/aggregates.py index 35daab2c2..e67cefa2f 100644 --- a/django_mongodb/aggregates.py +++ b/django_mongodb/aggregates.py @@ -1,9 +1,6 @@ -from copy import deepcopy - from django.db.models.aggregates import Aggregate, Count, StdDev, Variance from django.db.models.expressions import Case, Value, When from django.db.models.lookups import IsNull -from django.db.models.sql.where import WhereNode from .query_utils import process_lhs @@ -45,12 +42,9 @@ def count(self, compiler, connection, resolve_inner_expression=False, **extra_co node = self.copy() node.filter = None source_expressions = node.get_source_expressions() - filter_ = deepcopy(self.filter) - filter_.add( - WhereNode([IsNull(source_expressions[0], True)], negated=True), - filter_.default, + condition = When( + self.filter, then=Case(When(IsNull(source_expressions[0], False), then=Value(1))) ) - condition = When(filter_, then=Value(1)) node.set_source_expressions([Case(condition)] + source_expressions[1:]) inner_expression = process_lhs(node, compiler, connection) else: diff --git a/django_mongodb/compiler.py b/django_mongodb/compiler.py index 9f1478d75..1b8807a01 100644 --- a/django_mongodb/compiler.py +++ b/django_mongodb/compiler.py @@ -368,8 +368,12 @@ def build_query(self, columns=None): ) if not query.aggregation_pipeline: query.aggregation_pipeline = [] - query.aggregation_pipeline.append({"$group": {"_id": distinct_fields}}) - query.project_fields = {key: f"$_id.{key}" for key in distinct_fields} + query.aggregation_pipeline.extend( + [ + {"$group": {"_id": distinct_fields}}, + {"$project": {key: f"$_id.{key}" for key in distinct_fields}}, + ] + ) else: # Otherwise, project fields without grouping. query.project_fields = self.get_project_fields(columns, ordering_fields) diff --git a/django_mongodb/query.py b/django_mongodb/query.py index 1d7dd951d..f952c9aba 100644 --- a/django_mongodb/query.py +++ b/django_mongodb/query.py @@ -154,7 +154,7 @@ def join(self, compiler, connection): if isinstance(hand_side_value, Col): # If the column is not part of the joined table, add it to # lhs_fields. - if hand_side_value.alias != self.table_name: + if hand_side_value.alias != self.table_alias: pos = len(lhs_fields) lhs_fields.append(expr.lhs.as_mql(compiler, connection)) else: