Skip to content

Commit d500e3a

Browse files
committed
Add atlas vector search index
1 parent f6689c9 commit d500e3a

File tree

5 files changed

+213
-12
lines changed

5 files changed

+213
-12
lines changed

django_mongodb_backend/indexes.py

Lines changed: 86 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,26 @@
1+
import itertools
2+
13
from django.db import NotSupportedError
2-
from django.db.models import Index
4+
from django.db.models import (
5+
BooleanField,
6+
CharField,
7+
DateField,
8+
DateTimeField,
9+
DecimalField,
10+
FloatField,
11+
Index,
12+
IntegerField,
13+
TextField,
14+
UUIDField,
15+
)
316
from django.db.models.lookups import BuiltinLookup
417
from django.db.models.sql.query import Query
518
from django.db.models.sql.where import AND, XOR, WhereNode
619
from pymongo import ASCENDING, DESCENDING
720
from pymongo.operations import IndexModel, SearchIndexModel
821

22+
from django_mongodb_backend.fields import ArrayField, ObjectIdAutoField, ObjectIdField
23+
924
from .query_utils import process_rhs
1025

1126
MONGO_INDEX_OPERATORS = {
@@ -122,12 +137,81 @@ def create_mongodb_index(
122137
for field_name, _ in self.fields_orders:
123138
field_ = model._meta.get_field(field_name)
124139
type_ = connection.mongo_data_types[field_.get_internal_type()]
125-
fields[field_name] = {"type": type_}
140+
field_path = column_prefix + model._meta.get_field(field_name).column
141+
fields[field_path] = {"type": type_}
126142
return SearchIndexModel(
127143
definition={"mappings": {"dynamic": False, "fields": fields}}, name=self.name
128144
)
129145

130146

147+
class AtlasVectorSearchIndex(Index):
    """Index that is created as an Atlas search index of type ``vectorSearch``.

    Fixed-size ``ArrayField`` columns with a float or decimal base field are
    mapped as ``vector`` fields; the scalar field types listed in
    ``_FILTER_FIELD_TYPES`` are mapped as ``filter`` fields.

    ``similarities`` is either one similarity function name, applied to every
    vector field, or an iterable of names consumed in order, one per vector
    field. Each name must be in ``ALLOWED_SIMILARITY_FUNCTIONS``.
    """

    suffix = "atlas_vector_search"
    ALLOWED_SIMILARITY_FUNCTIONS = ("euclidean", "cosine", "dotProduct")
    # Field classes that may be indexed as "filter" fields. Built once here
    # instead of on every loop iteration in create_mongodb_index().
    _FILTER_FIELD_TYPES = (
        BooleanField,
        IntegerField,
        DateField,
        DateTimeField,
        CharField,
        TextField,
        UUIDField,
        ObjectIdField,
        ObjectIdAutoField,
    )

    def __init__(self, *expressions, similarities="cosine", **kwargs):
        """Validate and store ``similarities``; other arguments go to ``Index``.

        Raises:
            ValueError: if any similarity function name is not allowed.
        """
        super().__init__(*expressions, **kwargs)
        # A single name is validated as a one-element sequence.
        if isinstance(similarities, str):
            self._check_similarity_functions([similarities])
        else:
            self._check_similarity_functions(similarities)
        self.similarities = similarities

    def deconstruct(self):
        """Return the ``Index`` deconstruction extended with ``similarities``.

        Without this, anything that rebuilds the index from its deconstructed
        form would silently fall back to the default similarity.
        """
        path, args, kwargs = super().deconstruct()
        kwargs["similarities"] = self.similarities
        return path, args, kwargs

    def _check_similarity_functions(self, similarities):
        """Raise ``ValueError`` for the first name that is not allowed."""
        for func in similarities:
            if func not in self.ALLOWED_SIMILARITY_FUNCTIONS:
                options = ", ".join(self.ALLOWED_SIMILARITY_FUNCTIONS)
                raise ValueError(
                    f"{func} isn't a valid similarity function, options are {options}."
                )

    def create_mongodb_index(
        self, model, schema_editor, connection=None, field=None, unique=False, column_prefix=""
    ):
        """Build the ``SearchIndexModel`` describing this vector search index.

        ``schema_editor``, ``connection``, ``field`` and ``unique`` are not
        used here; presumably they are accepted so the signature matches the
        other index classes' ``create_mongodb_index`` (confirm against callers).

        Args:
            model: Model whose ``_meta`` is used to resolve the index fields.
            column_prefix: String prepended to every column name to form the
                indexed path.

        Returns:
            A ``SearchIndexModel`` of type ``vectorSearch`` named ``self.name``.

        Raises:
            ValueError: if an ``ArrayField`` has no usable ``size``, if its
                base field is not a float/decimal field, if there are fewer
                similarity functions than vector fields, or if a field type
                is not supported as a filter.
        """
        # One name is reused for every vector field; an iterable is consumed
        # positionally, one entry per vector field.
        similarities = (
            itertools.cycle([self.similarities])
            if isinstance(self.similarities, str)
            else iter(self.similarities)
        )
        fields = []
        for field_name, _ in self.fields_orders:
            field_ = model._meta.get_field(field_name)
            mappings = {"path": column_prefix + field_.column}
            if isinstance(field_, ArrayField):
                try:
                    vector_size = int(field_.size)
                except (ValueError, TypeError) as err:
                    raise ValueError("Atlas vector search requires fixed size.") from err
                if not isinstance(field_.base_field, FloatField | DecimalField):
                    raise ValueError("Base type must be Float or Decimal.")
                try:
                    similarity = next(similarities)
                except StopIteration:
                    # Report a clear error instead of leaking StopIteration.
                    raise ValueError(
                        f"Not enough similarity functions for the vector fields of "
                        f"index {self.name}."
                    ) from None
                mappings.update(
                    {
                        "type": "vector",
                        "numDimensions": vector_size,
                        "similarity": similarity,
                    }
                )
            elif isinstance(field_, self._FILTER_FIELD_TYPES):
                mappings["type"] = "filter"
            else:
                field_type = field_.get_internal_type()
                raise ValueError(f"Unsupported filter of type {field_type}.")
            fields.append(mappings)
        return SearchIndexModel(definition={"fields": fields}, name=self.name, type="vectorSearch")
213+
214+
131215
def register_indexes():
132216
BuiltinLookup.as_mql_idx = builtin_lookup_idx
133217
Index._get_condition_mql = _get_condition_mql

django_mongodb_backend/introspection.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from django.db.models import Index
33
from pymongo import ASCENDING, DESCENDING
44

5-
from django_mongodb_backend.indexes import AtlasSearchIndex
5+
from django_mongodb_backend.indexes import AtlasSearchIndex, AtlasVectorSearchIndex
66

77

88
class DatabaseIntrospection(BaseDatabaseIntrospection):
@@ -37,7 +37,14 @@ def _get_atlas_index_info(self, table_name):
3737
constraints = {}
3838
indexes = self.connection.get_collection(table_name).list_search_indexes()
3939
for details in indexes:
40-
columns = list(details["latestDefinition"]["mappings"].get("fields", {}).keys())
40+
if details["type"] == "vectorSearch":
41+
columns = [field["path"] for field in details["latestDefinition"]["fields"]]
42+
type_ = AtlasVectorSearchIndex.suffix
43+
options = details
44+
else:
45+
columns = list(details["latestDefinition"]["mappings"].get("fields", {}).keys())
46+
options = details["latestDefinition"]["mappings"]
47+
type_ = AtlasSearchIndex.suffix
4148
constraints[details["name"]] = {
4249
"check": False,
4350
"columns": columns,
@@ -46,9 +53,9 @@ def _get_atlas_index_info(self, table_name):
4653
"index": True,
4754
"orders": [],
4855
"primary_key": False,
49-
"type": AtlasSearchIndex.suffix,
56+
"type": type_,
5057
"unique": False,
51-
"options": details["latestDefinition"]["mappings"],
58+
"options": options,
5259
}
5360
return constraints
5461

django_mongodb_backend/schema.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from django.db.models import Index, UniqueConstraint
55
from pymongo.operations import IndexModel, SearchIndexModel
66

7-
from django_mongodb_backend.indexes import AtlasSearchIndex
7+
from django_mongodb_backend.indexes import AtlasSearchIndex, AtlasVectorSearchIndex
88

99
from .fields import EmbeddedModelField
1010
from .query import wrap_database_errors
@@ -310,7 +310,7 @@ def _(self, index: Index, model):
310310
return self.get_collection(model._meta.db_table).drop_index(index.name)
311311

312312
# Register both search index classes explicitly. Dispatching on a
# ``A | B`` annotation is only supported by functools from Python 3.11,
# whereas explicit ``register(cls)`` calls also work on 3.10.
@_remove_index.register(AtlasSearchIndex)
@_remove_index.register(AtlasVectorSearchIndex)
def _(self, index: AtlasSearchIndex | AtlasVectorSearchIndex, model):
    # Search indexes are dropped by name with drop_search_index() on the
    # model's collection rather than drop_index().
    return self.get_collection(model._meta.db_table).drop_search_index(index.name)
315315

316316
@ignore_embedded_models

tests/indexes_/models.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from django.db import models
22

3-
from django_mongodb_backend.fields import EmbeddedModelField
3+
from django_mongodb_backend.fields import ArrayField, EmbeddedModelField
44
from django_mongodb_backend.models import EmbeddedModel
55

66

@@ -15,3 +15,7 @@ class Article(models.Model):
1515
data = models.JSONField()
1616
embedded = EmbeddedModelField(Data)
1717
auto_now = models.DateTimeField(auto_now=True)
18+
title_embedded = ArrayField(models.FloatField(), size=10)
19+
description_embedded = ArrayField(models.FloatField(), size=10)
20+
number_list = ArrayField(models.FloatField())
21+
name_list = ArrayField(models.CharField(max_length=30), size=10)

tests/indexes_/test_atlas_indexes.py

Lines changed: 109 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from django.db import connection
55
from django.test import TestCase
66

7-
from django_mongodb_backend.indexes import AtlasSearchIndex
7+
from django_mongodb_backend.indexes import AtlasSearchIndex, AtlasVectorSearchIndex
88

99
from .models import Article, Data
1010

@@ -32,7 +32,7 @@ def assertAddRemoveIndex(self, editor, model, index):
3232
),
3333
)
3434

