Skip to content

Commit cf64ab2

Browse files
committed
Add combinable operators.
1 parent d467f7d commit cf64ab2

File tree

1 file changed

+132
-13
lines changed

1 file changed

+132
-13
lines changed

django_mongodb_backend/expressions/builtins.py

Lines changed: 132 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,56 @@ def value(self, compiler, connection): # noqa: ARG001
208208
return value
209209

210210

211-
class SearchExpression(Expression):
211+
class Operator:
212+
AND = "AND"
213+
OR = "OR"
214+
NOT = "NOT"
215+
216+
def __init__(self, operator):
217+
self.operator = operator
218+
219+
def __eq__(self, other):
220+
if isinstance(other, str):
221+
return self.operator == other
222+
return self.operator == other.operator
223+
224+
def negate(self):
225+
if self.operator == self.AND:
226+
return Operator(self.OR)
227+
if self.operator == self.OR:
228+
return Operator(self.AND)
229+
return Operator(self.operator)
230+
231+
232+
class SearchCombinable:
233+
def _combine(self, other, connector, reversed):
234+
if not isinstance(self, CompoundExpression | CombinedSearchExpression):
235+
lhs = CompoundExpression(must=[self])
236+
else:
237+
lhs = self
238+
if not isinstance(other, CompoundExpression | CombinedSearchExpression):
239+
rhs = CompoundExpression(must=[other])
240+
else:
241+
rhs = other
242+
return CombinedSearchExpression(lhs, connector, rhs)
243+
244+
def __invert__(self):
245+
return CombinedSearchExpression(self, Operator(Operator.NOT), None)
246+
247+
def __and__(self, other):
248+
return CombinedSearchExpression(self, Operator(Operator.AND), other)
249+
250+
def __rand__(self, other):
251+
return CombinedSearchExpression(self, Operator(Operator.AND), other)
252+
253+
def __or__(self, other):
254+
return CombinedSearchExpression(self, Operator(Operator.OR), other)
255+
256+
def __ror__(self, other):
257+
return CombinedSearchExpression(self, Operator(Operator.OR), other)
258+
259+
260+
class SearchExpression(SearchCombinable, Expression):
212261
output_field = FloatField()
213262

214263
def get_source_expressions(self):
@@ -525,6 +574,21 @@ def __init__(
525574
self.filter = filter
526575
super().__init__()
527576

577+
def __invert__(self):
578+
return ValueError("SearchVector cannot be negated")
579+
580+
def __and__(self, other):
581+
raise NotSupportedError("SearchVector cannot be combined")
582+
583+
def __rand__(self, other):
584+
raise NotSupportedError("SearchVector cannot be combined")
585+
586+
def __or__(self, other):
587+
raise NotSupportedError("SearchVector cannot be combined")
588+
589+
def __ror__(self, other):
590+
raise NotSupportedError("SearchVector cannot be combined")
591+
528592
def as_mql(self, compiler, connection):
529593
params = {
530594
"index": self.index,
@@ -541,15 +605,16 @@ def as_mql(self, compiler, connection):
541605
return {"$vectorSearch": params}
542606

543607

544-
class SearchScoreOption:
545-
"""Class to mutate scoring on a search operation"""
546-
547-
def __init__(self, definitions=None):
548-
self.definitions = definitions
549-
550-
551608
class CompoundExpression(SearchExpression):
552-
def __init__(self, must=None, must_not=None, should=None, filter=None, score=None):
609+
def __init__(
610+
self,
611+
must=None,
612+
must_not=None,
613+
should=None,
614+
filter=None,
615+
score=None,
616+
minimum_should_match=None,
617+
):
553618
self.must = must or []
554619
self.must_not = must_not or []
555620
self.should = should or []
@@ -558,13 +623,67 @@ def __init__(self, must=None, must_not=None, should=None, filter=None, score=Non
558623

559624
def as_mql(self, compiler, connection):
560625
params = {}
561-
for param in ["must", "must_not", "should", "filter"]:
562-
clauses = getattr(self, param)
563-
if clauses:
564-
params[param] = [clause.as_mql(compiler, connection) for clause in clauses]
626+
if self.must:
627+
params["must"] = [clause.as_mql(compiler, connection) for clause in self.must]
628+
if self.must_not:
629+
params["mustNot"] = [clause.as_mql(compiler, connection) for clause in self.must_not]
630+
if self.should:
631+
params["should"] = [clause.as_mql(compiler, connection) for clause in self.should]
632+
if self.filter:
633+
params["filter"] = [clause.as_mql(compiler, connection) for clause in self.filter]
634+
if self.minimum_should_match is not None:
635+
params["minimumShouldMatch"] = self.minimum_should_match
565636

566637
return {"$compound": params}
567638

639+
def negate(self):
640+
return CompoundExpression(must=self.must_not, must_not=self.must + self.filter)
641+
642+
643+
class CombinedSearchExpression(SearchExpression):
644+
def __init__(self, lhs, operator, rhs):
645+
self.lhs = lhs
646+
self.operator = operator
647+
self.rhs = rhs
648+
649+
@staticmethod
650+
def _flatten(node, negated=False):
651+
if node is None:
652+
return None
653+
# Leaf, resolve the compoundExpression
654+
if isinstance(node, CompoundExpression):
655+
return node.negate() if negated else node
656+
# Apply De Morgan's Laws.
657+
operator = node.operator.negate() if negated else node.operator
658+
negated = negated != (node.operator == Operator.NOT)
659+
lhs_compound = node._flatten(node.lhs, negated)
660+
rhs_compound = node._flatten(node.rhs, negated)
661+
if operator == Operator.OR:
662+
return CompoundExpression(should=[lhs_compound, rhs_compound], minimum_should_match=1)
663+
if node.operator == Operator.AND:
664+
return CompoundExpression(
665+
must=lhs_compound.must + rhs_compound.must,
666+
must_not=lhs_compound.must_not + rhs_compound.must_not,
667+
should=lhs_compound.should + rhs_compound.should,
668+
filter=lhs_compound.filter + rhs_compound.filter,
669+
)
670+
# it also can be written as:
671+
# this way is more consistent with OR, but the above is shorter in the debug query.
672+
# return CompoundExpression(must=[lhs_compound, rhs_compound])
673+
# not operator
674+
return lhs_compound
675+
676+
def as_mql(self, compiler, connection):
677+
expression = self._flatten(self)
678+
return expression.as_mql(compiler, connection)
679+
680+
681+
class SearchScoreOption:
682+
"""Class to mutate scoring on a search operation"""
683+
684+
def __init__(self, definitions=None):
685+
self.definitions = definitions
686+
568687

569688
def register_expressions():
570689
Case.as_mql = case

0 commit comments

Comments
 (0)