@@ -44,7 +44,7 @@ class Article(models.Model):
44
44
vector = ArrayField (models .FloatField (), size = 10 )
45
45
46
46
class Meta :
47
- indexes = [VectorSearchIndex (fields = ["title" , "vector" ])]
47
+ indexes = [VectorSearchIndex (fields = ["title" , "vector" ], similarities = "cosine" )]
48
48
49
49
errors = checks .run_checks (app_configs = self .apps .get_app_configs (), databases = {"default" })
50
50
self .assertEqual (
@@ -72,7 +72,7 @@ class Article(models.Model):
72
72
title_embedded = ArrayField (models .FloatField ())
73
73
74
74
class Meta :
75
- indexes = [VectorSearchIndex (fields = ["title_embedded" ])]
75
+ indexes = [VectorSearchIndex (fields = ["title_embedded" ], similarities = "cosine" )]
76
76
77
77
errors = checks .run_checks (app_configs = self .apps .get_app_configs (), databases = {"default" })
78
78
self .assertEqual (
@@ -91,7 +91,7 @@ class Article(models.Model):
91
91
title_embedded = ArrayField (models .CharField (), size = 30 )
92
92
93
93
class Meta :
94
- indexes = [VectorSearchIndex (fields = ["title_embedded" ])]
94
+ indexes = [VectorSearchIndex (fields = ["title_embedded" ], similarities = "cosine" )]
95
95
96
96
errors = checks .run_checks (app_configs = self .apps .get_app_configs (), databases = {"default" })
97
97
self .assertEqual (
@@ -112,7 +112,7 @@ class Article(models.Model):
112
112
vector = ArrayField (models .FloatField (), size = 10 )
113
113
114
114
class Meta :
115
- indexes = [VectorSearchIndex (fields = ["data" , "vector" ])]
115
+ indexes = [VectorSearchIndex (fields = ["data" , "vector" ], similarities = "cosine" )]
116
116
117
117
errors = checks .run_checks (app_configs = self .apps .get_app_configs (), databases = {"default" })
118
118
self .assertEqual (
@@ -127,7 +127,7 @@ class Meta:
127
127
],
128
128
)
129
129
130
- def test_invalid_number_similarity_function_singular (self ):
130
+ def test_fields_and_similarities_mismatch (self ):
131
131
class Article (models .Model ):
132
132
vector = ArrayField (models .FloatField (), size = 10 )
133
133
@@ -153,44 +153,17 @@ class Meta:
153
153
],
154
154
)
155
155
156
- def test_invalid_number_similarity_function_plural (self ):
157
- class Article (models .Model ):
158
- vector1 = ArrayField (models .FloatField (), size = 10 )
159
- vector2 = ArrayField (models .FloatField (), size = 10 )
160
-
161
- class Meta :
162
- indexes = [
163
- VectorSearchIndex (
164
- fields = ["vector1" , "vector2" ],
165
- similarities = ["dotProduct" ],
166
- )
167
- ]
168
-
169
- errors = checks .run_checks (app_configs = self .apps .get_app_configs (), databases = {"default" })
170
- self .assertEqual (
171
- errors ,
172
- [
173
- checks .Error (
174
- "VectorSearchIndex requires the same number of similarities "
175
- "and vector fields; Article has 2 ArrayField(s) but similarities "
176
- "has 1 element(s)." ,
177
- id = "django_mongodb_backend.indexes.VectorSearchIndex.E005" ,
178
- obj = Article ,
179
- ),
180
- ],
181
- )
182
-
183
156
def test_simple (self ):
184
157
class Article (models .Model ):
185
158
vector = ArrayField (models .FloatField (), size = 10 )
186
159
187
160
class Meta :
188
- indexes = [VectorSearchIndex (fields = ["vector" ])]
161
+ indexes = [VectorSearchIndex (fields = ["vector" ], similarities = "cosine" )]
189
162
190
163
errors = checks .run_checks (app_configs = self .apps .get_app_configs (), databases = {"default" })
191
164
self .assertEqual (errors , [])
192
165
193
- def test_all_valid_fields (self ):
166
+ def test_valid_fields (self ):
194
167
class Data (EmbeddedModel ):
195
168
integer = models .IntegerField ()
196
169
@@ -216,6 +189,7 @@ class Meta:
216
189
"boolean" ,
217
190
"date" ,
218
191
],
192
+ similarities = "cosine" ,
219
193
)
220
194
]
221
195
@@ -227,7 +201,11 @@ class NoSearchVectorModel(models.Model):
227
201
text = models .CharField (max_length = 100 )
228
202
229
203
class Meta :
230
- indexes = [VectorSearchIndex (name = "recent_test_idx" , fields = ["text" ])]
204
+ indexes = [
205
+ VectorSearchIndex (
206
+ name = "recent_test_idx" , fields = ["text" ], similarities = "cosine"
207
+ )
208
+ ]
231
209
232
210
errors = checks .run_checks (app_configs = self .apps .get_app_configs (), databases = {"default" })
233
211
self .assertEqual (
0 commit comments