Skip to content

Commit 0bf803b

Browse files
authored
Consolidate the HNSW tests (elastic#136488)
1 parent ec2b7f0 commit 0bf803b

File tree

7 files changed

+329
-307
lines changed

7 files changed

+329
-307
lines changed

server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormat.java

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,18 @@ public ES93HnswBinaryQuantizedVectorsFormat() {
4747
/**
4848
* Constructs a format whose underlying flat vectors format uses the given options; this overload takes no graph construction parameters.
4949
*
50-
* @param maxConn the maximum number of connections to a node in the HNSW graph
51-
* @param beamWidth the size of the queue maintained during graph construction.
50+
* @param useDirectIO whether to use direct IO when reading raw vectors
51+
*/
52+
public ES93HnswBinaryQuantizedVectorsFormat(boolean useBFloat16, boolean useDirectIO) {
53+
super(NAME);
54+
flatVectorsFormat = new ES93BinaryQuantizedVectorsFormat(useBFloat16, useDirectIO);
55+
}
56+
57+
/**
58+
* Constructs a format using the given graph construction parameters.
59+
*
60+
* @param maxConn the maximum number of connections to a node in the HNSW graph
61+
* @param beamWidth the size of the queue maintained during graph construction.
5262
* @param useDirectIO whether to use direct IO when reading raw vectors
5363
*/
5464
public ES93HnswBinaryQuantizedVectorsFormat(int maxConn, int beamWidth, boolean useBFloat16, boolean useDirectIO) {
@@ -70,8 +80,8 @@ public ES93HnswBinaryQuantizedVectorsFormat(int maxConn, int beamWidth, boolean
7080
public ES93HnswBinaryQuantizedVectorsFormat(
7181
int maxConn,
7282
int beamWidth,
73-
boolean useDirectIO,
7483
boolean useBFloat16,
84+
boolean useDirectIO,
7585
int numMergeWorkers,
7686
ExecutorService mergeExec
7787
) {
Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.index.codec.vectors;
11+
12+
import org.apache.lucene.codecs.Codec;
13+
import org.apache.lucene.codecs.KnnVectorsFormat;
14+
import org.apache.lucene.codecs.KnnVectorsReader;
15+
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
16+
import org.apache.lucene.document.Document;
17+
import org.apache.lucene.document.KnnFloatVectorField;
18+
import org.apache.lucene.index.CodecReader;
19+
import org.apache.lucene.index.DirectoryReader;
20+
import org.apache.lucene.index.FloatVectorValues;
21+
import org.apache.lucene.index.IndexReader;
22+
import org.apache.lucene.index.IndexWriter;
23+
import org.apache.lucene.index.IndexWriterConfig;
24+
import org.apache.lucene.index.KnnVectorValues;
25+
import org.apache.lucene.index.LeafReader;
26+
import org.apache.lucene.index.VectorSimilarityFunction;
27+
import org.apache.lucene.search.AcceptDocs;
28+
import org.apache.lucene.search.TopDocs;
29+
import org.apache.lucene.store.Directory;
30+
import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase;
31+
import org.apache.lucene.tests.util.TestUtil;
32+
import org.apache.lucene.util.SameThreadExecutorService;
33+
import org.apache.lucene.util.VectorUtil;
34+
import org.elasticsearch.common.logging.LogConfigurator;
35+
import org.hamcrest.Matcher;
36+
37+
import java.io.IOException;
38+
import java.util.Map;
39+
import java.util.concurrent.ExecutorService;
40+
41+
import static org.apache.lucene.index.VectorSimilarityFunction.DOT_PRODUCT;
42+
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
43+
import static org.hamcrest.Matchers.equalTo;
44+
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
45+
46+
public abstract class BaseHnswVectorsFormatTestCase extends BaseKnnVectorsFormatTestCase {
47+
48+
static {
49+
LogConfigurator.loadLog4jPlugins();
50+
LogConfigurator.configureESLogging(); // native access requires logging to be initialized
51+
}
52+
53+
protected abstract KnnVectorsFormat createFormat();
54+
55+
protected abstract KnnVectorsFormat createFormat(int maxConn, int beamWidth);
56+
57+
protected abstract KnnVectorsFormat createFormat(int maxConn, int beamWidth, int numMergeWorkers, ExecutorService service);
58+
59+
private KnnVectorsFormat format;
60+
61+
@Override
62+
public void setUp() throws Exception {
63+
format = createFormat();
64+
super.setUp();
65+
}
66+
67+
@Override
68+
protected Codec getCodec() {
69+
return TestUtil.alwaysKnnVectorsFormat(format);
70+
}
71+
72+
public void testLimits() {
73+
expectThrows(IllegalArgumentException.class, () -> createFormat(-1, 20));
74+
expectThrows(IllegalArgumentException.class, () -> createFormat(0, 20));
75+
expectThrows(IllegalArgumentException.class, () -> createFormat(20, 0));
76+
expectThrows(IllegalArgumentException.class, () -> createFormat(20, -1));
77+
expectThrows(IllegalArgumentException.class, () -> createFormat(512 + 1, 20));
78+
expectThrows(IllegalArgumentException.class, () -> createFormat(20, 3201));
79+
expectThrows(IllegalArgumentException.class, () -> createFormat(20, 100, 1, new SameThreadExecutorService()));
80+
}
81+
82+
public void testSingleVectorCase() throws Exception {
83+
float[] vector = randomVector(random().nextInt(12, 500));
84+
for (VectorSimilarityFunction similarityFunction : VectorSimilarityFunction.values()) {
85+
try (Directory dir = newDirectory(); IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
86+
Document doc = new Document();
87+
if (similarityFunction == VectorSimilarityFunction.COSINE) {
88+
VectorUtil.l2normalize(vector);
89+
}
90+
doc.add(new KnnFloatVectorField("f", vector, similarityFunction));
91+
w.addDocument(doc);
92+
w.commit();
93+
try (IndexReader reader = DirectoryReader.open(w)) {
94+
LeafReader r = getOnlyLeafReader(reader);
95+
FloatVectorValues vectorValues = r.getFloatVectorValues("f");
96+
KnnVectorValues.DocIndexIterator docIndexIterator = vectorValues.iterator();
97+
assertThat(vectorValues.size(), equalTo(1));
98+
while (docIndexIterator.nextDoc() != NO_MORE_DOCS) {
99+
assertArrayEquals(vector, vectorValues.vectorValue(docIndexIterator.index()), 0.00001f);
100+
}
101+
float[] randomVector = randomVector(vector.length);
102+
if (similarityFunction == VectorSimilarityFunction.COSINE) {
103+
VectorUtil.l2normalize(randomVector);
104+
}
105+
float trueScore = similarityFunction.compare(vector, randomVector);
106+
TopDocs td = r.searchNearestVectors(
107+
"f",
108+
randomVector,
109+
1,
110+
AcceptDocs.fromLiveDocs(r.getLiveDocs(), r.maxDoc()),
111+
Integer.MAX_VALUE
112+
);
113+
assertEquals(1, td.totalHits.value());
114+
assertThat(td.scoreDocs[0].score, greaterThanOrEqualTo(0f));
115+
// When it's the only vector in a segment, the score should be very close to the true score
116+
assertEquals(trueScore, td.scoreDocs[0].score, 0.01f);
117+
}
118+
}
119+
}
120+
}
121+
122+
protected static void testSimpleOffHeapSize(
123+
Directory dir,
124+
IndexWriterConfig config,
125+
float[] vector,
126+
Matcher<? super Map<String, Long>> matchesMap
127+
) throws IOException {
128+
try (IndexWriter w = new IndexWriter(dir, config)) {
129+
Document doc = new Document();
130+
doc.add(new KnnFloatVectorField("f", vector, DOT_PRODUCT));
131+
w.addDocument(doc);
132+
w.commit();
133+
try (IndexReader reader = DirectoryReader.open(w)) {
134+
LeafReader r = getOnlyLeafReader(reader);
135+
if (r instanceof CodecReader codecReader) {
136+
KnnVectorsReader knnVectorsReader = codecReader.getVectorReader();
137+
if (knnVectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader fieldsReader) {
138+
knnVectorsReader = fieldsReader.getFieldReader("f");
139+
}
140+
var fieldInfo = r.getFieldInfos().fieldInfo("f");
141+
var offHeap = knnVectorsReader.getOffHeapByteSize(fieldInfo);
142+
assertThat(offHeap, matchesMap);
143+
}
144+
}
145+
}
146+
}
147+
}

server/src/test/java/org/elasticsearch/index/codec/vectors/ES814HnswScalarQuantizedVectorsFormatTests.java

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,12 @@
3636

3737
import static org.apache.lucene.index.VectorSimilarityFunction.DOT_PRODUCT;
3838
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
39+
import static org.hamcrest.Matchers.aMapWithSize;
40+
import static org.hamcrest.Matchers.arrayWithSize;
41+
import static org.hamcrest.Matchers.equalTo;
42+
import static org.hamcrest.Matchers.greaterThan;
43+
import static org.hamcrest.Matchers.hasEntry;
3944

40-
// @com.carrotsearch.randomizedtesting.annotations.Repeat(iterations = 50) // tests.directory sys property?
4145
public class ES814HnswScalarQuantizedVectorsFormatTests extends BaseKnnVectorsFormatTestCase {
4246

4347
static {
@@ -178,7 +182,7 @@ private void testSingleVectorPerSegment(VectorSimilarityFunction sim) throws Exc
178182
AcceptDocs.fromLiveDocs(leafReader.getLiveDocs(), leafReader.maxDoc()),
179183
100
180184
);
181-
assertEquals(hits.scoreDocs.length, 3);
185+
assertThat(hits.scoreDocs, arrayWithSize(3));
182186
assertEquals("B", storedFields.document(hits.scoreDocs[0].doc).get("id"));
183187
assertEquals("A", storedFields.document(hits.scoreDocs[1].doc).get("id"));
184188
assertEquals("C", storedFields.document(hits.scoreDocs[2].doc).get("id"));
@@ -202,10 +206,11 @@ public void testSimpleOffHeapSize() throws IOException {
202206
}
203207
var fieldInfo = r.getFieldInfos().fieldInfo("f");
204208
var offHeap = knnVectorsReader.getOffHeapByteSize(fieldInfo);
205-
assertEquals(3, offHeap.size());
206-
assertEquals(vector.length * Float.BYTES, (long) offHeap.get("vec"));
207-
assertEquals(1L, (long) offHeap.get("vex"));
208-
assertTrue(offHeap.get("veq") > 0L);
209+
210+
assertThat(offHeap, aMapWithSize(3));
211+
assertThat(offHeap, hasEntry("vex", 1L));
212+
assertThat(offHeap, hasEntry(equalTo("veq"), greaterThan(0L)));
213+
assertThat(offHeap, hasEntry("vec", (long) vector.length * Float.BYTES));
209214
}
210215
}
211216
}

server/src/test/java/org/elasticsearch/index/codec/vectors/ES815HnswBitVectorsFormatTests.java

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,11 @@
2626

2727
import java.io.IOException;
2828

29+
import static org.hamcrest.Matchers.aMapWithSize;
30+
import static org.hamcrest.Matchers.equalTo;
31+
import static org.hamcrest.Matchers.greaterThan;
32+
import static org.hamcrest.Matchers.hasEntry;
33+
2934
public class ES815HnswBitVectorsFormatTests extends BaseKnnBitVectorsFormatTestCase {
3035

3136
static final Codec codec = TestUtil.alwaysKnnVectorsFormat(new ES815HnswBitVectorsFormat());
@@ -56,9 +61,10 @@ public void testSimpleOffHeapSize() throws IOException {
5661
}
5762
var fieldInfo = r.getFieldInfos().fieldInfo("f");
5863
var offHeap = knnVectorsReader.getOffHeapByteSize(fieldInfo);
59-
assertEquals(2, offHeap.size());
60-
assertTrue(offHeap.get("vec") > 0L);
61-
assertEquals(1L, (long) offHeap.get("vex"));
64+
65+
assertThat(offHeap, aMapWithSize(2));
66+
assertThat(offHeap, hasEntry("vex", 1L));
67+
assertThat(offHeap, hasEntry(equalTo("vec"), greaterThan(0L)));
6268
}
6369
}
6470
}

0 commit comments

Comments
 (0)