@@ -598,61 +598,6 @@ def get_search_fields(self):
598
598
return needed_fields
599
599
600
600
601
- class SearchVector (SearchExpression ):
602
- def __init__ (
603
- self ,
604
- path ,
605
- query_vector ,
606
- limit ,
607
- num_candidates = None ,
608
- exact = None ,
609
- filter = None ,
610
- ):
611
- self .path = path
612
- self .query_vector = query_vector
613
- self .limit = limit
614
- self .num_candidates = num_candidates
615
- self .exact = exact
616
- self .filter = filter
617
- super ().__init__ ()
618
-
619
- def __invert__ (self ):
620
- return ValueError ("SearchVector cannot be negated" )
621
-
622
- def __and__ (self , other ):
623
- raise NotSupportedError ("SearchVector cannot be combined" )
624
-
625
- def __rand__ (self , other ):
626
- raise NotSupportedError ("SearchVector cannot be combined" )
627
-
628
- def __or__ (self , other ):
629
- raise NotSupportedError ("SearchVector cannot be combined" )
630
-
631
- def __ror__ (self , other ):
632
- raise NotSupportedError ("SearchVector cannot be combined" )
633
-
634
- def get_search_fields (self ):
635
- return {self .path }
636
-
637
- def _get_query_index (self , field , compiler ):
638
- return "default"
639
-
640
- def as_mql (self , compiler , connection ):
641
- params = {
642
- "index" : self ._get_query_index (self .get_search_fields ()),
643
- "path" : self .path ,
644
- "queryVector" : self .query_vector ,
645
- "limit" : self .limit ,
646
- }
647
- if self .num_candidates is not None :
648
- params ["numCandidates" ] = self .num_candidates
649
- if self .exact is not None :
650
- params ["exact" ] = self .exact
651
- if self .filter is not None :
652
- params ["filter" ] = self .filter
653
- return {"$vectorSearch" : params }
654
-
655
-
656
601
class CompoundExpression (SearchExpression ):
657
602
def __init__ (
658
603
self ,
@@ -716,22 +661,76 @@ def resolve(node, negated=False):
716
661
if operator == Operator .OR :
717
662
return CompoundExpression (should = [lhs_compound , rhs_compound ], minimum_should_match = 1 )
718
663
if operator == Operator .AND :
719
- # NOTE: we can't just do the code below, think about this case (A | B) & (C | D)
720
- # return CompoundExpression(
721
- # must=lhs_compound.must + rhs_compound.must,
722
- # must_not=lhs_compound.must_not + rhs_compound.must_not,
723
- # should=lhs_compound.should + rhs_compound.should,
724
- # filter=lhs_compound.filter + rhs_compound.filter,
725
- # )
726
664
return CompoundExpression (must = [lhs_compound , rhs_compound ])
727
- # not operator
728
665
return lhs_compound
729
666
730
667
def as_mql (self , compiler , connection ):
731
668
expression = self .resolve (self )
732
669
return expression .as_mql (compiler , connection )
733
670
734
671
672
+ class SearchVector (SearchExpression ):
673
+ def __init__ (
674
+ self ,
675
+ path ,
676
+ query_vector ,
677
+ limit ,
678
+ num_candidates = None ,
679
+ exact = None ,
680
+ filter = None ,
681
+ ):
682
+ self .path = path
683
+ self .query_vector = query_vector
684
+ self .limit = limit
685
+ self .num_candidates = num_candidates
686
+ self .exact = exact
687
+ self .filter = filter
688
+ super ().__init__ ()
689
+
690
+ def __invert__ (self ):
691
+ return ValueError ("SearchVector cannot be negated" )
692
+
693
+ def __and__ (self , other ):
694
+ raise NotSupportedError ("SearchVector cannot be combined" )
695
+
696
+ def __rand__ (self , other ):
697
+ raise NotSupportedError ("SearchVector cannot be combined" )
698
+
699
+ def __or__ (self , other ):
700
+ raise NotSupportedError ("SearchVector cannot be combined" )
701
+
702
+ def __ror__ (self , other ):
703
+ raise NotSupportedError ("SearchVector cannot be combined" )
704
+
705
+ def get_search_fields (self ):
706
+ return {self .path }
707
+
708
+ def _get_query_index (self , fields , compiler ):
709
+ for search_indexes in compiler .collection .list_search_indexes ():
710
+ if search_indexes ["type" ] == "vectorSearch" :
711
+ index_field = {
712
+ field ["path" ] for field in search_indexes ["latestDefinition" ]["fields" ]
713
+ }
714
+ if fields .issubset (index_field ):
715
+ return search_indexes ["name" ]
716
+ return "default"
717
+
718
+ def as_mql (self , compiler , connection ):
719
+ params = {
720
+ "index" : self ._get_query_index (self .get_search_fields (), compiler ),
721
+ "path" : self .path ,
722
+ "queryVector" : self .query_vector ,
723
+ "limit" : self .limit ,
724
+ }
725
+ if self .num_candidates is not None :
726
+ params ["numCandidates" ] = self .num_candidates
727
+ if self .exact is not None :
728
+ params ["exact" ] = self .exact
729
+ if self .filter is not None :
730
+ params ["filter" ] = self .filter
731
+ return {"$vectorSearch" : params }
732
+
733
+
735
734
class SearchScoreOption :
736
735
"""Class to mutate scoring on a search operation"""
737
736
0 commit comments