Skip to content

Commit fde4c0c

Browse files
authored
Adds support for more than one input doc file for KnnIndexTester (#131308)
1 parent 5759822 commit fde4c0c

File tree

4 files changed

+95
-67
lines changed

4 files changed

+95
-67
lines changed

qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
* This class encapsulates all the parameters required to run the KNN index tests.
3030
*/
3131
record CmdLineArgs(
32-
Path docVectors,
32+
List<Path> docVectors,
3333
Path queryVectors,
3434
int numDocs,
3535
int numQueries,
@@ -88,7 +88,7 @@ static CmdLineArgs fromXContent(XContentParser parser) throws IOException {
8888
static final ObjectParser<CmdLineArgs.Builder, Void> PARSER = new ObjectParser<>("cmd_line_args", true, Builder::new);
8989

9090
static {
91-
PARSER.declareString(Builder::setDocVectors, DOC_VECTORS_FIELD);
91+
PARSER.declareStringArray(Builder::setDocVectors, DOC_VECTORS_FIELD);
9292
PARSER.declareString(Builder::setQueryVectors, QUERY_VECTORS_FIELD);
9393
PARSER.declareInt(Builder::setNumDocs, NUM_DOCS_FIELD);
9494
PARSER.declareInt(Builder::setNumQueries, NUM_QUERIES_FIELD);
@@ -118,7 +118,8 @@ static CmdLineArgs fromXContent(XContentParser parser) throws IOException {
118118
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
119119
builder.startObject();
120120
if (docVectors != null) {
121-
builder.field(DOC_VECTORS_FIELD.getPreferredName(), docVectors.toString());
121+
List<String> docVectorsStrings = docVectors.stream().map(Path::toString).toList();
122+
builder.field(DOC_VECTORS_FIELD.getPreferredName(), docVectorsStrings);
122123
}
123124
if (queryVectors != null) {
124125
builder.field(QUERY_VECTORS_FIELD.getPreferredName(), queryVectors.toString());
@@ -154,7 +155,7 @@ public String toString() {
154155
}
155156

156157
static class Builder {
157-
private Path docVectors;
158+
private List<Path> docVectors;
158159
private Path queryVectors;
159160
private int numDocs = 1000;
160161
private int numQueries = 100;
@@ -179,8 +180,12 @@ static class Builder {
179180
private float filterSelectivity = 1f;
180181
private long seed = 1751900822751L;
181182

182-
public Builder setDocVectors(String docVectors) {
183-
this.docVectors = PathUtils.get(docVectors);
183+
public Builder setDocVectors(List<String> docVectors) {
184+
if (docVectors == null || docVectors.isEmpty()) {
185+
throw new IllegalArgumentException("Document vectors path must be provided");
186+
}
187+
// Convert list of strings to list of Paths
188+
this.docVectors = docVectors.stream().map(PathUtils::get).toList();
184189
return this;
185190
}
186191

qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ private static String formatIndexPath(CmdLineArgs args) {
8383
suffix.add(Integer.toString(args.quantizeBits()));
8484
}
8585
}
86-
return INDEX_DIR + "/" + args.docVectors().getFileName() + "-" + String.join("-", suffix) + ".index";
86+
return INDEX_DIR + "/" + args.docVectors().get(0).getFileName() + "-" + String.join("-", suffix) + ".index";
8787
}
8888

8989
static Codec createCodec(CmdLineArgs args) {
@@ -137,7 +137,7 @@ public static void main(String[] args) throws Exception {
137137
System.out.println(
138138
Strings.toString(
139139
new CmdLineArgs.Builder().setDimensions(64)
140-
.setDocVectors("/doc/vectors/path")
140+
.setDocVectors(List.of("/doc/vectors/path"))
141141
.setQueryVectors("/query/vectors/path")
142142
.build(),
143143
true,
@@ -179,15 +179,15 @@ public static void main(String[] args) throws Exception {
179179
: new int[] { 0 };
180180
String indexType = cmdLineArgs.indexType().name().toLowerCase(Locale.ROOT);
181181
Results indexResults = new Results(
182-
cmdLineArgs.docVectors().getFileName().toString(),
182+
cmdLineArgs.docVectors().get(0).getFileName().toString(),
183183
indexType,
184184
cmdLineArgs.numDocs(),
185185
cmdLineArgs.filterSelectivity()
186186
);
187187
Results[] results = new Results[nProbes.length];
188188
for (int i = 0; i < nProbes.length; i++) {
189189
results[i] = new Results(
190-
cmdLineArgs.docVectors().getFileName().toString(),
190+
cmdLineArgs.docVectors().get(0).getFileName().toString(),
191191
indexType,
192192
cmdLineArgs.numDocs(),
193193
cmdLineArgs.filterSelectivity()

qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexer.java

Lines changed: 62 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ class KnnIndexer {
6161
static final String ID_FIELD = "id";
6262
static final String VECTOR_FIELD = "vector";
6363

64-
private final Path docsPath;
64+
private final List<Path> docsPath;
6565
private final Path indexPath;
6666
private final VectorEncoding vectorEncoding;
6767
private int dim;
@@ -71,7 +71,7 @@ class KnnIndexer {
7171
private final int numIndexThreads;
7272

7373
KnnIndexer(
74-
Path docsPath,
74+
List<Path> docsPath,
7575
Path indexPath,
7676
Codec codec,
7777
int numIndexThreads,
@@ -127,57 +127,70 @@ public boolean isEnabled(String component) {
127127
}
128128

129129
long start = System.nanoTime();
130-
try (
131-
FSDirectory dir = FSDirectory.open(indexPath);
132-
IndexWriter iw = new IndexWriter(dir, iwc);
133-
FileChannel in = FileChannel.open(docsPath)
134-
) {
135-
long docsPathSizeInBytes = in.size();
136-
int offsetByteSize = 0;
137-
if (dim == -1) {
138-
offsetByteSize = 4;
139-
ByteBuffer preamble = ByteBuffer.allocate(4).order(ByteOrder.LITTLE_ENDIAN);
140-
int bytesRead = Channels.readFromFileChannel(in, 0, preamble);
141-
if (bytesRead < 4) {
142-
throw new IllegalArgumentException(
143-
"docsPath \"" + docsPath + "\" does not contain a valid dims? size=" + docsPathSizeInBytes
130+
AtomicInteger numDocsIndexed = new AtomicInteger();
131+
try (FSDirectory dir = FSDirectory.open(indexPath); IndexWriter iw = new IndexWriter(dir, iwc);) {
132+
for (Path docsPath : this.docsPath) {
133+
int dim = this.dim;
134+
try (FileChannel in = FileChannel.open(docsPath)) {
135+
long docsPathSizeInBytes = in.size();
136+
int offsetByteSize = 0;
137+
if (dim == -1) {
138+
offsetByteSize = 4;
139+
ByteBuffer preamble = ByteBuffer.allocate(4).order(ByteOrder.LITTLE_ENDIAN);
140+
int bytesRead = Channels.readFromFileChannel(in, 0, preamble);
141+
if (bytesRead < 4) {
142+
throw new IllegalArgumentException(
143+
"docsPath \"" + docsPath + "\" does not contain a valid dims? size=" + docsPathSizeInBytes
144+
);
145+
}
146+
dim = preamble.getInt(0);
147+
if (dim <= 0) {
148+
throw new IllegalArgumentException("docsPath \"" + docsPath + "\" has invalid dimension: " + dim);
149+
}
150+
}
151+
FieldType fieldType = switch (vectorEncoding) {
152+
case BYTE -> KnnByteVectorField.createFieldType(dim, similarityFunction);
153+
case FLOAT32 -> KnnFloatVectorField.createFieldType(dim, similarityFunction);
154+
};
155+
if (docsPathSizeInBytes % (((long) dim * vectorEncoding.byteSize + offsetByteSize)) != 0) {
156+
throw new IllegalArgumentException(
157+
"docsPath \"" + docsPath + "\" does not contain a whole number of vectors? size=" + docsPathSizeInBytes
158+
);
159+
}
160+
int numDocs = (int) (docsPathSizeInBytes / ((long) dim * vectorEncoding.byteSize + offsetByteSize));
161+
numDocs = Math.min(this.numDocs - numDocsIndexed.get(), numDocs);
162+
if (numDocs <= 0) {
163+
break;
164+
}
165+
logger.info(
166+
"path={}, docsPathSizeInBytes={}, numDocs={}, dim={}, vectorEncoding={}, byteSize={}",
167+
docsPath,
168+
docsPathSizeInBytes,
169+
numDocs,
170+
dim,
171+
vectorEncoding,
172+
vectorEncoding.byteSize
144173
);
145-
}
146-
dim = preamble.getInt(0);
147-
if (dim <= 0) {
148-
throw new IllegalArgumentException("docsPath \"" + docsPath + "\" has invalid dimension: " + dim);
149-
}
150-
}
151-
FieldType fieldType = switch (vectorEncoding) {
152-
case BYTE -> KnnByteVectorField.createFieldType(dim, similarityFunction);
153-
case FLOAT32 -> KnnFloatVectorField.createFieldType(dim, similarityFunction);
154-
};
155-
if (docsPathSizeInBytes % (((long) dim * vectorEncoding.byteSize + offsetByteSize)) != 0) {
156-
throw new IllegalArgumentException(
157-
"docsPath \"" + docsPath + "\" does not contain a whole number of vectors? size=" + docsPathSizeInBytes
158-
);
159-
}
160-
logger.info(
161-
"docsPathSizeInBytes={}, dim={}, vectorEncoding={}, byteSize={}",
162-
docsPathSizeInBytes,
163-
dim,
164-
vectorEncoding,
165-
vectorEncoding.byteSize
166-
);
174+
// adjust numDocs to account for the number of documents already indexed
175+
// numDocsIndexed tracks the total docs read in order and is used for docIds
176+
// numDocs is the total number of docs to index from this file
177+
numDocs += numDocsIndexed.get();
167178

168-
VectorReader inReader = VectorReader.create(in, dim, vectorEncoding, offsetByteSize);
169-
try (ExecutorService exec = Executors.newFixedThreadPool(numIndexThreads, r -> new Thread(r, "KnnIndexer-Thread"))) {
170-
AtomicInteger numDocsIndexed = new AtomicInteger();
171-
List<Future<?>> threads = new ArrayList<>();
172-
for (int i = 0; i < numIndexThreads; i++) {
173-
Thread t = new IndexerThread(iw, inReader, dim, vectorEncoding, fieldType, numDocsIndexed, numDocs);
174-
t.setDaemon(true);
175-
threads.add(exec.submit(t));
176-
}
177-
for (Future<?> t : threads) {
178-
t.get();
179+
VectorReader inReader = VectorReader.create(in, dim, vectorEncoding, offsetByteSize);
180+
try (ExecutorService exec = Executors.newFixedThreadPool(numIndexThreads, r -> new Thread(r, "KnnIndexer-Thread"))) {
181+
List<Future<?>> threads = new ArrayList<>();
182+
for (int i = 0; i < numIndexThreads; i++) {
183+
Thread t = new IndexerThread(iw, inReader, dim, vectorEncoding, fieldType, numDocsIndexed, numDocs);
184+
t.setDaemon(true);
185+
threads.add(exec.submit(t));
186+
}
187+
for (Future<?> t : threads) {
188+
t.get();
189+
}
190+
}
179191
}
180192
}
193+
logger.info("KnnIndexer: indexed {} documents of desired {} numDocs", numDocsIndexed, numDocs);
181194
logger.debug("all indexing threads finished, now IndexWriter.commit()");
182195
iw.commit();
183196
ConcurrentMergeScheduler cms = (ConcurrentMergeScheduler) iwc.getMergeScheduler();

qa/vector/src/main/java/org/elasticsearch/test/knn/KnnSearcher.java

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@
9898

9999
class KnnSearcher {
100100

101-
private final Path docPath;
101+
private final List<Path> docPath;
102102
private final Path indexPath;
103103
private final Path queryPath;
104104
private final int numDocs;
@@ -153,12 +153,6 @@ void runSearch(KnnIndexTester.Results finalResults, boolean earlyTermination) th
153153
: null
154154
) {
155155
long queryPathSizeInBytes = input.size();
156-
logger.info(
157-
"queryPath size: "
158-
+ queryPathSizeInBytes
159-
+ " bytes, assuming vector count is "
160-
+ (queryPathSizeInBytes / ((long) dim * vectorEncoding.byteSize))
161-
);
162156
if (dim == -1) {
163157
offsetByteSize = 4;
164158
ByteBuffer preamble = ByteBuffer.allocate(4).order(ByteOrder.LITTLE_ENDIAN);
@@ -171,6 +165,17 @@ void runSearch(KnnIndexTester.Results finalResults, boolean earlyTermination) th
171165
throw new IllegalArgumentException("queryPath \"" + queryPath + "\" has invalid dimension: " + dim);
172166
}
173167
}
168+
if (queryPathSizeInBytes % (((long) dim * vectorEncoding.byteSize + offsetByteSize)) != 0) {
169+
throw new IllegalArgumentException(
170+
"docsPath \"" + queryPath + "\" does not contain a whole number of vectors? size=" + queryPathSizeInBytes
171+
);
172+
}
173+
logger.info(
174+
"queryPath size: "
175+
+ queryPathSizeInBytes
176+
+ " bytes, assuming vector count is "
177+
+ (queryPathSizeInBytes / ((long) dim * vectorEncoding.byteSize + offsetByteSize))
178+
);
174179
KnnIndexer.VectorReader targetReader = KnnIndexer.VectorReader.create(input, dim, vectorEncoding, offsetByteSize);
175180
long startNS;
176181
try (MMapDirectory dir = new MMapDirectory(indexPath)) {
@@ -368,8 +373,13 @@ private int[][] getOrCalculateExactNN(int vectorFileOffsetBytes, Query filterQue
368373
}
369374
}
370375

371-
private boolean isNewer(Path path, Path... others) throws IOException {
376+
private boolean isNewer(Path path, List<Path> paths, Path... others) throws IOException {
372377
FileTime modified = Files.getLastModifiedTime(path);
378+
for (Path p : paths) {
379+
if (Files.getLastModifiedTime(p).compareTo(modified) >= 0) {
380+
return false;
381+
}
382+
}
373383
for (Path other : others) {
374384
if (Files.getLastModifiedTime(other).compareTo(modified) >= 0) {
375385
return false;

0 commit comments

Comments
 (0)