@@ -593,61 +593,6 @@ def get_search_fields(self):
593
593
return needed_fields
594
594
595
595
596
- class SearchVector (SearchExpression ):
597
- def __init__ (
598
- self ,
599
- path ,
600
- query_vector ,
601
- limit ,
602
- num_candidates = None ,
603
- exact = None ,
604
- filter = None ,
605
- ):
606
- self .path = path
607
- self .query_vector = query_vector
608
- self .limit = limit
609
- self .num_candidates = num_candidates
610
- self .exact = exact
611
- self .filter = filter
612
- super ().__init__ ()
613
-
614
- def __invert__ (self ):
615
- return ValueError ("SearchVector cannot be negated" )
616
-
617
- def __and__ (self , other ):
618
- raise NotSupportedError ("SearchVector cannot be combined" )
619
-
620
- def __rand__ (self , other ):
621
- raise NotSupportedError ("SearchVector cannot be combined" )
622
-
623
- def __or__ (self , other ):
624
- raise NotSupportedError ("SearchVector cannot be combined" )
625
-
626
- def __ror__ (self , other ):
627
- raise NotSupportedError ("SearchVector cannot be combined" )
628
-
629
- def get_search_fields (self ):
630
- return {self .path }
631
-
632
- def _get_query_index (self , field , compiler ):
633
- return "default"
634
-
635
- def as_mql (self , compiler , connection ):
636
- params = {
637
- "index" : self ._get_query_index (self .get_search_fields ()),
638
- "path" : self .path ,
639
- "queryVector" : self .query_vector ,
640
- "limit" : self .limit ,
641
- }
642
- if self .num_candidates is not None :
643
- params ["numCandidates" ] = self .num_candidates
644
- if self .exact is not None :
645
- params ["exact" ] = self .exact
646
- if self .filter is not None :
647
- params ["filter" ] = self .filter
648
- return {"$vectorSearch" : params }
649
-
650
-
651
596
class CompoundExpression (SearchExpression ):
652
597
def __init__ (
653
598
self ,
@@ -711,22 +656,76 @@ def resolve(node, negated=False):
711
656
if operator == Operator .OR :
712
657
return CompoundExpression (should = [lhs_compound , rhs_compound ], minimum_should_match = 1 )
713
658
if operator == Operator .AND :
714
- # NOTE: we can't just do the code below, think about this case (A | B) & (C | D)
715
- # return CompoundExpression(
716
- # must=lhs_compound.must + rhs_compound.must,
717
- # must_not=lhs_compound.must_not + rhs_compound.must_not,
718
- # should=lhs_compound.should + rhs_compound.should,
719
- # filter=lhs_compound.filter + rhs_compound.filter,
720
- # )
721
659
return CompoundExpression (must = [lhs_compound , rhs_compound ])
722
- # not operator
723
660
return lhs_compound
724
661
725
662
def as_mql (self , compiler , connection ):
726
663
expression = self .resolve (self )
727
664
return expression .as_mql (compiler , connection )
728
665
729
666
667
+ class SearchVector (SearchExpression ):
668
+ def __init__ (
669
+ self ,
670
+ path ,
671
+ query_vector ,
672
+ limit ,
673
+ num_candidates = None ,
674
+ exact = None ,
675
+ filter = None ,
676
+ ):
677
+ self .path = path
678
+ self .query_vector = query_vector
679
+ self .limit = limit
680
+ self .num_candidates = num_candidates
681
+ self .exact = exact
682
+ self .filter = filter
683
+ super ().__init__ ()
684
+
685
+ def __invert__ (self ):
686
+ return ValueError ("SearchVector cannot be negated" )
687
+
688
+ def __and__ (self , other ):
689
+ raise NotSupportedError ("SearchVector cannot be combined" )
690
+
691
+ def __rand__ (self , other ):
692
+ raise NotSupportedError ("SearchVector cannot be combined" )
693
+
694
+ def __or__ (self , other ):
695
+ raise NotSupportedError ("SearchVector cannot be combined" )
696
+
697
+ def __ror__ (self , other ):
698
+ raise NotSupportedError ("SearchVector cannot be combined" )
699
+
700
+ def get_search_fields (self ):
701
+ return {self .path }
702
+
703
+ def _get_query_index (self , fields , compiler ):
704
+ for search_indexes in compiler .collection .list_search_indexes ():
705
+ if search_indexes ["type" ] == "vectorSearch" :
706
+ index_field = {
707
+ field ["path" ] for field in search_indexes ["latestDefinition" ]["fields" ]
708
+ }
709
+ if fields .issubset (index_field ):
710
+ return search_indexes ["name" ]
711
+ return "default"
712
+
713
+ def as_mql (self , compiler , connection ):
714
+ params = {
715
+ "index" : self ._get_query_index (self .get_search_fields (), compiler ),
716
+ "path" : self .path ,
717
+ "queryVector" : self .query_vector ,
718
+ "limit" : self .limit ,
719
+ }
720
+ if self .num_candidates is not None :
721
+ params ["numCandidates" ] = self .num_candidates
722
+ if self .exact is not None :
723
+ params ["exact" ] = self .exact
724
+ if self .filter is not None :
725
+ params ["filter" ] = self .filter
726
+ return {"$vectorSearch" : params }
727
+
728
+
730
729
class SearchScoreOption :
731
730
"""Class to mutate scoring on a search operation"""
732
731
0 commit comments