|
1 | 1 | import itertools
|
2 | 2 | from collections import defaultdict
|
3 | 3 |
|
| 4 | +from django.core.checks import Error |
4 | 5 | from django.db import NotSupportedError
|
5 | 6 | from django.db.models import (
|
| 7 | + BooleanField, |
| 8 | + CharField, |
| 9 | + DateField, |
| 10 | + DateTimeField, |
6 | 11 | DecimalField,
|
7 | 12 | FloatField,
|
8 | 13 | Index,
|
| 14 | + IntegerField, |
| 15 | + TextField, |
| 16 | + UUIDField, |
9 | 17 | )
|
10 | 18 | from django.db.models.lookups import BuiltinLookup
|
11 | 19 | from django.db.models.sql.query import Query
|
12 | 20 | from django.db.models.sql.where import AND, XOR, WhereNode
|
13 | 21 | from pymongo import ASCENDING, DESCENDING
|
14 | 22 | from pymongo.operations import IndexModel, SearchIndexModel
|
15 | 23 |
|
16 |
| -from django_mongodb_backend.fields import ArrayField |
| 24 | +from django_mongodb_backend.fields import ArrayField, ObjectIdAutoField, ObjectIdField |
17 | 25 |
|
18 | 26 | from .query_utils import process_rhs
|
19 | 27 |
|
@@ -151,19 +159,66 @@ class VectorSearchIndex(SearchIndex):
|
151 | 159 | def __init__(self, *expressions, similarities="cosine", **kwargs):
|
152 | 160 | super().__init__(*expressions, **kwargs)
|
153 | 161 | # validate the similarities types
|
154 |
| - if isinstance(similarities, str): |
155 |
| - self._check_similarity_functions([similarities]) |
156 |
| - else: |
157 |
| - self._check_similarity_functions(similarities) |
158 | 162 | self.similarities = similarities
|
159 | 163 |
|
160 |
| - def _check_similarity_functions(self, similarities): |
| 164 | + def check(self, model): |
| 165 | + errors = [] |
| 166 | + error_id_prefix = "django_mongodb_backend.indexes.VectorSearchIndex" |
| 167 | + similarities = ( |
| 168 | + self.similarity if isinstance(self.similarities, list) else [self.similarities] |
| 169 | + ) |
161 | 170 | for func in similarities:
|
162 | 171 | if func not in self.ALLOWED_SIMILARITY_FUNCTIONS:
|
163 |
| - raise ValueError( |
| 172 | + errors.append( |
164 | 173 | f"{func} isn't a valid similarity function, options "
|
165 |
| - f"'are {','.join(self.ALLOWED_SIMILARITY_FUNCTIONS)}" |
| 174 | + f"'are {','.join(self.ALLOWED_SIMILARITY_FUNCTIONS)}", |
| 175 | + obj=self, |
| 176 | + id=f"{error_id_prefix}.E003", |
| 177 | + ) |
| 178 | + for field_name, _ in self.fields_orders: |
| 179 | + field_ = model._meta.get_field(field_name) |
| 180 | + if isinstance(field_, ArrayField): |
| 181 | + try: |
| 182 | + int(field_.size) |
| 183 | + except (ValueError, TypeError): |
| 184 | + errors.append( |
| 185 | + Error( |
| 186 | + "Atlas vector search requires size.", |
| 187 | + obj=self, |
| 188 | + id=f"{error_id_prefix}.E001", |
| 189 | + ) |
| 190 | + ) |
| 191 | + if not isinstance(field_.base_field, FloatField | DecimalField): |
| 192 | + errors.append( |
| 193 | + Error( |
| 194 | + "Base type must be Float or Decimal.", |
| 195 | + obj=self, |
| 196 | + id=f"{error_id_prefix}.E002", |
| 197 | + ) |
| 198 | + ) |
| 199 | + # filter - for fields that contain boolean, date, objectId, |
| 200 | + # numeric, string, or UUID values. Reference: |
| 201 | + # https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-type/#atlas-vector-search-index-fields |
| 202 | + elif not isinstance( |
| 203 | + field_, |
| 204 | + BooleanField |
| 205 | + | IntegerField |
| 206 | + | DateField |
| 207 | + | DateTimeField |
| 208 | + | CharField |
| 209 | + | TextField |
| 210 | + | UUIDField |
| 211 | + | ObjectIdField |
| 212 | + | ObjectIdAutoField, |
| 213 | + ): |
| 214 | + errors.append( |
| 215 | + Error( |
| 216 | + f"Unsupported filter of type {field_.get_internal_type()}.", |
| 217 | + obj=self, |
| 218 | + id="django_mongodb_backend.indexes.VectorSearchIndex.E003", |
| 219 | + ) |
166 | 220 | )
|
| 221 | + return errors |
167 | 222 |
|
168 | 223 | def deconstruct(self):
|
169 | 224 | path, args, kwargs = super().deconstruct()
|
@@ -198,16 +253,7 @@ def get_pymongo_index_model(
|
198 | 253 | }
|
199 | 254 | )
|
200 | 255 | else:
|
201 |
| - field_type = field_.db_type(schema_editor.connection) |
202 |
| - search_type = self.search_index_data_types(field_, field_type) |
203 |
| - # filter - for fields that contain boolean, date, objectId, numeric, |
204 |
| - # string, or UUID values. Reference: |
205 |
| - # https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-type/#atlas-vector-search-index-fields |
206 |
| - if search_type in ("number", "string", "boolean", "objectId", "uuid", "date"): |
207 |
| - mappings["type"] = "filter" |
208 |
| - else: |
209 |
| - field_type = field_.get_internal_type() |
210 |
| - raise ValueError(f"Unsupported filter of type {field_type}.") |
| 256 | + mappings["type"] = "filter" |
211 | 257 | fields.append(mappings)
|
212 | 258 | return SearchIndexModel(definition={"fields": fields}, name=self.name, type="vectorSearch")
|
213 | 259 |
|
|
0 commit comments