Skip to content

Commit 22a5b80

Browse files
committed
Added system check for similarities and vector fields count mismatch.
1 parent d784806 commit 22a5b80

File tree

2 files changed

+72
-0
lines changed

2 files changed

+72
-0
lines changed

django_mongodb_backend/indexes.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,7 @@ def check(self, model, connection):
185185
)
186186
)
187187
viewed = set()
188+
expected_similarities = 0
188189
for field_name, _ in self.fields_orders:
189190
if field_name in viewed:
190191
errors.append(
@@ -197,9 +198,11 @@ def check(self, model, connection):
197198
id=f"{self._error_id_prefix}.E005",
198199
)
199200
)
201+
continue
200202
viewed.add(field_name)
201203
field_ = model._meta.get_field(field_name)
202204
if isinstance(field_, ArrayField):
205+
expected_similarities += 1
203206
try:
204207
int(field_.size)
205208
except (ValueError, TypeError):
@@ -235,6 +238,22 @@ def check(self, model, connection):
235238
id=f"{self._error_id_prefix}.E003",
236239
)
237240
)
241+
if isinstance(self.similarities, list) and expected_similarities != len(self.similarities):
242+
given_similarities = len(self.similarities)
243+
similarity_function_text = (
244+
"similarities functions" if given_similarities != 1 else "similarity function"
245+
)
246+
errors.append(
247+
Error(
248+
f"An Atlas vector search index requires the same number of similarities and "
249+
f"vector fields, but {expected_similarities} "
250+
f"{similarity_function_text} were expected and "
251+
f"{given_similarities} {'were' if given_similarities != 1 else 'was'} "
252+
"provided.",
253+
obj=self,
254+
id=f"{self._error_id_prefix}.E006",
255+
)
256+
)
238257
return errors
239258

240259
def deconstruct(self):

tests/indexes_/test_checks.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,59 @@ class Meta:
213213
],
214214
)
215215

216+
def test_invalid_number_similarity_function_singular(self):
217+
class Article(models.Model):
218+
vector_data = ArrayField(models.DecimalField(), size=10)
219+
220+
class Meta:
221+
indexes = [
222+
VectorSearchIndex(
223+
fields=["vector_data"],
224+
similarities=["dotProduct", "cosine"],
225+
)
226+
]
227+
228+
errors = checks.run_checks(app_configs=self.apps.get_app_configs(), databases={"default"})
229+
self.assertEqual(
230+
errors,
231+
[
232+
checks.Error(
233+
"An Atlas vector search index requires the same number of similarities "
234+
"and vector fields, but 1 similarity function were expected and 2 "
235+
"were provided.",
236+
id="django_mongodb_backend.indexes.VectorSearchIndex.E006",
237+
obj=Article._meta.indexes[0],
238+
),
239+
],
240+
)
241+
242+
def test_invalid_number_similarity_function_plural(self):
243+
class Article(models.Model):
244+
vector1 = ArrayField(models.DecimalField(), size=10)
245+
vector2 = ArrayField(models.DecimalField(), size=10)
246+
247+
class Meta:
248+
indexes = [
249+
VectorSearchIndex(
250+
fields=["vector1", "vector2"],
251+
similarities=["dotProduct"],
252+
)
253+
]
254+
255+
errors = checks.run_checks(app_configs=self.apps.get_app_configs(), databases={"default"})
256+
self.assertEqual(
257+
errors,
258+
[
259+
checks.Error(
260+
"An Atlas vector search index requires the same number of similarities "
261+
"and vector fields, but 2 similarities functions were expected and 1 "
262+
"was provided.",
263+
id="django_mongodb_backend.indexes.VectorSearchIndex.E006",
264+
obj=Article._meta.indexes[0],
265+
),
266+
],
267+
)
268+
216269
def test_simple(self):
217270
class Article(models.Model):
218271
vector_data = ArrayField(models.DecimalField(), size=10)

0 commit comments

Comments
 (0)