Skip to content

Commit 351905c

Browse files
committed
make similarities required
1 parent 863db4b commit 351905c

File tree

4 files changed

+37
-68
lines changed

4 files changed

+37
-68
lines changed

django_mongodb_backend/indexes.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ class VectorSearchIndex(SearchIndex):
167167
VALID_FIELD_TYPES = frozenset(("boolean", "date", "number", "objectId", "string", "uuid"))
168168
_error_id_prefix = "django_mongodb_backend.indexes.VectorSearchIndex"
169169

170-
def __init__(self, *, fields=(), similarities="cosine", name=None):
170+
def __init__(self, *, fields=(), name=None, similarities):
171171
super().__init__(fields=fields, name=name)
172172
self.similarities = similarities
173173
self._multiple_similarities = isinstance(similarities, tuple | list)
@@ -248,8 +248,7 @@ def check(self, model, connection):
248248

249249
def deconstruct(self):
250250
path, args, kwargs = super().deconstruct()
251-
if self.similarities != "cosine":
252-
kwargs["similarities"] = self.similarities
251+
kwargs["similarities"] = self.similarities
253252
return path, args, kwargs
254253

255254
def get_pymongo_index_model(

docs/source/ref/models/indexes.rst

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -31,24 +31,24 @@ model has multiple indexes).
3131
``VectorSearchIndex``
3232
=====================
3333

34-
.. class:: VectorSearchIndex(fields=(), similarities="cosine", name=None)
34+
.. class:: VectorSearchIndex(*, fields=(), name=None, similarities)
3535

3636
.. versionadded:: 5.2.0b0
3737

3838
A subclass of :class:`SearchIndex` that creates a :doc:`vector search index
3939
<atlas:atlas-vector-search/vector-search-type>` on the given field(s).
4040

41-
The index should reference at least one vector field: an :class:`.ArrayField`
41+
The index must reference at least one vector field: an :class:`.ArrayField`
4242
with a :attr:`~.ArrayField.base_field` of :class:`~django.db.models.FloatField`
43-
or :class:`~django.db.models.IntegerField`. It cannot reference an
44-
:class:`.ArrayField` of any other type. Each :class:`.ArrayField` must have a
45-
:attr:`~.ArrayField.size`.
43+
or :class:`~django.db.models.IntegerField` and a :attr:`~.ArrayField.size`. It
44+
cannot reference an :class:`.ArrayField` of any other type.
4645

4746
It may also have other fields to filter on, provided the field stores
4847
``boolean``, ``date``, ``objectId``, ``numeric``, ``string``, or ``uuid``.
4948

50-
Available values for ``similarities`` are ``"cosine"``, ``"dotProduct"``, and
51-
``"euclidean"`` (see :ref:`atlas:avs-similarity-functions`). You can provide
52-
this value either a string, in which case that value will be applied to all
53-
vector fields, or a list or tuple of values with a similarity corresponding to
54-
each vector field.
49+
Available values for the required ``similarities`` keyword argument are
50+
``"cosine"``, ``"dotProduct"``, and ``"euclidean"`` (see
51+
:ref:`atlas:avs-similarity-functions` for how to choose). You can provide this
52+
value either a string, in which case that value will be applied to all vector
53+
fields, or a list or tuple of values with a similarity corresponding to each
54+
vector field.

tests/indexes_/test_checks.py

Lines changed: 13 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ class Article(models.Model):
4444
vector = ArrayField(models.FloatField(), size=10)
4545

4646
class Meta:
47-
indexes = [VectorSearchIndex(fields=["title", "vector"])]
47+
indexes = [VectorSearchIndex(fields=["title", "vector"], similarities="cosine")]
4848

4949
errors = checks.run_checks(app_configs=self.apps.get_app_configs(), databases={"default"})
5050
self.assertEqual(
@@ -72,7 +72,7 @@ class Article(models.Model):
7272
title_embedded = ArrayField(models.FloatField())
7373

7474
class Meta:
75-
indexes = [VectorSearchIndex(fields=["title_embedded"])]
75+
indexes = [VectorSearchIndex(fields=["title_embedded"], similarities="cosine")]
7676

7777
errors = checks.run_checks(app_configs=self.apps.get_app_configs(), databases={"default"})
7878
self.assertEqual(
@@ -91,7 +91,7 @@ class Article(models.Model):
9191
title_embedded = ArrayField(models.CharField(), size=30)
9292

9393
class Meta:
94-
indexes = [VectorSearchIndex(fields=["title_embedded"])]
94+
indexes = [VectorSearchIndex(fields=["title_embedded"], similarities="cosine")]
9595

9696
errors = checks.run_checks(app_configs=self.apps.get_app_configs(), databases={"default"})
9797
self.assertEqual(
@@ -112,7 +112,7 @@ class Article(models.Model):
112112
vector = ArrayField(models.FloatField(), size=10)
113113

114114
class Meta:
115-
indexes = [VectorSearchIndex(fields=["data", "vector"])]
115+
indexes = [VectorSearchIndex(fields=["data", "vector"], similarities="cosine")]
116116

117117
errors = checks.run_checks(app_configs=self.apps.get_app_configs(), databases={"default"})
118118
self.assertEqual(
@@ -127,7 +127,7 @@ class Meta:
127127
],
128128
)
129129

130-
def test_invalid_number_similarity_function_singular(self):
130+
def test_fields_and_similarities_mismatch(self):
131131
class Article(models.Model):
132132
vector = ArrayField(models.FloatField(), size=10)
133133

@@ -153,44 +153,17 @@ class Meta:
153153
],
154154
)
155155

