Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ record CmdLineArgs(
int quantizeBits,
VectorEncoding vectorEncoding,
int dimensions,
boolean earlyTermination
boolean earlyTermination,
KnnIndexTester.MergePolicyType mergePolicy
) implements ToXContentObject {

static final ParseField DOC_VECTORS_FIELD = new ParseField("doc_vectors");
Expand All @@ -79,6 +80,7 @@ record CmdLineArgs(
static final ParseField EARLY_TERMINATION_FIELD = new ParseField("early_termination");
static final ParseField FILTER_SELECTIVITY_FIELD = new ParseField("filter_selectivity");
static final ParseField SEED_FIELD = new ParseField("seed");
static final ParseField MERGE_POLICY_FIELD = new ParseField("merge_policy");

static CmdLineArgs fromXContent(XContentParser parser) throws IOException {
Builder builder = PARSER.apply(parser, null);
Expand Down Expand Up @@ -112,6 +114,7 @@ static CmdLineArgs fromXContent(XContentParser parser) throws IOException {
PARSER.declareBoolean(Builder::setEarlyTermination, EARLY_TERMINATION_FIELD);
PARSER.declareFloat(Builder::setFilterSelectivity, FILTER_SELECTIVITY_FIELD);
PARSER.declareLong(Builder::setSeed, SEED_FIELD);
PARSER.declareString(Builder::setMergePolicy, MERGE_POLICY_FIELD);
}

@Override
Expand Down Expand Up @@ -179,6 +182,7 @@ static class Builder {
private boolean earlyTermination;
private float filterSelectivity = 1f;
private long seed = 1751900822751L;
private KnnIndexTester.MergePolicyType mergePolicy = null;

public Builder setDocVectors(List<String> docVectors) {
if (docVectors == null || docVectors.isEmpty()) {
Expand Down Expand Up @@ -304,6 +308,11 @@ public Builder setSeed(long seed) {
return this;
}

public Builder setMergePolicy(String mergePolicy) {
this.mergePolicy = KnnIndexTester.MergePolicyType.valueOf(mergePolicy.toUpperCase(Locale.ROOT));
return this;
}

public CmdLineArgs build() {
if (docVectors == null) {
throw new IllegalArgumentException("Document vectors path must be provided");
Expand Down Expand Up @@ -337,7 +346,8 @@ public CmdLineArgs build() {
quantizeBits,
vectorEncoding,
dimensions,
earlyTermination
earlyTermination,
mergePolicy
);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.lucene101.Lucene101Codec;
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat;
import org.apache.lucene.index.LogByteSizeMergePolicy;
import org.apache.lucene.index.LogDocMergePolicy;
import org.apache.lucene.index.MergePolicy;
import org.apache.lucene.index.NoMergePolicy;
import org.apache.lucene.index.TieredMergePolicy;
import org.elasticsearch.cli.ProcessInfo;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.logging.LogConfigurator;
Expand Down Expand Up @@ -69,6 +74,13 @@ enum IndexType {
IVF
}

enum MergePolicyType {
TIERED,
LOG_BYTE,
NO,
LOG_DOC
}

private static String formatIndexPath(CmdLineArgs args) {
List<String> suffix = new ArrayList<>();
if (args.indexType() == IndexType.FLAT) {
Expand Down Expand Up @@ -196,6 +208,7 @@ public static void main(String[] args) throws Exception {
logger.info("Running KNN index tester with arguments: " + cmdLineArgs);
Codec codec = createCodec(cmdLineArgs);
Path indexPath = PathUtils.get(formatIndexPath(cmdLineArgs));
MergePolicy mergePolicy = getMergePolicy(cmdLineArgs);
if (cmdLineArgs.reindex() || cmdLineArgs.forceMerge()) {
KnnIndexer knnIndexer = new KnnIndexer(
cmdLineArgs.docVectors(),
Expand All @@ -205,7 +218,8 @@ public static void main(String[] args) throws Exception {
cmdLineArgs.vectorEncoding(),
cmdLineArgs.dimensions(),
cmdLineArgs.vectorSpace(),
cmdLineArgs.numDocs()
cmdLineArgs.numDocs(),
mergePolicy
);
if (cmdLineArgs.reindex() == false && Files.exists(indexPath) == false) {
throw new IllegalArgumentException("Index path does not exist: " + indexPath);
Expand All @@ -232,6 +246,24 @@ public static void main(String[] args) throws Exception {
logger.info("Results: \n" + formattedResults);
}

private static MergePolicy getMergePolicy(CmdLineArgs args) {
MergePolicy mergePolicy = null;
if (args.mergePolicy() != null) {
if (args.mergePolicy() == MergePolicyType.TIERED) {
mergePolicy = new TieredMergePolicy();
} else if (args.mergePolicy() == MergePolicyType.LOG_BYTE) {
mergePolicy = new LogByteSizeMergePolicy();
} else if (args.mergePolicy() == MergePolicyType.NO) {
mergePolicy = NoMergePolicy.INSTANCE;
} else if (args.mergePolicy() == MergePolicyType.LOG_DOC) {
mergePolicy = new LogDocMergePolicy();
} else {
throw new IllegalArgumentException("Invalid merge policy: " + args.mergePolicy());
}
}
return mergePolicy;
}

static class FormattedResults {
List<Results> indexResults = new ArrayList<>();
List<Results> queryResults = new ArrayList<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.MergePolicy;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.FSDirectory;
Expand Down Expand Up @@ -69,6 +70,7 @@ class KnnIndexer {
private final Codec codec;
private final int numDocs;
private final int numIndexThreads;
private final MergePolicy mergePolicy;

KnnIndexer(
List<Path> docsPath,
Expand All @@ -78,7 +80,8 @@ class KnnIndexer {
VectorEncoding vectorEncoding,
int dim,
VectorSimilarityFunction similarityFunction,
int numDocs
int numDocs,
MergePolicy mergePolicy
) {
this.docsPath = docsPath;
this.indexPath = indexPath;
Expand All @@ -88,6 +91,7 @@ class KnnIndexer {
this.dim = dim;
this.similarityFunction = similarityFunction;
this.numDocs = numDocs;
this.mergePolicy = mergePolicy;
}

void numSegments(KnnIndexTester.Results result) {
Expand All @@ -103,7 +107,9 @@ void createIndex(KnnIndexTester.Results result) throws IOException, InterruptedE
iwc.setCodec(codec);
iwc.setRAMBufferSizeMB(WRITER_BUFFER_MB);
iwc.setUseCompoundFile(false);

if (mergePolicy != null) {
iwc.setMergePolicy(mergePolicy);
}
iwc.setMaxFullFlushMergeWaitMillis(0);

iwc.setInfoStream(new PrintStreamInfoStream(System.out) {
Expand Down