Skip to content

Commit 8b49f07

Browse files
committed
Add combinable operators.
1 parent 2b91578 commit 8b49f07

File tree

1 file changed

+132
-13
lines changed

1 file changed

+132
-13
lines changed

django_mongodb_backend/expressions/builtins.py

Lines changed: 132 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,56 @@ def value(self, compiler, connection): # noqa: ARG001
213213
return value
214214

215215

216-
class SearchExpression(Expression):
216+
class Operator:
217+
AND = "AND"
218+
OR = "OR"
219+
NOT = "NOT"
220+
221+
def __init__(self, operator):
222+
self.operator = operator
223+
224+
def __eq__(self, other):
225+
if isinstance(other, str):
226+
return self.operator == other
227+
return self.operator == other.operator
228+
229+
def negate(self):
230+
if self.operator == self.AND:
231+
return Operator(self.OR)
232+
if self.operator == self.OR:
233+
return Operator(self.AND)
234+
return Operator(self.operator)
235+
236+
237+
class SearchCombinable:
238+
def _combine(self, other, connector, reversed):
239+
if not isinstance(self, CompoundExpression | CombinedSearchExpression):
240+
lhs = CompoundExpression(must=[self])
241+
else:
242+
lhs = self
243+
if not isinstance(other, CompoundExpression | CombinedSearchExpression):
244+
rhs = CompoundExpression(must=[other])
245+
else:
246+
rhs = other
247+
return CombinedSearchExpression(lhs, connector, rhs)
248+
249+
def __invert__(self):
250+
return CombinedSearchExpression(self, Operator(Operator.NOT), None)
251+
252+
def __and__(self, other):
253+
return CombinedSearchExpression(self, Operator(Operator.AND), other)
254+
255+
def __rand__(self, other):
256+
return CombinedSearchExpression(self, Operator(Operator.AND), other)
257+
258+
def __or__(self, other):
259+
return CombinedSearchExpression(self, Operator(Operator.OR), other)
260+
261+
def __ror__(self, other):
262+
return CombinedSearchExpression(self, Operator(Operator.OR), other)
263+
264+
265+
class SearchExpression(SearchCombinable, Expression):
217266
output_field = FloatField()
218267

219268
def get_source_expressions(self):
@@ -530,6 +579,21 @@ def __init__(
530579
self.filter = filter
531580
super().__init__()
532581

582+
def __invert__(self):
583+
return ValueError("SearchVector cannot be negated")
584+
585+
def __and__(self, other):
586+
raise NotSupportedError("SearchVector cannot be combined")
587+
588+
def __rand__(self, other):
589+
raise NotSupportedError("SearchVector cannot be combined")
590+
591+
def __or__(self, other):
592+
raise NotSupportedError("SearchVector cannot be combined")
593+
594+
def __ror__(self, other):
595+
raise NotSupportedError("SearchVector cannot be combined")
596+
533597
def as_mql(self, compiler, connection):
534598
params = {
535599
"index": self.index,
@@ -546,15 +610,16 @@ def as_mql(self, compiler, connection):
546610
return {"$vectorSearch": params}
547611

548612

549-
class SearchScoreOption:
550-
"""Class to mutate scoring on a search operation"""
551-
552-
def __init__(self, definitions=None):
553-
self.definitions = definitions
554-
555-
556613
class CompoundExpression(SearchExpression):
557-
def __init__(self, must=None, must_not=None, should=None, filter=None, score=None):
614+
def __init__(
615+
self,
616+
must=None,
617+
must_not=None,
618+
should=None,
619+
filter=None,
620+
score=None,
621+
minimum_should_match=None,
622+
):
558623
self.must = must or []
559624
self.must_not = must_not or []
560625
self.should = should or []
@@ -563,13 +628,67 @@ def __init__(self, must=None, must_not=None, should=None, filter=None, score=Non
563628

564629
def as_mql(self, compiler, connection):
565630
params = {}
566-
for param in ["must", "must_not", "should", "filter"]:
567-
clauses = getattr(self, param)
568-
if clauses:
569-
params[param] = [clause.as_mql(compiler, connection) for clause in clauses]
631+
if self.must:
632+
params["must"] = [clause.as_mql(compiler, connection) for clause in self.must]
633+
if self.must_not:
634+
params["mustNot"] = [clause.as_mql(compiler, connection) for clause in self.must_not]
635+
if self.should:
636+
params["should"] = [clause.as_mql(compiler, connection) for clause in self.should]
637+
if self.filter:
638+
params["filter"] = [clause.as_mql(compiler, connection) for clause in self.filter]
639+
if self.minimum_should_match is not None:
640+
params["minimumShouldMatch"] = self.minimum_should_match
570641

571642
return {"$compound": params}
572643

644+
def negate(self):
645+
return CompoundExpression(must=self.must_not, must_not=self.must + self.filter)
646+
647+
648+
class CombinedSearchExpression(SearchExpression):
649+
def __init__(self, lhs, operator, rhs):
650+
self.lhs = lhs
651+
self.operator = operator
652+
self.rhs = rhs
653+
654+
@staticmethod
655+
def _flatten(node, negated=False):
656+
if node is None:
657+
return None
658+
# Leaf, resolve the compoundExpression
659+
if isinstance(node, CompoundExpression):
660+
return node.negate() if negated else node
661+
# Apply De Morgan's Laws.
662+
operator = node.operator.negate() if negated else node.operator
663+
negated = negated != (node.operator == Operator.NOT)
664+
lhs_compound = node._flatten(node.lhs, negated)
665+
rhs_compound = node._flatten(node.rhs, negated)
666+
if operator == Operator.OR:
667+
return CompoundExpression(should=[lhs_compound, rhs_compound], minimum_should_match=1)
668+
if node.operator == Operator.AND:
669+
return CompoundExpression(
670+
must=lhs_compound.must + rhs_compound.must,
671+
must_not=lhs_compound.must_not + rhs_compound.must_not,
672+
should=lhs_compound.should + rhs_compound.should,
673+
filter=lhs_compound.filter + rhs_compound.filter,
674+
)
675+
# it also can be written as:
676+
# this way is more consistent with OR, but the above is shorter in the debug query.
677+
# return CompoundExpression(must=[lhs_compound, rhs_compound])
678+
# not operator
679+
return lhs_compound
680+
681+
def as_mql(self, compiler, connection):
682+
expression = self._flatten(self)
683+
return expression.as_mql(compiler, connection)
684+
685+
686+
class SearchScoreOption:
687+
"""Class to mutate scoring on a search operation"""
688+
689+
def __init__(self, definitions=None):
690+
self.definitions = definitions
691+
573692

574693
def register_expressions():
575694
Case.as_mql = case

0 commit comments

Comments
 (0)