35-
def test_simple_atlas_index(self):
35+
def test_simple(self):
3636
with connection.schema_editor() as editor:
3737
index = AtlasSearchIndex(
3838
name="recent_article_idx",
@@ -41,7 +41,7 @@ def test_simple_atlas_index(self):
4141
editor.add_index(index=index, model=Article)
4242
self.assertAddRemoveIndex(editor, Article, index)
4343

44-
def test_multiple_fields_atlas_index(self):
44+
def test_multiple_fields(self):
4545
with connection.schema_editor() as editor:
4646
index = AtlasSearchIndex(
4747
name="recent_article_idx",
@@ -106,6 +106,112 @@ def setUpTestData(cls):
106106
data="{json: i}",
107107
embedded=Data(integer=i),
108108
auto_now=datetime.datetime.now(),
109+
title_embedded=[0.1] * 10,
110+
description_embedded=[2.5] * 10,
111+
number_list=[2] * i,
112+
name_list=[f"name_{i}"] * 10,
113+
)
114+
for i in range(5)
115+
]
116+
cls.objs = Article.objects.bulk_create(articles)
117+
118+
119+
class AtlasSearchIndexTests(TestCase):
    """Add/remove tests for ``AtlasVectorSearchIndex`` through the schema editor."""

    # Schema editor is used to create the index to test that it works.
    # TODO: confirm whether this override is needed at all.
    available_apps = None  # could be removed?

    def assertAddRemoveIndex(self, editor, model, index):
        """Add ``index``, check it is introspected, remove it, check it is gone."""
        editor.add_index(index=index, model=model)
        self.assertIn(
            index.name,
            connection.introspection.get_constraints(
                cursor=None,
                table_name=model._meta.db_table,
            ),
        )
        editor.remove_index(index=index, model=model)
        self.assertNotIn(
            index.name,
            connection.introspection.get_constraints(
                cursor=None,
                table_name=model._meta.db_table,
            ),
        )

    def test_simple_atlas_vector_search(self):
        with connection.schema_editor() as editor:
            index = AtlasVectorSearchIndex(
                name="recent_article_idx",
                fields=["number"],
            )
            editor.add_index(index=index, model=Article)
            self.assertAddRemoveIndex(editor, Article, index)

    def test_multiple_fields(self):
        with connection.schema_editor() as editor:
            index = AtlasVectorSearchIndex(
                name="recent_article_idx",
                fields=["headline", "number", "body", "description_embedded"],
            )
            editor.add_index(index=index, model=Article)
            index_info = connection.introspection.get_constraints(
                cursor=None,
                table_name=Article._meta.db_table,
            )
            expected_options = {
                "latestDefinition": {
                    "fields": [
                        {"path": "headline", "type": "filter"},
                        {"path": "number", "type": "filter"},
                        {"path": "body", "type": "filter"},
                        {
                            "numDimensions": 10,
                            "path": "description_embedded",
                            "similarity": "cosine",
                            "type": "vector",
                        },
                    ]
                },
                "latestVersion": 0,
                "name": "recent_article_idx",
                "queryable": False,
                "type": "vectorSearch",
            }
            self.assertCountEqual(index_info[index.name]["columns"], index.fields)
            # "id" and "status" are dropped before comparing; presumably they
            # are server-assigned and not stable between runs.
            index_info[index.name]["options"].pop("id")
            index_info[index.name]["options"].pop("status")
            self.assertEqual(index_info[index.name]["options"], expected_options)
            self.assertAddRemoveIndex(editor, Article, index)

    def test_field_not_exists(self):
        index = AtlasVectorSearchIndex(
            name="recent_article_idx",
            fields=["headline", "number1", "title_embedded"],
        )
        msg = "Article has no field named 'number1'"
        # A single schema editor is enough; the previous outer editor was
        # opened and then immediately shadowed by this one.
        with self.assertRaisesMessage(
            FieldDoesNotExist, msg
        ), connection.schema_editor() as editor:
            editor.add_index(index=index, model=Article)
198+
199+
200+
class AtlasSearchIndexTestsWithData(AtlasSearchIndexTests):
201+
@classmethod
202+
def setUpTestData(cls):
203+
articles = [
204+
Article(
205+
headline=f"Title {i}",
206+
number=i,
207+
body=f"body {i}",
208+
data="{json: i}",
209+
embedded=Data(integer=i),
210+
auto_now=datetime.datetime.now(),
211+
title_embedded=[0.1] * 10,
212+
description_embedded=[2.5] * 10,
213+
number_list=[2] * i,
214+
name_list=[f"name_{i}"] * 10,
109215
)
110216
for i in range(5)
111217
]

0 commit comments

Comments
 (0)