@@ -213,7 +213,56 @@ def value(self, compiler, connection): # noqa: ARG001
213
213
return value
214
214
215
215
216
- class SearchExpression (Expression ):
216
+ class Operator :
217
+ AND = "AND"
218
+ OR = "OR"
219
+ NOT = "NOT"
220
+
221
+ def __init__ (self , operator ):
222
+ self .operator = operator
223
+
224
+ def __eq__ (self , other ):
225
+ if isinstance (other , str ):
226
+ return self .operator == other
227
+ return self .operator == other .operator
228
+
229
+ def negate (self ):
230
+ if self .operator == self .AND :
231
+ return Operator (self .OR )
232
+ if self .operator == self .OR :
233
+ return Operator (self .AND )
234
+ return Operator (self .operator )
235
+
236
+
237
+ class SearchCombinable :
238
+ def _combine (self , other , connector , reversed ):
239
+ if not isinstance (self , CompoundExpression | CombinedSearchExpression ):
240
+ lhs = CompoundExpression (must = [self ])
241
+ else :
242
+ lhs = self
243
+ if not isinstance (other , CompoundExpression | CombinedSearchExpression ):
244
+ rhs = CompoundExpression (must = [other ])
245
+ else :
246
+ rhs = other
247
+ return CombinedSearchExpression (lhs , connector , rhs )
248
+
249
+ def __invert__ (self ):
250
+ return CombinedSearchExpression (self , Operator (Operator .NOT ), None )
251
+
252
+ def __and__ (self , other ):
253
+ return CombinedSearchExpression (self , Operator (Operator .AND ), other )
254
+
255
+ def __rand__ (self , other ):
256
+ return CombinedSearchExpression (self , Operator (Operator .AND ), other )
257
+
258
+ def __or__ (self , other ):
259
+ return CombinedSearchExpression (self , Operator (Operator .OR ), other )
260
+
261
+ def __ror__ (self , other ):
262
+ return CombinedSearchExpression (self , Operator (Operator .OR ), other )
263
+
264
+
265
+ class SearchExpression (SearchCombinable , Expression ):
217
266
output_field = FloatField ()
218
267
219
268
def get_source_expressions (self ):
@@ -530,6 +579,21 @@ def __init__(
530
579
self .filter = filter
531
580
super ().__init__ ()
532
581
582
+ def __invert__ (self ):
583
+ return ValueError ("SearchVector cannot be negated" )
584
+
585
+ def __and__ (self , other ):
586
+ raise NotSupportedError ("SearchVector cannot be combined" )
587
+
588
+ def __rand__ (self , other ):
589
+ raise NotSupportedError ("SearchVector cannot be combined" )
590
+
591
+ def __or__ (self , other ):
592
+ raise NotSupportedError ("SearchVector cannot be combined" )
593
+
594
+ def __ror__ (self , other ):
595
+ raise NotSupportedError ("SearchVector cannot be combined" )
596
+
533
597
def as_mql (self , compiler , connection ):
534
598
params = {
535
599
"index" : self .index ,
@@ -546,15 +610,16 @@ def as_mql(self, compiler, connection):
546
610
return {"$vectorSearch" : params }
547
611
548
612
549
- class SearchScoreOption :
550
- """Class to mutate scoring on a search operation"""
551
-
552
- def __init__ (self , definitions = None ):
553
- self .definitions = definitions
554
-
555
-
556
613
class CompoundExpression (SearchExpression ):
557
- def __init__ (self , must = None , must_not = None , should = None , filter = None , score = None ):
614
+ def __init__ (
615
+ self ,
616
+ must = None ,
617
+ must_not = None ,
618
+ should = None ,
619
+ filter = None ,
620
+ score = None ,
621
+ minimum_should_match = None ,
622
+ ):
558
623
self .must = must or []
559
624
self .must_not = must_not or []
560
625
self .should = should or []
@@ -563,13 +628,67 @@ def __init__(self, must=None, must_not=None, should=None, filter=None, score=Non
563
628
564
629
def as_mql (self , compiler , connection ):
565
630
params = {}
566
- for param in ["must" , "must_not" , "should" , "filter" ]:
567
- clauses = getattr (self , param )
568
- if clauses :
569
- params [param ] = [clause .as_mql (compiler , connection ) for clause in clauses ]
631
+ if self .must :
632
+ params ["must" ] = [clause .as_mql (compiler , connection ) for clause in self .must ]
633
+ if self .must_not :
634
+ params ["mustNot" ] = [clause .as_mql (compiler , connection ) for clause in self .must_not ]
635
+ if self .should :
636
+ params ["should" ] = [clause .as_mql (compiler , connection ) for clause in self .should ]
637
+ if self .filter :
638
+ params ["filter" ] = [clause .as_mql (compiler , connection ) for clause in self .filter ]
639
+ if self .minimum_should_match is not None :
640
+ params ["minimumShouldMatch" ] = self .minimum_should_match
570
641
571
642
return {"$compound" : params }
572
643
644
+ def negate (self ):
645
+ return CompoundExpression (must = self .must_not , must_not = self .must + self .filter )
646
+
647
+
648
+ class CombinedSearchExpression (SearchExpression ):
649
+ def __init__ (self , lhs , operator , rhs ):
650
+ self .lhs = lhs
651
+ self .operator = operator
652
+ self .rhs = rhs
653
+
654
+ @staticmethod
655
+ def _flatten (node , negated = False ):
656
+ if node is None :
657
+ return None
658
+ # Leaf, resolve the compoundExpression
659
+ if isinstance (node , CompoundExpression ):
660
+ return node .negate () if negated else node
661
+ # Apply De Morgan's Laws.
662
+ operator = node .operator .negate () if negated else node .operator
663
+ negated = negated != (node .operator == Operator .NOT )
664
+ lhs_compound = node ._flatten (node .lhs , negated )
665
+ rhs_compound = node ._flatten (node .rhs , negated )
666
+ if operator == Operator .OR :
667
+ return CompoundExpression (should = [lhs_compound , rhs_compound ], minimum_should_match = 1 )
668
+ if node .operator == Operator .AND :
669
+ return CompoundExpression (
670
+ must = lhs_compound .must + rhs_compound .must ,
671
+ must_not = lhs_compound .must_not + rhs_compound .must_not ,
672
+ should = lhs_compound .should + rhs_compound .should ,
673
+ filter = lhs_compound .filter + rhs_compound .filter ,
674
+ )
675
+ # it also can be written as:
676
+ # this way is more consistent with OR, but the above is shorter in the debug query.
677
+ # return CompoundExpression(must=[lhs_compound, rhs_compound])
678
+ # not operator
679
+ return lhs_compound
680
+
681
+ def as_mql (self , compiler , connection ):
682
+ expression = self ._flatten (self )
683
+ return expression .as_mql (compiler , connection )
684
+
685
+
686
+ class SearchScoreOption :
687
+ """Class to mutate scoring on a search operation"""
688
+
689
+ def __init__ (self , definitions = None ):
690
+ self .definitions = definitions
691
+
573
692
574
693
def register_expressions ():
575
694
Case .as_mql = case
0 commit comments