@@ -144,25 +144,64 @@ def read_objects_by_external_ids(self, ids: List[int]) -> OrderedDict:
144144 return {"object" : objects , "external_id" : external_ids }
145145
146146
147- def evaluate_query (index_uri , query_kwargs , dim_id , vector_dim_offset , config = None ):
def assert_equal(
    index_type: str,
    ids: np.ndarray,
    expected_ids: np.ndarray,
    ivf_pq_accuracy_threshold: float,
):
    """
    Assert that the ids returned by a query match the expected ids.

    IVF_PQ index has a lower recall rate than other indexes b/c of PQ-encoding, so for
    that index type only a minimum fraction of the expected ids is required. Every other
    index type must return exactly the expected ids.

    Parameters
    ----------
    index_type: str
        The index type. Only the value "IVF_PQ" selects the relaxed comparison.
    ids: np.ndarray
        The ids returned by the query.
    expected_ids: np.ndarray
        The expected ids.
    ivf_pq_accuracy_threshold: float
        The minimum fraction of expected_ids that must be in ids. Only used when
        index_type is "IVF_PQ".

    Raises
    ------
    AssertionError
        If the two arrays differ in length, if the IVF_PQ recall is below the threshold,
        or (for any other index type) if the arrays are not equivalent.
    """
    assert len(ids) == len(
        expected_ids
    ), f"Expected {len(expected_ids)} ids but got {len(ids)}: ids={ids}, expected_ids={expected_ids}"

    if index_type == "IVF_PQ":
        if len(ids) == 0:
            # Lengths already match, so both arrays are empty: there is nothing to
            # compare, and computing a recall would divide by zero.
            return
        # intersect1d returns the unique values present in both arrays.
        matches = np.intersect1d(ids, expected_ids)
        recall = len(matches) / len(ids)
        assert recall >= ivf_pq_accuracy_threshold, (
            f"IVF_PQ recall {recall:.3f} is below the threshold {ivf_pq_accuracy_threshold}: "
            f"ids={ids}, expected_ids={expected_ids}"
        )
        return

    assert np.array_equiv(
        ids, expected_ids
    ), f"{index_type} ids do not match: ids={ids}, expected_ids={expected_ids}"
