Skip to content

Commit e1d7abd

Browse files
committed
Adding vector knn testing tools
1 parent 43a9e6f commit e1d7abd

File tree

5 files changed

+1077
-0
lines changed

5 files changed

+1077
-0
lines changed

settings.gradle

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,3 +171,5 @@ if (extraProjects.exists()) {
171171
addSubProjects('', extraProjectDir)
172172
}
173173
}
174+
175+
// Registers the 'test:external-modules:vector' project with this build.
include 'test:external-modules:vector'
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
// Build script for a runnable command-line application.
plugins {
  id 'application'
}

// Fully-qualified name of the class whose main() the 'application' plugin launches,
// kept as an extra property so it can be referenced by name below.
ext.javaMainClass = "org.elasticsearch.test.knn.KnnIndexTester"

application {
  mainClass = javaMainClass
}

// NOTE(review): the Java sources of this module import org.elasticsearch.* classes
// (LogConfigurator, the ES vector formats); none of the artifacts below provide them —
// confirm they reach the classpath some other way.
dependencies {
  api "org.apache.lucene:lucene-core:${versions.lucene}"
  api "org.apache.lucene:lucene-codecs:${versions.lucene}"
  api "commons-logging:commons-logging:${versions.commonslogging}"
  api "commons-codec:commons-codec:${versions.commonscodec}"
}
Lines changed: 259 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,259 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.test.knn;
11+
12+
import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.lucene101.Lucene101Codec;
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.elasticsearch.common.logging.LogConfigurator;
import org.elasticsearch.index.codec.vectors.ES813Int8FlatVectorFormat;
import org.elasticsearch.index.codec.vectors.ES814HnswScalarQuantizedVectorsFormat;
import org.elasticsearch.index.codec.vectors.IVFVectorsFormat;
import org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat;
import org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat;

