Skip to content

Commit 80c6129

Browse files
committed
add ivf sq test
1 parent dd667b3 commit 80c6129

File tree

3 files changed

+67
-21
lines changed

3 files changed

+67
-21
lines changed

paimon-faiss/paimon-faiss-jni/pom.xml

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@ under the License.
3333

3434
<properties>
3535
<faiss.version>1.7.4</faiss.version>
36-
<skipFaissTests>true</skipFaissTests>
3736
</properties>
3837

3938
<dependencies>
@@ -83,15 +82,6 @@ under the License.
8382
</nonFilteredFileExtensions>
8483
</configuration>
8584
</plugin>
86-
87-
<!-- Surefire for tests -->
88-
<plugin>
89-
<groupId>org.apache.maven.plugins</groupId>
90-
<artifactId>maven-surefire-plugin</artifactId>
91-
<configuration>
92-
<skipTests>${skipFaissTests}</skipTests>
93-
</configuration>
94-
</plugin>
9585
</plugins>
9686
</build>
9787

paimon-faiss/paimon-faiss-jni/src/test/java/org/apache/paimon/faiss/IndexTest.java

Lines changed: 67 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -130,14 +130,6 @@ void testBatchSearch() {
130130
}
131131
}
132132

133-
private Index createFlatIndexWithMetric(MetricType metricType) {
134-
return IndexFactory.create(DIMENSION, "Flat", metricType);
135-
}
136-
137-
private Index createFlatIndex() {
138-
return IndexFactory.create(DIMENSION, "Flat", MetricType.L2);
139-
}
140-
141133
@Test
142134
void testInnerProductMetric() {
143135
try (Index index = createFlatIndexWithMetric(MetricType.INNER_PRODUCT)) {
@@ -279,6 +271,65 @@ void testHNSWIndex() {
279271
}
280272
}
281273

274+
@Test
275+
void testIVFSQ8Index() {
276+
// IVF16384,SQ8 is a quantized index that needs training
277+
try (Index index = IndexFactory.create(DIMENSION, "IVF16384,SQ8", MetricType.L2)) {
278+
assertEquals(DIMENSION, index.getDimension());
279+
assertEquals(MetricType.L2, index.getMetricType());
280+
281+
// IVF index needs training
282+
assertTrue(!index.isTrained(), "IVF index should not be trained initially");
283+
284+
// Train the index with training vectors
285+
int numTrainingVectors = 20000; // Should be >= nlist (16384) for good training
286+
ByteBuffer trainingBuffer = createVectorBuffer(numTrainingVectors, DIMENSION);
287+
index.train(numTrainingVectors, trainingBuffer);
288+
289+
assertTrue(index.isTrained(), "Index should be trained after training");
290+
291+
// Add vectors after training
292+
ByteBuffer vectorBuffer = createVectorBuffer(NUM_VECTORS, DIMENSION);
293+
index.add(NUM_VECTORS, vectorBuffer);
294+
assertEquals(NUM_VECTORS, index.getCount());
295+
296+
// Set nprobe for search (number of clusters to visit)
297+
IndexIVF.setNprobe(index, 64);
298+
assertEquals(64, IndexIVF.getNprobe(index));
299+
300+
// Search
301+
float[] queryVectors = createQueryVectors(1, DIMENSION);
302+
float[] distances = new float[K];
303+
long[] labels = new long[K];
304+
305+
index.search(1, queryVectors, K, distances, labels);
306+
307+
// Verify search results
308+
for (int i = 0; i < K; i++) {
309+
assertTrue(
310+
labels[i] >= 0 && labels[i] < NUM_VECTORS,
311+
"Label " + labels[i] + " out of range");
312+
assertTrue(distances[i] >= 0, "Distance should be non-negative for L2");
313+
}
314+
315+
// Test batch search
316+
int numQueries = 3;
317+
float[] batchQueryVectors = createQueryVectors(numQueries, DIMENSION);
318+
float[] batchDistances = new float[numQueries * K];
319+
long[] batchLabels = new long[numQueries * K];
320+
321+
index.search(numQueries, batchQueryVectors, K, batchDistances, batchLabels);
322+
323+
for (int q = 0; q < numQueries; q++) {
324+
for (int n = 0; n < K; n++) {
325+
int idx = q * K + n;
326+
assertTrue(batchLabels[idx] >= 0 && batchLabels[idx] < NUM_VECTORS);
327+
assertTrue(batchDistances[idx] >= 0);
328+
}
329+
}
330+
}
331+
}
332+
282333
@Test
283334
void testErrorHandling() {
284335
// Test invalid dimension
@@ -372,6 +423,14 @@ void testBufferAllocationHelpers() {
372423
assertEquals(10 * Long.BYTES, idBuffer.capacity());
373424
}
374425

426+
private Index createFlatIndexWithMetric(MetricType metricType) {
427+
return IndexFactory.create(DIMENSION, "Flat", metricType);
428+
}
429+
430+
private Index createFlatIndex() {
431+
return IndexFactory.create(DIMENSION, "Flat", MetricType.L2);
432+
}
433+
375434
/** Create a direct ByteBuffer with random vectors. */
376435
private ByteBuffer createVectorBuffer(int n, int d) {
377436
ByteBuffer buffer = Index.allocateVectorBuffer(n, d);

pom.xml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -530,9 +530,6 @@ under the License.
530530
</modules>
531531
<activation>
532532
<activeByDefault>true</activeByDefault>
533-
<property>
534-
<name>paimon-faiss-vector</name>
535-
</property>
536533
</activation>
537534
</profile>
538535
</profiles>

0 commit comments

Comments
 (0)