175+
176+
177+ def evaluate_query (
178+ index_type : str , index_uri , query_kwargs , dim_id , vector_dim_offset , config = None
179+ ):
148180 v_id = dim_id - vector_dim_offset
181+
149182 index = object_index .ObjectIndex (uri = index_uri , config = config )
150183 distances , objects , metadata = index .query (
151- {"object" : np .array ([[dim_id , dim_id , dim_id , dim_id ]])}, k = 5 , ** query_kwargs
184+ {"object" : np .array ([[dim_id , dim_id , dim_id , dim_id ]])}, k = 21 , ** query_kwargs
152185 )
153- assert np .array_equiv (
186+ assert_equal (
187+ index_type ,
154188 np .unique (objects ["external_id" ]),
155- np .array ([v_id - 2 , v_id - 1 , v_id , v_id + 1 , v_id + 2 ]),
189+ np .array ([v_id + i for i in range (- 10 , 11 )]),
190+ ivf_pq_accuracy_threshold = 0.8 ,
156191 )
192+
157193 distances , object_ids = index .query (
158194 {"object" : np .array ([[dim_id , dim_id , dim_id , dim_id ]])},
159- k = 5 ,
195+ k = 21 ,
160196 return_objects = False ,
161197 return_metadata = False ,
162198 ** query_kwargs ,
163199 )
164- assert np .array_equiv (
165- np .unique (object_ids ), np .array ([v_id - 2 , v_id - 1 , v_id , v_id + 1 , v_id + 2 ])
200+ assert_equal (
201+ index_type ,
202+ np .unique (object_ids ),
203+ np .array ([v_id + i for i in range (- 10 , 11 )]),
204+ ivf_pq_accuracy_threshold = 0.8 ,
166205 )
167206
168207 def df_filter (row ):
@@ -171,66 +210,84 @@ def df_filter(row):
171210 distances , objects , metadata = index .query (
172211 {"object" : np .array ([[dim_id , dim_id , dim_id , dim_id ]])},
173212 metadata_df_filter_fn = df_filter ,
174- k = 5 ,
213+ k = 21 ,
175214 ** query_kwargs ,
176215 )
177- assert np .array_equiv (
178- objects ["external_id" ], np .array ([v_id , v_id + 1 , v_id + 2 , v_id + 3 , v_id + 4 ])
216+ assert_equal (
217+ index_type ,
218+ np .unique (objects ["external_id" ]),
219+ np .array ([v_id + i for i in range (0 , 21 )]),
220+ ivf_pq_accuracy_threshold = 0.8 ,
179221 )
180222
181223 distances , object_ids = index .query (
182224 {"object" : np .array ([[dim_id , dim_id , dim_id , dim_id ]])},
183225 metadata_df_filter_fn = df_filter ,
184- k = 5 ,
226+ k = 21 ,
185227 return_objects = False ,
186228 return_metadata = False ,
187229 ** query_kwargs ,
188230 )
189- assert np .array_equiv (
190- object_ids , np .array ([v_id , v_id + 1 , v_id + 2 , v_id + 3 , v_id + 4 ])
231+ assert_equal (
232+ index_type ,
233+ np .unique (object_ids ),
234+ np .array ([v_id + i for i in range (0 , 21 )]),
235+ ivf_pq_accuracy_threshold = 0.8 ,
191236 )
192237
193238 index = object_index .ObjectIndex (
194239 uri = index_uri , load_metadata_in_memory = False , config = config
195240 )
196241 distances , objects , metadata = index .query (
197- {"object" : np .array ([[dim_id , dim_id , dim_id , dim_id ]])}, k = 5 , ** query_kwargs
242+ {"object" : np .array ([[dim_id , dim_id , dim_id , dim_id ]])}, k = 21 , ** query_kwargs
198243 )
199- assert np .array_equiv (
244+ assert_equal (
245+ index_type ,
200246 np .unique (objects ["external_id" ]),
201- np .array ([v_id - 2 , v_id - 1 , v_id , v_id + 1 , v_id + 2 ]),
247+ np .array ([v_id + i for i in range (- 10 , 11 )]),
248+ ivf_pq_accuracy_threshold = 0.8 ,
202249 )
250+
203251 distances , object_ids = index .query (
204252 {"object" : np .array ([[dim_id , dim_id , dim_id , dim_id ]])},
205- k = 5 ,
253+ k = 21 ,
206254 return_objects = False ,
207255 return_metadata = False ,
208256 ** query_kwargs ,
209257 )
210- assert np .array_equiv (
211- np .unique (object_ids ), np .array ([v_id - 2 , v_id - 1 , v_id , v_id + 1 , v_id + 2 ])
258+ assert_equal (
259+ index_type ,
260+ np .unique (object_ids ),
261+ np .array ([v_id + i for i in range (- 10 , 11 )]),
262+ ivf_pq_accuracy_threshold = 0.8 ,
212263 )
213264
214265 distances , objects , metadata = index .query (
215266 {"object" : np .array ([[dim_id , dim_id , dim_id , dim_id ]])},
216267 metadata_array_cond = f"test_attr >= { dim_id } " ,
217- k = 5 ,
268+ k = 21 ,
218269 ** query_kwargs ,
219270 )
220- assert np .array_equiv (
221- objects ["external_id" ], np .array ([v_id , v_id + 1 , v_id + 2 , v_id + 3 , v_id + 4 ])
271+ assert_equal (
272+ index_type ,
273+ np .unique (objects ["external_id" ]),
274+ np .array ([v_id + i for i in range (0 , 21 )]),
275+ ivf_pq_accuracy_threshold = 0.8 ,
222276 )
223277
224278 distances , object_ids = index .query (
225279 {"object" : np .array ([[dim_id , dim_id , dim_id , dim_id ]])},
226280 metadata_array_cond = f"test_attr >= { dim_id } " ,
227- k = 5 ,
281+ k = 21 ,
228282 return_objects = False ,
229283 return_metadata = False ,
230284 ** query_kwargs ,
231285 )
232- assert np .array_equiv (
233- object_ids , np .array ([v_id , v_id + 1 , v_id + 2 , v_id + 3 , v_id + 4 ])
286+ assert_equal (
287+ index_type ,
288+ np .unique (object_ids ),
289+ np .array ([v_id + i for i in range (0 , 21 )]),
290+ ivf_pq_accuracy_threshold = 0.8 ,
234291 )
235292
236293
@@ -256,12 +313,8 @@ def test_object_index(tmp_path):
256313
257314 # Check initial ingestion
258315 index .update_index (partitions = 10 )
259-
260- # TODO(SC-48908): Fix IVF_PQ with object index queries and remove.
261- if index_type == "IVF_PQ" :
262- continue
263-
264316 evaluate_query (
317+ index_type = index_type ,
265318 index_uri = index_uri ,
266319 query_kwargs = {"nprobe" : 10 , "l_search" : 250 },
267320 dim_id = 42 ,
@@ -272,6 +325,7 @@ def test_object_index(tmp_path):
272325 index = object_index .ObjectIndex (uri = index_uri )
273326 index .update_index (partitions = 10 )
274327 evaluate_query (
328+ index_type = index_type ,
275329 index_uri = index_uri ,
276330 query_kwargs = {"nprobe" : 10 , "l_search" : 500 },
277331 dim_id = 42 ,
@@ -288,6 +342,7 @@ def test_object_index(tmp_path):
288342 index .update_object_reader (reader )
289343 index .update_index (partitions = 10 )
290344 evaluate_query (
345+ index_type = index_type ,
291346 index_uri = index_uri ,
292347 query_kwargs = {"nprobe" : 10 , "l_search" : 500 },
293348 dim_id = 1042 ,
@@ -304,6 +359,7 @@ def test_object_index(tmp_path):
304359 index .update_object_reader (reader )
305360 index .update_index (partitions = 10 )
306361 evaluate_query (
362+ index_type = index_type ,
307363 index_uri = index_uri ,
308364 query_kwargs = {"nprobe" : 10 , "l_search" : 500 },
309365 dim_id = 2042 ,
@@ -351,6 +407,7 @@ def test_object_index_ivf_flat_cloud(tmp_path):
351407 config = config ,
352408 )
353409 evaluate_query (
410+ index_type = "IVF_FLAT" ,
354411 index_uri = index_uri ,
355412 query_kwargs = {"nprobe" : 10 },
356413 dim_id = 42 ,
@@ -381,6 +438,7 @@ def test_object_index_ivf_flat_cloud(tmp_path):
381438 config = config ,
382439 )
383440 evaluate_query (
441+ index_type = "IVF_FLAT" ,
384442 index_uri = index_uri ,
385443 query_kwargs = {"nprobe" : 10 },
386444 dim_id = 1042 ,
@@ -409,6 +467,7 @@ def test_object_index_flat(tmp_path):
409467 # Check initial ingestion
410468 index .update_index ()
411469 evaluate_query (
470+ index_type = "FLAT" ,
412471 index_uri = index_uri ,
413472 query_kwargs = {},
414473 dim_id = 42 ,
@@ -419,6 +478,7 @@ def test_object_index_flat(tmp_path):
419478 index = object_index .ObjectIndex (uri = index_uri )
420479 index .update_index ()
421480 evaluate_query (
481+ index_type = "FLAT" ,
422482 index_uri = index_uri ,
423483 query_kwargs = {},
424484 dim_id = 42 ,
@@ -435,6 +495,7 @@ def test_object_index_flat(tmp_path):
435495 index .update_object_reader (reader )
436496 index .update_index ()
437497 evaluate_query (
498+ index_type = "FLAT" ,
438499 index_uri = index_uri ,
439500 query_kwargs = {},
440501 dim_id = 1042 ,
@@ -451,6 +512,7 @@ def test_object_index_flat(tmp_path):
451512 index .update_object_reader (reader )
452513 index .update_index ()
453514 evaluate_query (
515+ index_type = "FLAT" ,
454516 index_uri = index_uri ,
455517 query_kwargs = {},
456518 dim_id = 2042 ,
0 commit comments