Skip to content

Commit 0a56eb1

Browse files
committed
Make operators combinable and add compound expressions.
1 parent 7eeb73b commit 0a56eb1

File tree

1 file changed

+202
-1
lines changed

1 file changed

+202
-1
lines changed

django_mongodb_backend/expressions/search.py

Lines changed: 202 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,65 @@ def cast_as_field(path):
1212
return F(path) if isinstance(path, str) else path
1313

1414

15-
class SearchExpression(Expression):
15+
class Operator:
16+
AND = "AND"
17+
OR = "OR"
18+
NOT = "NOT"
19+
20+
def __init__(self, operator):
21+
self.operator = operator
22+
23+
def __eq__(self, other):
24+
if isinstance(other, str):
25+
return self.operator == other
26+
return self.operator == other.operator
27+
28+
def negate(self):
29+
if self.operator == self.AND:
30+
return Operator(self.OR)
31+
if self.operator == self.OR:
32+
return Operator(self.AND)
33+
return Operator(self.operator)
34+
35+
def __hash__(self):
36+
return hash(self.operator)
37+
38+
def __str__(self):
39+
return self.operator
40+
41+
def __repr__(self):
42+
return self.operator
43+
44+
45+
class SearchCombinable:
46+
def _combine(self, other, connector):
47+
if not isinstance(self, CompoundExpression | CombinedSearchExpression):
48+
lhs = CompoundExpression(must=[self])
49+
else:
50+
lhs = self
51+
if other and not isinstance(other, CompoundExpression | CombinedSearchExpression):
52+
rhs = CompoundExpression(must=[other])
53+
else:
54+
rhs = other
55+
return CombinedSearchExpression(lhs, connector, rhs)
56+
57+
def __invert__(self):
58+
return self._combine(None, Operator(Operator.NOT))
59+
60+
def __and__(self, other):
61+
return self._combine(other, Operator(Operator.AND))
62+
63+
def __rand__(self, other):
64+
return self._combine(other, Operator(Operator.AND))
65+
66+
def __or__(self, other):
67+
return self._combine(other, Operator(Operator.OR))
68+
69+
def __ror__(self, other):
70+
return self._combine(other, Operator(Operator.OR))
71+
72+
73+
class SearchExpression(SearchCombinable, Expression):
1674
"""Base expression node for MongoDB Atlas `$search` stages.
1775
1876
This class bridges Django's `Expression` API with the MongoDB Atlas
@@ -677,6 +735,149 @@ def get_search_fields(self, compiler, connection):
677735
return needed_fields
678736

679737

738+
class CompoundExpression(SearchExpression):
739+
"""
740+
Compound expression that combines multiple search clauses using boolean logic.
741+
742+
This expression corresponds to the `compound` operator in MongoDB Atlas Search,
743+
allowing fine-grained control by combining multiple sub-expressions with
744+
`must`, `must_not`, `should`, and `filter` clauses.
745+
746+
Example:
747+
CompoundExpression(
748+
must=[expr1, expr2],
749+
must_not=[expr3],
750+
should=[expr4],
751+
minimum_should_match=1
752+
)
753+
754+
Args:
755+
must: List of expressions that **must** match.
756+
must_not: List of expressions that **must not** match.
757+
should: List of expressions that **should** match (optional relevance boost).
758+
filter: List of expressions to filter results without affecting relevance.
759+
score: Optional expression to adjust scoring.
760+
minimum_should_match: Minimum number of `should` clauses that must match.
761+
762+
Reference: https://www.mongodb.com/docs/atlas/atlas-search/compound/
763+
"""
764+
765+
def __init__(
766+
self,
767+
must=None,
768+
must_not=None,
769+
should=None,
770+
filter=None,
771+
score=None,
772+
minimum_should_match=None,
773+
):
774+
self.must = must or []
775+
self.must_not = must_not or []
776+
self.should = should or []
777+
self.filter = filter or []
778+
self.score = score
779+
self.minimum_should_match = minimum_should_match
780+
781+
def get_search_fields(self, compiler, connection):
782+
fields = set()
783+
for clause in self.must + self.should + self.filter + self.must_not:
784+
fields.update(clause.get_search_fields(compiler, connection))
785+
return fields
786+
787+
def resolve_expression(
788+
self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
789+
):
790+
c = self.copy()
791+
c.is_summary = summarize
792+
c.must = [
793+
expr.resolve_expression(query, allow_joins, reuse, summarize) for expr in self.must
794+
]
795+
c.must_not = [
796+
expr.resolve_expression(query, allow_joins, reuse, summarize) for expr in self.must_not
797+
]
798+
c.should = [
799+
expr.resolve_expression(query, allow_joins, reuse, summarize) for expr in self.should
800+
]
801+
c.filter = [
802+
expr.resolve_expression(query, allow_joins, reuse, summarize) for expr in self.filter
803+
]
804+
return c
805+
806+
def search_operator(self, compiler, connection):
807+
params = {}
808+
if self.must:
809+
params["must"] = [clause.search_operator(compiler, connection) for clause in self.must]
810+
if self.must_not:
811+
params["mustNot"] = [
812+
clause.search_operator(compiler, connection) for clause in self.must_not
813+
]
814+
if self.should:
815+
params["should"] = [
816+
clause.search_operator(compiler, connection) for clause in self.should
817+
]
818+
if self.filter:
819+
params["filter"] = [
820+
clause.search_operator(compiler, connection) for clause in self.filter
821+
]
822+
if self.minimum_should_match is not None:
823+
params["minimumShouldMatch"] = self.minimum_should_match
824+
return {"compound": params}
825+
826+
def negate(self):
827+
return CompoundExpression(must_not=[self])
828+
829+
830+
class CombinedSearchExpression(SearchExpression):
831+
"""
832+
Combines two search expressions with a logical operator.
833+
834+
This expression allows combining two Atlas Search expressions
835+
(left-hand side and right-hand side) using a boolean operator
836+
such as `and`, `or`, or `not`.
837+
838+
Example:
839+
CombinedSearchExpression(expr1, "and", expr2)
840+
841+
Args:
842+
lhs: The left-hand search expression.
843+
operator: The boolean operator as a string (e.g., "and", "or", "not").
844+
rhs: The right-hand search expression.
845+
"""
846+
847+
def __init__(self, lhs, operator, rhs):
848+
self.lhs = lhs
849+
self.operator = operator
850+
self.rhs = rhs
851+
852+
def get_source_expressions(self):
853+
return [self.lhs, self.rhs]
854+
855+
def set_source_expressions(self, exprs):
856+
self.lhs, self.rhs = exprs
857+
858+
@staticmethod
859+
def resolve(node, negated=False):
860+
if node is None:
861+
return None
862+
# Leaf, resolve the compoundExpression
863+
if isinstance(node, CompoundExpression):
864+
return node.negate() if negated else node
865+
# Apply De Morgan's Laws.
866+
operator = node.operator.negate() if negated else node.operator
867+
negated = negated != (node.operator == Operator.NOT)
868+
lhs_compound = node.resolve(node.lhs, negated)
869+
rhs_compound = node.resolve(node.rhs, negated)
870+
if operator == Operator.OR:
871+
return CompoundExpression(should=[lhs_compound, rhs_compound], minimum_should_match=1)
872+
if operator == Operator.AND:
873+
return CompoundExpression(must=[lhs_compound, rhs_compound])
874+
return lhs_compound
875+
876+
def as_mql(self, compiler, connection):
877+
expression = self.resolve(self)
878+
return expression.as_mql(compiler, connection)
879+
880+
680881
class SearchScoreOption(Expression):
681882
"""Class to mutate scoring on a search operation"""
682883

0 commit comments

Comments
 (0)