|
| 1 | +/* |
| 2 | + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one |
| 3 | + * or more contributor license agreements. Licensed under the "Elastic License |
| 4 | + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side |
| 5 | + * Public License v 1"; you may not use this file except in compliance with, at |
| 6 | + * your election, the "Elastic License 2.0", the "GNU Affero General Public |
| 7 | + * License v3.0 only", or the "Server Side Public License, v 1". |
| 8 | + */ |
| 9 | + |
| 10 | +package org.elasticsearch.test.knn; |
| 11 | + |
| 12 | +import org.apache.lucene.index.VectorEncoding; |
| 13 | +import org.apache.lucene.index.VectorSimilarityFunction; |
| 14 | +import org.elasticsearch.common.Strings; |
| 15 | +import org.elasticsearch.xcontent.ObjectParser; |
| 16 | +import org.elasticsearch.xcontent.ParseField; |
| 17 | +import org.elasticsearch.xcontent.ToXContentObject; |
| 18 | +import org.elasticsearch.xcontent.XContentBuilder; |
| 19 | +import org.elasticsearch.xcontent.XContentParser; |
| 20 | + |
| 21 | +import java.io.IOException; |
| 22 | +import java.nio.file.Path; |
| 23 | +import java.util.Locale; |
| 24 | + |
| 25 | +/** |
| 26 | + * Command line arguments for the KNN index tester. |
| 27 | + * This class encapsulates all the parameters required to run the KNN index tests. |
| 28 | + */ |
| 29 | +record CmdLineArgs( |
| 30 | + Path docVectors, |
| 31 | + Path queryVectors, |
| 32 | + int numDocs, |
| 33 | + int numQueries, |
| 34 | + KnnIndexTester.IndexType indexType, |
| 35 | + int numCandidates, |
| 36 | + int k, |
| 37 | + int nProbe, |
| 38 | + int ivfClusterSize, |
| 39 | + int overSamplingFactor, |
| 40 | + int hnswM, |
| 41 | + int hnswEfConstruction, |
| 42 | + int searchThreads, |
| 43 | + int indexThreads, |
| 44 | + boolean reindex, |
| 45 | + boolean forceMerge, |
| 46 | + VectorSimilarityFunction vectorSpace, |
| 47 | + int quantizeBits, |
| 48 | + VectorEncoding vectorEncoding, |
| 49 | + int dimensions |
| 50 | +) implements ToXContentObject { |
| 51 | + |
| 52 | + static final ParseField DOC_VECTORS_FIELD = new ParseField("doc_vectors"); |
| 53 | + static final ParseField QUERY_VECTORS_FIELD = new ParseField("query_vectors"); |
| 54 | + static final ParseField NUM_DOCS_FIELD = new ParseField("num_docs"); |
| 55 | + static final ParseField NUM_QUERIES_FIELD = new ParseField("num_queries"); |
| 56 | + static final ParseField INDEX_TYPE_FIELD = new ParseField("index_type"); |
| 57 | + static final ParseField NUM_CANDIDATES_FIELD = new ParseField("num_candidates"); |
| 58 | + static final ParseField K_FIELD = new ParseField("k"); |
| 59 | + static final ParseField N_PROBE_FIELD = new ParseField("n_probe"); |
| 60 | + static final ParseField IVF_CLUSTER_SIZE_FIELD = new ParseField("ivf_cluster_size"); |
| 61 | + static final ParseField OVER_SAMPLING_FACTOR_FIELD = new ParseField("over_sampling_factor"); |
| 62 | + static final ParseField HNSW_M_FIELD = new ParseField("hnsw_m"); |
| 63 | + static final ParseField HNSW_EF_CONSTRUCTION_FIELD = new ParseField("hnsw_ef_construction"); |
| 64 | + static final ParseField SEARCH_THREADS_FIELD = new ParseField("search_threads"); |
| 65 | + static final ParseField INDEX_THREADS_FIELD = new ParseField("index_threads"); |
| 66 | + static final ParseField REINDEX_FIELD = new ParseField("reindex"); |
| 67 | + static final ParseField FORCE_MERGE_FIELD = new ParseField("force_merge"); |
| 68 | + static final ParseField VECTOR_SPACE_FIELD = new ParseField("vector_space"); |
| 69 | + static final ParseField QUANTIZE_BITS_FIELD = new ParseField("quantize_bits"); |
| 70 | + static final ParseField VECTOR_ENCODING_FIELD = new ParseField("vector_encoding"); |
| 71 | + static final ParseField DIMENSIONS_FIELD = new ParseField("dimensions"); |
| 72 | + |
| 73 | + static CmdLineArgs fromXContent(XContentParser parser) throws IOException { |
| 74 | + Builder builder = PARSER.apply(parser, null); |
| 75 | + return builder.build(); |
| 76 | + } |
| 77 | + |
| 78 | + static final ObjectParser<CmdLineArgs.Builder, Void> PARSER = new ObjectParser<>("cmd_line_args", true, Builder::new); |
| 79 | + |
| 80 | + static { |
| 81 | + PARSER.declareString(Builder::setDocVectors, DOC_VECTORS_FIELD); |
| 82 | + PARSER.declareString(Builder::setQueryVectors, QUERY_VECTORS_FIELD); |
| 83 | + PARSER.declareInt(Builder::setNumDocs, NUM_DOCS_FIELD); |
| 84 | + PARSER.declareInt(Builder::setNumQueries, NUM_QUERIES_FIELD); |
| 85 | + PARSER.declareString(Builder::setIndexType, INDEX_TYPE_FIELD); |
| 86 | + PARSER.declareInt(Builder::setNumCandidates, NUM_CANDIDATES_FIELD); |
| 87 | + PARSER.declareInt(Builder::setK, K_FIELD); |
| 88 | + PARSER.declareInt(Builder::setNProbe, N_PROBE_FIELD); |
| 89 | + PARSER.declareInt(Builder::setIvfClusterSize, IVF_CLUSTER_SIZE_FIELD); |
| 90 | + PARSER.declareInt(Builder::setOverSamplingFactor, OVER_SAMPLING_FACTOR_FIELD); |
| 91 | + PARSER.declareInt(Builder::setHnswM, HNSW_M_FIELD); |
| 92 | + PARSER.declareInt(Builder::setHnswEfConstruction, HNSW_EF_CONSTRUCTION_FIELD); |
| 93 | + PARSER.declareInt(Builder::setSearchThreads, SEARCH_THREADS_FIELD); |
| 94 | + PARSER.declareInt(Builder::setIndexThreads, INDEX_THREADS_FIELD); |
| 95 | + PARSER.declareBoolean(Builder::setReindex, REINDEX_FIELD); |
| 96 | + PARSER.declareBoolean(Builder::setForceMerge, FORCE_MERGE_FIELD); |
| 97 | + PARSER.declareString(Builder::setVectorSpace, VECTOR_SPACE_FIELD); |
| 98 | + PARSER.declareInt(Builder::setQuantizeBits, QUANTIZE_BITS_FIELD); |
| 99 | + PARSER.declareString(Builder::setVectorEncoding, VECTOR_ENCODING_FIELD); |
| 100 | + PARSER.declareInt(Builder::setDimensions, DIMENSIONS_FIELD); |
| 101 | + } |
| 102 | + |
| 103 | + @Override |
| 104 | + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { |
| 105 | + builder.startObject(); |
| 106 | + if (docVectors != null) { |
| 107 | + builder.field(DOC_VECTORS_FIELD.getPreferredName(), docVectors.toString()); |
| 108 | + } |
| 109 | + if (queryVectors != null) { |
| 110 | + builder.field(QUERY_VECTORS_FIELD.getPreferredName(), queryVectors.toString()); |
| 111 | + } |
| 112 | + builder.field(NUM_DOCS_FIELD.getPreferredName(), numDocs); |
| 113 | + builder.field(NUM_QUERIES_FIELD.getPreferredName(), numQueries); |
| 114 | + builder.field(INDEX_TYPE_FIELD.getPreferredName(), indexType.name().toLowerCase(Locale.ROOT)); |
| 115 | + builder.field(NUM_CANDIDATES_FIELD.getPreferredName(), numCandidates); |
| 116 | + builder.field(K_FIELD.getPreferredName(), k); |
| 117 | + builder.field(N_PROBE_FIELD.getPreferredName(), nProbe); |
| 118 | + builder.field(IVF_CLUSTER_SIZE_FIELD.getPreferredName(), ivfClusterSize); |
| 119 | + builder.field(OVER_SAMPLING_FACTOR_FIELD.getPreferredName(), overSamplingFactor); |
| 120 | + builder.field(HNSW_M_FIELD.getPreferredName(), hnswM); |
| 121 | + builder.field(HNSW_EF_CONSTRUCTION_FIELD.getPreferredName(), hnswEfConstruction); |
| 122 | + builder.field(SEARCH_THREADS_FIELD.getPreferredName(), searchThreads); |
| 123 | + builder.field(INDEX_THREADS_FIELD.getPreferredName(), indexThreads); |
| 124 | + builder.field(REINDEX_FIELD.getPreferredName(), reindex); |
| 125 | + builder.field(FORCE_MERGE_FIELD.getPreferredName(), forceMerge); |
| 126 | + builder.field(VECTOR_SPACE_FIELD.getPreferredName(), vectorSpace.name().toLowerCase(Locale.ROOT)); |
| 127 | + builder.field(QUANTIZE_BITS_FIELD.getPreferredName(), quantizeBits); |
| 128 | + builder.field(VECTOR_ENCODING_FIELD.getPreferredName(), vectorEncoding.name().toLowerCase(Locale.ROOT)); |
| 129 | + builder.field(DIMENSIONS_FIELD.getPreferredName(), dimensions); |
| 130 | + return builder.endObject(); |
| 131 | + } |
| 132 | + |
| 133 | + @Override |
| 134 | + public String toString() { |
| 135 | + return Strings.toString(this, false, false); |
| 136 | + } |
| 137 | + |
| 138 | + static class Builder { |
| 139 | + private Path docVectors; |
| 140 | + private Path queryVectors; |
| 141 | + private int numDocs = 1000; |
| 142 | + private int numQueries = 100; |
| 143 | + private KnnIndexTester.IndexType indexType = KnnIndexTester.IndexType.HNSW; |
| 144 | + private int numCandidates = 1000; |
| 145 | + private int k = 10; |
| 146 | + private int nProbe = 10; |
| 147 | + private int ivfClusterSize = 1000; |
| 148 | + private int overSamplingFactor = 1; |
| 149 | + private int hnswM = 16; |
| 150 | + private int hnswEfConstruction = 200; |
| 151 | + private int searchThreads = 1; |
| 152 | + private int indexThreads = 1; |
| 153 | + private boolean reindex = false; |
| 154 | + private boolean forceMerge = false; |
| 155 | + private VectorSimilarityFunction vectorSpace = VectorSimilarityFunction.EUCLIDEAN; |
| 156 | + private int quantizeBits = 8; |
| 157 | + private VectorEncoding vectorEncoding = VectorEncoding.FLOAT32; |
| 158 | + private int dimensions; |
| 159 | + |
| 160 | + public Builder setDocVectors(String docVectors) { |
| 161 | + this.docVectors = Path.of(docVectors); |
| 162 | + return this; |
| 163 | + } |
| 164 | + |
| 165 | + public Builder setQueryVectors(String queryVectors) { |
| 166 | + this.queryVectors = Path.of(queryVectors); |
| 167 | + return this; |
| 168 | + } |
| 169 | + |
| 170 | + public Builder setNumDocs(int numDocs) { |
| 171 | + this.numDocs = numDocs; |
| 172 | + return this; |
| 173 | + } |
| 174 | + |
| 175 | + public Builder setNumQueries(int numQueries) { |
| 176 | + this.numQueries = numQueries; |
| 177 | + return this; |
| 178 | + } |
| 179 | + |
| 180 | + public Builder setIndexType(String indexType) { |
| 181 | + this.indexType = KnnIndexTester.IndexType.valueOf(indexType.toUpperCase(Locale.ROOT)); |
| 182 | + return this; |
| 183 | + } |
| 184 | + |
| 185 | + public Builder setNumCandidates(int numCandidates) { |
| 186 | + this.numCandidates = numCandidates; |
| 187 | + return this; |
| 188 | + } |
| 189 | + |
| 190 | + public Builder setK(int k) { |
| 191 | + this.k = k; |
| 192 | + return this; |
| 193 | + } |
| 194 | + |
| 195 | + public Builder setNProbe(int nProbe) { |
| 196 | + this.nProbe = nProbe; |
| 197 | + return this; |
| 198 | + } |
| 199 | + |
| 200 | + public Builder setIvfClusterSize(int ivfClusterSize) { |
| 201 | + this.ivfClusterSize = ivfClusterSize; |
| 202 | + return this; |
| 203 | + } |
| 204 | + |
| 205 | + public Builder setOverSamplingFactor(int overSamplingFactor) { |
| 206 | + this.overSamplingFactor = overSamplingFactor; |
| 207 | + return this; |
| 208 | + } |
| 209 | + |
| 210 | + public Builder setHnswM(int hnswM) { |
| 211 | + this.hnswM = hnswM; |
| 212 | + return this; |
| 213 | + } |
| 214 | + |
| 215 | + public Builder setHnswEfConstruction(int hnswEfConstruction) { |
| 216 | + this.hnswEfConstruction = hnswEfConstruction; |
| 217 | + return this; |
| 218 | + } |
| 219 | + |
| 220 | + public Builder setSearchThreads(int searchThreads) { |
| 221 | + this.searchThreads = searchThreads; |
| 222 | + return this; |
| 223 | + } |
| 224 | + |
| 225 | + public Builder setIndexThreads(int indexThreads) { |
| 226 | + this.indexThreads = indexThreads; |
| 227 | + return this; |
| 228 | + } |
| 229 | + |
| 230 | + public Builder setReindex(boolean reindex) { |
| 231 | + this.reindex = reindex; |
| 232 | + return this; |
| 233 | + } |
| 234 | + |
| 235 | + public Builder setForceMerge(boolean forceMerge) { |
| 236 | + this.forceMerge = forceMerge; |
| 237 | + return this; |
| 238 | + } |
| 239 | + |
| 240 | + public Builder setVectorSpace(String vectorSpace) { |
| 241 | + this.vectorSpace = VectorSimilarityFunction.valueOf(vectorSpace.toUpperCase(Locale.ROOT)); |
| 242 | + return this; |
| 243 | + } |
| 244 | + |
| 245 | + public Builder setQuantizeBits(int quantizeBits) { |
| 246 | + this.quantizeBits = quantizeBits; |
| 247 | + return this; |
| 248 | + } |
| 249 | + |
| 250 | + public Builder setVectorEncoding(String vectorEncoding) { |
| 251 | + this.vectorEncoding = VectorEncoding.valueOf(vectorEncoding.toUpperCase(Locale.ROOT)); |
| 252 | + return this; |
| 253 | + } |
| 254 | + |
| 255 | + public Builder setDimensions(int dimensions) { |
| 256 | + this.dimensions = dimensions; |
| 257 | + return this; |
| 258 | + } |
| 259 | + |
| 260 | + public CmdLineArgs build() { |
| 261 | + if (docVectors == null) { |
| 262 | + throw new IllegalArgumentException("Document vectors path must be provided"); |
| 263 | + } |
| 264 | + if (dimensions <= 0) { |
| 265 | + throw new IllegalArgumentException("dimensions must be a positive integer"); |
| 266 | + } |
| 267 | + return new CmdLineArgs( |
| 268 | + docVectors, |
| 269 | + queryVectors, |
| 270 | + numDocs, |
| 271 | + numQueries, |
| 272 | + indexType, |
| 273 | + numCandidates, |
| 274 | + k, |
| 275 | + nProbe, |
| 276 | + ivfClusterSize, |
| 277 | + overSamplingFactor, |
| 278 | + hnswM, |
| 279 | + hnswEfConstruction, |
| 280 | + searchThreads, |
| 281 | + indexThreads, |
| 282 | + reindex, |
| 283 | + forceMerge, |
| 284 | + vectorSpace, |
| 285 | + quantizeBits, |
| 286 | + vectorEncoding, |
| 287 | + dimensions |
| 288 | + ); |
| 289 | + } |
| 290 | + } |
| 291 | +} |
0 commit comments