Skip to content

Commit ab413ec

Browse files
committed
Check function in VectorSearchIndex.
1 parent a5580ba commit ab413ec

File tree

1 file changed

+64
-18
lines changed

1 file changed

+64
-18
lines changed

django_mongodb_backend/indexes.py

Lines changed: 64 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,27 @@
11
import itertools
22
from collections import defaultdict
33

4+
from django.core.checks import Error
45
from django.db import NotSupportedError
56
from django.db.models import (
7+
BooleanField,
8+
CharField,
9+
DateField,
10+
DateTimeField,
611
DecimalField,
712
FloatField,
813
Index,
14+
IntegerField,
15+
TextField,
16+
UUIDField,
917
)
1018
from django.db.models.lookups import BuiltinLookup
1119
from django.db.models.sql.query import Query
1220
from django.db.models.sql.where import AND, XOR, WhereNode
1321
from pymongo import ASCENDING, DESCENDING
1422
from pymongo.operations import IndexModel, SearchIndexModel
1523

16-
from django_mongodb_backend.fields import ArrayField
24+
from django_mongodb_backend.fields import ArrayField, ObjectIdAutoField, ObjectIdField
1725

1826
from .query_utils import process_rhs
1927

@@ -151,19 +159,66 @@ class VectorSearchIndex(SearchIndex):
151159
def __init__(self, *expressions, similarities="cosine", **kwargs):
152160
super().__init__(*expressions, **kwargs)
153161
# validate the similarities types
154-
if isinstance(similarities, str):
155-
self._check_similarity_functions([similarities])
156-
else:
157-
self._check_similarity_functions(similarities)
158162
self.similarities = similarities
159163

160-
def _check_similarity_functions(self, similarities):
164+
def check(self, model):
165+
errors = []
166+
error_id_prefix = "django_mongodb_backend.indexes.VectorSearchIndex"
167+
similarities = (
168+
self.similarity if isinstance(self.similarities, list) else [self.similarities]
169+
)
161170
for func in similarities:
162171
if func not in self.ALLOWED_SIMILARITY_FUNCTIONS:
163-
raise ValueError(
172+
errors.append(
164173
f"{func} isn't a valid similarity function, options "
165-
f"'are {','.join(self.ALLOWED_SIMILARITY_FUNCTIONS)}"
174+
f"'are {','.join(self.ALLOWED_SIMILARITY_FUNCTIONS)}",
175+
obj=self,
176+
id=f"{error_id_prefix}.E003",
177+
)
178+
for field_name, _ in self.fields_orders:
179+
field_ = model._meta.get_field(field_name)
180+
if isinstance(field_, ArrayField):
181+
try:
182+
int(field_.size)
183+
except (ValueError, TypeError):
184+
errors.append(
185+
Error(
186+
"Atlas vector search requires size.",
187+
obj=self,
188+
id=f"{error_id_prefix}.E001",
189+
)
190+
)
191+
if not isinstance(field_.base_field, FloatField | DecimalField):
192+
errors.append(
193+
Error(
194+
"Base type must be Float or Decimal.",
195+
obj=self,
196+
id=f"{error_id_prefix}.E002",
197+
)
198+
)
199+
# filter - for fields that contain boolean, date, objectId,
200+
# numeric, string, or UUID values. Reference:
201+
# https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-type/#atlas-vector-search-index-fields
202+
elif not isinstance(
203+
field_,
204+
BooleanField
205+
| IntegerField
206+
| DateField
207+
| DateTimeField
208+
| CharField
209+
| TextField
210+
| UUIDField
211+
| ObjectIdField
212+
| ObjectIdAutoField,
213+
):
214+
errors.append(
215+
Error(
216+
f"Unsupported filter of type {field_.get_internal_type()}.",
217+
obj=self,
218+
id="django_mongodb_backend.indexes.VectorSearchIndex.E003",
219+
)
166220
)
221+
return errors
167222

168223
def deconstruct(self):
169224
path, args, kwargs = super().deconstruct()
@@ -198,16 +253,7 @@ def get_pymongo_index_model(
198253
}
199254
)
200255
else:
201-
field_type = field_.db_type(schema_editor.connection)
202-
search_type = self.search_index_data_types(field_, field_type)
203-
# filter - for fields that contain boolean, date, objectId, numeric,
204-
# string, or UUID values. Reference:
205-
# https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-type/#atlas-vector-search-index-fields
206-
if search_type in ("number", "string", "boolean", "objectId", "uuid", "date"):
207-
mappings["type"] = "filter"
208-
else:
209-
field_type = field_.get_internal_type()
210-
raise ValueError(f"Unsupported filter of type {field_type}.")
256+
mappings["type"] = "filter"
211257
fields.append(mappings)
212258
return SearchIndexModel(definition={"fields": fields}, name=self.name, type="vectorSearch")
213259

0 commit comments

Comments
 (0)