Skip to content

Commit bdac663

Browse files
committed
Add atlas vector search index
1 parent e134012 commit bdac663

File tree

5 files changed

+212
-13
lines changed

5 files changed

+212
-13
lines changed

django_mongodb_backend/indexes.py

Lines changed: 85 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,27 @@
11
from collections import defaultdict
2+
import itertools
23

34
from django.db import NotSupportedError
4-
from django.db.models import Index
5+
from django.db.models import (
6+
BooleanField,
7+
CharField,
8+
DateField,
9+
DateTimeField,
10+
DecimalField,
11+
FloatField,
12+
Index,
13+
IntegerField,
14+
TextField,
15+
UUIDField,
16+
)
517
from django.db.models.lookups import BuiltinLookup
618
from django.db.models.sql.query import Query
719
from django.db.models.sql.where import AND, XOR, WhereNode
820
from pymongo import ASCENDING, DESCENDING
921
from pymongo.operations import IndexModel, SearchIndexModel
1022

23+
from django_mongodb_backend.fields import ArrayField, ObjectIdAutoField, ObjectIdField
24+
1125
from .query_utils import process_rhs
1226

1327
MONGO_INDEX_OPERATORS = {
@@ -111,7 +125,6 @@ def create_mongodb_index(
111125
unique=False,
112126
column_prefix="",
113127
):
114-
from collections import defaultdict
115128

116129
if self.contains_expressions:
117130
return None
@@ -163,12 +176,81 @@ def create_mongodb_index(
163176
for field_name, _ in self.fields_orders:
164177
field_ = model._meta.get_field(field_name)
165178
type_ = connection.mongo_data_types[field_.get_internal_type()]
166-
fields[field_name] = {"type": type_}
179+
field_path = column_prefix + model._meta.get_field(field_name).column
180+
fields[field_path] = {"type": type_}
167181
return SearchIndexModel(
168182
definition={"mappings": {"dynamic": False, "fields": fields}}, name=self.name
169183
)
170184

171185

186+
class AtlasVectorSearchIndex(Index):
    """Index that is created as a ``vectorSearch`` search index.

    Each indexed field becomes one entry of the definition's ``fields`` list:

    * an ``ArrayField`` with a fixed ``size`` and a Float/Decimal base field is
      mapped as ``"type": "vector"`` with ``numDimensions`` and ``similarity``;
    * boolean, integer, date/datetime, char/text, UUID and ObjectId fields are
      mapped as ``"type": "filter"``;
    * any other field type is rejected with ``ValueError``.

    ``similarities`` is either a single function name applied to every vector
    field, or an iterable with one function name per vector field (consumed in
    the order the vector fields appear in ``fields``).
    """

    suffix = "atlas_vector_search"
    ALLOWED_SIMILARITY_FUNCTIONS = ("euclidean", "cosine", "dotProduct")

    def __init__(self, *expressions, similarities="cosine", **kwargs):
        """Validate and store ``similarities``.

        Raises:
            ValueError: if any similarity name is not one of
                ``ALLOWED_SIMILARITY_FUNCTIONS``.
        """
        super().__init__(*expressions, **kwargs)
        if isinstance(similarities, str):
            self._check_similarity_functions([similarities])
        else:
            # Materialize one-shot iterables (e.g. generators): validation
            # below would otherwise exhaust them before they are stored.
            if not isinstance(similarities, (list, tuple)):
                similarities = list(similarities)
            self._check_similarity_functions(similarities)
        self.similarities = similarities

    def _check_similarity_functions(self, similarities):
        """Raise ``ValueError`` on the first unknown similarity function name."""
        for func in similarities:
            if func not in self.ALLOWED_SIMILARITY_FUNCTIONS:
                raise ValueError(
                    f"{func} isn't a valid similarity function, options "
                    f"are {', '.join(self.ALLOWED_SIMILARITY_FUNCTIONS)}."
                )

    def deconstruct(self):
        """Include ``similarities`` in the deconstructed keyword arguments.

        Without this, anything that rebuilds the index from ``deconstruct()``
        would fall back to the default ``"cosine"`` similarity.
        """
        path, args, kwargs = super().deconstruct()
        kwargs["similarities"] = self.similarities
        return path, args, kwargs

    def create_mongodb_index(
        self, model, schema_editor, connection=None, field=None, unique=False, column_prefix=""
    ):
        """Build the ``SearchIndexModel`` describing this vector search index.

        ``column_prefix`` is prepended to each field's column to form the
        ``path`` of its mapping. ``schema_editor``, ``connection``, ``field``
        and ``unique`` are accepted but not used by this method.

        Raises:
            ValueError: if an ``ArrayField`` has no usable ``size``, if its
                base field is not Float/Decimal, if there are fewer
                similarity functions than vector fields, or if a field type
                is neither a vector nor a supported filter type.
        """
        if isinstance(self.similarities, str):
            # A single name applies to every vector field.
            similarities = itertools.cycle([self.similarities])
        else:
            similarities = iter(self.similarities)
        fields = []
        for field_name, _ in self.fields_orders:
            field_ = model._meta.get_field(field_name)
            mappings = {"path": column_prefix + field_.column}
            if isinstance(field_, ArrayField):
                try:
                    vector_size = int(field_.size)
                except (ValueError, TypeError) as err:
                    raise ValueError("Atlas vector search requires fixed size.") from err
                if not isinstance(field_.base_field, (FloatField, DecimalField)):
                    raise ValueError("Base type must be Float or Decimal.")
                try:
                    similarity = next(similarities)
                except StopIteration:
                    # Surface a clear error instead of a bare StopIteration.
                    raise ValueError(
                        f"Not enough similarity functions for the vector fields of "
                        f"index {self.name!r}; provide one per vector field or a "
                        f"single function name."
                    ) from None
                mappings.update(
                    {
                        "type": "vector",
                        "numDimensions": vector_size,
                        "similarity": similarity,
                    }
                )
            elif isinstance(
                field_,
                (
                    BooleanField,
                    IntegerField,
                    DateField,
                    DateTimeField,
                    CharField,
                    TextField,
                    UUIDField,
                    ObjectIdField,
                    ObjectIdAutoField,
                ),
            ):
                mappings["type"] = "filter"
            else:
                field_type = field_.get_internal_type()
                raise ValueError(f"Unsupported filter of type {field_type}.")
            fields.append(mappings)
        return SearchIndexModel(definition={"fields": fields}, name=self.name, type="vectorSearch")
252+
253+
172254
def register_indexes():
173255
BuiltinLookup.as_mql_idx = builtin_lookup_idx
174256
Index._get_condition_mql = _get_condition_mql

django_mongodb_backend/introspection.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from django.db.models import Index
33
from pymongo import ASCENDING, DESCENDING
44

5-
from django_mongodb_backend.indexes import AtlasSearchIndex
5+
from django_mongodb_backend.indexes import AtlasSearchIndex, AtlasVectorSearchIndex
66

77

88
class DatabaseIntrospection(BaseDatabaseIntrospection):
@@ -37,7 +37,14 @@ def _get_atlas_index_info(self, table_name):
3737
constraints = {}
3838
indexes = self.connection.get_collection(table_name).list_search_indexes()
3939
for details in indexes:
40-
columns = list(details["latestDefinition"]["mappings"].get("fields", {}).keys())
40+
if details["type"] == "vectorSearch":
41+
columns = [field["path"] for field in details["latestDefinition"]["fields"]]
42+
type_ = AtlasVectorSearchIndex.suffix
43+
options = details
44+
else:
45+
columns = list(details["latestDefinition"]["mappings"].get("fields", {}).keys())
46+
options = details["latestDefinition"]["mappings"]
47+
type_ = AtlasSearchIndex.suffix
4148
constraints[details["name"]] = {
4249
"check": False,
4350
"columns": columns,
@@ -46,9 +53,9 @@ def _get_atlas_index_info(self, table_name):
4653
"index": True,
4754
"orders": [],
4855
"primary_key": False,
49-
"type": AtlasSearchIndex.suffix,
56+
"type": type_,
5057
"unique": False,
51-
"options": details["latestDefinition"]["mappings"],
58+
"options": options,
5259
}
5360
return constraints
5461

django_mongodb_backend/schema.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from django.db.models import Index, UniqueConstraint
55
from pymongo.operations import IndexModel, SearchIndexModel
66

7-
from django_mongodb_backend.indexes import AtlasSearchIndex
7+
from django_mongodb_backend.indexes import AtlasSearchIndex, AtlasVectorSearchIndex
88

99
from .fields import EmbeddedModelField
1010
from .query import wrap_database_errors
@@ -305,7 +305,7 @@ def _(self, index: Index, model):
305305
return self.get_collection(model._meta.db_table).drop_index(index.name)
306306

307307
# Register each class explicitly instead of through a ``X | Y`` annotation:
# functools only accepts union annotations in ``register()`` from Python 3.11.
@_remove_index.register(AtlasSearchIndex)
@_remove_index.register(AtlasVectorSearchIndex)
def _(self, index, model):
    """Drop the search index named ``index.name`` from the model's collection."""
    return self.get_collection(model._meta.db_table).drop_search_index(index.name)
310310

311311
@ignore_embedded_models

tests/indexes_/models.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from django.db import models
22

3-
from django_mongodb_backend.fields import EmbeddedModelField
3+
from django_mongodb_backend.fields import ArrayField, EmbeddedModelField
44
from django_mongodb_backend.models import EmbeddedModel
55

66

@@ -15,3 +15,7 @@ class Article(models.Model):
1515
data = models.JSONField()
1616
embedded = EmbeddedModelField(Data)
1717
auto_now = models.DateTimeField(auto_now=True)
18+
title_embedded = ArrayField(models.FloatField(), size=10)
19+
description_embedded = ArrayField(models.FloatField(), size=10)
20+
number_list = ArrayField(models.FloatField())
21+
name_list = ArrayField(models.CharField(max_length=30), size=10)

tests/indexes_/test_atlas_indexes.py

Lines changed: 109 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from django.db import connection
55
from django.test import TestCase
66

7-
from django_mongodb_backend.indexes import AtlasSearchIndex
7+
from django_mongodb_backend.indexes import AtlasSearchIndex, AtlasVectorSearchIndex
88

99
from .models import Article, Data
1010

@@ -32,7 +32,7 @@ def assertAddRemoveIndex(self, editor, model, index):
3232
),
3333
)
3434

