|
| 1 | +/* |
| 2 | + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one |
| 3 | + * or more contributor license agreements. Licensed under the "Elastic License |
| 4 | + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side |
| 5 | + * Public License v 1"; you may not use this file except in compliance with, at |
| 6 | + * your election, the "Elastic License 2.0", the "GNU Affero General Public |
| 7 | + * License v3.0 only", or the "Server Side Public License, v 1". |
| 8 | + */ |
| 9 | + |
| 10 | +package org.elasticsearch.test.knn; |
| 11 | + |
| 12 | +import org.apache.lucene.codecs.Codec; |
| 13 | +import org.apache.lucene.codecs.KnnVectorsFormat; |
| 14 | +import org.apache.lucene.codecs.lucene101.Lucene101Codec; |
| 15 | +import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat; |
| 16 | +import org.apache.lucene.index.VectorEncoding; |
| 17 | +import org.apache.lucene.index.VectorSimilarityFunction; |
| 18 | +import org.elasticsearch.common.logging.LogConfigurator; |
| 19 | +import org.elasticsearch.index.codec.vectors.ES813Int8FlatVectorFormat; |
| 20 | +import org.elasticsearch.index.codec.vectors.ES814HnswScalarQuantizedVectorsFormat; |
| 21 | +import org.elasticsearch.index.codec.vectors.IVFVectorsFormat; |
| 22 | +import org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat; |
| 23 | +import org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat; |
| 24 | + |
| 25 | +import java.nio.file.Path; |
| 26 | +import java.util.ArrayList; |
| 27 | +import java.util.List; |
| 28 | + |
| 29 | +class KnnIndexTester { |
| 30 | + static { |
| 31 | + LogConfigurator.loadLog4jPlugins(); |
| 32 | + LogConfigurator.configureESLogging(); // native access requires logging to be initialized |
| 33 | + } |
| 34 | + |
| 35 | + static final String INDEX_DIR = "target/knn_index"; |
| 36 | + |
| 37 | + enum IndexType { |
| 38 | + HNSW, |
| 39 | + FLAT, |
| 40 | + IVF |
| 41 | + } |
| 42 | + |
| 43 | + static class CmdLineArgs { |
| 44 | + static CmdLineArgs parse(String[] args) { |
| 45 | + CmdLineArgs cmdLineArgs = new CmdLineArgs(); |
| 46 | + |
| 47 | + for (String arg : args) { |
| 48 | + String[] parts = arg.split("="); |
| 49 | + String key = parts[0].trim(); |
| 50 | + String value = null; |
| 51 | + if (parts.length > 1) { |
| 52 | + value = parts[1].trim(); |
| 53 | + } |
| 54 | + if (parts.length > 2) { |
| 55 | + throw new IllegalArgumentException("Too many parts in argument: " + arg); |
| 56 | + } |
| 57 | + |
| 58 | + switch (key) { |
| 59 | + case "--docVectors": |
| 60 | + cmdLineArgs.docVectors = Path.of(value); |
| 61 | + break; |
| 62 | + case "--queryVectors": |
| 63 | + cmdLineArgs.queryVectors = Path.of(value); |
| 64 | + break; |
| 65 | + case "--numDocs": |
| 66 | + cmdLineArgs.numDocs = Integer.parseInt(value); |
| 67 | + break; |
| 68 | + case "--numQueries": |
| 69 | + cmdLineArgs.numQueries = Integer.parseInt(value); |
| 70 | + break; |
| 71 | + case "--indexType": |
| 72 | + cmdLineArgs.indexType = IndexType.valueOf(value.toUpperCase()); |
| 73 | + break; |
| 74 | + case "--numCandidates": |
| 75 | + cmdLineArgs.numCandidates = Integer.parseInt(value); |
| 76 | + break; |
| 77 | + case "--k": |
| 78 | + cmdLineArgs.k = Integer.parseInt(value); |
| 79 | + break; |
| 80 | + case "--nProbe": |
| 81 | + cmdLineArgs.nProbe = Integer.parseInt(value); |
| 82 | + break; |
| 83 | + case "--ivfClusterSize": |
| 84 | + cmdLineArgs.ivfClusterSize = Integer.parseInt(value); |
| 85 | + break; |
| 86 | + case "--overSamplingFactor": |
| 87 | + cmdLineArgs.overSamplingFactor = Integer.parseInt(value); |
| 88 | + break; |
| 89 | + case "--hnswM": |
| 90 | + cmdLineArgs.hnswM = Integer.parseInt(value); |
| 91 | + break; |
| 92 | + case "--hnswEfConstruction": |
| 93 | + cmdLineArgs.hnswEfConstruction = Integer.parseInt(value); |
| 94 | + break; |
| 95 | + case "--searchThreads": |
| 96 | + cmdLineArgs.searchThreads = Integer.parseInt(value); |
| 97 | + break; |
| 98 | + case "--indexThreads": |
| 99 | + cmdLineArgs.indexThreads = Integer.parseInt(value); |
| 100 | + break; |
| 101 | + case "--reindex": |
| 102 | + cmdLineArgs.reindex = true; |
| 103 | + break; |
| 104 | + case "--forceMerge": |
| 105 | + cmdLineArgs.forceMerge = true; |
| 106 | + break; |
| 107 | + case "--vectorSpace": |
| 108 | + cmdLineArgs.vectorSpace = VectorSimilarityFunction.valueOf(value.toUpperCase()); |
| 109 | + break; |
| 110 | + case "--quantizeBits": |
| 111 | + cmdLineArgs.quantizeBits = Integer.parseInt(value); |
| 112 | + break; |
| 113 | + case "--vectorEncoding": |
| 114 | + cmdLineArgs.vectorEncoding = VectorEncoding.valueOf(value.toUpperCase()); |
| 115 | + break; |
| 116 | + case "--dimensions": |
| 117 | + cmdLineArgs.dimensions = Integer.parseInt(value); |
| 118 | + break; |
| 119 | + default: |
| 120 | + throw new IllegalArgumentException("Unknown argument: " + key); |
| 121 | + } |
| 122 | + } |
| 123 | + return cmdLineArgs; |
| 124 | + } |
| 125 | + |
| 126 | + int numDocs = 1000; |
| 127 | + int numQueries = 10; |
| 128 | + IndexType indexType = IndexType.IVF; |
| 129 | + int numCandidates = 100; |
| 130 | + int k = 100; |
| 131 | + int nProbe = -1; |
| 132 | + int ivfClusterSize = 384; |
| 133 | + int overSamplingFactor = 0; |
| 134 | + int hnswM = 16; |
| 135 | + int hnswEfConstruction = 100; |
| 136 | + int searchThreads = 1; |
| 137 | + int indexThreads = 1; |
| 138 | + boolean reindex = true; |
| 139 | + boolean forceMerge = false; |
| 140 | + VectorSimilarityFunction vectorSpace = VectorSimilarityFunction.EUCLIDEAN; |
| 141 | + // 32 means no quantization |
| 142 | + int quantizeBits = 32; |
| 143 | + int dimensions = 1024; // Default dimension size for vectors |
| 144 | + VectorEncoding vectorEncoding = VectorEncoding.FLOAT32; |
| 145 | + Path docVectors = null; |
| 146 | + Path queryVectors = null; |
| 147 | + } |
| 148 | + |
| 149 | + private static String formatIndexPath(CmdLineArgs args) { |
| 150 | + List<String> suffix = new ArrayList<>(); |
| 151 | + if (args.indexType == IndexType.FLAT) { |
| 152 | + suffix.add("flat"); |
| 153 | + } else if (args.indexType == IndexType.IVF) { |
| 154 | + suffix.add("ivf"); |
| 155 | + suffix.add(Integer.toString(args.ivfClusterSize)); |
| 156 | + } else { |
| 157 | + suffix.add(Integer.toString(args.hnswM)); |
| 158 | + suffix.add(Integer.toString(args.hnswEfConstruction)); |
| 159 | + if (args.quantizeBits < 32) { |
| 160 | + suffix.add(Integer.toString(args.quantizeBits)); |
| 161 | + } |
| 162 | + } |
| 163 | + return INDEX_DIR + "/" + args.docVectors.getFileName() + "-" + String.join("-", suffix) + ".index"; |
| 164 | + } |
| 165 | + |
| 166 | + static Codec createCodec(CmdLineArgs args) { |
| 167 | + final KnnVectorsFormat format; |
| 168 | + if (args.indexType == IndexType.IVF) { |
| 169 | + format = new IVFVectorsFormat(args.ivfClusterSize); |
| 170 | + } else { |
| 171 | + if (args.quantizeBits == 1) { |
| 172 | + if (args.indexType == IndexType.FLAT) { |
| 173 | + format = new ES818BinaryQuantizedVectorsFormat(); |
| 174 | + } else { |
| 175 | + format = new ES818HnswBinaryQuantizedVectorsFormat(args.hnswM, args.hnswEfConstruction, 1, null); |
| 176 | + } |
| 177 | + } else if (args.quantizeBits < 32) { |
| 178 | + if (args.indexType == IndexType.FLAT) { |
| 179 | + format = new ES813Int8FlatVectorFormat(null, args.quantizeBits, true); |
| 180 | + } else { |
| 181 | + format = new ES814HnswScalarQuantizedVectorsFormat(args.hnswM, args.hnswEfConstruction, null, args.quantizeBits, true); |
| 182 | + } |
| 183 | + } else { |
| 184 | + format = new Lucene99HnswVectorsFormat(args.hnswM, args.hnswEfConstruction, 1, null); |
| 185 | + } |
| 186 | + } |
| 187 | + return new Lucene101Codec() { |
| 188 | + @Override |
| 189 | + public KnnVectorsFormat getKnnVectorsFormatForField(String field) { |
| 190 | + return format; |
| 191 | + } |
| 192 | + }; |
| 193 | + } |
| 194 | + |
| 195 | + public static void main(String[] args) throws Exception { |
| 196 | + CmdLineArgs cmdLineArgs = CmdLineArgs.parse(args); |
| 197 | + if (cmdLineArgs.docVectors == null || cmdLineArgs.docVectors.toFile().exists() == false) { |
| 198 | + throw new IllegalArgumentException("Document vectors file does not exist: " + cmdLineArgs.docVectors); |
| 199 | + } |
| 200 | + Codec codec = createCodec(cmdLineArgs); |
| 201 | + Path indexPath = Path.of(formatIndexPath(cmdLineArgs)); |
| 202 | + long indexCreationTimeMS = 0; |
| 203 | + long forceMergeTimeMS = 0; |
| 204 | + int numSegments = 1; |
| 205 | + StringBuilder resultHeaders = new StringBuilder(); |
| 206 | + StringBuilder resultValues = new StringBuilder(); |
| 207 | + // indicate params used for index creation |
| 208 | + resultHeaders.append("index_type,"); |
| 209 | + resultValues.append(cmdLineArgs.indexType).append(","); |
| 210 | + resultHeaders.append("num_docs,"); |
| 211 | + resultValues.append(cmdLineArgs.numDocs).append(","); |
| 212 | + if (cmdLineArgs.reindex || cmdLineArgs.forceMerge) { |
| 213 | + KnnIndexer knnIndexer = new KnnIndexer( |
| 214 | + cmdLineArgs.docVectors, |
| 215 | + indexPath, |
| 216 | + codec, |
| 217 | + cmdLineArgs.indexThreads, |
| 218 | + cmdLineArgs.vectorEncoding, |
| 219 | + cmdLineArgs.dimensions, |
| 220 | + cmdLineArgs.vectorSpace, |
| 221 | + cmdLineArgs.numDocs |
| 222 | + ); |
| 223 | + if (cmdLineArgs.reindex) { |
| 224 | + indexCreationTimeMS = knnIndexer.createIndex(); |
| 225 | + } |
| 226 | + if (cmdLineArgs.forceMerge) { |
| 227 | + forceMergeTimeMS = knnIndexer.forceMerge(); |
| 228 | + } else { |
| 229 | + numSegments = knnIndexer.numSegments(); |
| 230 | + } |
| 231 | + } |
| 232 | + if (indexCreationTimeMS > 0) { |
| 233 | + resultHeaders.append("index_time(ms)").append(","); |
| 234 | + resultValues.append(indexCreationTimeMS).append(","); |
| 235 | + } |
| 236 | + if (forceMergeTimeMS > 0) { |
| 237 | + resultHeaders.append("force_merge_time(ms)").append(","); |
| 238 | + resultValues.append(forceMergeTimeMS).append(","); |
| 239 | + } |
| 240 | + resultHeaders.append("num_segments,"); |
| 241 | + resultValues.append(numSegments).append(","); |
| 242 | + |
| 243 | + if (cmdLineArgs.queryVectors != null) { |
| 244 | + KnnSearcher knnSearcher = new KnnSearcher(indexPath, cmdLineArgs); |
| 245 | + KnnSearcher.SearcherResults results = knnSearcher.runSearch(); |
| 246 | + resultHeaders.append("latency(ms),"); |
| 247 | + resultValues.append(results.avgLatency()).append(","); |
| 248 | + resultHeaders.append("qps,"); |
| 249 | + resultValues.append(results.qps()).append(","); |
| 250 | + resultHeaders.append("recall,"); |
| 251 | + resultValues.append(results.avgRecall()).append(","); |
| 252 | + resultHeaders.append("visited,"); |
| 253 | + resultValues.append(results.averageVisited()).append(","); |
| 254 | + return; |
| 255 | + } |
| 256 | + System.out.println(resultHeaders); |
| 257 | + System.out.println(resultValues); |
| 258 | + } |
| 259 | +} |
0 commit comments