diff --git a/qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java b/qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java index f51c550e5292e..85fa02aecaaef 100644 --- a/qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java +++ b/qa/vector/src/main/java/org/elasticsearch/test/knn/CmdLineArgs.java @@ -52,7 +52,8 @@ record CmdLineArgs( int quantizeBits, VectorEncoding vectorEncoding, int dimensions, - boolean earlyTermination + boolean earlyTermination, + KnnIndexTester.MergePolicyType mergePolicy ) implements ToXContentObject { static final ParseField DOC_VECTORS_FIELD = new ParseField("doc_vectors"); @@ -79,6 +80,7 @@ record CmdLineArgs( static final ParseField EARLY_TERMINATION_FIELD = new ParseField("early_termination"); static final ParseField FILTER_SELECTIVITY_FIELD = new ParseField("filter_selectivity"); static final ParseField SEED_FIELD = new ParseField("seed"); + static final ParseField MERGE_POLICY_FIELD = new ParseField("merge_policy"); static CmdLineArgs fromXContent(XContentParser parser) throws IOException { Builder builder = PARSER.apply(parser, null); @@ -112,6 +114,7 @@ static CmdLineArgs fromXContent(XContentParser parser) throws IOException { PARSER.declareBoolean(Builder::setEarlyTermination, EARLY_TERMINATION_FIELD); PARSER.declareFloat(Builder::setFilterSelectivity, FILTER_SELECTIVITY_FIELD); PARSER.declareLong(Builder::setSeed, SEED_FIELD); + PARSER.declareString(Builder::setMergePolicy, MERGE_POLICY_FIELD); } @Override @@ -179,6 +182,7 @@ static class Builder { private boolean earlyTermination; private float filterSelectivity = 1f; private long seed = 1751900822751L; + private KnnIndexTester.MergePolicyType mergePolicy = null; public Builder setDocVectors(List docVectors) { if (docVectors == null || docVectors.isEmpty()) { @@ -304,6 +308,11 @@ public Builder setSeed(long seed) { return this; } + public Builder setMergePolicy(String mergePolicy) { + this.mergePolicy = KnnIndexTester.MergePolicyType.valueOf(mergePolicy.toUpperCase(Locale.ROOT)); + return this; + } + public CmdLineArgs build() { if (docVectors == null) { throw new IllegalArgumentException("Document vectors path must be provided"); @@ -337,7 +346,8 @@ public CmdLineArgs build() { quantizeBits, vectorEncoding, dimensions, - earlyTermination + earlyTermination, + mergePolicy ); } } diff --git a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java index c4b0ccdfe35e3..17257dcb73d59 100644 --- a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java +++ b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java @@ -15,6 +15,11 @@ import org.apache.lucene.codecs.KnnVectorsFormat; import org.apache.lucene.codecs.lucene101.Lucene101Codec; import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat; +import org.apache.lucene.index.LogByteSizeMergePolicy; +import org.apache.lucene.index.LogDocMergePolicy; +import org.apache.lucene.index.MergePolicy; +import org.apache.lucene.index.NoMergePolicy; +import org.apache.lucene.index.TieredMergePolicy; import org.elasticsearch.cli.ProcessInfo; import org.elasticsearch.common.Strings; import org.elasticsearch.common.logging.LogConfigurator; @@ -69,6 +74,13 @@ enum IndexType { IVF } + enum MergePolicyType { + TIERED, + LOG_BYTE, + NO, + LOG_DOC + } + private static String formatIndexPath(CmdLineArgs args) { List suffix = new ArrayList<>(); if (args.indexType() == IndexType.FLAT) { @@ -196,6 +208,7 @@ public static void main(String[] args) throws Exception { logger.info("Running KNN index tester with arguments: " + cmdLineArgs); Codec codec = createCodec(cmdLineArgs); Path indexPath = PathUtils.get(formatIndexPath(cmdLineArgs)); + MergePolicy mergePolicy = getMergePolicy(cmdLineArgs); if (cmdLineArgs.reindex() || cmdLineArgs.forceMerge()) { KnnIndexer knnIndexer = new KnnIndexer( cmdLineArgs.docVectors(), @@ -205,7 +218,8 @@ public static void main(String[] args) throws Exception { cmdLineArgs.vectorEncoding(), cmdLineArgs.dimensions(), cmdLineArgs.vectorSpace(), - cmdLineArgs.numDocs() + cmdLineArgs.numDocs(), + mergePolicy ); if (cmdLineArgs.reindex() == false && Files.exists(indexPath) == false) { throw new IllegalArgumentException("Index path does not exist: " + indexPath); @@ -232,6 +246,24 @@ public static void main(String[] args) throws Exception { logger.info("Results: \n" + formattedResults); } + private static MergePolicy getMergePolicy(CmdLineArgs args) { + MergePolicy mergePolicy = null; + if (args.mergePolicy() != null) { + if (args.mergePolicy() == MergePolicyType.TIERED) { + mergePolicy = new TieredMergePolicy(); + } else if (args.mergePolicy() == MergePolicyType.LOG_BYTE) { + mergePolicy = new LogByteSizeMergePolicy(); + } else if (args.mergePolicy() == MergePolicyType.NO) { + mergePolicy = NoMergePolicy.INSTANCE; + } else if (args.mergePolicy() == MergePolicyType.LOG_DOC) { + mergePolicy = new LogDocMergePolicy(); + } else { + throw new IllegalArgumentException("Invalid merge policy: " + args.mergePolicy()); + } + } + return mergePolicy; + } + static class FormattedResults { List indexResults = new ArrayList<>(); List queryResults = new ArrayList<>(); diff --git a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexer.java b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexer.java index f7d00c9806c8d..aa8792bb2c4a5 100644 --- a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexer.java +++ b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexer.java @@ -31,6 +31,7 @@ import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.MergePolicy; import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.store.FSDirectory; @@ -69,6 +70,7 @@ class KnnIndexer { private final Codec codec; private final int numDocs; private final int numIndexThreads; + private final MergePolicy mergePolicy; KnnIndexer( List docsPath, @@ -78,7 +80,8 @@ class KnnIndexer { VectorEncoding vectorEncoding, int dim, VectorSimilarityFunction similarityFunction, - int numDocs + int numDocs, + MergePolicy mergePolicy ) { this.docsPath = docsPath; this.indexPath = indexPath; @@ -88,6 +91,7 @@ class KnnIndexer { this.dim = dim; this.similarityFunction = similarityFunction; this.numDocs = numDocs; + this.mergePolicy = mergePolicy; } void numSegments(KnnIndexTester.Results result) { @@ -103,7 +107,9 @@ void createIndex(KnnIndexTester.Results result) throws IOException, InterruptedE iwc.setCodec(codec); iwc.setRAMBufferSizeMB(WRITER_BUFFER_MB); iwc.setUseCompoundFile(false); - + if (mergePolicy != null) { + iwc.setMergePolicy(mergePolicy); + } iwc.setMaxFullFlushMergeWaitMillis(0); iwc.setInfoStream(new PrintStreamInfoStream(System.out) {