35-
def test_simple_atlas_index(self):
35+
def test_simple(self):
3636
with connection.schema_editor() as editor:
3737
index = AtlasSearchIndex(
3838
name="recent_article_idx",
@@ -41,7 +41,7 @@ def test_simple_atlas_index(self):
4141
editor.add_index(index=index, model=Article)
4242
self.assertAddRemoveIndex(editor, Article, index)
4343

44-
def test_multiple_fields_atlas_index(self):
44+
def test_multiple_fields(self):
4545
with connection.schema_editor() as editor:
4646
index = AtlasSearchIndex(
4747
name="recent_article_idx",
@@ -106,6 +106,112 @@ def setUpTestData(cls):
106106
data="{json: i}",
107107
embedded=Data(integer=i),
108108
auto_now=datetime.datetime.now(),
109+
title_embedded=[0.1] * 10,
110+
description_embedded=[2.5] * 10,
111+
number_list=[2] * i,
112+
name_list=[f"name_{i}"] * 10,
113+
)
114+
for i in range(5)
115+
]
116+
cls.objs = Article.objects.bulk_create(articles)
117+
118+
119+
class AtlasSearchIndexTests(TestCase):
    """Add/remove and introspection tests for ``AtlasVectorSearchIndex``."""

    # Schema editor is used to create the index to test that it works.
    # NOTE(review): the original author asked whether this attribute could be
    # removed; it is kept unchanged here.
    available_apps = None

    def assertAddRemoveIndex(self, editor, model, index):
        """Add ``index``, check it is introspected, remove it, check it is gone."""
        editor.add_index(index=index, model=model)
        self.assertIn(
            index.name,
            connection.introspection.get_constraints(
                cursor=None,
                table_name=model._meta.db_table,
            ),
        )
        editor.remove_index(index=index, model=model)
        self.assertNotIn(
            index.name,
            connection.introspection.get_constraints(
                cursor=None,
                table_name=model._meta.db_table,
            ),
        )

    def test_simple_atlas_vector_search(self):
        with connection.schema_editor() as editor:
            index = AtlasVectorSearchIndex(
                name="recent_article_idx",
                fields=["number"],
            )
            editor.add_index(index=index, model=Article)
            self.assertAddRemoveIndex(editor, Article, index)

    def test_multiple_fields(self):
        with connection.schema_editor() as editor:
            index = AtlasVectorSearchIndex(
                name="recent_article_idx",
                fields=["headline", "number", "body", "description_embedded"],
            )
            editor.add_index(index=index, model=Article)
            index_info = connection.introspection.get_constraints(
                cursor=None,
                table_name=Article._meta.db_table,
            )
            expected_options = {
                "latestDefinition": {
                    "fields": [
                        {"path": "headline", "type": "filter"},
                        {"path": "number", "type": "filter"},
                        {"path": "body", "type": "filter"},
                        {
                            "numDimensions": 10,
                            "path": "description_embedded",
                            "similarity": "cosine",
                            "type": "vector",
                        },
                    ]
                },
                "latestVersion": 0,
                "name": "recent_article_idx",
                "queryable": False,
                "type": "vectorSearch",
            }
            self.assertCountEqual(index_info[index.name]["columns"], index.fields)
            # "id" and "status" are dropped before comparing against the
            # expected options, which do not list them.
            index_info[index.name]["options"].pop("id")
            index_info[index.name]["options"].pop("status")
            self.assertEqual(index_info[index.name]["options"], expected_options)
            self.assertAddRemoveIndex(editor, Article, index)

    def test_field_not_exists(self):
        index = AtlasVectorSearchIndex(
            name="recent_article_idx",
            fields=["headline", "number1", "title_embedded"],
        )
        msg = "Article has no field named 'number1'"
        # A single schema editor is enough; the previous outer editor was
        # opened and then immediately shadowed by this one.
        with self.assertRaisesMessage(
            FieldDoesNotExist, msg
        ), connection.schema_editor() as editor:
            editor.add_index(index=index, model=Article)
198+
199+
200+
class AtlasSearchIndexTestsWithData(AtlasSearchIndexTests):
201+
@classmethod
202+
def setUpTestData(cls):
203+
articles = [
204+
Article(
205+
headline=f"Title {i}",
206+
number=i,
207+
body=f"body {i}",
208+
data="{json: i}",
209+
embedded=Data(integer=i),
210+
auto_now=datetime.datetime.now(),
211+
title_embedded=[0.1] * 10,
212+
description_embedded=[2.5] * 10,
213+
number_list=[2] * i,
214+
name_list=[f"name_{i}"] * 10,
109215
)
110216
for i in range(5)
111217
]

0 commit comments

Comments
 (0)