import java.nio.file.Path;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
28+
29+
class KnnIndexTester {
30+
static {
31+
LogConfigurator.loadLog4jPlugins();
32+
LogConfigurator.configureESLogging(); // native access requires logging to be initialized
33+
}
34+
35+
static final String INDEX_DIR = "target/knn_index";
36+
37+
enum IndexType {
38+
HNSW,
39+
FLAT,
40+
IVF
41+
}
42+
43+
static class CmdLineArgs {
44+
static CmdLineArgs parse(String[] args) {
45+
CmdLineArgs cmdLineArgs = new CmdLineArgs();
46+
47+
for (String arg : args) {
48+
String[] parts = arg.split("=");
49+
String key = parts[0].trim();
50+
String value = null;
51+
if (parts.length > 1) {
52+
value = parts[1].trim();
53+
}
54+
if (parts.length > 2) {
55+
throw new IllegalArgumentException("Too many parts in argument: " + arg);
56+
}
57+
58+
switch (key) {
59+
case "--docVectors":
60+
cmdLineArgs.docVectors = Path.of(value);
61+
break;
62+
case "--queryVectors":
63+
cmdLineArgs.queryVectors = Path.of(value);
64+
break;
65+
case "--numDocs":
66+
cmdLineArgs.numDocs = Integer.parseInt(value);
67+
break;
68+
case "--numQueries":
69+
cmdLineArgs.numQueries = Integer.parseInt(value);
70+
break;
71+
case "--indexType":
72+
cmdLineArgs.indexType = IndexType.valueOf(value.toUpperCase());
73+
break;
74+
case "--numCandidates":
75+
cmdLineArgs.numCandidates = Integer.parseInt(value);
76+
break;
77+
case "--k":
78+
cmdLineArgs.k = Integer.parseInt(value);
79+
break;
80+
case "--nProbe":
81+
cmdLineArgs.nProbe = Integer.parseInt(value);
82+
break;
83+
case "--ivfClusterSize":
84+
cmdLineArgs.ivfClusterSize = Integer.parseInt(value);
85+
break;
86+
case "--overSamplingFactor":
87+
cmdLineArgs.overSamplingFactor = Integer.parseInt(value);
88+
break;
89+
case "--hnswM":
90+
cmdLineArgs.hnswM = Integer.parseInt(value);
91+
break;
92+
case "--hnswEfConstruction":
93+
cmdLineArgs.hnswEfConstruction = Integer.parseInt(value);
94+
break;
95+
case "--searchThreads":
96+
cmdLineArgs.searchThreads = Integer.parseInt(value);
97+
break;
98+
case "--indexThreads":
99+
cmdLineArgs.indexThreads = Integer.parseInt(value);
100+
break;
101+
case "--reindex":
102+
cmdLineArgs.reindex = true;
103+
break;
104+
case "--forceMerge":
105+
cmdLineArgs.forceMerge = true;
106+
break;
107+
case "--vectorSpace":
108+
cmdLineArgs.vectorSpace = VectorSimilarityFunction.valueOf(value.toUpperCase());
109+
break;
110+
case "--quantizeBits":
111+
cmdLineArgs.quantizeBits = Integer.parseInt(value);
112+
break;
113+
case "--vectorEncoding":
114+
cmdLineArgs.vectorEncoding = VectorEncoding.valueOf(value.toUpperCase());
115+
break;
116+
case "--dimensions":
117+
cmdLineArgs.dimensions = Integer.parseInt(value);
118+
break;
119+
default:
120+
throw new IllegalArgumentException("Unknown argument: " + key);
121+
}
122+
}
123+
return cmdLineArgs;
124+
}
125+
126+
int numDocs = 1000;
127+
int numQueries = 10;
128+
IndexType indexType = IndexType.IVF;
129+
int numCandidates = 100;
130+
int k = 100;
131+
int nProbe = -1;
132+
int ivfClusterSize = 384;
133+
int overSamplingFactor = 0;
134+
int hnswM = 16;
135+
int hnswEfConstruction = 100;
136+
int searchThreads = 1;
137+
int indexThreads = 1;
138+
boolean reindex = true;
139+
boolean forceMerge = false;
140+
VectorSimilarityFunction vectorSpace = VectorSimilarityFunction.EUCLIDEAN;
141+
// 32 means no quantization
142+
int quantizeBits = 32;
143+
int dimensions = 1024; // Default dimension size for vectors
144+
VectorEncoding vectorEncoding = VectorEncoding.FLOAT32;
145+
Path docVectors = null;
146+
Path queryVectors = null;
147+
}
148+
149+
private static String formatIndexPath(CmdLineArgs args) {
150+
List<String> suffix = new ArrayList<>();
151+
if (args.indexType == IndexType.FLAT) {
152+
suffix.add("flat");
153+
} else if (args.indexType == IndexType.IVF) {
154+
suffix.add("ivf");
155+
suffix.add(Integer.toString(args.ivfClusterSize));
156+
} else {
157+
suffix.add(Integer.toString(args.hnswM));
158+
suffix.add(Integer.toString(args.hnswEfConstruction));
159+
if (args.quantizeBits < 32) {
160+
suffix.add(Integer.toString(args.quantizeBits));
161+
}
162+
}
163+
return INDEX_DIR + "/" + args.docVectors.getFileName() + "-" + String.join("-", suffix) + ".index";
164+
}
165+
166+
static Codec createCodec(CmdLineArgs args) {
167+
final KnnVectorsFormat format;
168+
if (args.indexType == IndexType.IVF) {
169+
format = new IVFVectorsFormat(args.ivfClusterSize);
170+
} else {
171+
if (args.quantizeBits == 1) {
172+
if (args.indexType == IndexType.FLAT) {
173+
format = new ES818BinaryQuantizedVectorsFormat();
174+
} else {
175+
format = new ES818HnswBinaryQuantizedVectorsFormat(args.hnswM, args.hnswEfConstruction, 1, null);
176+
}
177+
} else if (args.quantizeBits < 32) {
178+
if (args.indexType == IndexType.FLAT) {
179+
format = new ES813Int8FlatVectorFormat(null, args.quantizeBits, true);
180+
} else {
181+
format = new ES814HnswScalarQuantizedVectorsFormat(args.hnswM, args.hnswEfConstruction, null, args.quantizeBits, true);
182+
}
183+
} else {
184+
format = new Lucene99HnswVectorsFormat(args.hnswM, args.hnswEfConstruction, 1, null);
185+
}
186+
}
187+
return new Lucene101Codec() {
188+
@Override
189+
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
190+
return format;
191+
}
192+
};
193+
}
194+
195+
public static void main(String[] args) throws Exception {
196+
CmdLineArgs cmdLineArgs = CmdLineArgs.parse(args);
197+
if (cmdLineArgs.docVectors == null || cmdLineArgs.docVectors.toFile().exists() == false) {
198+
throw new IllegalArgumentException("Document vectors file does not exist: " + cmdLineArgs.docVectors);
199+
}
200+
Codec codec = createCodec(cmdLineArgs);
201+
Path indexPath = Path.of(formatIndexPath(cmdLineArgs));
202+
long indexCreationTimeMS = 0;
203+
long forceMergeTimeMS = 0;
204+
int numSegments = 1;
205+
StringBuilder resultHeaders = new StringBuilder();
206+
StringBuilder resultValues = new StringBuilder();
207+
// indicate params used for index creation
208+
resultHeaders.append("index_type,");
209+
resultValues.append(cmdLineArgs.indexType).append(",");
210+
resultHeaders.append("num_docs,");
211+
resultValues.append(cmdLineArgs.numDocs).append(",");
212+
if (cmdLineArgs.reindex || cmdLineArgs.forceMerge) {
213+
KnnIndexer knnIndexer = new KnnIndexer(
214+
cmdLineArgs.docVectors,
215+
indexPath,
216+
codec,
217+
cmdLineArgs.indexThreads,
218+
cmdLineArgs.vectorEncoding,
219+
cmdLineArgs.dimensions,
220+
cmdLineArgs.vectorSpace,
221+
cmdLineArgs.numDocs
222+
);
223+
if (cmdLineArgs.reindex) {
224+
indexCreationTimeMS = knnIndexer.createIndex();
225+
}
226+
if (cmdLineArgs.forceMerge) {
227+
forceMergeTimeMS = knnIndexer.forceMerge();
228+
} else {
229+
numSegments = knnIndexer.numSegments();
230+
}
231+
}
232+
if (indexCreationTimeMS > 0) {
233+
resultHeaders.append("index_time(ms)").append(",");
234+
resultValues.append(indexCreationTimeMS).append(",");
235+
}
236+
if (forceMergeTimeMS > 0) {
237+
resultHeaders.append("force_merge_time(ms)").append(",");
238+
resultValues.append(forceMergeTimeMS).append(",");
239+
}
240+
resultHeaders.append("num_segments,");
241+
resultValues.append(numSegments).append(",");
242+
243+
if (cmdLineArgs.queryVectors != null) {
244+
KnnSearcher knnSearcher = new KnnSearcher(indexPath, cmdLineArgs);
245+
KnnSearcher.SearcherResults results = knnSearcher.runSearch();
246+
resultHeaders.append("latency(ms),");
247+
resultValues.append(results.avgLatency()).append(",");
248+
resultHeaders.append("qps,");
249+
resultValues.append(results.qps()).append(",");
250+
resultHeaders.append("recall,");
251+
resultValues.append(results.avgRecall()).append(",");
252+
resultHeaders.append("visited,");
253+
resultValues.append(results.averageVisited()).append(",");
254+
return;
255+
}
256+
System.out.println(resultHeaders);
257+
System.out.println(resultValues);
258+
}
259+
}

0 commit comments

Comments
 (0)