@@ -208,7 +208,56 @@ def value(self, compiler, connection): # noqa: ARG001
208
208
return value
209
209
210
210
211
- class SearchExpression (Expression ):
211
+ class Operator :
212
+ AND = "AND"
213
+ OR = "OR"
214
+ NOT = "NOT"
215
+
216
+ def __init__ (self , operator ):
217
+ self .operator = operator
218
+
219
+ def __eq__ (self , other ):
220
+ if isinstance (other , str ):
221
+ return self .operator == other
222
+ return self .operator == other .operator
223
+
224
+ def negate (self ):
225
+ if self .operator == self .AND :
226
+ return Operator (self .OR )
227
+ if self .operator == self .OR :
228
+ return Operator (self .AND )
229
+ return Operator (self .operator )
230
+
231
+
232
+ class SearchCombinable :
233
+ def _combine (self , other , connector , reversed ):
234
+ if not isinstance (self , CompoundExpression | CombinedSearchExpression ):
235
+ lhs = CompoundExpression (must = [self ])
236
+ else :
237
+ lhs = self
238
+ if not isinstance (other , CompoundExpression | CombinedSearchExpression ):
239
+ rhs = CompoundExpression (must = [other ])
240
+ else :
241
+ rhs = other
242
+ return CombinedSearchExpression (lhs , connector , rhs )
243
+
244
+ def __invert__ (self ):
245
+ return CombinedSearchExpression (self , Operator (Operator .NOT ), None )
246
+
247
+ def __and__ (self , other ):
248
+ return CombinedSearchExpression (self , Operator (Operator .AND ), other )
249
+
250
+ def __rand__ (self , other ):
251
+ return CombinedSearchExpression (self , Operator (Operator .AND ), other )
252
+
253
+ def __or__ (self , other ):
254
+ return CombinedSearchExpression (self , Operator (Operator .OR ), other )
255
+
256
+ def __ror__ (self , other ):
257
+ return CombinedSearchExpression (self , Operator (Operator .OR ), other )
258
+
259
+
260
+ class SearchExpression (SearchCombinable , Expression ):
212
261
output_field = FloatField ()
213
262
214
263
def get_source_expressions (self ):
@@ -525,6 +574,21 @@ def __init__(
525
574
self .filter = filter
526
575
super ().__init__ ()
527
576
577
+ def __invert__ (self ):
578
+ return ValueError ("SearchVector cannot be negated" )
579
+
580
+ def __and__ (self , other ):
581
+ raise NotSupportedError ("SearchVector cannot be combined" )
582
+
583
+ def __rand__ (self , other ):
584
+ raise NotSupportedError ("SearchVector cannot be combined" )
585
+
586
+ def __or__ (self , other ):
587
+ raise NotSupportedError ("SearchVector cannot be combined" )
588
+
589
+ def __ror__ (self , other ):
590
+ raise NotSupportedError ("SearchVector cannot be combined" )
591
+
528
592
def as_mql (self , compiler , connection ):
529
593
params = {
530
594
"index" : self .index ,
@@ -541,15 +605,16 @@ def as_mql(self, compiler, connection):
541
605
return {"$vectorSearch" : params }
542
606
543
607
544
- class SearchScoreOption :
545
- """Class to mutate scoring on a search operation"""
546
-
547
- def __init__ (self , definitions = None ):
548
- self .definitions = definitions
549
-
550
-
551
608
class CompoundExpression (SearchExpression ):
552
- def __init__ (self , must = None , must_not = None , should = None , filter = None , score = None ):
609
+ def __init__ (
610
+ self ,
611
+ must = None ,
612
+ must_not = None ,
613
+ should = None ,
614
+ filter = None ,
615
+ score = None ,
616
+ minimum_should_match = None ,
617
+ ):
553
618
self .must = must or []
554
619
self .must_not = must_not or []
555
620
self .should = should or []
@@ -558,13 +623,67 @@ def __init__(self, must=None, must_not=None, should=None, filter=None, score=Non
558
623
559
624
def as_mql (self , compiler , connection ):
560
625
params = {}
561
- for param in ["must" , "must_not" , "should" , "filter" ]:
562
- clauses = getattr (self , param )
563
- if clauses :
564
- params [param ] = [clause .as_mql (compiler , connection ) for clause in clauses ]
626
+ if self .must :
627
+ params ["must" ] = [clause .as_mql (compiler , connection ) for clause in self .must ]
628
+ if self .must_not :
629
+ params ["mustNot" ] = [clause .as_mql (compiler , connection ) for clause in self .must_not ]
630
+ if self .should :
631
+ params ["should" ] = [clause .as_mql (compiler , connection ) for clause in self .should ]
632
+ if self .filter :
633
+ params ["filter" ] = [clause .as_mql (compiler , connection ) for clause in self .filter ]
634
+ if self .minimum_should_match is not None :
635
+ params ["minimumShouldMatch" ] = self .minimum_should_match
565
636
566
637
return {"$compound" : params }
567
638
639
+ def negate (self ):
640
+ return CompoundExpression (must = self .must_not , must_not = self .must + self .filter )
641
+
642
+
643
+ class CombinedSearchExpression (SearchExpression ):
644
+ def __init__ (self , lhs , operator , rhs ):
645
+ self .lhs = lhs
646
+ self .operator = operator
647
+ self .rhs = rhs
648
+
649
+ @staticmethod
650
+ def _flatten (node , negated = False ):
651
+ if node is None :
652
+ return None
653
+ # Leaf, resolve the compoundExpression
654
+ if isinstance (node , CompoundExpression ):
655
+ return node .negate () if negated else node
656
+ # Apply De Morgan's Laws.
657
+ operator = node .operator .negate () if negated else node .operator
658
+ negated = negated != (node .operator == Operator .NOT )
659
+ lhs_compound = node ._flatten (node .lhs , negated )
660
+ rhs_compound = node ._flatten (node .rhs , negated )
661
+ if operator == Operator .OR :
662
+ return CompoundExpression (should = [lhs_compound , rhs_compound ], minimum_should_match = 1 )
663
+ if node .operator == Operator .AND :
664
+ return CompoundExpression (
665
+ must = lhs_compound .must + rhs_compound .must ,
666
+ must_not = lhs_compound .must_not + rhs_compound .must_not ,
667
+ should = lhs_compound .should + rhs_compound .should ,
668
+ filter = lhs_compound .filter + rhs_compound .filter ,
669
+ )
670
+ # it also can be written as:
671
+ # this way is more consistent with OR, but the above is shorter in the debug query.
672
+ # return CompoundExpression(must=[lhs_compound, rhs_compound])
673
+ # not operator
674
+ return lhs_compound
675
+
676
+ def as_mql (self , compiler , connection ):
677
+ expression = self ._flatten (self )
678
+ return expression .as_mql (compiler , connection )
679
+
680
+
681
+ class SearchScoreOption :
682
+ """Class to mutate scoring on a search operation"""
683
+
684
+ def __init__ (self , definitions = None ):
685
+ self .definitions = definitions
686
+
568
687
569
688
def register_expressions ():
570
689
Case .as_mql = case
0 commit comments