diff --git a/django_mongodb_backend/aggregates.py b/django_mongodb_backend/aggregates.py index 31f4b29ba..609ded75a 100644 --- a/django_mongodb_backend/aggregates.py +++ b/django_mongodb_backend/aggregates.py @@ -24,7 +24,7 @@ def aggregate( node.set_source_expressions([Case(condition), *source_expressions[1:]]) else: node = self - lhs_mql = process_lhs(node, compiler, connection) + lhs_mql = process_lhs(node, compiler, connection, as_expr=True) if resolve_inner_expression: return lhs_mql operator = operator or MONGO_AGGREGATIONS.get(self.__class__, self.function.lower()) @@ -46,9 +46,9 @@ def count(self, compiler, connection, resolve_inner_expression=False, **extra_co self.filter, then=Case(When(IsNull(source_expressions[0], False), then=Value(1))) ) node.set_source_expressions([Case(condition), *source_expressions[1:]]) - inner_expression = process_lhs(node, compiler, connection) + inner_expression = process_lhs(node, compiler, connection, as_expr=True) else: - lhs_mql = process_lhs(self, compiler, connection) + lhs_mql = process_lhs(self, compiler, connection, as_expr=True) null_cond = {"$in": [{"$type": lhs_mql}, ["missing", "null"]]} inner_expression = { "$cond": {"if": null_cond, "then": None, "else": lhs_mql if self.distinct else 1} @@ -58,7 +58,7 @@ def count(self, compiler, connection, resolve_inner_expression=False, **extra_co return {"$sum": inner_expression} # If distinct=True or resolve_inner_expression=False, sum the size of the # set. - lhs_mql = process_lhs(self, compiler, connection) + lhs_mql = process_lhs(self, compiler, connection, as_expr=True) # None shouldn't be counted, so subtract 1 if it's present. exits_null = {"$cond": {"if": {"$in": [{"$literal": None}, lhs_mql]}, "then": -1, "else": 0}} return {"$add": [{"$size": lhs_mql}, exits_null]} @@ -73,7 +73,7 @@ def stddev_variance(self, compiler, connection, **extra_context): def register_aggregates(): - Aggregate.as_mql = aggregate - Count.as_mql = count - StdDev.as_mql = stddev_variance - Variance.as_mql = stddev_variance + Aggregate.as_mql_expr = aggregate + Count.as_mql_expr = count + StdDev.as_mql_expr = stddev_variance + Variance.as_mql_expr = stddev_variance diff --git a/django_mongodb_backend/base.py b/django_mongodb_backend/base.py index ecc7f78c7..47efd9466 100644 --- a/django_mongodb_backend/base.py +++ b/django_mongodb_backend/base.py @@ -2,7 +2,8 @@ import logging import os -from django.core.exceptions import ImproperlyConfigured +from bson import Decimal128 +from django.core.exceptions import EmptyResultSet, FullResultSet, ImproperlyConfigured from django.db import DEFAULT_DB_ALIAS from django.db.backends.base.base import BaseDatabaseWrapper from django.db.backends.utils import debug_transaction @@ -20,7 +21,7 @@ from .features import DatabaseFeatures from .introspection import DatabaseIntrospection from .operations import DatabaseOperations -from .query_utils import regex_match +from .query_utils import regex_expr, regex_match from .schema import DatabaseSchemaEditor from .utils import OperationDebugWrapper from .validation import DatabaseValidation @@ -97,40 +98,91 @@ class DatabaseWrapper(BaseDatabaseWrapper): } _connection_pools = {} - def _isnull_operator(a, b): - is_null = { + def _isnull_operator_expr(field, is_null): + is_null_expr = { "$or": [ # The path does not exist (i.e. is "missing") - {"$eq": [{"$type": a}, "missing"]}, + {"$eq": [{"$type": field}, "missing"]}, # or the value is None. - {"$eq": [a, None]}, + {"$eq": [field, None]}, ] } - return is_null if b else {"$not": is_null} + return is_null_expr if is_null else {"$not": is_null_expr} - mongo_operators = { + mongo_expr_operators = { "exact": lambda a, b: {"$eq": [a, b]}, "gt": lambda a, b: {"$gt": [a, b]}, "gte": lambda a, b: {"$gte": [a, b]}, # MongoDB considers null less than zero. Exclude null values to match # SQL behavior. - "lt": lambda a, b: {"$and": [{"$lt": [a, b]}, DatabaseWrapper._isnull_operator(a, False)]}, + "lt": lambda a, b: { + "$and": [{"$lt": [a, b]}, DatabaseWrapper._isnull_operator_expr(a, False)] + }, "lte": lambda a, b: { - "$and": [{"$lte": [a, b]}, DatabaseWrapper._isnull_operator(a, False)] + "$and": [{"$lte": [a, b]}, DatabaseWrapper._isnull_operator_expr(a, False)] }, - "in": lambda a, b: {"$in": [a, b]}, - "isnull": _isnull_operator, + "in": lambda a, b: {"$in": (a, b)}, + "isnull": _isnull_operator_expr, "range": lambda a, b: { "$and": [ - {"$or": [DatabaseWrapper._isnull_operator(b[0], True), {"$gte": [a, b[0]]}]}, - {"$or": [DatabaseWrapper._isnull_operator(b[1], True), {"$lte": [a, b[1]]}]}, + {"$or": [DatabaseWrapper._isnull_operator_expr(b[0], True), {"$gte": [a, b[0]]}]}, + {"$or": [DatabaseWrapper._isnull_operator_expr(b[1], True), {"$lte": [a, b[1]]}]}, ] }, - "iexact": lambda a, b: regex_match(a, ("^", b, {"$literal": "$"}), insensitive=True), - "startswith": lambda a, b: regex_match(a, ("^", b)), - "istartswith": lambda a, b: regex_match(a, ("^", b), insensitive=True), - "endswith": lambda a, b: regex_match(a, (b, {"$literal": "$"})), - "iendswith": lambda a, b: regex_match(a, (b, {"$literal": "$"}), insensitive=True), + "iexact": lambda a, b: regex_expr(a, ("^", b, {"$literal": "$"}), insensitive=True), + "startswith": lambda a, b: regex_expr(a, ("^", b)), + "istartswith": lambda a, b: regex_expr(a, ("^", b), insensitive=True), + "endswith": lambda a, b: regex_expr(a, (b, {"$literal": "$"})), + "iendswith": lambda a, b: regex_expr(a, (b, {"$literal": "$"}), insensitive=True), + "contains": lambda a, b: regex_expr(a, b), + "icontains": lambda a, b: regex_expr(a, b, insensitive=True), + "regex": lambda a, b: regex_expr(a, b), + "iregex": lambda a, b: regex_expr(a, b, insensitive=True), + } + + def range_match(a, b): + conditions = [] + start, end = b + if start is not None: + conditions.append({a: {"$gte": b[0]}}) + if end is not None: + conditions.append({a: {"$lte": b[1]}}) + if not conditions: + raise FullResultSet + if start is not None and end is not None: + if isinstance(start, Decimal128): + start = start.to_decimal() + if isinstance(end, Decimal128): + end = end.to_decimal() + if start > end: + raise EmptyResultSet + return {"$and": conditions} + + def _isnull_operator_match(field, is_null): + if is_null: + return {"$or": [{field: {"$exists": False}}, {field: None}]} + return {"$and": [{field: {"$exists": True}}, {field: {"$ne": None}}]} + + mongo_operators = { + "exact": lambda a, b: {a: b}, + "gt": lambda a, b: {a: {"$gt": b}}, + "gte": lambda a, b: {a: {"$gte": b}}, + # MongoDB considers null less than zero. Exclude null values to match + # SQL behavior. + "lt": lambda a, b: { + "$and": [{a: {"$lt": b}}, DatabaseWrapper._isnull_operator_match(a, False)] + }, + "lte": lambda a, b: { + "$and": [{a: {"$lte": b}}, DatabaseWrapper._isnull_operator_match(a, False)] + }, + "in": lambda a, b: {a: {"$in": tuple(b)}}, + "isnull": _isnull_operator_match, + "range": range_match, + "iexact": lambda a, b: regex_match(a, f"^{b}$", insensitive=True), + "startswith": lambda a, b: regex_match(a, f"^{b}"), + "istartswith": lambda a, b: regex_match(a, f"^{b}", insensitive=True), + "endswith": lambda a, b: regex_match(a, f"{b}$"), + "iendswith": lambda a, b: regex_match(a, f"{b}$", insensitive=True), "contains": lambda a, b: regex_match(a, b), "icontains": lambda a, b: regex_match(a, b, insensitive=True), "regex": lambda a, b: regex_match(a, b), diff --git a/django_mongodb_backend/compiler.py b/django_mongodb_backend/compiler.py index 628a91e84..aef1862ff 100644 --- a/django_mongodb_backend/compiler.py +++ b/django_mongodb_backend/compiler.py @@ -69,12 +69,14 @@ def _get_replace_expr(self, sub_expr, group, alias): if getattr(sub_expr, "distinct", False): # If the expression should return distinct values, use $addToSet to # deduplicate. - rhs = sub_expr.as_mql(self, self.connection, resolve_inner_expression=True) + rhs = sub_expr.as_mql( + self, self.connection, resolve_inner_expression=True, as_expr=True + ) group[alias] = {"$addToSet": rhs} replacing_expr = sub_expr.copy() replacing_expr.set_source_expressions([inner_column, None]) else: - group[alias] = sub_expr.as_mql(self, self.connection) + group[alias] = sub_expr.as_mql(self, self.connection, as_expr=True) replacing_expr = inner_column # Count must return 0 rather than null. if isinstance(sub_expr, Count): @@ -302,9 +304,7 @@ def _compound_searches_queries(self, search_replacements): search.as_mql(self, self.connection), { "$addFields": { - result_col.as_mql(self, self.connection, as_path=True): { - "$meta": score_function - } + result_col.as_mql(self, self.connection): {"$meta": score_function} } }, ] @@ -334,7 +334,7 @@ def pre_sql_setup(self, with_col_aliases=False): pipeline.extend(query.get_pipeline()) # Remove the added subqueries. self.subqueries = [] - pipeline.append({"$match": {"$expr": having}}) + pipeline.append({"$match": having}) self.aggregation_pipeline = pipeline self.annotations = { target: expr.replace_expressions(all_replacements) @@ -481,11 +481,11 @@ def build_query(self, columns=None): query.lookup_pipeline = self.get_lookup_pipeline() where = self.get_where() try: - expr = where.as_mql(self, self.connection) if where else {} + match = where.as_mql(self, self.connection) if where else {} except FullResultSet: query.match_mql = {} else: - query.match_mql = {"$expr": expr} + query.match_mql = match if extra_fields: query.extra_fields = self.get_project_fields(extra_fields, force_expression=True) query.subqueries = self.subqueries @@ -643,7 +643,9 @@ def get_combinator_queries(self): for alias, expr in self.columns: # Unfold foreign fields. if isinstance(expr, Col) and expr.alias != self.collection_name: - ids[expr.alias][expr.target.column] = expr.as_mql(self, self.connection) + ids[expr.alias][expr.target.column] = expr.as_mql( + self, self.connection, as_expr=True + ) else: ids[alias] = f"${alias}" # Convert defaultdict to dict so it doesn't appear as @@ -707,16 +709,16 @@ def get_project_fields(self, columns=None, ordering=None, force_expression=False # For brevity/simplicity, project {"field_name": 1} # instead of {"field_name": "$field_name"}. if isinstance(expr, Col) and name == expr.target.column and not force_expression - else expr.as_mql(self, self.connection) + else expr.as_mql(self, self.connection, as_expr=True) ) except EmptyResultSet: empty_result_set_value = getattr(expr, "empty_result_set_value", NotImplemented) value = ( False if empty_result_set_value is NotImplemented else empty_result_set_value ) - fields[collection][name] = Value(value).as_mql(self, self.connection) + fields[collection][name] = Value(value).as_mql(self, self.connection, as_expr=True) except FullResultSet: - fields[collection][name] = Value(True).as_mql(self, self.connection) + fields[collection][name] = Value(True).as_mql(self, self.connection, as_expr=True) # Annotations (stored in None) and the main collection's fields # should appear in the top-level of the fields dict. fields.update(fields.pop(None, {})) @@ -739,10 +741,10 @@ def _get_ordering(self): idx = itertools.count(start=1) for order in self.order_by_objs or []: if isinstance(order.expression, Col): - field_name = order.as_mql(self, self.connection).removeprefix("$") + field_name = order.as_mql(self, self.connection, as_expr=True).removeprefix("$") fields.append((order.expression.target.column, order.expression)) elif isinstance(order.expression, Ref): - field_name = order.as_mql(self, self.connection).removeprefix("$") + field_name = order.as_mql(self, self.connection, as_expr=True).removeprefix("$") else: field_name = f"__order{next(idx)}" fields.append((field_name, order.expression)) @@ -879,7 +881,7 @@ def execute_sql(self, result_type): ) prepared = field.get_db_prep_save(value, connection=self.connection) if hasattr(value, "as_mql"): - prepared = prepared.as_mql(self, self.connection) + prepared = prepared.as_mql(self, self.connection, as_expr=True) values[field.column] = prepared try: criteria = self.build_query().match_mql diff --git a/django_mongodb_backend/expressions/builtins.py b/django_mongodb_backend/expressions/builtins.py index 0387fb97c..4138b8024 100644 --- a/django_mongodb_backend/expressions/builtins.py +++ b/django_mongodb_backend/expressions/builtins.py @@ -1,11 +1,13 @@ import datetime from decimal import Decimal +from functools import partialmethod from uuid import UUID from bson import Decimal128 from django.core.exceptions import EmptyResultSet, FullResultSet from django.db import NotSupportedError from django.db.models.expressions import ( + BaseExpression, Case, Col, ColPairs, @@ -28,21 +30,29 @@ from django_mongodb_backend.query_utils import process_lhs +def base_expression(self, compiler, connection, as_expr=False, **extra): + if not as_expr and hasattr(self, "as_mql_path") and getattr(self, "can_use_path", False): + return self.as_mql_path(compiler, connection, **extra) + + expr = self.as_mql_expr(compiler, connection, **extra) + return expr if as_expr else {"$expr": expr} + + def case(self, compiler, connection): case_parts = [] for case in self.cases: case_mql = {} try: - case_mql["case"] = case.as_mql(compiler, connection) + case_mql["case"] = case.as_mql(compiler, connection, as_expr=True) except EmptyResultSet: continue except FullResultSet: - default_mql = case.result.as_mql(compiler, connection) + default_mql = case.result.as_mql(compiler, connection, as_expr=True) break - case_mql["then"] = case.result.as_mql(compiler, connection) + case_mql["then"] = case.result.as_mql(compiler, connection, as_expr=True) case_parts.append(case_mql) else: - default_mql = self.default.as_mql(compiler, connection) + default_mql = self.default.as_mql(compiler, connection, as_expr=True) if not case_parts: return default_mql return { @@ -53,7 +63,7 @@ def case(self, compiler, connection): } -def col(self, compiler, connection, as_path=False): # noqa: ARG001 +def col(self, compiler, connection, as_expr=False): # noqa: ARG001 # If the column is part of a subquery and belongs to one of the parent # queries, it will be stored for reference using $let in a $lookup stage. # If the query is built with `alias_cols=False`, treat the column as @@ -71,28 +81,28 @@ def col(self, compiler, connection, as_path=False): # noqa: ARG001 # Add the column's collection's alias for columns in joined collections. has_alias = self.alias and self.alias != compiler.collection_name prefix = f"{self.alias}." if has_alias else "" - if not as_path: + if as_expr: prefix = f"${prefix}" return f"{prefix}{self.target.column}" -def col_pairs(self, compiler, connection): +def col_pairs(self, compiler, connection, as_expr=False): cols = self.get_cols() if len(cols) > 1: raise NotSupportedError("ColPairs is not supported.") - return cols[0].as_mql(compiler, connection) + return cols[0].as_mql(compiler, connection, as_expr=as_expr) def combined_expression(self, compiler, connection): expressions = [ - self.lhs.as_mql(compiler, connection), - self.rhs.as_mql(compiler, connection), + self.lhs.as_mql(compiler, connection, as_expr=True), + self.rhs.as_mql(compiler, connection, as_expr=True), ] return connection.ops.combine_expression(self.connector, expressions) def expression_wrapper(self, compiler, connection): - return self.expression.as_mql(compiler, connection) + return self.expression.as_mql(compiler, connection, as_expr=True) def negated_expression(self, compiler, connection): @@ -100,10 +110,10 @@ def negated_expression(self, compiler, connection): def order_by(self, compiler, connection): - return self.expression.as_mql(compiler, connection) + return self.expression.as_mql(compiler, connection, as_expr=True) -def query(self, compiler, connection, get_wrapping_pipeline=None): +def query(self, compiler, connection, get_wrapping_pipeline=None, as_expr=False): subquery_compiler = self.get_compiler(connection=connection) subquery_compiler.pre_sql_setup(with_col_aliases=False) field_name, expr = subquery_compiler.columns[0] @@ -123,7 +133,7 @@ def query(self, compiler, connection, get_wrapping_pipeline=None): "as": table_output, "from": from_table, "let": { - compiler.PARENT_FIELD_TEMPLATE.format(i): col.as_mql(compiler, connection) + compiler.PARENT_FIELD_TEMPLATE.format(i): col.as_mql(compiler, connection, as_expr=True) for col, i in subquery_compiler.column_indices.items() }, } @@ -145,14 +155,16 @@ def query(self, compiler, connection, get_wrapping_pipeline=None): # Erase project_fields since the required value is projected above. subquery.project_fields = None compiler.subqueries.append(subquery) - return f"${table_output}.{field_name}" + if as_expr: + return f"${table_output}.{field_name}" + return f"{table_output}.{field_name}" def raw_sql(self, compiler, connection): # noqa: ARG001 raise NotSupportedError("RawSQL is not supported on MongoDB.") -def ref(self, compiler, connection): # noqa: ARG001 +def ref(self, compiler, connection, as_expr=False): # noqa: ARG001 prefix = ( f"{self.source.alias}." if isinstance(self.source, Col) and self.source.alias != compiler.collection_name @@ -162,32 +174,43 @@ def ref(self, compiler, connection): # noqa: ARG001 refs, _ = compiler.columns[self.ordinal - 1] else: refs = self.refs - return f"${prefix}{refs}" + if as_expr: + prefix = f"${prefix}" + return f"{prefix}{refs}" + + +@property +def ref_is_simple_column(self): + return self.source.is_simple_column def star(self, compiler, connection): # noqa: ARG001 return {"$literal": True} -def subquery(self, compiler, connection, get_wrapping_pipeline=None): - return self.query.as_mql(compiler, connection, get_wrapping_pipeline=get_wrapping_pipeline) +def subquery(self, compiler, connection, get_wrapping_pipeline=None, as_expr=False): + return self.query.as_mql( + compiler, connection, get_wrapping_pipeline=get_wrapping_pipeline, as_expr=as_expr + ) def exists(self, compiler, connection, get_wrapping_pipeline=None): try: - lhs_mql = subquery(self, compiler, connection, get_wrapping_pipeline=get_wrapping_pipeline) + lhs_mql = subquery( + self, compiler, connection, get_wrapping_pipeline=get_wrapping_pipeline, as_expr=True + ) except EmptyResultSet: - return Value(False).as_mql(compiler, connection) - return connection.mongo_operators["isnull"](lhs_mql, False) + return Value(False).as_mql(compiler, connection, as_expr=True) + return connection.mongo_expr_operators["isnull"](lhs_mql, False) def when(self, compiler, connection): - return self.condition.as_mql(compiler, connection) + return self.condition.as_mql(compiler, connection, as_expr=True) -def value(self, compiler, connection): # noqa: ARG001 +def value(self, compiler, connection, as_expr=False): # noqa: ARG001 value = self.value - if isinstance(value, (list, int)): + if isinstance(value, (list, int)) and as_expr: # Wrap lists & numbers in $literal to prevent ambiguity when Value # appears in $project. return {"$literal": value} @@ -210,20 +233,25 @@ def value(self, compiler, connection): # noqa: ARG001 def register_expressions(): - Case.as_mql = case + BaseExpression.as_mql = base_expression + BaseExpression.is_simple_column = False + Case.as_mql_expr = case Col.as_mql = col + Col.is_simple_column = True ColPairs.as_mql = col_pairs - CombinedExpression.as_mql = combined_expression - Exists.as_mql = exists + CombinedExpression.as_mql_expr = combined_expression + Exists.as_mql_expr = exists ExpressionList.as_mql = process_lhs - ExpressionWrapper.as_mql = expression_wrapper - NegatedExpression.as_mql = negated_expression - OrderBy.as_mql = order_by + ExpressionWrapper.as_mql_expr = expression_wrapper + NegatedExpression.as_mql_expr = negated_expression + OrderBy.as_mql_expr = order_by Query.as_mql = query RawSQL.as_mql = raw_sql Ref.as_mql = ref + Ref.is_simple_column = ref_is_simple_column ResolvedOuterRef.as_mql = ResolvedOuterRef.as_sql - Star.as_mql = star - Subquery.as_mql = subquery - When.as_mql = when + Star.as_mql_expr = star + Subquery.as_mql_expr = partialmethod(subquery, as_expr=True) + Subquery.as_mql_path = partialmethod(subquery, as_expr=False) + When.as_mql_expr = when Value.as_mql = value diff --git a/django_mongodb_backend/expressions/search.py b/django_mongodb_backend/expressions/search.py index 6b14be497..f759eba33 100644 --- a/django_mongodb_backend/expressions/search.py +++ b/django_mongodb_backend/expressions/search.py @@ -118,7 +118,7 @@ def _get_query_index(self, fields, compiler): def search_operator(self, compiler, connection): raise NotImplementedError - def as_mql(self, compiler, connection): + def as_mql(self, compiler, connection, as_expr=False): index = self._get_query_index(self.get_search_fields(compiler, connection), compiler) return {"$search": {**self.search_operator(compiler, connection), "index": index}} @@ -154,11 +154,11 @@ def __init__(self, path, query, *, fuzzy=None, token_order=None, score=None): super().__init__() def get_search_fields(self, compiler, connection): - return {self.path.as_mql(compiler, connection, as_path=True)} + return {self.path.as_mql(compiler, connection)} def search_operator(self, compiler, connection): params = { - "path": self.path.as_mql(compiler, connection, as_path=True), + "path": self.path.as_mql(compiler, connection), "query": self.query, } if self.score: @@ -192,11 +192,11 @@ def __init__(self, path, value, *, score=None): super().__init__() def get_search_fields(self, compiler, connection): - return {self.path.as_mql(compiler, connection, as_path=True)} + return {self.path.as_mql(compiler, connection)} def search_operator(self, compiler, connection): params = { - "path": self.path.as_mql(compiler, connection, as_path=True), + "path": self.path.as_mql(compiler, connection), "value": self.value, } if self.score: @@ -228,11 +228,11 @@ def __init__(self, path, *, score=None): super().__init__() def get_search_fields(self, compiler, connection): - return {self.path.as_mql(compiler, connection, as_path=True)} + return {self.path.as_mql(compiler, connection)} def search_operator(self, compiler, connection): params = { - "path": self.path.as_mql(compiler, connection, as_path=True), + "path": self.path.as_mql(compiler, connection), } if self.score: params["score"] = self.score.as_mql(compiler, connection) @@ -261,11 +261,11 @@ def __init__(self, path, value, *, score=None): super().__init__() def get_search_fields(self, compiler, connection): - return {self.path.as_mql(compiler, connection, as_path=True)} + return {self.path.as_mql(compiler, connection)} def search_operator(self, compiler, connection): params = { - "path": self.path.as_mql(compiler, connection, as_path=True), + "path": self.path.as_mql(compiler, connection), "value": self.value, } if self.score: @@ -302,11 +302,11 @@ def __init__(self, path, query, *, slop=None, synonyms=None, score=None): super().__init__() def get_search_fields(self, compiler, connection): - return {self.path.as_mql(compiler, connection, as_path=True)} + return {self.path.as_mql(compiler, connection)} def search_operator(self, compiler, connection): params = { - "path": self.path.as_mql(compiler, connection, as_path=True), + "path": self.path.as_mql(compiler, connection), "query": self.query, } if self.score: @@ -344,11 +344,11 @@ def __init__(self, path, query, *, score=None): super().__init__() def get_search_fields(self, compiler, connection): - return {self.path.as_mql(compiler, connection, as_path=True)} + return {self.path.as_mql(compiler, connection)} def search_operator(self, compiler, connection): params = { - "defaultPath": self.path.as_mql(compiler, connection, as_path=True), + "defaultPath": self.path.as_mql(compiler, connection), "query": self.query, } if self.score: @@ -387,11 +387,11 @@ def __init__(self, path, *, lt=None, lte=None, gt=None, gte=None, score=None): super().__init__() def get_search_fields(self, compiler, connection): - return {self.path.as_mql(compiler, connection, as_path=True)} + return {self.path.as_mql(compiler, connection)} def search_operator(self, compiler, connection): params = { - "path": self.path.as_mql(compiler, connection, as_path=True), + "path": self.path.as_mql(compiler, connection), } if self.score: params["score"] = self.score.as_mql(compiler, connection) @@ -431,11 +431,11 @@ def __init__(self, path, query, *, allow_analyzed_field=None, score=None): super().__init__() def get_search_fields(self, compiler, connection): - return {self.path.as_mql(compiler, connection, as_path=True)} + return {self.path.as_mql(compiler, connection)} def search_operator(self, compiler, connection): params = { - "path": self.path.as_mql(compiler, connection, as_path=True), + "path": self.path.as_mql(compiler, connection), "query": self.query, } if self.score: @@ -481,11 +481,11 @@ def __init__(self, path, query, *, fuzzy=None, match_criteria=None, synonyms=Non super().__init__() def get_search_fields(self, compiler, connection): - return {self.path.as_mql(compiler, connection, as_path=True)} + return {self.path.as_mql(compiler, connection)} def search_operator(self, compiler, connection): params = { - "path": self.path.as_mql(compiler, connection, as_path=True), + "path": self.path.as_mql(compiler, connection), "query": self.query, } if self.score: @@ -527,11 +527,11 @@ def __init__(self, path, query, allow_analyzed_field=None, score=None): super().__init__() def get_search_fields(self, compiler, connection): - return {self.path.as_mql(compiler, connection, as_path=True)} + return {self.path.as_mql(compiler, connection)} def search_operator(self, compiler, connection): params = { - "path": self.path.as_mql(compiler, connection, as_path=True), + "path": self.path.as_mql(compiler, connection), "query": self.query, } if self.score: @@ -573,11 +573,11 @@ def __init__(self, path, relation, geometry, *, score=None): super().__init__() def get_search_fields(self, compiler, connection): - return {self.path.as_mql(compiler, connection, as_path=True)} + return {self.path.as_mql(compiler, connection)} def search_operator(self, compiler, connection): params = { - "path": self.path.as_mql(compiler, connection, as_path=True), + "path": self.path.as_mql(compiler, connection), "relation": self.relation, "geometry": self.geometry, } @@ -617,11 +617,11 @@ def __init__(self, path, kind, geometry, *, score=None): super().__init__() def get_search_fields(self, compiler, connection): - return {self.path.as_mql(compiler, connection, as_path=True)} + return {self.path.as_mql(compiler, connection)} def search_operator(self, compiler, connection): params = { - "path": self.path.as_mql(compiler, connection, as_path=True), + "path": self.path.as_mql(compiler, connection), self.kind: self.geometry, } if self.score: @@ -817,9 +817,9 @@ def resolve(node, negated=False): return CompoundExpression(must=[lhs_compound, rhs_compound]) return CompoundExpression(should=[lhs_compound, rhs_compound], minimum_should_match=1) - def as_mql(self, compiler, connection): + def as_mql(self, compiler, connection, as_expr=False): expression = self.resolve(self) - return expression.as_mql(compiler, connection) + return expression.as_mql(compiler, connection, as_expr=as_expr) class SearchVector(SearchExpression): @@ -879,7 +879,7 @@ def __ror__(self, other): raise NotSupportedError("SearchVector cannot be combined") def get_search_fields(self, compiler, connection): - return {self.path.as_mql(compiler, connection, as_path=True)} + return {self.path.as_mql(compiler, connection)} def _get_query_index(self, fields, compiler): for search_indexes in compiler.collection.list_search_indexes(): @@ -891,10 +891,10 @@ def _get_query_index(self, fields, compiler): return search_indexes["name"] return "default" - def as_mql(self, compiler, connection): + def as_mql(self, compiler, connection, as_expr=False): params = { "index": self._get_query_index(self.get_search_fields(compiler, connection), compiler), - "path": self.path.as_mql(compiler, connection, as_path=True), + "path": self.path.as_mql(compiler, connection), "queryVector": self.query_vector, "limit": self.limit, } @@ -913,7 +913,7 @@ class SearchScoreOption(Expression): def __init__(self, definitions=None): self._definitions = definitions - def as_mql(self, compiler, connection): + def as_mql(self, compiler, connection, as_expr=False): return self._definitions @@ -933,10 +933,15 @@ def __str__(self): def __repr__(self): return f"SearchText({self.lhs}, {self.rhs})" - def as_mql(self, compiler, connection): + def as_mql_expr(self, compiler, connection): + lhs_mql = process_lhs(self, compiler, connection, as_expr=True) + value = process_rhs(self, compiler, connection, as_expr=True) + return {"$gte": [lhs_mql, value]} + + def as_mql_path(self, compiler, connection): lhs_mql = process_lhs(self, compiler, connection) value = process_rhs(self, compiler, connection) - return {"$gte": [lhs_mql, value]} + return {lhs_mql: {"$gte": value}} CharField.register_lookup(SearchTextLookup) diff --git a/django_mongodb_backend/features.py b/django_mongodb_backend/features.py index 7e29c5003..18a048bf6 100644 --- a/django_mongodb_backend/features.py +++ b/django_mongodb_backend/features.py @@ -90,9 +90,6 @@ class DatabaseFeatures(GISFeatures, BaseDatabaseFeatures): "auth_tests.test_views.LoginTest.test_login_session_without_hash_session_key", # GenericRelation.value_to_string() assumes integer pk. "contenttypes_tests.test_fields.GenericRelationTests.test_value_to_string", - # icontains doesn't work on ArrayField: - # Unsupported conversion from array to string in $convert - "model_fields_.test_arrayfield.QueryingTests.test_icontains", # ArrayField's contained_by lookup crashes with Exists: "both operands " # of $setIsSubset must be arrays. Second argument is of type: null" # https://jira.mongodb.org/browse/SERVER-99186 diff --git a/django_mongodb_backend/fields/array.py b/django_mongodb_backend/fields/array.py index a6369b21a..0daf2d41b 100644 --- a/django_mongodb_backend/fields/array.py +++ b/django_mongodb_backend/fields/array.py @@ -4,10 +4,11 @@ from django.db.models import Field, Func, IntegerField, Transform, Value from django.db.models.fields.mixins import CheckFieldDefaultMixin from django.db.models.lookups import Exact, FieldGetDbPrepValueMixin, In, Lookup +from django.utils.functional import cached_property from django.utils.translation import gettext_lazy as _ from django_mongodb_backend.forms import SimpleArrayField -from django_mongodb_backend.query_utils import process_lhs, process_rhs +from django_mongodb_backend.query_utils import is_constant_value, process_lhs, process_rhs from django_mongodb_backend.utils import prefix_validation_error from django_mongodb_backend.validators import ArrayMaxLengthValidator, LengthValidator @@ -230,9 +231,23 @@ def formfield(self, **kwargs): class Array(Func): - def as_mql(self, compiler, connection): + def as_mql_expr(self, compiler, connection): + return [ + expr.as_mql(compiler, connection, as_expr=True) + for expr in self.get_source_expressions() + ] + + def as_mql_path(self, compiler, connection): return [expr.as_mql(compiler, connection) for expr in self.get_source_expressions()] + @cached_property + def can_use_path(self): + return all(is_constant_value(expr) for expr in self.get_source_expressions()) + + @property + def is_simple_column(self): + return False + class ArrayRHSMixin: def __init__(self, lhs, rhs): @@ -251,9 +266,9 @@ def __init__(self, lhs, rhs): class ArrayContains(ArrayRHSMixin, FieldGetDbPrepValueMixin, Lookup): lookup_name = "contains" - def as_mql(self, compiler, connection): - lhs_mql = process_lhs(self, compiler, connection) - value = process_rhs(self, compiler, connection) + def as_mql_expr(self, compiler, connection): + lhs_mql = process_lhs(self, compiler, connection, as_expr=True) + value = process_rhs(self, compiler, connection, as_expr=True) return { "$and": [ {"$ne": [lhs_mql, None]}, @@ -262,14 +277,19 @@ def as_mql(self, compiler, connection): ] } + def as_mql_path(self, compiler, connection): + lhs_mql = process_lhs(self, compiler, connection) + value = process_rhs(self, compiler, connection) + return {lhs_mql: {"$all": value}} + @ArrayField.register_lookup class ArrayContainedBy(ArrayRHSMixin, FieldGetDbPrepValueMixin, Lookup): lookup_name = "contained_by" - def as_mql(self, compiler, connection): - lhs_mql = process_lhs(self, compiler, connection) - value = process_rhs(self, compiler, connection) + def as_mql_expr(self, compiler, connection): + lhs_mql = process_lhs(self, compiler, connection, as_expr=True) + value = process_rhs(self, compiler, connection, as_expr=True) return { "$and": [ {"$ne": [lhs_mql, None]}, @@ -293,7 +313,7 @@ def get_subquery_wrapping_pipeline(self, compiler, connection, field_name, expr) { "$facet": { "group": [ - {"$project": {"tmp_name": expr.as_mql(compiler, connection)}}, + {"$project": {"tmp_name": expr.as_mql(compiler, connection, as_expr=True)}}, { "$unwind": "$tmp_name", }, @@ -323,21 +343,29 @@ def get_subquery_wrapping_pipeline(self, compiler, connection, field_name, expr) }, ] - def as_mql(self, compiler, connection): - lhs_mql = process_lhs(self, compiler, connection) - value = process_rhs(self, compiler, connection) + def as_mql_expr(self, compiler, connection): + lhs_mql = process_lhs(self, compiler, connection, as_expr=True) + value = process_rhs(self, compiler, connection, as_expr=True) return { - "$and": [{"$ne": [lhs_mql, None]}, {"$size": {"$setIntersection": [value, lhs_mql]}}] + "$and": [ + {"$ne": [lhs_mql, None]}, + {"$size": {"$setIntersection": [value, lhs_mql]}}, + ] } + def as_mql_path(self, compiler, connection): + lhs_mql = process_lhs(self, compiler, connection) + value = process_rhs(self, compiler, connection) + return {lhs_mql: {"$in": value}} + @ArrayField.register_lookup class ArrayLenTransform(Transform): lookup_name = "len" output_field = IntegerField() - def as_mql(self, compiler, connection): - lhs_mql = process_lhs(self, compiler, connection) + def as_mql_expr(self, compiler, connection): + lhs_mql = process_lhs(self, compiler, connection, as_expr=True) return {"$cond": {"if": {"$isArray": lhs_mql}, "then": {"$size": lhs_mql}, "else": None}} @@ -363,10 +391,22 @@ def __init__(self, index, base_field, *args, **kwargs): self.index = index self.base_field = base_field - def as_mql(self, compiler, connection): - lhs_mql = process_lhs(self, compiler, connection) + @property + def can_use_path(self): + return self.is_simple_column + + @property + def is_simple_column(self): + return self.lhs.is_simple_column + + def as_mql_expr(self, compiler, connection): + lhs_mql = process_lhs(self, compiler, connection, as_expr=True) return {"$arrayElemAt": [lhs_mql, self.index]} + def as_mql_path(self, compiler, connection): + lhs_mql = process_lhs(self, compiler, connection) + return f"{lhs_mql}.{self.index}" + @property def output_field(self): return self.base_field @@ -387,8 +427,8 @@ def __init__(self, start, end, *args, **kwargs): self.start = start self.end = end - def as_mql(self, compiler, connection): - lhs_mql = process_lhs(self, compiler, connection) + def as_mql_expr(self, compiler, connection): + lhs_mql = process_lhs(self, compiler, connection, as_expr=True) return {"$slice": [lhs_mql, self.start, self.end]} diff --git a/django_mongodb_backend/fields/embedded_model.py b/django_mongodb_backend/fields/embedded_model.py index 83423def8..b3d9a2e02 100644 --- a/django_mongodb_backend/fields/embedded_model.py +++ b/django_mongodb_backend/fields/embedded_model.py @@ -5,8 +5,10 @@ from django.db import models from django.db.models.fields.related import lazy_related_operation from django.db.models.lookups import Transform +from django.utils.functional import cached_property from django_mongodb_backend import forms +from django_mongodb_backend.query_utils import valid_path_key_name class EmbeddedModelField(models.Field): @@ -166,6 +168,19 @@ def __init__(self, field, *args, **kwargs): def get_lookup(self, name): return self.field.get_lookup(name) + @property + def can_use_path(self): + return self.is_simple_column + + @cached_property + def is_simple_column(self): + previous = self + while isinstance(previous, EmbeddedModelTransform): + if not valid_path_key_name(previous._field.column): + return False + previous = previous.lhs + return previous.is_simple_column + def get_transform(self, name): """ Validate that `name` is either a field of an embedded model or a @@ -185,21 +200,27 @@ def get_transform(self, name): f"{suggestion}" ) - def as_mql(self, compiler, connection, as_path=False): + def _get_target_path(self): previous = self columns = [] while isinstance(previous, EmbeddedModelTransform): columns.insert(0, previous.field.column) previous = previous.lhs - if as_path: - mql = previous.as_mql(compiler, connection, as_path=True) - mql_path = ".".join(columns) - return f"{mql}.{mql_path}" - mql = previous.as_mql(compiler, connection) - for column in columns: - mql = {"$getField": {"input": mql, "field": column}} + return columns, previous + + def as_mql_expr(self, compiler, connection): + columns, parent_field = self._get_target_path() + mql = parent_field.as_mql(compiler, connection, as_expr=True) + for key in columns: + mql = {"$getField": {"input": mql, "field": key}} return mql + def as_mql_path(self, compiler, connection): + columns, parent_field = self._get_target_path() + mql = parent_field.as_mql(compiler, connection) + mql_path = ".".join(columns) + return f"{mql}.{mql_path}" + @property def output_field(self): return self._field diff --git a/django_mongodb_backend/fields/embedded_model_array.py b/django_mongodb_backend/fields/embedded_model_array.py index a220969cd..b75017194 100644 --- a/django_mongodb_backend/fields/embedded_model_array.py +++ b/django_mongodb_backend/fields/embedded_model_array.py @@ -5,12 +5,12 @@ from django.db.models.expressions import Col from django.db.models.fields.related import lazy_related_operation from django.db.models.lookups import Lookup, Transform +from django.utils.functional import cached_property from django_mongodb_backend import forms -from django_mongodb_backend.query_utils import process_lhs, process_rhs - -from . import EmbeddedModelField -from .array import ArrayField, ArrayLenTransform +from django_mongodb_backend.fields import EmbeddedModelField +from django_mongodb_backend.fields.array import ArrayField, ArrayLenTransform +from django_mongodb_backend.query_utils import process_lhs, process_rhs, valid_path_key_name class EmbeddedModelArrayField(ArrayField): @@ -77,7 +77,7 @@ def _get_lookup(self, lookup_name): return lookup class EmbeddedModelArrayFieldLookups(Lookup): - def as_mql(self, compiler, connection): + def as_mql(self, compiler, connection, as_expr=False): raise ValueError( "Lookups aren't supported on EmbeddedModelArrayField. " "Try querying one of its embedded fields instead." @@ -116,7 +116,7 @@ def get_lookup(self, name): class EmbeddedModelArrayFieldBuiltinLookup(Lookup): - def process_rhs(self, compiler, connection): + def process_rhs(self, compiler, connection, as_expr=False): value = self.rhs if not self.get_db_prep_lookup_value_is_iterable: value = [value] @@ -130,14 +130,14 @@ def process_rhs(self, compiler, connection): for v in value ] - def as_mql(self, compiler, connection): + def as_mql_expr(self, compiler, connection): # Querying a subfield within the array elements (via nested # KeyTransform). Replicate MongoDB's implicit ANY-match by mapping over # the array and applying $in on the subfield. - lhs_mql = process_lhs(self, compiler, connection) + lhs_mql = process_lhs(self, compiler, connection, as_expr=True) inner_lhs_mql = lhs_mql["$ifNull"][0]["$map"]["in"] - values = process_rhs(self, compiler, connection) - lhs_mql["$ifNull"][0]["$map"]["in"] = connection.mongo_operators[self.lookup_name]( + values = process_rhs(self, compiler, connection, as_expr=True) + lhs_mql["$ifNull"][0]["$map"]["in"] = connection.mongo_expr_operators[self.lookup_name]( inner_lhs_mql, values ) return {"$anyElementTrue": lhs_mql} @@ -153,7 +153,7 @@ def get_subquery_wrapping_pipeline(self, compiler, connection, field_name, expr) { "$facet": { "gathered_data": [ - {"$project": {"tmp_name": expr.as_mql(compiler, connection)}}, + {"$project": {"tmp_name": expr.as_mql(compiler, connection, as_expr=True)}}, # To concatenate all the values from the RHS subquery, # use an $unwind followed by a $group. { @@ -236,6 +236,7 @@ def __init__(self, field, *args, **kwargs): column_name = f"$item.{field.column}" column_target.db_column = column_name column_target.set_attributes_from_name(column_name) + self._field = field self._lhs = Col(None, column_target) self._sub_transform = None @@ -243,6 +244,19 @@ def __call__(self, this, *args, **kwargs): self._lhs = self._sub_transform(self._lhs, *args, **kwargs) return self + @property + def can_use_path(self): + return self.is_simple_column + + @cached_property + def is_simple_column(self): + previous = self + while isinstance(previous, EmbeddedModelArrayFieldTransform): + if not valid_path_key_name(previous._field.column): + return False + previous = previous.lhs + return previous.is_simple_column and self._lhs.is_simple_column + def get_lookup(self, name): return self.output_field.get_lookup(name) @@ -275,9 +289,9 @@ def get_transform(self, name): f"{suggestion}" ) - def as_mql(self, compiler, connection): - inner_lhs_mql = self._lhs.as_mql(compiler, connection) - lhs_mql = process_lhs(self, compiler, connection) + def as_mql_expr(self, compiler, connection): + inner_lhs_mql = self._lhs.as_mql(compiler, connection, as_expr=True) + lhs_mql = process_lhs(self, compiler, connection, as_expr=True) return { "$ifNull": [ { @@ -291,6 +305,11 @@ def as_mql(self, compiler, connection): ] } + def as_mql_path(self, compiler, connection): + inner_lhs_mql = self._lhs.as_mql(compiler, connection).removeprefix("$item.") + lhs_mql = process_lhs(self, compiler, connection) + return f"{lhs_mql}.{inner_lhs_mql}" + @property def output_field(self): return _EmbeddedModelArrayOutputField(self._lhs.output_field) diff --git a/django_mongodb_backend/fields/json.py b/django_mongodb_backend/fields/json.py index aeb792d75..cc5ccea9c 100644 --- a/django_mongodb_backend/fields/json.py +++ b/django_mongodb_backend/fields/json.py @@ -1,3 +1,6 @@ +from functools import partialmethod +from itertools import chain + from django.db import NotSupportedError from django.db.models.fields.json import ( ContainedBy, @@ -8,17 +11,20 @@ HasKeys, JSONExact, KeyTransform, + KeyTransformExact, KeyTransformIn, KeyTransformIsNull, KeyTransformNumericLookupMixin, ) -from django_mongodb_backend.lookups import builtin_lookup -from django_mongodb_backend.query_utils import process_lhs, process_rhs +from django_mongodb_backend.lookups import builtin_lookup_expr, builtin_lookup_path +from django_mongodb_backend.query_utils import process_lhs, process_rhs, valid_path_key_name -def build_json_mql_path(lhs, key_transforms): +def build_json_mql_path(lhs, key_transforms, as_expr=False): # Build the MQL path using the collected key transforms. + if not as_expr: + return ".".join(chain([lhs], key_transforms)) result = lhs for key in key_transforms: get_field = {"$getField": {"input": result, "field": key}} @@ -37,16 +43,18 @@ def build_json_mql_path(lhs, key_transforms): return result -def contained_by(self, compiler, connection): # noqa: ARG001 +def contained_by(self, compiler, connection, as_expr=False): # noqa: ARG001 raise NotSupportedError("contained_by lookup is not supported on this database backend.") -def data_contains(self, compiler, connection): # noqa: ARG001 +def data_contains(self, compiler, connection, as_expr=False): # noqa: ARG001 raise NotSupportedError("contains lookup is not supported on this database backend.") -def _has_key_predicate(path, root_column, negated=False): +def _has_key_predicate(path, root_column=None, negated=False, as_expr=False): """Return MQL to check for the existence of `path`.""" + if not as_expr: + return {path: {"$exists": not negated}} result = { "$and": [ # The path must exist (i.e. not be "missing"). @@ -61,10 +69,16 @@ def _has_key_predicate(path, root_column, negated=False): return result -def has_key_lookup(self, compiler, connection): +@property +def has_key_check_simple_expression(self): + rhs = [self.rhs] if not isinstance(self.rhs, (list, tuple)) else self.rhs + return self.is_simple_column and all(valid_path_key_name(key) for key in rhs) + + +def has_key_lookup(self, compiler, connection, as_expr=False): """Return MQL to check for the existence of a key.""" rhs = self.rhs - lhs = process_lhs(self, compiler, connection) + lhs = process_lhs(self, compiler, connection, as_expr=as_expr) if not isinstance(rhs, (list, tuple)): rhs = [rhs] paths = [] @@ -72,10 +86,10 @@ def has_key_lookup(self, compiler, connection): # in the code that follows. for key in rhs: rhs_json_path = key if isinstance(key, KeyTransform) else KeyTransform(key, self.lhs) - paths.append(rhs_json_path.as_mql(compiler, connection)) + paths.append(rhs_json_path.as_mql(compiler, connection, as_expr=as_expr)) keys = [] for path in paths: - keys.append(_has_key_predicate(path, lhs)) + keys.append(_has_key_predicate(path, lhs, as_expr=as_expr)) if self.mongo_operator is None: return keys[0] return {self.mongo_operator: keys} @@ -93,7 +107,7 @@ def json_exact_process_rhs(self, compiler, connection): ) -def key_transform(self, compiler, connection): +def key_transform(self, compiler, connection, as_expr=False): """ Return MQL for this KeyTransform (JSON path). @@ -108,28 +122,38 @@ def key_transform(self, compiler, connection): while isinstance(previous, KeyTransform): key_transforms.insert(0, previous.key_name) previous = previous.lhs - lhs_mql = previous.as_mql(compiler, connection) - return build_json_mql_path(lhs_mql, key_transforms) + lhs_mql = previous.as_mql(compiler, connection, as_expr=as_expr) + return build_json_mql_path(lhs_mql, key_transforms, as_expr=as_expr) -def key_transform_in(self, compiler, connection): +def key_transform_exact_path(self, compiler, connection): + lhs_mql = process_lhs(self, compiler, connection) + return { + "$and": [ + builtin_lookup_path(self, compiler, connection), + _has_key_predicate(lhs_mql, None), + ] + } + + +def key_transform_in_expr(self, compiler, connection): """ Return MQL to check if a JSON path exists and that its values are in the set of specified values (rhs). """ - lhs_mql = process_lhs(self, compiler, connection) + lhs_mql = process_lhs(self, compiler, connection, as_expr=True) # Traverse to the root column. previous = self.lhs while isinstance(previous, KeyTransform): previous = previous.lhs - root_column = previous.as_mql(compiler, connection) - value = process_rhs(self, compiler, connection) + root_column = previous.as_mql(compiler, connection, as_expr=True) + value = process_rhs(self, compiler, connection, as_expr=True) # Construct the expression to check if lhs_mql values are in rhs values. - expr = connection.mongo_operators[self.lookup_name](lhs_mql, value) - return {"$and": [_has_key_predicate(lhs_mql, root_column), expr]} + expr = connection.mongo_expr_operators[self.lookup_name](lhs_mql, value) + return {"$and": [_has_key_predicate(lhs_mql, root_column, as_expr=True), expr]} -def key_transform_is_null(self, compiler, connection): +def key_transform_is_null_expr(self, compiler, connection): """ Return MQL to check the nullability of a key. @@ -139,37 +163,63 @@ def key_transform_is_null(self, compiler, connection): Reference: https://code.djangoproject.com/ticket/32252 """ - lhs_mql = process_lhs(self, compiler, connection) - rhs_mql = process_rhs(self, compiler, connection) + lhs_mql = process_lhs(self, compiler, connection, as_expr=True) + rhs_mql = process_rhs(self, compiler, connection, as_expr=True) # Get the root column. previous = self.lhs while isinstance(previous, KeyTransform): previous = previous.lhs - root_column = previous.as_mql(compiler, connection) - return _has_key_predicate(lhs_mql, root_column, negated=rhs_mql) + root_column = previous.as_mql(compiler, connection, as_expr=True) + return _has_key_predicate(lhs_mql, root_column, negated=rhs_mql, as_expr=True) -def key_transform_numeric_lookup_mixin(self, compiler, connection): +def key_transform_is_null_path(self, compiler, connection): + """ + Return MQL to check the nullability of a key using the operator $exists. + """ + lhs_mql = process_lhs(self, compiler, connection) + rhs_mql = process_rhs(self, compiler, connection) + return _has_key_predicate(lhs_mql, None, negated=rhs_mql) + + +def key_transform_numeric_lookup_mixin_expr(self, compiler, connection): """ Return MQL to check if the field exists (i.e., is not "missing" or "null") and that the field matches the given numeric lookup expression. """ - expr = builtin_lookup(self, compiler, connection) - lhs = process_lhs(self, compiler, connection) + expr = builtin_lookup_expr(self, compiler, connection) + lhs = process_lhs(self, compiler, connection, as_expr=True) # Check if the type of lhs is not "missing" or "null". not_missing_or_null = {"$not": {"$in": [{"$type": lhs}, ["missing", "null"]]}} return {"$and": [expr, not_missing_or_null]} +@property +def keytransform_is_simple_column(self): + previous = self + while isinstance(previous, KeyTransform): + if not valid_path_key_name(previous.key_name): + return False + previous = previous.lhs + return previous.is_simple_column + + def register_json_field(): ContainedBy.as_mql = contained_by DataContains.as_mql = data_contains HasAnyKeys.mongo_operator = "$or" HasKey.mongo_operator = None - HasKeyLookup.as_mql = has_key_lookup + HasKeyLookup.as_mql_expr = partialmethod(has_key_lookup, as_expr=True) + HasKeyLookup.as_mql_path = partialmethod(has_key_lookup, as_expr=False) + HasKeyLookup.can_use_path = has_key_check_simple_expression HasKeys.mongo_operator = "$and" JSONExact.process_rhs = json_exact_process_rhs - KeyTransform.as_mql = key_transform - KeyTransformIn.as_mql = key_transform_in - KeyTransformIsNull.as_mql = key_transform_is_null - KeyTransformNumericLookupMixin.as_mql = key_transform_numeric_lookup_mixin + KeyTransform.as_mql_expr = partialmethod(key_transform, as_expr=True) + KeyTransform.as_mql_path = partialmethod(key_transform, as_expr=False) + KeyTransform.can_use_path = keytransform_is_simple_column + KeyTransform.is_simple_column = keytransform_is_simple_column + KeyTransformExact.as_mql_path = key_transform_exact_path + KeyTransformIn.as_mql_expr = key_transform_in_expr + KeyTransformIsNull.as_mql_expr = key_transform_is_null_expr + KeyTransformIsNull.as_mql_path = key_transform_is_null_path + KeyTransformNumericLookupMixin.as_mql_expr = key_transform_numeric_lookup_mixin_expr diff --git a/django_mongodb_backend/fields/polymorphic_embedded_model_array.py b/django_mongodb_backend/fields/polymorphic_embedded_model_array.py index 6325ca4fc..5578e1c6b 100644 --- a/django_mongodb_backend/fields/polymorphic_embedded_model_array.py +++ b/django_mongodb_backend/fields/polymorphic_embedded_model_array.py @@ -80,7 +80,7 @@ def _get_lookup(self, lookup_name): return lookup class EmbeddedModelArrayFieldLookups(Lookup): - def as_mql(self, compiler, connection): + def as_mql(self, compiler, connection, as_expr=False): raise ValueError( "Lookups aren't supported on PolymorphicEmbeddedModelArrayField. " "Try querying one of its embedded fields instead." diff --git a/django_mongodb_backend/functions.py b/django_mongodb_backend/functions.py index c45800a0a..c5d3a270e 100644 --- a/django_mongodb_backend/functions.py +++ b/django_mongodb_backend/functions.py @@ -1,4 +1,5 @@ from datetime import datetime +from functools import partialmethod from django.conf import settings from django.db import NotSupportedError @@ -67,7 +68,7 @@ def cast(self, compiler, connection): output_type = connection.data_types[self.output_field.get_internal_type()] - lhs_mql = process_lhs(self, compiler, connection)[0] + lhs_mql = process_lhs(self, compiler, connection, as_expr=True)[0] if max_length := self.output_field.max_length: lhs_mql = {"$substrCP": [lhs_mql, 0, max_length]} # Skip the conversion for "object" as it doesn't need to be transformed for @@ -81,22 +82,22 @@ def cast(self, compiler, connection): def concat(self, compiler, connection): - return self.get_source_expressions()[0].as_mql(compiler, connection) + return self.get_source_expressions()[0].as_mql(compiler, connection, as_expr=True) def concat_pair(self, compiler, connection): # null on either side results in null for expression, wrap with coalesce. coalesced = self.coalesce() - return super(ConcatPair, coalesced).as_mql(compiler, connection) + return super(ConcatPair, coalesced).as_mql_expr(compiler, connection) def cot(self, compiler, connection): - lhs_mql = process_lhs(self, compiler, connection) + lhs_mql = process_lhs(self, compiler, connection, as_expr=True) return {"$divide": [1, {"$tan": lhs_mql}]} def extract(self, compiler, connection): - lhs_mql = process_lhs(self, compiler, connection) + lhs_mql = process_lhs(self, compiler, connection, as_expr=True) operator = EXTRACT_OPERATORS.get(self.lookup_name) if operator is None: raise NotSupportedError(f"{self.__class__.__name__} is not supported.") @@ -106,7 +107,7 @@ def extract(self, compiler, connection): def func(self, compiler, connection): - lhs_mql = process_lhs(self, compiler, connection) + lhs_mql = process_lhs(self, compiler, connection, as_expr=True) if self.function is None: raise NotSupportedError(f"{self} may need an as_mql() method.") operator = MONGO_OPERATORS.get(self.__class__, self.function.lower()) @@ -114,12 +115,12 @@ def func(self, compiler, connection): def left(self, compiler, connection): - return self.get_substr().as_mql(compiler, connection) + return self.get_substr().as_mql(compiler, connection, as_expr=True) def length(self, compiler, connection): # Check for null first since $strLenCP only accepts strings. - lhs_mql = process_lhs(self, compiler, connection) + lhs_mql = process_lhs(self, compiler, connection, as_expr=True) return {"$cond": {"if": {"$eq": [lhs_mql, None]}, "then": None, "else": {"$strLenCP": lhs_mql}}} @@ -136,7 +137,9 @@ def now(self, compiler, connection): # noqa: ARG001 def null_if(self, compiler, connection): """Return None if expr1==expr2 else expr1.""" - expr1, expr2 = (expr.as_mql(compiler, connection) for expr in self.get_source_expressions()) + expr1, expr2 = ( + expr.as_mql(compiler, connection, as_expr=True) for expr in self.get_source_expressions() + ) return {"$cond": {"if": {"$eq": [expr1, expr2]}, "then": None, "else": expr1}} @@ -144,10 +147,10 @@ def preserve_null(operator): # If the argument is null, the function should return null, not # $toLower/Upper's behavior of returning an empty string. def wrapped(self, compiler, connection): - lhs_mql = process_lhs(self, compiler, connection) + lhs_mql = process_lhs(self, compiler, connection, as_expr=True) return { "$cond": { - "if": connection.mongo_operators["isnull"](lhs_mql, True), + "if": connection.mongo_expr_operators["isnull"](lhs_mql, True), "then": None, "else": {f"${operator}": lhs_mql}, } @@ -157,24 +160,29 @@ def wrapped(self, compiler, connection): def replace(self, compiler, connection): - expression, text, replacement = process_lhs(self, compiler, connection) + expression, text, replacement = process_lhs(self, compiler, connection, as_expr=True) return {"$replaceAll": {"input": expression, "find": text, "replacement": replacement}} def round_(self, compiler, connection): # Round needs its own function because it's a special case that inherits # from Transform but has two arguments. - return {"$round": [expr.as_mql(compiler, connection) for expr in self.get_source_expressions()]} + return { + "$round": [ + expr.as_mql(compiler, connection, as_expr=True) + for expr in self.get_source_expressions() + ] + } def str_index(self, compiler, connection): - lhs = process_lhs(self, compiler, connection) + lhs = process_lhs(self, compiler, connection, as_expr=True) # StrIndex should be 0-indexed (not found) but it's -1-indexed on MongoDB. return {"$add": [{"$indexOfCP": lhs}, 1]} def substr(self, compiler, connection): - lhs = process_lhs(self, compiler, connection) + lhs = process_lhs(self, compiler, connection, as_expr=True) # The starting index is zero-indexed on MongoDB rather than one-indexed. lhs[1] = {"$add": [lhs[1], -1]} # If no limit is specified, use the length of the string since $substrCP @@ -186,14 +194,14 @@ def substr(self, compiler, connection): def trim(operator): def wrapped(self, compiler, connection): - lhs = process_lhs(self, compiler, connection) + lhs = process_lhs(self, compiler, connection, as_expr=True) return {f"${operator}": {"input": lhs}} return wrapped def trunc(self, compiler, connection): - lhs_mql = process_lhs(self, compiler, connection) + lhs_mql = process_lhs(self, compiler, connection, as_expr=True) lhs_mql = {"date": lhs_mql, "unit": self.kind, "startOfWeek": "mon"} if timezone := self.get_tzname(): lhs_mql["timezone"] = timezone @@ -232,7 +240,7 @@ def trunc_convert_value(self, value, expression, connection): def trunc_date(self, compiler, connection): # Cast to date rather than truncate to date. - lhs_mql = process_lhs(self, compiler, connection) + lhs_mql = process_lhs(self, compiler, connection, as_expr=True) tzname = self.get_tzname() if tzname and tzname != "UTC": raise NotSupportedError(f"TruncDate with tzinfo ({tzname}) isn't supported on MongoDB.") @@ -255,7 +263,7 @@ def trunc_time(self, compiler, connection): tzname = self.get_tzname() if tzname and tzname != "UTC": raise NotSupportedError(f"TruncTime with tzinfo ({tzname}) isn't supported on MongoDB.") - lhs_mql = process_lhs(self, compiler, connection) + lhs_mql = process_lhs(self, compiler, connection, as_expr=True) return { "$dateFromString": { "dateString": { @@ -272,28 +280,29 @@ def trunc_time(self, compiler, connection): def register_functions(): - Cast.as_mql = cast - Concat.as_mql = concat - ConcatPair.as_mql = concat_pair - Cot.as_mql = cot - Extract.as_mql = extract - Func.as_mql = func - JSONArray.as_mql = process_lhs - Left.as_mql = left - Length.as_mql = length - Log.as_mql = log - Lower.as_mql = preserve_null("toLower") - LTrim.as_mql = trim("ltrim") - Now.as_mql = now - NullIf.as_mql = null_if - Replace.as_mql = replace - Round.as_mql = round_ - RTrim.as_mql = trim("rtrim") - StrIndex.as_mql = str_index - Substr.as_mql = substr - Trim.as_mql = trim("trim") - TruncBase.as_mql = trunc + Cast.as_mql_expr = cast + Concat.as_mql_expr = concat + ConcatPair.as_mql_expr = concat_pair + Cot.as_mql_expr = cot + Extract.as_mql_expr = extract + Func.as_mql_expr = func + Func.can_use_path = False + JSONArray.as_mql_expr = partialmethod(process_lhs, as_expr=True) + Left.as_mql_expr = left + Length.as_mql_expr = length + Log.as_mql_expr = log + Lower.as_mql_expr = preserve_null("toLower") + LTrim.as_mql_expr = trim("ltrim") + Now.as_mql_expr = now + NullIf.as_mql_expr = null_if + Replace.as_mql_expr = replace + Round.as_mql_expr = round_ + RTrim.as_mql_expr = trim("rtrim") + StrIndex.as_mql_expr = str_index + Substr.as_mql_expr = substr + Trim.as_mql_expr = trim("trim") + TruncBase.as_mql_expr = trunc TruncBase.convert_value = trunc_convert_value - TruncDate.as_mql = trunc_date - TruncTime.as_mql = trunc_time - Upper.as_mql = preserve_null("toUpper") + TruncDate.as_mql_expr = trunc_date + TruncTime.as_mql_expr = trunc_time + Upper.as_mql_expr = preserve_null("toUpper") diff --git a/django_mongodb_backend/gis/lookups.py b/django_mongodb_backend/gis/lookups.py index 29c2e1e96..8df8ed59c 100644 --- a/django_mongodb_backend/gis/lookups.py +++ b/django_mongodb_backend/gis/lookups.py @@ -2,7 +2,7 @@ from django.db import NotSupportedError -def gis_lookup(self, compiler, connection): # noqa: ARG001 +def gis_lookup(self, compiler, connection, as_expr=False): # noqa: ARG001 raise NotSupportedError(f"MongoDB does not support the {self.lookup_name} lookup.") diff --git a/django_mongodb_backend/lookups.py b/django_mongodb_backend/lookups.py index 8dda2bab3..3cc94acb1 100644 --- a/django_mongodb_backend/lookups.py +++ b/django_mongodb_backend/lookups.py @@ -4,14 +4,21 @@ BuiltinLookup, FieldGetDbPrepValueIterableMixin, IsNull, + Lookup, PatternLookup, UUIDTextMixin, ) -from .query_utils import process_lhs, process_rhs +from .query_utils import is_constant_value, process_lhs, process_rhs -def builtin_lookup(self, compiler, connection): +def builtin_lookup_expr(self, compiler, connection): + value = process_rhs(self, compiler, connection, as_expr=True) + lhs_mql = process_lhs(self, compiler, connection, as_expr=True) + return connection.mongo_expr_operators[self.lookup_name](lhs_mql, value) + + +def builtin_lookup_path(self, compiler, connection): lhs_mql = process_lhs(self, compiler, connection) value = process_rhs(self, compiler, connection) return connection.mongo_operators[self.lookup_name](lhs_mql, value) @@ -33,14 +40,17 @@ def field_resolve_expression_parameter(self, compiler, connection, sql, param): return sql, sql_params -def in_(self, compiler, connection): - db_rhs = getattr(self.rhs, "_db", None) - if db_rhs is not None and db_rhs != connection.alias: - raise ValueError( - "Subqueries aren't allowed across different databases. Force " - "the inner query to be evaluated using `list(inner_query)`." - ) - return builtin_lookup(self, compiler, connection) +def wrap_in(function): + def inner(self, compiler, connection): + db_rhs = getattr(self.rhs, "_db", None) + if db_rhs is not None and db_rhs != connection.alias: + raise ValueError( + "Subqueries aren't allowed across different databases. Force " + "the inner query to be evaluated using `list(inner_query)`." + ) + return function(self, compiler, connection) + + return inner def get_subquery_wrapping_pipeline(self, compiler, connection, field_name, expr): # noqa: ARG001 @@ -51,7 +61,9 @@ def get_subquery_wrapping_pipeline(self, compiler, connection, field_name, expr) { "$group": { "_id": None, - "tmp_name": {"$addToSet": expr.as_mql(compiler, connection)}, + "tmp_name": { + "$addToSet": expr.as_mql(compiler, connection, as_expr=True) + }, } } ] @@ -75,7 +87,14 @@ def get_subquery_wrapping_pipeline(self, compiler, connection, field_name, expr) ] -def is_null(self, compiler, connection): +def is_null_expr(self, compiler, connection): + if not isinstance(self.rhs, bool): + raise ValueError("The QuerySet value for an isnull lookup must be True or False.") + lhs_mql = process_lhs(self, compiler, connection, as_expr=True) + return connection.mongo_expr_operators["isnull"](lhs_mql, self.rhs) + + +def is_null_path(self, compiler, connection): if not isinstance(self.rhs, bool): raise ValueError("The QuerySet value for an isnull lookup must be True or False.") lhs_mql = process_lhs(self, compiler, connection) @@ -121,13 +140,25 @@ def uuid_text_mixin(self, compiler, connection): # noqa: ARG001 raise NotSupportedError("Pattern lookups on UUIDField are not supported.") +@property +def can_use_path(self): + simple_column = getattr(self.lhs, "is_simple_column", False) + constant_value = is_constant_value(self.rhs) + return simple_column and constant_value + + def register_lookups(): - BuiltinLookup.as_mql = builtin_lookup + BuiltinLookup.as_mql_expr = builtin_lookup_expr + BuiltinLookup.as_mql_path = builtin_lookup_path FieldGetDbPrepValueIterableMixin.resolve_expression_parameter = ( field_resolve_expression_parameter ) - In.as_mql = RelatedIn.as_mql = in_ + In.as_mql_expr = RelatedIn.as_mql_expr = wrap_in(builtin_lookup_expr) + In.as_mql_path = RelatedIn.as_mql_path = wrap_in(builtin_lookup_path) In.get_subquery_wrapping_pipeline = get_subquery_wrapping_pipeline - IsNull.as_mql = is_null + IsNull.as_mql_expr = is_null_expr + IsNull.as_mql_path = is_null_path + Lookup.can_use_path = can_use_path PatternLookup.prep_lookup_value_mongo = pattern_lookup_prep_lookup_value + # Patching the main method, it is not supported yet. UUIDTextMixin.as_mql = uuid_text_mixin diff --git a/django_mongodb_backend/query.py b/django_mongodb_backend/query.py index c86b8721b..8dc88974e 100644 --- a/django_mongodb_backend/query.py +++ b/django_mongodb_backend/query.py @@ -11,8 +11,6 @@ from django.db.models.sql.where import AND, OR, XOR, ExtraWhere, NothingNode, WhereNode from pymongo.errors import BulkWriteError, DuplicateKeyError, PyMongoError -from .query_conversion.query_optimizer import convert_expr_to_match - def wrap_database_errors(func): @wraps(func) @@ -89,7 +87,7 @@ def get_pipeline(self): for query in self.subqueries or (): pipeline.extend(query.get_pipeline()) if self.match_mql: - pipeline.extend(convert_expr_to_match(self.match_mql)) + pipeline.append({"$match": self.match_mql}) if self.aggregation_pipeline: pipeline.extend(self.aggregation_pipeline) if self.project_fields: @@ -156,7 +154,9 @@ def _get_reroot_replacements(expression): # lhs_fields. if hand_side_value.alias != self.table_alias: pos = len(lhs_fields) - lhs_fields.append(hand_side_value.as_mql(compiler, connection)) + lhs_fields.append( + hand_side_value.as_mql(compiler, connection, as_expr=True) + ) else: pos = None columns.append((hand_side_value, pos)) @@ -168,6 +168,7 @@ def _get_reroot_replacements(expression): target.remote_field = col.target.remote_field column_target = Col(compiler.collection_name, target) if parent_pos is not None: + column_target.is_simple_column = False target_col = f"${parent_template}{parent_pos}" column_target.target.db_column = target_col column_target.target.set_attributes_from_name(target_col) @@ -184,10 +185,10 @@ def _get_reroot_replacements(expression): lhs, rhs = connection.ops.prepare_join_on_clause( self.parent_alias, lhs, compiler.collection_name, rhs ) - lhs_fields.append(lhs.as_mql(compiler, connection)) + lhs_fields.append(lhs.as_mql(compiler, connection, as_expr=True)) # In the lookup stage, the reference to this column doesn't include the # collection name. - rhs_fields.append(rhs.as_mql(compiler, connection)) + rhs_fields.append(rhs.as_mql(compiler, connection, as_expr=True)) # Handle any join conditions besides matching field pairs. extra = self.join_field.get_extra_restriction(self.table_alias, self.parent_alias) extra_conditions = [] @@ -211,6 +212,23 @@ def _get_reroot_replacements(expression): compiler, connection ) ) + + # Match the conditions: + # self.table_name.field1 = parent_table.field1 + # AND + # self.table_name.field2 = parent_table.field2 + # AND + # ... + condition = { + "$expr": { + "$and": [ + {"$eq": [f"$${parent_template}{i}", field]} for i, field in enumerate(rhs_fields) + ] + } + } + if extra_conditions: + condition = {"$and": [condition, *extra_conditions]} + lookup_pipeline = [ { "$lookup": { @@ -222,25 +240,7 @@ def _get_reroot_replacements(expression): f"{parent_template}{i}": parent_field for i, parent_field in enumerate(lhs_fields) }, - "pipeline": [ - { - # Match the conditions: - # self.table_name.field1 = parent_table.field1 - # AND - # self.table_name.field2 = parent_table.field2 - # AND - # ... - "$match": { - "$expr": { - "$and": [ - {"$eq": [f"$${parent_template}{i}", field]} - for i, field in enumerate(rhs_fields) - ] - + extra_conditions - } - } - } - ], + "pipeline": [{"$match": condition}], # Rename the output as table_alias. "as": self.table_alias, } @@ -274,7 +274,7 @@ def _get_reroot_replacements(expression): return lookup_pipeline -def where_node(self, compiler, connection): +def where_node(self, compiler, connection, as_expr=False): if self.connector == AND: full_needed, empty_needed = len(self.children), 1 else: @@ -297,14 +297,16 @@ def where_node(self, compiler, connection): if len(self.children) > 2: rhs_sum = Mod(rhs_sum, 2) rhs = Exact(1, rhs_sum) - return self.__class__([lhs, rhs], AND, self.negated).as_mql(compiler, connection) + return self.__class__([lhs, rhs], AND, self.negated).as_mql( + compiler, connection, as_expr=as_expr + ) else: operator = "$or" children_mql = [] for child in self.children: try: - mql = child.as_mql(compiler, connection) + mql = child.as_mql(compiler, connection, as_expr=as_expr) except EmptyResultSet: empty_needed -= 1 except FullResultSet: @@ -331,13 +333,17 @@ def where_node(self, compiler, connection): raise FullResultSet if self.negated and mql: - mql = {"$not": mql} + mql = {"$not": [mql]} if as_expr else {"$nor": [mql]} return mql +def nothing_node(self, compiler, connection, as_expr=False): # noqa: ARG001 + return self.as_sql(compiler, connection) + + def register_nodes(): ExtraWhere.as_mql = extra_where Join.as_mql = join - NothingNode.as_mql = NothingNode.as_sql + NothingNode.as_mql = nothing_node WhereNode.as_mql = where_node diff --git a/django_mongodb_backend/query_conversion/__init__.py b/django_mongodb_backend/query_conversion/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/django_mongodb_backend/query_conversion/expression_converters.py b/django_mongodb_backend/query_conversion/expression_converters.py deleted file mode 100644 index dc9b7df5b..000000000 --- a/django_mongodb_backend/query_conversion/expression_converters.py +++ /dev/null @@ -1,172 +0,0 @@ -class BaseConverter: - """Base class for $expr to $match converters.""" - - @classmethod - def convert(cls, expr): - raise NotImplementedError("Subclasses must implement this method.") - - @classmethod - def is_simple_value(cls, value): - """Is the value is a simple type (not a dict)?""" - if value is None: - return True - if isinstance(value, str) and value.startswith("$"): - return False - if isinstance(value, (list, tuple, set)): - return all(cls.is_simple_value(v) for v in value) - # TODO: Support `$getField` conversion. - return not isinstance(value, dict) - - -class BinaryConverter(BaseConverter): - """ - Base class for converting binary operations. - - For example: - "$expr": { - {"$gt": ["$price", 100]} - } - is converted to: - {"price": {"$gt": 100}} - """ - - operator: str - - @classmethod - def convert(cls, args): - if isinstance(args, list) and len(args) == 2: - field_expr, value = args - # Check if first argument is a simple field reference. - if ( - isinstance(field_expr, str) - and field_expr.startswith("$") - and cls.is_simple_value(value) - ): - field_name = field_expr[1:] # Remove the $ prefix. - if cls.operator == "$eq": - return {field_name: value} - return {field_name: {cls.operator: value}} - return None - - -class EqConverter(BinaryConverter): - """ - Convert $eq operation to a $match query. - - For example: - "$expr": { - {"$eq": ["$status", "active"]} - } - is converted to: - {"status": "active"} - """ - - operator = "$eq" - - -class GtConverter(BinaryConverter): - operator = "$gt" - - -class GteConverter(BinaryConverter): - operator = "$gte" - - -class LtConverter(BinaryConverter): - operator = "$lt" - - -class LteConverter(BinaryConverter): - operator = "$lte" - - -class InConverter(BaseConverter): - """ - Convert $in operation to a $match query. - - For example: - "$expr": { - {"$in": ["$category", ["electronics", "books"]]} - } - is converted to: - {"category": {"$in": ["electronics", "books"]}} - """ - - @classmethod - def convert(cls, in_args): - if isinstance(in_args, list) and len(in_args) == 2: - field_expr, values = in_args - # Check if first argument is a simple field reference. - if isinstance(field_expr, str) and field_expr.startswith("$"): - field_name = field_expr[1:] # Remove the $ prefix. - if isinstance(values, (list, tuple, set)) and all( - cls.is_simple_value(v) for v in values - ): - return {field_name: {"$in": values}} - return None - - -class LogicalConverter(BaseConverter): - """ - Base class for converting logical operations to a $match query. - - For example: - "$expr": { - "$or": [ - {"$eq": ["$status", "active"]}, - {"$in": ["$category", ["electronics", "books"]]}, - ] - } - is converted to: - "$or": [ - {"status": "active"}, - {"category": {"$in": ["electronics", "books"]}}, - ] - """ - - @classmethod - def convert(cls, combined_conditions): - if isinstance(combined_conditions, list): - optimized_conditions = [] - for condition in combined_conditions: - if isinstance(condition, dict) and len(condition) == 1: - if optimized_condition := convert_expression(condition): - optimized_conditions.append(optimized_condition) - else: - # Any failure should stop optimization. - return None - if optimized_conditions: - return {cls._logical_op: optimized_conditions} - return None - - -class OrConverter(LogicalConverter): - _logical_op = "$or" - - -class AndConverter(LogicalConverter): - _logical_op = "$and" - - -OPTIMIZABLE_OPS = { - "$eq": EqConverter, - "$in": InConverter, - "$and": AndConverter, - "$or": OrConverter, - "$gt": GtConverter, - "$gte": GteConverter, - "$lt": LtConverter, - "$lte": LteConverter, -} - - -def convert_expression(expr): - """ - Optimize MQL by converting an $expr condition to $match. Return the $match - MQL, or None if not optimizable. - """ - if isinstance(expr, dict) and len(expr) == 1: - op = next(iter(expr.keys())) - if op in OPTIMIZABLE_OPS: - return OPTIMIZABLE_OPS[op].convert(expr[op]) - return None diff --git a/django_mongodb_backend/query_conversion/query_optimizer.py b/django_mongodb_backend/query_conversion/query_optimizer.py deleted file mode 100644 index 368c89504..000000000 --- a/django_mongodb_backend/query_conversion/query_optimizer.py +++ /dev/null @@ -1,73 +0,0 @@ -from .expression_converters import convert_expression - - -def convert_expr_to_match(query): - """ - Optimize an MQL query by converting conditions into a list of $match - stages. - """ - if "$expr" not in query: - return [query] - if query["$expr"] == {}: - return [{"$match": {}}] - return _process_expression(query["$expr"]) - - -def _process_expression(expr): - """Process an expression and extract optimizable conditions.""" - match_conditions = [] - remaining_conditions = [] - if isinstance(expr, dict): - has_and = "$and" in expr - has_or = "$or" in expr - # Do a top-level check for $and or $or because these should inform. - # If they fail, they should failover to a remaining conditions list. - # There's probably a better way to do this. - if has_and: - and_match_conditions = _process_logical_conditions("$and", expr["$and"]) - match_conditions.extend(and_match_conditions) - if has_or: - or_match_conditions = _process_logical_conditions("$or", expr["$or"]) - match_conditions.extend(or_match_conditions) - if not has_and and not has_or: - # Process single condition. - if optimized := convert_expression(expr): - match_conditions.append({"$match": optimized}) - else: - remaining_conditions.append({"$match": {"$expr": expr}}) - else: - # Can't optimize. - remaining_conditions.append({"$expr": expr}) - return match_conditions + remaining_conditions - - -def _process_logical_conditions(logical_op, logical_conditions): - """Process conditions within a logical array.""" - optimized_conditions = [] - match_conditions = [] - remaining_conditions = [] - for condition in logical_conditions: - _remaining_conditions = [] - if isinstance(condition, dict): - if optimized := convert_expression(condition): - optimized_conditions.append(optimized) - else: - _remaining_conditions.append(condition) - else: - _remaining_conditions.append(condition) - if _remaining_conditions: - # Any expressions that can't be optimized must remain in a $expr - # that preserves the logical operator. - if len(_remaining_conditions) > 1: - remaining_conditions.append({"$expr": {logical_op: _remaining_conditions}}) - else: - remaining_conditions.append({"$expr": _remaining_conditions[0]}) - if optimized_conditions: - optimized_conditions.extend(remaining_conditions) - if len(optimized_conditions) > 1: - match_conditions.append({"$match": {logical_op: optimized_conditions}}) - else: - match_conditions.append({"$match": optimized_conditions[0]}) - else: - match_conditions.append({"$match": {logical_op: remaining_conditions}}) - return match_conditions diff --git a/django_mongodb_backend/query_utils.py b/django_mongodb_backend/query_utils.py index 4b744241e..54ade6395 100644 --- a/django_mongodb_backend/query_utils.py +++ b/django_mongodb_backend/query_utils.py @@ -1,13 +1,16 @@ +import re + from django.core.exceptions import FullResultSet from django.db.models.aggregates import Aggregate -from django.db.models.expressions import Value +from django.db.models.expressions import CombinedExpression, Func, Value +from django.db.models.sql.query import Query def is_direct_value(node): return not hasattr(node, "as_sql") -def process_lhs(node, compiler, connection): +def process_lhs(node, compiler, connection, as_expr=False): if not hasattr(node, "lhs"): # node is a Func or Expression, possibly with multiple source expressions. result = [] @@ -15,27 +18,30 @@ def process_lhs(node, compiler, connection): if expr is None: continue try: - result.append(expr.as_mql(compiler, connection)) + result.append(expr.as_mql(compiler, connection, as_expr=as_expr)) except FullResultSet: - result.append(Value(True).as_mql(compiler, connection)) + result.append(Value(True).as_mql(compiler, connection, as_expr=as_expr)) if isinstance(node, Aggregate): return result[0] return result # node is a Transform with just one source expression, aliased as "lhs". if is_direct_value(node.lhs): return node - return node.lhs.as_mql(compiler, connection) + return node.lhs.as_mql(compiler, connection, as_expr=as_expr) -def process_rhs(node, compiler, connection): +def process_rhs(node, compiler, connection, as_expr=False): rhs = node.rhs if hasattr(rhs, "as_mql"): if getattr(rhs, "subquery", False) and hasattr(node, "get_subquery_wrapping_pipeline"): value = rhs.as_mql( - compiler, connection, get_wrapping_pipeline=node.get_subquery_wrapping_pipeline + compiler, + connection, + get_wrapping_pipeline=node.get_subquery_wrapping_pipeline, + as_expr=as_expr, ) else: - value = rhs.as_mql(compiler, connection) + value = rhs.as_mql(compiler, connection, as_expr=as_expr) else: _, value = node.process_rhs(compiler, connection) lookup_name = node.lookup_name @@ -47,7 +53,47 @@ def process_rhs(node, compiler, connection): return value -def regex_match(field, regex_vals, insensitive=False): +def regex_expr(field, regex_vals, insensitive=False): regex = {"$concat": regex_vals} if isinstance(regex_vals, tuple) else regex_vals options = "i" if insensitive else "" return {"$regexMatch": {"input": field, "regex": regex, "options": options}} + + +def regex_match(field, regex, insensitive=False): + options = "i" if insensitive else "" + return {field: {"$regex": regex, "$options": options}} + + +def is_constant_value(value): + if isinstance(value, CombinedExpression): + # Temporary: treat all CombinedExpressions as non-constant until + # constant cases are handled + return False + if isinstance(value, list): + return all(map(is_constant_value, value)) + if is_direct_value(value): + return True + if hasattr(value, "get_source_expressions"): + # Temporary: similar limitation as above, sub-expressions should be + # resolved in the future + constants_sub_expressions = all(map(is_constant_value, value.get_source_expressions())) + else: + constants_sub_expressions = True + constants_sub_expressions = constants_sub_expressions and not ( + isinstance(value, Query) + or value.contains_aggregate + or value.contains_over_clause + or value.contains_column_references + or value.contains_subquery + ) + return constants_sub_expressions and ( + isinstance(value, Value) + or + # Some closed functions cannot yet be converted to constant values. + # Allow Func with can_use_path as a temporary exception. + (isinstance(value, Func) and value.can_use_path) + ) + + +def valid_path_key_name(key_name): + return bool(re.fullmatch(r"[A-Za-z0-9_]+", key_name)) diff --git a/django_mongodb_backend/test.py b/django_mongodb_backend/test.py index 561832a15..ee35b4e21 100644 --- a/django_mongodb_backend/test.py +++ b/django_mongodb_backend/test.py @@ -1,6 +1,6 @@ """Not a public API.""" -from bson import SON, ObjectId +from bson import SON, Decimal128, ObjectId class MongoTestCaseMixin: @@ -16,6 +16,6 @@ def assertAggregateQuery(self, query, expected_collection, expected_pipeline): self.assertEqual(operator, "aggregate") self.assertEqual(collection, expected_collection) self.assertEqual( - eval(pipeline[:-1], {"SON": SON, "ObjectId": ObjectId}, {}), # noqa: S307 + eval(pipeline[:-1], {"SON": SON, "ObjectId": ObjectId, "Decimal128": Decimal128}, {}), # noqa: S307 expected_pipeline, ) diff --git a/tests/expression_converter_/__init__.py b/tests/expression_converter_/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/expression_converter_/test_match_conversion.py b/tests/expression_converter_/test_match_conversion.py deleted file mode 100644 index e78e5c0cc..000000000 --- a/tests/expression_converter_/test_match_conversion.py +++ /dev/null @@ -1,215 +0,0 @@ -from django.test import SimpleTestCase - -from django_mongodb_backend.query_conversion.query_optimizer import convert_expr_to_match - - -class ConvertExprToMatchTests(SimpleTestCase): - def assertOptimizerEqual(self, input, expected): - result = convert_expr_to_match(input) - self.assertEqual(result, expected) - - def test_multiple_optimizable_conditions(self): - expr = { - "$expr": { - "$and": [ - {"$eq": ["$status", "active"]}, - {"$in": ["$category", ["electronics", "books"]]}, - {"$eq": ["$verified", True]}, - {"$gte": ["$price", 50]}, - ] - } - } - expected = [ - { - "$match": { - "$and": [ - {"status": "active"}, - {"category": {"$in": ["electronics", "books"]}}, - {"verified": True}, - {"price": {"$gte": 50}}, - ] - } - } - ] - self.assertOptimizerEqual(expr, expected) - - def test_mixed_optimizable_and_non_optimizable_conditions(self): - expr = { - "$expr": { - "$and": [ - {"$eq": ["$status", "active"]}, - {"$gt": ["$price", "$min_price"]}, # Not optimizable - {"$in": ["$category", ["electronics"]]}, - ] - } - } - expected = [ - { - "$match": { - "$and": [ - {"status": "active"}, - {"category": {"$in": ["electronics"]}}, - {"$expr": {"$gt": ["$price", "$min_price"]}}, - ], - } - } - ] - self.assertOptimizerEqual(expr, expected) - - def test_non_optimizable_condition(self): - expr = {"$expr": {"$gt": ["$price", "$min_price"]}} - expected = [ - { - "$match": { - "$expr": {"$gt": ["$price", "$min_price"]}, - } - } - ] - self.assertOptimizerEqual(expr, expected) - - def test_nested_logical_conditions(self): - expr = { - "$expr": { - "$or": [ - {"$eq": ["$status", "active"]}, - {"$in": ["$category", ["electronics", "books"]]}, - {"$and": [{"$eq": ["$verified", True]}, {"$lte": ["$price", 50]}]}, - ] - } - } - expected = [ - { - "$match": { - "$or": [ - {"status": "active"}, - {"category": {"$in": ["electronics", "books"]}}, - {"$and": [{"verified": True}, {"price": {"$lte": 50}}]}, - ] - } - } - ] - self.assertOptimizerEqual(expr, expected) - - def test_complex_nested_with_non_optimizable_parts(self): - expr = { - "$expr": { - "$and": [ - { - "$or": [ - {"$eq": ["$status", "active"]}, - {"$gt": ["$views", 1000]}, - ] - }, - {"$in": ["$category", ["electronics", "books"]]}, - {"$eq": ["$verified", True]}, - {"$gt": ["$price", "$min_price"]}, # Not optimizable - ] - } - } - expected = [ - { - "$match": { - "$and": [ - { - "$or": [ - {"status": "active"}, - {"views": {"$gt": 1000}}, - ] - }, - {"category": {"$in": ["electronics", "books"]}}, - {"verified": True}, - {"$expr": {"$gt": ["$price", "$min_price"]}}, - ] - } - } - ] - self.assertOptimizerEqual(expr, expected) - - def test_london_in_case(self): - expr = {"$expr": {"$in": ["$author_city", ["London"]]}} - expected = [{"$match": {"author_city": {"$in": ["London"]}}}] - self.assertOptimizerEqual(expr, expected) - - def test_deeply_nested_logical_operators(self): - expr = { - "$expr": { - "$and": [ - { - "$or": [ - {"$eq": ["$type", "premium"]}, - { - "$and": [ - {"$eq": ["$type", "standard"]}, - {"$in": ["$region", ["US", "CA"]]}, - ] - }, - ] - }, - {"$eq": ["$active", True]}, - ] - } - } - expected = [ - { - "$match": { - "$and": [ - { - "$or": [ - {"type": "premium"}, - { - "$and": [ - {"type": "standard"}, - {"region": {"$in": ["US", "CA"]}}, - ] - }, - ] - }, - {"active": True}, - ] - } - } - ] - self.assertOptimizerEqual(expr, expected) - - def test_deeply_nested_logical_operator_with_variable(self): - expr = { - "$expr": { - "$and": [ - { - "$or": [ - {"$eq": ["$type", "premium"]}, - { - "$and": [ - {"$eq": ["$type", "$$standard"]}, # Not optimizable - {"$in": ["$region", ["US", "CA"]]}, - ] - }, - ] - }, - {"$eq": ["$active", True]}, - ] - } - } - expected = [ - { - "$match": { - "$and": [ - {"active": True}, - { - "$expr": { - "$or": [ - {"$eq": ["$type", "premium"]}, - { - "$and": [ - {"$eq": ["$type", "$$standard"]}, - {"$in": ["$region", ["US", "CA"]]}, - ] - }, - ] - } - }, - ] - } - } - ] - self.assertOptimizerEqual(expr, expected) diff --git a/tests/expression_converter_/test_op_expressions.py b/tests/expression_converter_/test_op_expressions.py deleted file mode 100644 index ce4caf2d4..000000000 --- a/tests/expression_converter_/test_op_expressions.py +++ /dev/null @@ -1,233 +0,0 @@ -import datetime -from uuid import UUID - -from bson import Decimal128 -from django.test import SimpleTestCase - -from django_mongodb_backend.query_conversion.expression_converters import convert_expression - - -class ConversionTestCase(SimpleTestCase): - CONVERTIBLE_TYPES = { - "int": 42, - "float": 3.14, - "decimal128": Decimal128("3.14"), - "boolean": True, - "NoneType": None, - "string": "string", - "datetime": datetime.datetime.now(datetime.timezone.utc), - "duration": datetime.timedelta(days=5, hours=3), - "uuid": UUID("12345678123456781234567812345678"), - } - - def assertConversionEqual(self, input, expected): - result = convert_expression(input) - self.assertEqual(result, expected) - - def assertNotOptimizable(self, input): - result = convert_expression(input) - self.assertIsNone(result) - - def _test_conversion_various_types(self, conversion_test): - for _type, val in self.CONVERTIBLE_TYPES.items(): - with self.subTest(_type=_type, val=val): - conversion_test(val) - - -class ExpressionTests(ConversionTestCase): - def test_non_dict(self): - self.assertNotOptimizable(["$status", "active"]) - - def test_empty_dict(self): - self.assertNotOptimizable({}) - - -class EqTests(ConversionTestCase): - def test_conversion(self): - self.assertConversionEqual({"$eq": ["$status", "active"]}, {"status": "active"}) - - def test_no_conversion_non_string_field(self): - self.assertNotOptimizable({"$eq": [123, "active"]}) - - def test_no_conversion_dict_value(self): - self.assertNotOptimizable({"$eq": ["$status", {"$gt": 5}]}) - - def _test_conversion_valid_type(self, _type): - self.assertConversionEqual({"$eq": ["$age", _type]}, {"age": _type}) - - def _test_conversion_valid_array_type(self, _type): - self.assertConversionEqual({"$eq": ["$age", _type]}, {"age": _type}) - - def test_conversion_various_types(self): - self._test_conversion_various_types(self._test_conversion_valid_type) - - def test_conversion_various_array_types(self): - self._test_conversion_various_types(self._test_conversion_valid_array_type) - - -class InTests(ConversionTestCase): - def test_conversion(self): - expr = {"$in": ["$category", ["electronics", "books", "clothing"]]} - expected = {"category": {"$in": ["electronics", "books", "clothing"]}} - self.assertConversionEqual(expr, expected) - - def test_no_conversion_non_string_field(self): - self.assertNotOptimizable({"$in": [123, ["electronics", "books"]]}) - - def test_no_conversion_dict_value(self): - self.assertNotOptimizable({"$in": ["$status", [{"bad": "val"}]]}) - - def _test_conversion_valid_type(self, _type): - self.assertConversionEqual({"$in": ["$age", [_type]]}, {"age": {"$in": [_type]}}) - - def test_conversion_various_types(self): - for _type, val in self.CONVERTIBLE_TYPES.items(): - with self.subTest(_type=_type, val=val): - self._test_conversion_valid_type(val) - - -class LogicalTests(ConversionTestCase): - def test_and(self): - expr = { - "$and": [ - {"$eq": ["$status", "active"]}, - {"$in": ["$category", ["electronics", "books"]]}, - {"$eq": ["$verified", True]}, - ] - } - expected = { - "$and": [ - {"status": "active"}, - {"category": {"$in": ["electronics", "books"]}}, - {"verified": True}, - ] - } - self.assertConversionEqual(expr, expected) - - def test_or(self): - expr = { - "$or": [ - {"$eq": ["$status", "active"]}, - {"$in": ["$category", ["electronics", "books"]]}, - ] - } - expected = { - "$or": [ - {"status": "active"}, - {"category": {"$in": ["electronics", "books"]}}, - ] - } - self.assertConversionEqual(expr, expected) - - def test_or_failure(self): - expr = { - "$or": [ - {"$eq": ["$status", "active"]}, - {"$in": ["$category", ["electronics", "books"]]}, - { - "$and": [ - {"verified": True}, - {"$gt": ["$price", "$min_price"]}, # Not optimizable - ] - }, - ] - } - self.assertNotOptimizable(expr) - - def test_mixed(self): - expr = { - "$and": [ - { - "$or": [ - {"$eq": ["$status", "active"]}, - {"$gt": ["$views", 1000]}, - ] - }, - {"$in": ["$category", ["electronics", "books"]]}, - {"$eq": ["$verified", True]}, - {"$lte": ["$price", 2000]}, - ] - } - expected = { - "$and": [ - {"$or": [{"status": "active"}, {"views": {"$gt": 1000}}]}, - {"category": {"$in": ["electronics", "books"]}}, - {"verified": True}, - {"price": {"$lte": 2000}}, - ] - } - self.assertConversionEqual(expr, expected) - - -class GtTests(ConversionTestCase): - def test_conversion(self): - self.assertConversionEqual({"$gt": ["$price", 100]}, {"price": {"$gt": 100}}) - - def test_no_conversion_non_simple_field(self): - self.assertNotOptimizable({"$gt": ["$price", "$min_price"]}) - - def test_no_conversion_dict_value(self): - self.assertNotOptimizable({"$gt": ["$price", {}]}) - - def _test_conversion_valid_type(self, _type): - self.assertConversionEqual({"$gt": ["$price", _type]}, {"price": {"$gt": _type}}) - - def test_conversion_various_types(self): - self._test_conversion_various_types(self._test_conversion_valid_type) - - -class GteTests(ConversionTestCase): - def test_conversion(self): - expr = {"$gte": ["$price", 100]} - expected = {"price": {"$gte": 100}} - self.assertConversionEqual(expr, expected) - - def test_no_conversion_non_simple_field(self): - expr = {"$gte": ["$price", "$min_price"]} - self.assertNotOptimizable(expr) - - def test_no_conversion_dict_value(self): - expr = {"$gte": ["$price", {}]} - self.assertNotOptimizable(expr) - - def _test_conversion_valid_type(self, _type): - expr = {"$gte": ["$price", _type]} - expected = {"price": {"$gte": _type}} - self.assertConversionEqual(expr, expected) - - def test_conversion_various_types(self): - self._test_conversion_various_types(self._test_conversion_valid_type) - - -class LtTests(ConversionTestCase): - def test_conversion(self): - self.assertConversionEqual({"$lt": ["$price", 100]}, {"price": {"$lt": 100}}) - - def test_no_conversion_non_simple_field(self): - self.assertNotOptimizable({"$lt": ["$price", "$min_price"]}) - - def test_no_conversion_dict_value(self): - self.assertNotOptimizable({"$lt": ["$price", {}]}) - - def _test_conversion_valid_type(self, _type): - self.assertConversionEqual({"$lt": ["$price", _type]}, {"price": {"$lt": _type}}) - - def test_conversion_various_types(self): - self._test_conversion_various_types(self._test_conversion_valid_type) - - -class LteTests(ConversionTestCase): - def test_conversion(self): - self.assertConversionEqual({"$lte": ["$price", 100]}, {"price": {"$lte": 100}}) - - def test_no_conversion_non_simple_field(self): - self.assertNotOptimizable({"$lte": ["$price", "$min_price"]}) - - def test_no_conversion_dict_value(self): - self.assertNotOptimizable({"$lte": ["$price", {}]}) - - def _test_conversion_valid_type(self, _type): - self.assertConversionEqual({"$lte": ["$price", _type]}, {"price": {"$lte": _type}}) - - def test_conversion_various_types(self): - self._test_conversion_various_types(self._test_conversion_valid_type) diff --git a/tests/expressions_/test_value.py b/tests/expressions_/test_value.py index 3ba86b899..4adbc1d3e 100644 --- a/tests/expressions_/test_value.py +++ b/tests/expressions_/test_value.py @@ -24,7 +24,7 @@ def test_decimal(self): self.assertEqual(Value(Decimal("1.0")).as_mql(None, None), Decimal128("1.0")) def test_list(self): - self.assertEqual(Value([1, 2]).as_mql(None, None), {"$literal": [1, 2]}) + self.assertEqual(Value([1, 2]).as_mql(None, None, as_expr=True), {"$literal": [1, 2]}) def test_time(self): self.assertEqual( @@ -36,7 +36,7 @@ def test_timedelta(self): self.assertEqual(Value(datetime.timedelta(3600)).as_mql(None, None), 311040000000.0) def test_int(self): - self.assertEqual(Value(1).as_mql(None, None), {"$literal": 1}) + self.assertEqual(Value(1).as_mql(None, None, as_expr=True), {"$literal": 1}) def test_str(self): self.assertEqual(Value("foo").as_mql(None, None), "foo") diff --git a/tests/lookup_/tests.py b/tests/lookup_/tests.py index 6fce89942..6166e1003 100644 --- a/tests/lookup_/tests.py +++ b/tests/lookup_/tests.py @@ -1,3 +1,4 @@ +from bson import SON from django.test import TestCase from django_mongodb_backend.test import MongoTestCaseMixin @@ -5,12 +6,12 @@ from .models import Book, Number -class NumericLookupTests(TestCase): +class NumericLookupTests(MongoTestCaseMixin, TestCase): @classmethod def setUpTestData(cls): cls.objs = Number.objects.bulk_create(Number(num=x) for x in range(5)) # Null values should be excluded in less than queries. - Number.objects.create() + cls.null_number = Number.objects.create() def test_lt(self): self.assertQuerySetEqual(Number.objects.filter(num__lt=3), self.objs[:3]) @@ -18,6 +19,20 @@ def test_lt(self): def test_lte(self): self.assertQuerySetEqual(Number.objects.filter(num__lte=3), self.objs[:4]) + def test_empty_range(self): + with self.assertNumQueries(0): + self.assertQuerySetEqual(Number.objects.filter(num__range=[3, 1]), []) + + def test_full_range(self): + with self.assertNumQueries(1) as ctx: + self.assertQuerySetEqual( + Number.objects.filter(num__range=[None, None]), [self.null_number, *self.objs] + ) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, "lookup__number", [{"$addFields": {"num": "$num"}}, {"$sort": SON([("num", 1)])}] + ) + class RegexTests(MongoTestCaseMixin, TestCase): def test_mql(self): @@ -29,15 +44,7 @@ def test_mql(self): self.assertAggregateQuery( query, "lookup__book", - [ - { - "$match": { - "$expr": { - "$regexMatch": {"input": "$title", "regex": "Moby Dick", "options": ""} - } - } - } - ], + [{"$match": {"title": {"$regex": "Moby Dick", "$options": ""}}}], ) diff --git a/tests/model_fields_/test_arrayfield.py b/tests/model_fields_/test_arrayfield.py index 06a918ebc..00fdf663e 100644 --- a/tests/model_fields_/test_arrayfield.py +++ b/tests/model_fields_/test_arrayfield.py @@ -21,6 +21,8 @@ from django.utils import timezone from django_mongodb_backend.fields import ArrayField +from django_mongodb_backend.fields.array import Array +from django_mongodb_backend.test import MongoTestCaseMixin from .models import ( ArrayEnumModel, @@ -216,7 +218,7 @@ def test_nested_nullable_base_field(self): self.assertEqual(instance.field_nested, [[None, None], [None, None]]) -class QueryingTests(TestCase): +class QueryingTests(MongoTestCaseMixin, TestCase): @classmethod def setUpTestData(cls): cls.objs = NullableIntegerArrayModel.objects.bulk_create( @@ -241,9 +243,34 @@ def test_empty_list(self): self.assertEqual(obj.field, []) self.assertEqual(obj.empty_array, []) - def test_exact(self): - self.assertSequenceEqual( - NullableIntegerArrayModel.objects.filter(field__exact=[1]), self.objs[:1] + def test_exact_expr(self): + with self.assertNumQueries(1) as ctx: + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter(field__exact=[Value(3) / 3]), self.objs[:1] + ) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, + "model_fields__nullableintegerarraymodel", + [ + { + "$match": { + "$expr": { + "$eq": ["$field", [{"$divide": [{"$literal": 3}, {"$literal": 3}]}]] + } + } + } + ], + ) + + def test_exact_path(self): + with self.assertNumQueries(1) as ctx: + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter(field__exact=[1]), self.objs[:1] + ) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, "model_fields__nullableintegerarraymodel", [{"$match": {"field": [1]}}] ) def test_exact_null_only_array(self): @@ -261,23 +288,42 @@ def test_exact_null_only_nested_array(self): obj2 = NullableIntegerArrayModel.objects.create( field_nested=[[None, None], [None, None]], ) - self.assertSequenceEqual( - NullableIntegerArrayModel.objects.filter( - field_nested__exact=[[None, None]], - ), - [obj1], + with self.assertNumQueries(1) as ctx: + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter( + field_nested__exact=[[None, None]], + ), + [obj1], + ) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, + "model_fields__nullableintegerarraymodel", + [{"$match": {"field_nested": [[None, None]]}}], ) - self.assertSequenceEqual( - NullableIntegerArrayModel.objects.filter( - field_nested__exact=[[None, None], [None, None]], - ), - [obj2], + with self.assertNumQueries(1) as ctx: + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter( + field_nested__exact=[[None, None], [None, None]], + ), + [obj2], + ) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, + "model_fields__nullableintegerarraymodel", + [{"$match": {"field_nested": [[None, None], [None, None]]}}], ) def test_exact_with_expression(self): - self.assertSequenceEqual( - NullableIntegerArrayModel.objects.filter(field__exact=[Value(1)]), - self.objs[:1], + with self.assertNumQueries(1) as ctx: + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter(field__exact=[Value(1)]), + self.objs[:1], + ) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, "model_fields__nullableintegerarraymodel", [{"$match": {"field": [1]}}] ) def test_exact_charfield(self): @@ -291,24 +337,140 @@ def test_exact_nested(self): ) def test_isnull(self): - self.assertSequenceEqual( - NullableIntegerArrayModel.objects.filter(field__isnull=True), self.objs[-1:] + with self.assertNumQueries(1) as ctx: + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter(field__isnull=True), self.objs[-1:] + ) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, + "model_fields__nullableintegerarraymodel", + [{"$match": {"$or": [{"field": {"$exists": False}}, {"field": None}]}}], ) - def test_gt(self): - self.assertSequenceEqual( - NullableIntegerArrayModel.objects.filter(field__gt=[0]), self.objs[:4] + def test_gt_expr(self): + with self.assertNumQueries(1) as ctx: + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter(field__gt=Array(Value(0) * 3)), + self.objs[:4], + ) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, + "model_fields__nullableintegerarraymodel", + [ + { + "$match": { + "$expr": { + "$gt": ["$field", [{"$multiply": [{"$literal": 0}, {"$literal": 3}]}]] + } + } + } + ], ) - def test_lt(self): - self.assertSequenceEqual( - NullableIntegerArrayModel.objects.filter(field__lt=[2]), self.objs[:1] + def test_gt_path(self): + with self.assertNumQueries(1) as ctx: + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter(field__gt=Array(0)), self.objs[:4] + ) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, "model_fields__nullableintegerarraymodel", [{"$match": {"field": {"$gt": [0]}}}] ) - def test_in(self): - self.assertSequenceEqual( - NullableIntegerArrayModel.objects.filter(field__in=[[1], [2]]), - self.objs[:2], + def test_lt_expr(self): + with self.assertNumQueries(1) as ctx: + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter(field__lt=Array(Value(1) + 1)), + self.objs[:1], + ) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, + "model_fields__nullableintegerarraymodel", + [ + { + "$match": { + "$expr": { + "$and": [ + {"$lt": ["$field", [{"$add": [{"$literal": 1}, {"$literal": 1}]}]]}, + { + "$not": { + "$or": [ + {"$eq": [{"$type": "$field"}, "missing"]}, + {"$eq": ["$field", None]}, + ] + } + }, + ] + } + } + } + ], + ) + + def test_lt_path(self): + with self.assertNumQueries(1) as ctx: + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter(field__lt=[2]), self.objs[:1] + ) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, + "model_fields__nullableintegerarraymodel", + [ + { + "$match": { + "$and": [ + {"field": {"$lt": [2]}}, + {"$and": [{"field": {"$exists": True}}, {"field": {"$ne": None}}]}, + ] + } + } + ], + ) + + def test_in_expr(self): + with self.assertNumQueries(1) as ctx: + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter( + field__in=Array(Array(Value(1) * 1), Array(2)) + ), + self.objs[:2], + ) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, + "model_fields__nullableintegerarraymodel", + [ + { + "$match": { + "$expr": { + "$in": ( + "$field", + [ + [{"$multiply": [{"$literal": 1}, {"$literal": 1}]}], + [{"$literal": 2}], + ], + ) + } + } + } + ], + ) + + def test_in_path(self): + with self.assertNumQueries(1) as ctx: + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter(field__in=[[1], [2]]), + self.objs[:2], + ) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, + "model_fields__nullableintegerarraymodel", + [{"$match": {"field": {"$in": ([1], [2])}}}], ) def test_in_subquery(self): @@ -352,10 +514,45 @@ def test_contained_by_including_F_object(self): self.objs[:3], ) - def test_contains(self): - self.assertSequenceEqual( - NullableIntegerArrayModel.objects.filter(field__contains=[2]), - self.objs[1:3], + def test_contains_expr(self): + with self.assertNumQueries(1) as ctx: + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter(field__contains=[Value(1) + 1]), + self.objs[1:3], + ) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, + "model_fields__nullableintegerarraymodel", + [ + { + "$match": { + "$expr": { + "$and": [ + {"$ne": ["$field", None]}, + {"$ne": [[{"$add": [{"$literal": 1}, {"$literal": 1}]}], None]}, + { + "$setIsSubset": [ + [{"$add": [{"$literal": 1}, {"$literal": 1}]}], + "$field", + ] + }, + ] + } + } + } + ], + ) + + def test_contains_path(self): + with self.assertNumQueries(1) as ctx: + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter(field__contains=[2]), + self.objs[1:3], + ) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, "model_fields__nullableintegerarraymodel", [{"$match": {"field": {"$all": [2]}}}] ) def test_contains_subquery(self): @@ -395,7 +592,16 @@ def test_contains_including_expression(self): def test_icontains(self): instance = CharArrayModel.objects.create(field=["FoO"]) - self.assertSequenceEqual(CharArrayModel.objects.filter(field__icontains="foo"), [instance]) + with self.assertNumQueries(1) as ctx: + self.assertSequenceEqual( + CharArrayModel.objects.filter(field__icontains="foo"), [instance] + ) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, + "model_fields__chararraymodel", + [{"$match": {"field": {"$regex": "foo", "$options": "i"}}}], + ) def test_contains_charfield(self): self.assertSequenceEqual(CharArrayModel.objects.filter(field__contains=["text"]), []) @@ -455,10 +661,51 @@ def test_index_used_on_nested_data(self): NestedIntegerArrayModel.objects.filter(field__0=[1, 2]), [instance] ) - def test_overlap(self): - self.assertSequenceEqual( - NullableIntegerArrayModel.objects.filter(field__overlap=[1, 2]), - self.objs[0:3], + def test_overlap_expr(self): + with self.assertNumQueries(1) as ctx: + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter(field__overlap=[1, Value(1) + 1]), + self.objs[0:3], + ) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, + "model_fields__nullableintegerarraymodel", + [ + { + "$match": { + "$expr": { + "$and": [ + {"$ne": ["$field", None]}, + { + "$size": { + "$setIntersection": [ + [ + {"$literal": 1}, + {"$add": [{"$literal": 1}, {"$literal": 1}]}, + ], + "$field", + ] + } + }, + ] + } + } + } + ], + ) + + def test_overlap_path(self): + with self.assertNumQueries(1) as ctx: + self.assertSequenceEqual( + NullableIntegerArrayModel.objects.filter(field__overlap=[1, 2]), + self.objs[0:3], + ) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, + "model_fields__nullableintegerarraymodel", + [{"$match": {"field": {"$in": [1, 2]}}}], ) def test_index_annotation(self): diff --git a/tests/model_fields_/test_embedded_model.py b/tests/model_fields_/test_embedded_model.py index 332cba79b..931900114 100644 --- a/tests/model_fields_/test_embedded_model.py +++ b/tests/model_fields_/test_embedded_model.py @@ -10,12 +10,14 @@ Max, OuterRef, Sum, + Value, ) from django.test import SimpleTestCase, TestCase from django.test.utils import isolate_apps from django_mongodb_backend.fields import EmbeddedModelField from django_mongodb_backend.models import EmbeddedModel +from django_mongodb_backend.test import MongoTestCaseMixin from .models import ( Address, @@ -130,7 +132,7 @@ def test_embedded_model_field_respects_db_column(self): self.assertEqual(query[0]["data"]["integer_"], 5) -class QueryingTests(TestCase): +class QueryingTests(MongoTestCaseMixin, TestCase): @classmethod def setUpTestData(cls): cls.objs = [ @@ -144,23 +146,360 @@ def setUpTestData(cls): for x in range(6) ] - def test_exact(self): - self.assertCountEqual(Holder.objects.filter(data__integer=3), [self.objs[3]]) + def test_exact_expr(self): + with self.assertNumQueries(1) as ctx: + self.assertCountEqual(Holder.objects.filter(data__integer=Value(4) - 1), [self.objs[3]]) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, + "model_fields__holder", + [ + { + "$match": { + "$expr": { + "$eq": [ + {"$getField": {"input": "$data", "field": "integer_"}}, + {"$subtract": [{"$literal": 4}, {"$literal": 1}]}, + ] + } + } + } + ], + ) + + def test_exact_path(self): + with self.assertNumQueries(1) as ctx: + self.assertCountEqual(Holder.objects.filter(data__integer=3), [self.objs[3]]) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery(query, "model_fields__holder", [{"$match": {"data.integer_": 3}}]) + + def test_lt_expr(self): + with self.assertNumQueries(1) as ctx: + self.assertCountEqual( + Holder.objects.filter(data__integer__lt=Value(4) - 1), self.objs[:3] + ) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, + "model_fields__holder", + [ + { + "$match": { + "$expr": { + "$and": [ + { + "$lt": [ + {"$getField": {"input": "$data", "field": "integer_"}}, + {"$subtract": [{"$literal": 4}, {"$literal": 1}]}, + ] + }, + { + "$not": { + "$or": [ + { + "$eq": [ + { + "$type": { + "$getField": { + "input": "$data", + "field": "integer_", + } + } + }, + "missing", + ] + }, + { + "$eq": [ + { + "$getField": { + "input": "$data", + "field": "integer_", + } + }, + None, + ] + }, + ] + } + }, + ] + } + } + } + ], + ) - def test_lt(self): - self.assertCountEqual(Holder.objects.filter(data__integer__lt=3), self.objs[:3]) + def test_lt_path(self): + with self.assertNumQueries(1) as ctx: + self.assertCountEqual(Holder.objects.filter(data__integer__lt=3), self.objs[:3]) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, + "model_fields__holder", + [ + { + "$match": { + "$and": [ + {"data.integer_": {"$lt": 3}}, + { + "$and": [ + {"data.integer_": {"$exists": True}}, + {"data.integer_": {"$ne": None}}, + ] + }, + ] + } + } + ], + ) - def test_lte(self): - self.assertCountEqual(Holder.objects.filter(data__integer__lte=3), self.objs[:4]) + def test_lte_expr(self): + with self.assertNumQueries(1) as ctx: + self.assertCountEqual( + Holder.objects.filter(data__integer__lte=Value(4) - 1), self.objs[:4] + ) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, + "model_fields__holder", + [ + { + "$match": { + "$expr": { + "$and": [ + { + "$lte": [ + {"$getField": {"input": "$data", "field": "integer_"}}, + {"$subtract": [{"$literal": 4}, {"$literal": 1}]}, + ] + }, + { + "$not": { + "$or": [ + { + "$eq": [ + { + "$type": { + "$getField": { + "input": "$data", + "field": "integer_", + } + } + }, + "missing", + ] + }, + { + "$eq": [ + { + "$getField": { + "input": "$data", + "field": "integer_", + } + }, + None, + ] + }, + ] + } + }, + ] + } + } + } + ], + ) - def test_gt(self): - self.assertCountEqual(Holder.objects.filter(data__integer__gt=3), self.objs[4:]) + def test_lte_path(self): + with self.assertNumQueries(1) as ctx: + self.assertCountEqual(Holder.objects.filter(data__integer__lte=3), self.objs[:4]) + query = ctx.captured_queries[0]["sql"] - def test_gte(self): - self.assertCountEqual(Holder.objects.filter(data__integer__gte=3), self.objs[3:]) + self.assertAggregateQuery( + query, + "model_fields__holder", + [ + { + "$match": { + "$and": [ + {"data.integer_": {"$lte": 3}}, + { + "$and": [ + {"data.integer_": {"$exists": True}}, + {"data.integer_": {"$ne": None}}, + ] + }, + ] + } + } + ], + ) - def test_range(self): - self.assertCountEqual(Holder.objects.filter(data__integer__range=(2, 4)), self.objs[2:5]) + def test_gt_expr(self): + with self.assertNumQueries(1) as ctx: + self.assertCountEqual( + Holder.objects.filter(data__integer__gt=Value(4) - 1), self.objs[4:] + ) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, + "model_fields__holder", + [ + { + "$match": { + "$expr": { + "$gt": [ + {"$getField": {"input": "$data", "field": "integer_"}}, + {"$subtract": [{"$literal": 4}, {"$literal": 1}]}, + ] + } + } + } + ], + ) + + def test_gt_path(self): + with self.assertNumQueries(1) as ctx: + self.assertCountEqual(Holder.objects.filter(data__integer__gt=3), self.objs[4:]) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, "model_fields__holder", [{"$match": {"data.integer_": {"$gt": 3}}}] + ) + + def test_gte_expr(self): + with self.assertNumQueries(1) as ctx: + self.assertCountEqual( + Holder.objects.filter(data__integer__gte=Value(4) - 1), self.objs[3:] + ) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, + "model_fields__holder", + [ + { + "$match": { + "$expr": { + "$gte": [ + {"$getField": {"input": "$data", "field": "integer_"}}, + {"$subtract": [{"$literal": 4}, {"$literal": 1}]}, + ] + } + } + } + ], + ) + + def test_gte_path(self): + with self.assertNumQueries(1) as ctx: + self.assertCountEqual(Holder.objects.filter(data__integer__gte=3), self.objs[3:]) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, "model_fields__holder", [{"$match": {"data.integer_": {"$gte": 3}}}] + ) + + def test_range_expr(self): + with self.assertNumQueries(1) as ctx: + self.assertCountEqual( + Holder.objects.filter(data__integer__range=(2, Value(5) - 1)), self.objs[2:5] + ) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, + "model_fields__holder", + [ + { + "$match": { + "$expr": { + "$and": [ + { + "$or": [ + { + "$or": [ + {"$eq": [{"$type": {"$literal": 2}}, "missing"]}, + {"$eq": [{"$literal": 2}, None]}, + ] + }, + { + "$gte": [ + { + "$getField": { + "input": "$data", + "field": "integer_", + } + }, + {"$literal": 2}, + ] + }, + ] + }, + { + "$or": [ + { + "$or": [ + { + "$eq": [ + { + "$type": { + "$subtract": [ + {"$literal": 5}, + {"$literal": 1}, + ] + } + }, + "missing", + ] + }, + { + "$eq": [ + { + "$subtract": [ + {"$literal": 5}, + {"$literal": 1}, + ] + }, + None, + ] + }, + ] + }, + { + "$lte": [ + { + "$getField": { + "input": "$data", + "field": "integer_", + } + }, + {"$subtract": [{"$literal": 5}, {"$literal": 1}]}, + ] + }, + ] + }, + ] + } + } + } + ], + ) + + def test_range_path(self): + with self.assertNumQueries(1) as ctx: + self.assertCountEqual( + Holder.objects.filter(data__integer__range=(2, 4)), self.objs[2:5] + ) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, + "model_fields__holder", + [ + { + "$match": { + "$and": [{"data.integer_": {"$gte": 2}}, {"data.integer_": {"$lte": 4}}] + } + } + ], + ) def test_exact_decimal(self): # EmbeddedModelField lookups call @@ -247,6 +586,17 @@ def test_nested(self): ) self.assertCountEqual(Book.objects.filter(author__address__city="NYC"), [obj]) + def test_filter_by_simple_annotate(self): + obj = Book.objects.create( + author=Author(name="Shakespeare", age=55, address=Address(city="NYC", state="NY")) + ) + with self.assertNumQueries(1) as ctx: + book_from_ny = ( + Book.objects.annotate(city=F("author__address__city")).filter(city="NYC").first() + ) + self.assertCountEqual(book_from_ny.city, obj.author.address.city) + self.assertIn("{'$match': {'author.address.city': 'NYC'}}", ctx.captured_queries[0]["sql"]) + class ArrayFieldTests(TestCase): @classmethod diff --git a/tests/model_fields_/test_embedded_model_array.py b/tests/model_fields_/test_embedded_model_array.py index 5ae396e2a..837fa2b94 100644 --- a/tests/model_fields_/test_embedded_model_array.py +++ b/tests/model_fields_/test_embedded_model_array.py @@ -5,11 +5,13 @@ from django.core.exceptions import FieldDoesNotExist from django.db import connection, models from django.db.models.expressions import Value +from django.db.models.functions import Concat from django.test import SimpleTestCase, TestCase from django.test.utils import CaptureQueriesContext, isolate_apps from django_mongodb_backend.fields import ArrayField, EmbeddedModelArrayField from django_mongodb_backend.models import EmbeddedModel +from django_mongodb_backend.test import MongoTestCaseMixin from .models import Artifact, Audit, Exhibit, Movie, Restoration, Review, Section, Tour @@ -84,7 +86,7 @@ def test_embedded_model_field_respects_db_column(self): self.assertEqual(query[0]["reviews"][0]["title_"], "Awesome") -class QueryingTests(TestCase): +class QueryingTests(MongoTestCaseMixin, TestCase): @classmethod def setUpTestData(cls): cls.egypt = Exhibit.objects.create( @@ -177,23 +179,171 @@ def setUpTestData(cls): cls.audit_2 = Audit.objects.create(section_number=2, reviewed=True) cls.audit_3 = Audit.objects.create(section_number=5, reviewed=False) - def test_exact(self): - self.assertCountEqual( - Exhibit.objects.filter(sections__number=1), [self.egypt, self.wonders] + def test_exact_expr(self): + with self.assertNumQueries(1) as ctx: + self.assertCountEqual( + Exhibit.objects.filter(sections__number=Value(2) - 1), [self.egypt, self.wonders] + ) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, + "model_fields__exhibit", + [ + { + "$match": { + "$expr": { + "$anyElementTrue": { + "$ifNull": [ + { + "$map": { + "input": "$sections", + "as": "item", + "in": { + "$eq": [ + "$$item.number", + { + "$subtract": [ + {"$literal": 2}, + {"$literal": 1}, + ] + }, + ] + }, + } + }, + [], + ] + } + } + } + } + ], ) - def test_array_index(self): - self.assertCountEqual( - Exhibit.objects.filter(sections__0__number=1), - [self.egypt, self.wonders], + def test_exact_path(self): + with self.assertNumQueries(1) as ctx: + self.assertCountEqual( + Exhibit.objects.filter(sections__number=1), [self.egypt, self.wonders] + ) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, "model_fields__exhibit", [{"$match": {"sections.number": 1}}] ) - def test_nested_array_index(self): - self.assertCountEqual( - Exhibit.objects.filter( - main_section__artifacts__restorations__0__restored_by="Zacarias" - ), - [self.lost_empires], + def test_array_index_expr(self): + with self.assertNumQueries(1) as ctx: + self.assertCountEqual( + Exhibit.objects.filter(sections__0__number=Value(2) - 1), + [self.egypt, self.wonders], + ) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, + "model_fields__exhibit", + [ + { + "$match": { + "$expr": { + "$eq": [ + { + "$getField": { + "input": {"$arrayElemAt": ["$sections", 0]}, + "field": "number", + } + }, + {"$subtract": [{"$literal": 2}, {"$literal": 1}]}, + ] + } + } + } + ], + ) + + def test_array_index_path(self): + with self.assertNumQueries(1) as ctx: + self.assertCountEqual( + Exhibit.objects.filter(sections__0__number=1), + [self.egypt, self.wonders], + ) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, "model_fields__exhibit", [{"$match": {"sections.0.number": 1}}] + ) + + def test_nested_array_index_expr(self): + with self.assertNumQueries(1) as ctx: + self.assertCountEqual( + Exhibit.objects.filter( + main_section__artifacts__restorations__0__restored_by=Concat( + Value("Z"), Value("acarias") + ) + ), + [self.lost_empires], + ) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, + "model_fields__exhibit", + [ + { + "$match": { + "$expr": { + "$anyElementTrue": { + "$ifNull": [ + { + "$map": { + "input": { + "$getField": { + "input": "$main_section", + "field": "artifacts", + } + }, + "as": "item", + "in": { + "$eq": [ + { + "$getField": { + "input": { + "$arrayElemAt": [ + "$$item.restorations", + 0, + ] + }, + "field": "restored_by", + } + }, + { + "$concat": [ + {"$ifNull": ["Z", ""]}, + {"$ifNull": ["acarias", ""]}, + ] + }, + ] + }, + } + }, + [], + ] + } + } + } + } + ], + ) + + def test_nested_array_index_path(self): + with self.assertNumQueries(1) as ctx: + self.assertCountEqual( + Exhibit.objects.filter( + main_section__artifacts__restorations__0__restored_by="Zacarias" + ), + [self.lost_empires], + ) + query = ctx.captured_queries[0]["sql"] + self.assertAggregateQuery( + query, + "model_fields__exhibit", + [{"$match": {"main_section.artifacts.restorations.0.restored_by": "Zacarias"}}], ) def test_array_slice(self): @@ -207,7 +357,21 @@ def test_filter_unsupported_lookups_in_json(self): kwargs = {f"main_section__artifacts__metadata__origin__{lookup}": ["Pergamon", "Egypt"]} with CaptureQueriesContext(connection) as captured_queries: self.assertCountEqual(Exhibit.objects.filter(**kwargs), []) - self.assertIn(f"'field': '{lookup}'", captured_queries[0]["sql"]) + query = captured_queries[0]["sql"] + self.assertAggregateQuery( + query, + "model_fields__exhibit", + [ + { + "$match": { + f"main_section.artifacts.metadata.origin.{lookup}": [ + "Pergamon", + "Egypt", + ] + } + } + ], + ) def test_len(self): self.assertCountEqual(Exhibit.objects.filter(sections__len=10), []) @@ -310,10 +474,19 @@ def test_nested_lookup(self): with self.assertRaisesMessage(ValueError, msg): Exhibit.objects.filter(sections__artifacts__name="") - def test_foreign_field_exact(self): + def test_foreign_field_exact_path(self): + """Querying from a foreign key to an EmbeddedModelArrayField.""" + with self.assertNumQueries(1) as ctx: + qs = Tour.objects.filter(exhibit__sections__number=1) + self.assertCountEqual(qs, [self.egypt_tour, self.wonders_tour]) + self.assertNotIn("anyElementTrue", ctx.captured_queries[0]["sql"]) + + def test_foreign_field_exact_expr(self): """Querying from a foreign key to an EmbeddedModelArrayField.""" - qs = Tour.objects.filter(exhibit__sections__number=1) - self.assertCountEqual(qs, [self.egypt_tour, self.wonders_tour]) + with self.assertNumQueries(1) as ctx: + qs = Tour.objects.filter(exhibit__sections__number=Value(2) - Value(1)) + self.assertCountEqual(qs, [self.egypt_tour, self.wonders_tour]) + self.assertIn("anyElementTrue", ctx.captured_queries[0]["sql"]) def test_foreign_field_with_slice(self): qs = Tour.objects.filter(exhibit__sections__0_2__number__in=[1, 2]) diff --git a/tests/queries_/test_mql.py b/tests/queries_/test_mql.py index ffd1e2e32..e8837bf8a 100644 --- a/tests/queries_/test_mql.py +++ b/tests/queries_/test_mql.py @@ -11,9 +11,7 @@ class MQLTests(MongoTestCaseMixin, TestCase): def test_all(self): with self.assertNumQueries(1) as ctx: list(Author.objects.all()) - self.assertAggregateQuery( - ctx.captured_queries[0]["sql"], "queries__author", [{"$match": {}}] - ) + self.assertAggregateQuery(ctx.captured_queries[0]["sql"], "queries__author", []) def test_join(self): with self.assertNumQueries(1) as ctx: @@ -29,12 +27,14 @@ def test_join(self): "pipeline": [ { "$match": { - "$expr": { - "$and": [ - {"$eq": ["$$parent__field__0", "$_id"]}, - {"$eq": ["$name", "Bob"]}, - ] - } + "$and": [ + { + "$expr": { + "$and": [{"$eq": ["$$parent__field__0", "$_id"]}] + } + }, + {"name": "Bob"}, + ] } } ], @@ -62,12 +62,14 @@ def test_filter_on_local_and_related_fields(self): "pipeline": [ { "$match": { - "$expr": { - "$and": [ - {"$eq": ["$$parent__field__0", "$_id"]}, - {"$eq": ["$name", "John"]}, - ] - } + "$and": [ + { + "$expr": { + "$and": [{"$eq": ["$$parent__field__0", "$_id"]}] + } + }, + {"name": "John"}, + ] } } ], @@ -123,22 +125,19 @@ def test_filter_on_self_join_fields(self): "pipeline": [ { "$match": { - "$expr": { - "$and": [ - {"$eq": ["$$parent__field__0", "$_id"]}, - { - "$and": [ - { - "$eq": [ - "$group_id", - ObjectId("6891ff7822e475eddc20f159"), - ] - }, - {"$eq": ["$name", "parent"]}, - ] - }, - ] - } + "$and": [ + { + "$expr": { + "$and": [{"$eq": ["$$parent__field__0", "$_id"]}] + } + }, + { + "$and": [ + {"group_id": ObjectId("6891ff7822e475eddc20f159")}, + {"name": "parent"}, + ] + }, + ] } } ], @@ -171,17 +170,16 @@ def test_filter_on_reverse_foreignkey_relation(self): "pipeline": [ { "$match": { - "$expr": { - "$and": [ - {"$eq": ["$$parent__field__0", "$order_id"]}, - { - "$eq": [ - "$status", - ObjectId("6891ff7822e475eddc20f159"), + "$and": [ + { + "$expr": { + "$and": [ + {"$eq": ["$$parent__field__0", "$order_id"]} ] - }, - ] - } + } + }, + {"status": ObjectId("6891ff7822e475eddc20f159")}, + ] } } ], @@ -215,17 +213,16 @@ def test_filter_on_local_and_nested_join_fields(self): "pipeline": [ { "$match": { - "$expr": { - "$and": [ - {"$eq": ["$$parent__field__0", "$order_id"]}, - { - "$eq": [ - "$status", - ObjectId("6891ff7822e475eddc20f159"), + "$and": [ + { + "$expr": { + "$and": [ + {"$eq": ["$$parent__field__0", "$order_id"]} ] - }, - ] - } + } + }, + {"status": ObjectId("6891ff7822e475eddc20f159")}, + ] } } ], @@ -240,12 +237,14 @@ def test_filter_on_local_and_nested_join_fields(self): "pipeline": [ { "$match": { - "$expr": { - "$and": [ - {"$eq": ["$$parent__field__0", "$_id"]}, - {"$eq": ["$name", "My Order"]}, - ] - } + "$and": [ + { + "$expr": { + "$and": [{"$eq": ["$$parent__field__0", "$_id"]}] + } + }, + {"name": "My Order"}, + ] } } ], @@ -276,6 +275,7 @@ def test_negated_related_filter_is_not_pushable(self): [ { "$lookup": { + "as": "queries__author", "from": "queries__author", "let": {"parent__field__0": "$author_id"}, "pipeline": [ @@ -285,11 +285,10 @@ def test_negated_related_filter_is_not_pushable(self): } } ], - "as": "queries__author", } }, {"$unwind": "$queries__author"}, - {"$match": {"$expr": {"$not": {"$eq": ["$queries__author.name", "John"]}}}}, + {"$match": {"$nor": [{"queries__author.name": "John"}]}}, ], ) @@ -341,21 +340,25 @@ def test_push_equality_between_parent_and_child_fields(self): [ { "$lookup": { + "as": "queries__orderitem", "from": "queries__orderitem", "let": {"parent__field__0": "$_id", "parent__field__1": "$_id"}, "pipeline": [ { "$match": { - "$expr": { - "$and": [ - {"$eq": ["$$parent__field__0", "$order_id"]}, - {"$eq": ["$status", "$$parent__field__1"]}, - ] - } + "$and": [ + { + "$expr": { + "$and": [ + {"$eq": ["$$parent__field__0", "$order_id"]} + ] + } + }, + {"$expr": {"$eq": ["$status", "$$parent__field__1"]}}, + ] } } ], - "as": "queries__orderitem", } }, {"$unwind": "$queries__orderitem"}, @@ -398,12 +401,14 @@ def test_simple_related_filter_is_pushed(self): "pipeline": [ { "$match": { - "$expr": { - "$and": [ - {"$eq": ["$$parent__field__0", "$_id"]}, - {"$eq": ["$name", "Alice"]}, - ] - } + "$and": [ + { + "$expr": { + "$and": [{"$eq": ["$$parent__field__0", "$_id"]}] + } + }, + {"name": "Alice"}, + ] } } ], @@ -416,6 +421,7 @@ def test_simple_related_filter_is_pushed(self): ) def test_subquery_join_is_pushed(self): + # TODO; isn't fully OPTIMIZED with self.assertNumQueries(1) as ctx: list(Library.objects.filter(~models.Q(readers__name="Alice"))) @@ -436,12 +442,21 @@ def test_subquery_join_is_pushed(self): "pipeline": [ { "$match": { - "$expr": { - "$and": [ - {"$eq": ["$$parent__field__0", "$_id"]}, - {"$eq": ["$name", "Alice"]}, - ] - } + "$and": [ + { + "$expr": { + "$and": [ + { + "$eq": [ + "$$parent__field__0", + "$_id", + ] + } + ] + } + }, + {"name": "Alice"}, + ] } } ], @@ -480,21 +495,28 @@ def test_subquery_join_is_pushed(self): }, { "$match": { - "$expr": { - "$not": { - "$eq": [ - { - "$not": { - "$or": [ - {"$eq": [{"$type": "$__subquery0.a"}, "missing"]}, - {"$eq": ["$__subquery0.a", None]}, - ] - } - }, - True, - ] + "$nor": [ + { + "$expr": { + "$eq": [ + { + "$not": { + "$or": [ + { + "$eq": [ + {"$type": "$__subquery0.a"}, + "missing", + ] + }, + {"$eq": ["$__subquery0.a", None]}, + ] + } + }, + True, + ] + } } - } + ] } }, ], @@ -531,12 +553,14 @@ def test_filter_on_local_and_related_fields(self): "pipeline": [ { "$match": { - "$expr": { - "$and": [ - {"$eq": ["$$parent__field__0", "$_id"]}, - {"$eq": ["$name", "Alice"]}, - ] - } + "$and": [ + { + "$expr": { + "$and": [{"$eq": ["$$parent__field__0", "$_id"]}] + } + }, + {"name": "Alice"}, + ] } } ],