|
| 1 | +from django.db import NotSupportedError |
1 | 2 | from django.db.models import Expression, FloatField
|
2 | 3 | from django.db.models.expressions import F, Value
|
3 | 4 |
|
@@ -878,6 +879,108 @@ def as_mql(self, compiler, connection):
|
878 | 879 | return expression.as_mql(compiler, connection)
|
879 | 880 |
|
880 | 881 |
|
| 882 | +class SearchVector(SearchExpression): |
| 883 | + """ |
| 884 | + Atlas Search expression that performs vector similarity search on embedded vectors. |
| 885 | +
|
| 886 | + This expression uses the **knnBeta** operator to find documents whose vector |
| 887 | + embeddings are most similar to a given query vector. |
| 888 | +
|
| 889 | + Example: |
| 890 | + SearchVector("embedding", [0.1, 0.2, 0.3], limit=10, num_candidates=100) |
| 891 | +
|
| 892 | + Args: |
| 893 | + path: The document path to the vector field (as string or expression). |
| 894 | + query_vector: The query vector to compare against. |
| 895 | + limit: Maximum number of matching documents to return. |
| 896 | + num_candidates: Optional number of candidates to consider during search. |
| 897 | + exact: Optional flag to enforce exact matching. |
| 898 | + filter: Optional filter expression to narrow candidate documents. |
| 899 | +
|
| 900 | + Reference: https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/ |
| 901 | + """ |
| 902 | + |
| 903 | + def __init__( |
| 904 | + self, |
| 905 | + path, |
| 906 | + query_vector, |
| 907 | + limit, |
| 908 | + num_candidates=None, |
| 909 | + exact=None, |
| 910 | + filter=None, |
| 911 | + ): |
| 912 | + self.path = cast_as_field(path) |
| 913 | + self.query_vector = cast_as_value(query_vector) |
| 914 | + self.limit = cast_as_value(limit) |
| 915 | + self.num_candidates = cast_as_value(num_candidates) |
| 916 | + self.exact = cast_as_value(exact) |
| 917 | + self.filter = cast_as_value(filter) |
| 918 | + super().__init__() |
| 919 | + |
| 920 | + def __invert__(self): |
| 921 | + return ValueError("SearchVector cannot be negated") |
| 922 | + |
| 923 | + def __and__(self, other): |
| 924 | + raise NotSupportedError("SearchVector cannot be combined") |
| 925 | + |
| 926 | + def __rand__(self, other): |
| 927 | + raise NotSupportedError("SearchVector cannot be combined") |
| 928 | + |
| 929 | + def __or__(self, other): |
| 930 | + raise NotSupportedError("SearchVector cannot be combined") |
| 931 | + |
| 932 | + def __ror__(self, other): |
| 933 | + raise NotSupportedError("SearchVector cannot be combined") |
| 934 | + |
| 935 | + def get_search_fields(self, compiler, connection): |
| 936 | + return {self.path.as_mql(compiler, connection, as_path=True)} |
| 937 | + |
| 938 | + def get_source_expressions(self): |
| 939 | + return [ |
| 940 | + self.path, |
| 941 | + self.query_vector, |
| 942 | + self.limit, |
| 943 | + self.num_candidates, |
| 944 | + self.exact, |
| 945 | + self.filter, |
| 946 | + ] |
| 947 | + |
| 948 | + def set_source_expressions(self, exprs): |
| 949 | + ( |
| 950 | + self.path, |
| 951 | + self.query_vector, |
| 952 | + self.limit, |
| 953 | + self.num_candidates, |
| 954 | + self.exact, |
| 955 | + self.filter, |
| 956 | + ) = exprs |
| 957 | + |
| 958 | + def _get_query_index(self, fields, compiler): |
| 959 | + for search_indexes in compiler.collection.list_search_indexes(): |
| 960 | + if search_indexes["type"] == "vectorSearch": |
| 961 | + index_field = { |
| 962 | + field["path"] for field in search_indexes["latestDefinition"]["fields"] |
| 963 | + } |
| 964 | + if fields.issubset(index_field): |
| 965 | + return search_indexes["name"] |
| 966 | + return "default" |
| 967 | + |
| 968 | + def as_mql(self, compiler, connection): |
| 969 | + params = { |
| 970 | + "index": self._get_query_index(self.get_search_fields(compiler, connection), compiler), |
| 971 | + "path": self.path.as_mql(compiler, connection, as_path=True), |
| 972 | + "queryVector": self.query_vector.value, |
| 973 | + "limit": self.limit.value, |
| 974 | + } |
| 975 | + if self.num_candidates is not None: |
| 976 | + params["numCandidates"] = self.num_candidates.value |
| 977 | + if self.exact is not None: |
| 978 | + params["exact"] = self.exact.value |
| 979 | + if self.filter is not None: |
| 980 | + params["filter"] = self.filter.as_mql(compiler, connection) |
| 981 | + return {"$vectorSearch": params} |
| 982 | + |
| 983 | + |
881 | 984 | class SearchScoreOption(Expression):
|
882 | 985 | """Class to mutate scoring on a search operation"""
|
883 | 986 |
|
|
0 commit comments