@@ -130,14 +130,6 @@ void testBatchSearch() {
130130 }
131131 }
132132
133- private Index createFlatIndexWithMetric (MetricType metricType ) {
134- return IndexFactory .create (DIMENSION , "Flat" , metricType );
135- }
136-
137- private Index createFlatIndex () {
138- return IndexFactory .create (DIMENSION , "Flat" , MetricType .L2 );
139- }
140-
141133 @ Test
142134 void testInnerProductMetric () {
143135 try (Index index = createFlatIndexWithMetric (MetricType .INNER_PRODUCT )) {
@@ -279,6 +271,65 @@ void testHNSWIndex() {
279271 }
280272 }
281273
274+ @ Test
275+ void testIVFSQ8Index () {
276+ // IVF16384,SQ8 is a quantized index that needs training
277+ try (Index index = IndexFactory .create (DIMENSION , "IVF16384,SQ8" , MetricType .L2 )) {
278+ assertEquals (DIMENSION , index .getDimension ());
279+ assertEquals (MetricType .L2 , index .getMetricType ());
280+
281+ // IVF index needs training
282+ assertTrue (!index .isTrained (), "IVF index should not be trained initially" );
283+
284+ // Train the index with training vectors
285+ int numTrainingVectors = 20000 ; // Should be >= nlist (16384) for good training
286+ ByteBuffer trainingBuffer = createVectorBuffer (numTrainingVectors , DIMENSION );
287+ index .train (numTrainingVectors , trainingBuffer );
288+
289+ assertTrue (index .isTrained (), "Index should be trained after training" );
290+
291+ // Add vectors after training
292+ ByteBuffer vectorBuffer = createVectorBuffer (NUM_VECTORS , DIMENSION );
293+ index .add (NUM_VECTORS , vectorBuffer );
294+ assertEquals (NUM_VECTORS , index .getCount ());
295+
296+ // Set nprobe for search (number of clusters to visit)
297+ IndexIVF .setNprobe (index , 64 );
298+ assertEquals (64 , IndexIVF .getNprobe (index ));
299+
300+ // Search
301+ float [] queryVectors = createQueryVectors (1 , DIMENSION );
302+ float [] distances = new float [K ];
303+ long [] labels = new long [K ];
304+
305+ index .search (1 , queryVectors , K , distances , labels );
306+
307+ // Verify search results
308+ for (int i = 0 ; i < K ; i ++) {
309+ assertTrue (
310+ labels [i ] >= 0 && labels [i ] < NUM_VECTORS ,
311+ "Label " + labels [i ] + " out of range" );
312+ assertTrue (distances [i ] >= 0 , "Distance should be non-negative for L2" );
313+ }
314+
315+ // Test batch search
316+ int numQueries = 3 ;
317+ float [] batchQueryVectors = createQueryVectors (numQueries , DIMENSION );
318+ float [] batchDistances = new float [numQueries * K ];
319+ long [] batchLabels = new long [numQueries * K ];
320+
321+ index .search (numQueries , batchQueryVectors , K , batchDistances , batchLabels );
322+
323+ for (int q = 0 ; q < numQueries ; q ++) {
324+ for (int n = 0 ; n < K ; n ++) {
325+ int idx = q * K + n ;
326+ assertTrue (batchLabels [idx ] >= 0 && batchLabels [idx ] < NUM_VECTORS );
327+ assertTrue (batchDistances [idx ] >= 0 );
328+ }
329+ }
330+ }
331+ }
332+
282333 @ Test
283334 void testErrorHandling () {
284335 // Test invalid dimension
@@ -372,6 +423,14 @@ void testBufferAllocationHelpers() {
372423 assertEquals (10 * Long .BYTES , idBuffer .capacity ());
373424 }
374425
426+ private Index createFlatIndexWithMetric (MetricType metricType ) {
427+ return IndexFactory .create (DIMENSION , "Flat" , metricType );
428+ }
429+
430+ private Index createFlatIndex () {
431+ return IndexFactory .create (DIMENSION , "Flat" , MetricType .L2 );
432+ }
433+
375434 /** Create a direct ByteBuffer with random vectors. */
376435 private ByteBuffer createVectorBuffer (int n , int d ) {
377436 ByteBuffer buffer = Index .allocateVectorBuffer (n , d );
0 commit comments