156-
def test_invalid_number_similarity_function_plural(self):
157-
class Article(models.Model):
158-
vector1 = ArrayField(models.FloatField(), size=10)
159-
vector2 = ArrayField(models.FloatField(), size=10)
160-
161-
class Meta:
162-
indexes = [
163-
VectorSearchIndex(
164-
fields=["vector1", "vector2"],
165-
similarities=["dotProduct"],
166-
)
167-
]
168-
169-
errors = checks.run_checks(app_configs=self.apps.get_app_configs(), databases={"default"})
170-
self.assertEqual(
171-
errors,
172-
[
173-
checks.Error(
174-
"VectorSearchIndex requires the same number of similarities "
175-
"and vector fields; Article has 2 ArrayField(s) but similarities "
176-
"has 1 element(s).",
177-
id="django_mongodb_backend.indexes.VectorSearchIndex.E005",
178-
obj=Article,
179-
),
180-
],
181-
)
182-
183156
def test_simple(self):
184157
class Article(models.Model):
185158
vector = ArrayField(models.FloatField(), size=10)
186159

187160
class Meta:
188-
indexes = [VectorSearchIndex(fields=["vector"])]
161+
indexes = [VectorSearchIndex(fields=["vector"], similarities="cosine")]
189162

190163
errors = checks.run_checks(app_configs=self.apps.get_app_configs(), databases={"default"})
191164
self.assertEqual(errors, [])
192165

193-
def test_all_valid_fields(self):
166+
def test_valid_fields(self):
194167
class Data(EmbeddedModel):
195168
integer = models.IntegerField()
196169

@@ -216,6 +189,7 @@ class Meta:
216189
"boolean",
217190
"date",
218191
],
192+
similarities="cosine",
219193
)
220194
]
221195

@@ -227,7 +201,11 @@ class NoSearchVectorModel(models.Model):
227201
text = models.CharField(max_length=100)
228202

229203
class Meta:
230-
indexes = [VectorSearchIndex(name="recent_test_idx", fields=["text"])]
204+
indexes = [
205+
VectorSearchIndex(
206+
name="recent_test_idx", fields=["text"], similarities="cosine"
207+
)
208+
]
231209

232210
errors = checks.run_checks(app_configs=self.apps.get_app_configs(), databases={"default"})
233211
self.assertEqual(

tests/indexes_/test_search_indexes.py

Lines changed: 12 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def test_search_index_not_created(self):
2626
editor.remove_index(index=index, model=SearchIndexTestModel)
2727

2828
def test_vector_index_not_created(self):
29-
index = VectorSearchIndex(name="recent_test_idx", fields=["number"])
29+
index = VectorSearchIndex(name="recent_test_idx", fields=["number"], similarities="cosine")
3030
with connection.schema_editor() as editor, self.assertNumQueries(0):
3131
editor.add_index(index=index, model=SearchIndexTestModel)
3232
self.assertNotIn(
@@ -67,27 +67,18 @@ def test_no_extra_kargs(self):
6767
with self.assertRaisesMessage(TypeError, msg):
6868
VectorSearchIndex(condition="")
6969

70+
def test_no_similarities(self):
71+
msg = (
72+
"VectorSearchIndex.__init__() missing 1 required keyword-only argument: 'similarities'"
73+
)
74+
with self.assertRaisesMessage(TypeError, msg):
75+
VectorSearchIndex(name="recent_test_idx", fields=["number"])
76+
7077
def test_deconstruct(self):
71-
index = VectorSearchIndex(name="recent_test_idx", fields=["number"])
78+
index = VectorSearchIndex(name="recent_test_idx", fields=["number"], similarities="cosine")
7279
name, args, kwargs = index.deconstruct()
73-
self.assertEqual(kwargs, {"name": "recent_test_idx", "fields": ["number"]})
74-
new = VectorSearchIndex(*args, **kwargs)
75-
self.assertEqual(new.similarities, index.similarities)
76-
77-
def test_deconstruct_with_similarities(self):
78-
index = VectorSearchIndex(
79-
name="recent_test_idx",
80-
fields=["number", "char"],
81-
similarities=["cosine", "dotProduct"],
82-
)
83-
path, args, kwargs = index.deconstruct()
8480
self.assertEqual(
85-
kwargs,
86-
{
87-
"name": "recent_test_idx",
88-
"fields": ["number", "char"],
89-
"similarities": ["cosine", "dotProduct"],
90-
},
81+
kwargs, {"name": "recent_test_idx", "fields": ["number"], "similarities": "cosine"}
9182
)
9283
new = VectorSearchIndex(*args, **kwargs)
9384
self.assertEqual(new.similarities, index.similarities)
@@ -195,7 +186,7 @@ def test_valid_fields(self):
195186
@skipUnlessDBFeature("supports_atlas_search")
196187
class VectorSearchIndexSchemaTests(SchemaAssertionMixin, TestCase):
197188
def test_simple(self):
198-
index = VectorSearchIndex(name="recent_test_idx", fields=["integer"])
189+
index = VectorSearchIndex(name="recent_test_idx", fields=["integer"], similarities="cosine")
199190
with connection.schema_editor() as editor:
200191
self.assertAddRemoveIndex(editor, index=index, model=SearchIndexTestModel)
201192

@@ -212,6 +203,7 @@ def test_multiple_fields(self):
212203
"vector_float",
213204
"vector_integer",
214205
],
206+
similarities="cosine",
215207
)
216208
with connection.schema_editor() as editor:
217209
editor.add_index(index=index, model=SearchIndexTestModel)

0 commit comments

Comments
 (0)