Merged
Changes from 5 commits
qa/vector/build.gradle (5 additions, 0 deletions)
@@ -7,6 +7,8 @@
* License v3.0 only", or the "Server Side Public License, v 1".
*/

import org.elasticsearch.gradle.internal.test.TestUtil

apply plugin: 'elasticsearch.java'
apply plugin: 'elasticsearch.build'

@@ -23,6 +25,8 @@ dependencies {
api "org.apache.lucene:lucene-core:${versions.lucene}"
api "org.apache.lucene:lucene-queries:${versions.lucene}"
api "org.apache.lucene:lucene-codecs:${versions.lucene}"
implementation project(':libs:simdvec')
implementation project(':libs:native')
implementation project(':libs:logging')
implementation project(':server')
}
@@ -37,6 +41,7 @@ tasks.register("checkVec", JavaExec) {
// Configure logging to console
systemProperty "es.logger.out", "console"
systemProperty "es.logger.level", "INFO" // Change to DEBUG if needed
systemProperty 'es.nativelibs.path', TestUtil.getTestLibraryPath(file("../../libs/native/libraries/build/platform/").toString())

if (buildParams.getRuntimeJavaVersion().map { it.majorVersion.toInteger() }.get() >= 21) {
jvmArgs '-Xms4g', '-Xmx4g', '--add-modules=jdk.incubator.vector', '--enable-native-access=ALL-UNNAMED', '-Djava.util.concurrent.ForkJoinPool.common.parallelism=8', '-XX:+UnlockDiagnosticVMOptions', '-XX:+DebugNonSafepoints', '-XX:+HeapDumpOnOutOfMemoryError'
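A typical invocation of the JavaExec task configured above, assuming the standard Gradle project path for qa/vector (the exact task path is an assumption, not stated in this change):

    ./gradlew :qa:vector:checkVec

For more verbose output, change the es.logger.level system property above from INFO to DEBUG, as its inline comment suggests.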
qa/vector/src/main/java/module-info.java (1 addition, 0 deletions)
@@ -11,6 +11,7 @@
requires org.elasticsearch.base;
requires org.elasticsearch.server;
requires org.elasticsearch.xcontent;
requires org.elasticsearch.cli;
requires org.apache.lucene.core;
requires org.apache.lucene.codecs;
requires org.apache.lucene.queries;
KnnIndexTester.java
@@ -15,15 +15,19 @@
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.lucene101.Lucene101Codec;
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat;
import org.elasticsearch.cli.ProcessInfo;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.logging.LogConfigurator;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.core.PathUtils;
import org.elasticsearch.index.codec.vectors.ES813Int8FlatVectorFormat;
import org.elasticsearch.index.codec.vectors.ES814HnswScalarQuantizedVectorsFormat;
import org.elasticsearch.index.codec.vectors.IVFVectorsFormat;
import org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat;
import org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat;
import org.elasticsearch.logging.Level;
import org.elasticsearch.logging.LogManager;
import org.elasticsearch.logging.Logger;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
@@ -35,19 +39,26 @@
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Map;

/**
* A utility class to create and test KNN indices using Lucene.
* It supports various index types (HNSW, FLAT, IVF) and configurations.
*/
public class KnnIndexTester {
static final Level LOG_LEVEL = Level.DEBUG;

static final SysOutLogger logger = new SysOutLogger();
static final Logger logger;

static {
LogConfigurator.loadLog4jPlugins();
LogConfigurator.configureESLogging(); // native access requires logging to be initialized

// necessary; otherwise the es.logger.level system property set in build.gradle is ignored
ProcessInfo pinfo = ProcessInfo.fromSystem();
Map<String, String> sysprops = pinfo.sysprops();
String loggerLevel = sysprops.getOrDefault("es.logger.level", Level.INFO.name());
Settings settings = Settings.builder().put("logger.level", loggerLevel).build();
LogConfigurator.configureWithoutConfig(settings);

logger = LogManager.getLogger(KnnIndexTester.class);
}

static final String INDEX_DIR = "target/knn_index";
@@ -163,7 +174,7 @@ public static void main(String[] args) throws Exception {
FormattedResults formattedResults = new FormattedResults();
for (CmdLineArgs cmdLineArgs : cmdLineArgsList) {
Results result = new Results(cmdLineArgs.indexType().name().toLowerCase(Locale.ROOT), cmdLineArgs.numDocs());
System.out.println("Running KNN index tester with arguments: " + cmdLineArgs);
logger.info("Running KNN index tester with arguments: " + cmdLineArgs);
Codec codec = createCodec(cmdLineArgs);
Path indexPath = PathUtils.get(formatIndexPath(cmdLineArgs));
if (cmdLineArgs.reindex() || cmdLineArgs.forceMerge()) {
@@ -195,8 +206,7 @@ public static void main(String[] args) throws Exception {
}
formattedResults.results.add(result);
}
System.out.println("Results:");
System.out.println(formattedResults);
logger.info("Results: \n" + formattedResults);
}

static class FormattedResults {
@@ -326,57 +336,6 @@ static class Results {
}
}

static final class SysOutLogger {

void warn(String message) {
if (LOG_LEVEL.ordinal() >= Level.WARN.ordinal()) {
System.out.println(message);
}
}

void warn(String message, Object... params) {
if (LOG_LEVEL.ordinal() >= Level.WARN.ordinal()) {
System.out.println(String.format(Locale.ROOT, message, params));
}
}

void info(String message) {
if (LOG_LEVEL.ordinal() >= Level.INFO.ordinal()) {
System.out.println(message);
}
}

void info(String message, Object... params) {
if (LOG_LEVEL.ordinal() >= Level.INFO.ordinal()) {
System.out.println(String.format(Locale.ROOT, message, params));
}
}

void debug(String message) {
if (LOG_LEVEL.ordinal() >= Level.DEBUG.ordinal()) {
System.out.println(message);
}
}

void debug(String message, Object... params) {
if (LOG_LEVEL.ordinal() >= Level.DEBUG.ordinal()) {
System.out.println(String.format(Locale.ROOT, message, params));
}
}

void trace(String message) {
if (LOG_LEVEL == Level.TRACE) {
System.out.println(message);
}
}

void trace(String message, Object... params) {
if (LOG_LEVEL == Level.TRACE) {
System.out.println(String.format(Locale.ROOT, message, params));
}
}
}

static final class ThreadDetails {
private static final ThreadMXBean threadBean = (ThreadMXBean) java.lang.management.ManagementFactory.getThreadMXBean();
public final long[] threadIDs;
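With the hand-rolled SysOutLogger gone, the tester logs through the org.elasticsearch.logging API bootstrapped in the static initializer above. A minimal sketch of the resulting pattern, using an illustrative class that is not part of this change:

    import org.elasticsearch.logging.LogManager;
    import org.elasticsearch.logging.Logger;

    class ExampleComponent {
        private static final Logger logger = LogManager.getLogger(ExampleComponent.class);

        void run(int numDocs) {
            // Parameterized messages defer formatting until the level is enabled,
            // unlike eager string concatenation.
            logger.debug("indexing {} docs", numDocs);
        }
    }

The effective level is controlled by the es.logger.level system property that the static block reads via ProcessInfo.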
DefaultIVFVectorsWriter.java
@@ -17,11 +17,12 @@
import org.apache.lucene.internal.hppc.IntArrayList;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.util.InfoStream;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;
import org.elasticsearch.index.codec.vectors.cluster.HierarchicalKMeans;
import org.elasticsearch.index.codec.vectors.cluster.KMeansResult;
import org.elasticsearch.logging.LogManager;
import org.elasticsearch.logging.Logger;
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;

import java.io.IOException;
@@ -31,14 +32,14 @@
import static org.apache.lucene.codecs.lucene102.Lucene102BinaryQuantizedVectorsFormat.INDEX_BITS;
import static org.apache.lucene.util.quantization.OptimizedScalarQuantizer.discretize;
import static org.apache.lucene.util.quantization.OptimizedScalarQuantizer.packAsBinary;
import static org.elasticsearch.index.codec.vectors.IVFVectorsFormat.IVF_VECTOR_COMPONENT;

/**
* Default implementation of {@link IVFVectorsWriter}. It uses the {@link HierarchicalKMeans} algorithm to
* partition the vector space, and then stores the centroids and posting lists in a sequential
* fashion.
*/
public class DefaultIVFVectorsWriter extends IVFVectorsWriter {
private static final Logger logger = LogManager.getLogger(DefaultIVFVectorsWriter.class);

private final int vectorPerCluster;

@@ -53,7 +54,6 @@ long[] buildAndWritePostingsLists(
CentroidSupplier centroidSupplier,
FloatVectorValues floatVectorValues,
IndexOutput postingsOutput,
InfoStream infoStream,
IntArrayList[] assignmentsByCluster
) throws IOException {
// write the posting lists
@@ -79,14 +79,14 @@ long[] buildAndWritePostingsLists(
writePostingList(cluster, postingsOutput, binarizedByteVectorValues);
}

if (infoStream.isEnabled(IVF_VECTOR_COMPONENT)) {
printClusterQualityStatistics(assignmentsByCluster, infoStream);
if (logger.isDebugEnabled()) {
printClusterQualityStatistics(assignmentsByCluster);
}

return offsets;
}

private static void printClusterQualityStatistics(IntArrayList[] clusters, InfoStream infoStream) {
private static void printClusterQualityStatistics(IntArrayList[] clusters) {
float min = Float.MAX_VALUE;
float max = Float.MIN_VALUE;
float mean = 0;
@@ -105,20 +105,14 @@ private static void printClusterQualityStatistics(IntArrayList[] clusters, InfoS
max = Math.max(max, cluster.size());
}
float variance = m2 / (clusters.length - 1);
infoStream.message(
IVF_VECTOR_COMPONENT,
"Centroid count: "
+ clusters.length
+ " min: "
+ min
+ " max: "
+ max
+ " mean: "
+ mean
+ " stdDev: "
+ Math.sqrt(variance)
+ " variance: "
+ variance
logger.debug(
"Centroid count: {} min: {} max: {} mean: {} stdDev: {} variance: {}",
clusters.length,
min,
max,
mean,
Math.sqrt(variance),
variance
);
}

@@ -208,17 +202,16 @@ CentroidAssignments calculateAndWriteCentroids(
float[] globalCentroid
) throws IOException {
// TODO: take advantage of prior generated clusters from mergeState in the future
return calculateAndWriteCentroids(fieldInfo, floatVectorValues, centroidOutput, mergeState.infoStream, globalCentroid, false);
return calculateAndWriteCentroids(fieldInfo, floatVectorValues, centroidOutput, globalCentroid, false);
}

CentroidAssignments calculateAndWriteCentroids(
FieldInfo fieldInfo,
FloatVectorValues floatVectorValues,
IndexOutput centroidOutput,
InfoStream infoStream,
float[] globalCentroid
) throws IOException {
return calculateAndWriteCentroids(fieldInfo, floatVectorValues, centroidOutput, infoStream, globalCentroid, true);
return calculateAndWriteCentroids(fieldInfo, floatVectorValues, centroidOutput, globalCentroid, true);
}

/**
@@ -228,7 +221,6 @@ CentroidAssignments calculateAndWriteCentroids(
* @param fieldInfo merging field info
* @param floatVectorValues the float vector values to merge
* @param centroidOutput the centroid output
* @param infoStream the merge state
* @param globalCentroid the global centroid, calculated by this method and used to quantize the centroids
* @param cacheCentroids whether the centroids are kept or discarded once computed
* @return the vector assignments, soar assignments, and if asked the centroids themselves that were computed
@@ -238,7 +230,6 @@ CentroidAssignments calculateAndWriteCentroids(
FieldInfo fieldInfo,
FloatVectorValues floatVectorValues,
IndexOutput centroidOutput,
InfoStream infoStream,
float[] globalCentroid,
boolean cacheCentroids
) throws IOException {
@@ -266,12 +257,9 @@ CentroidAssignments calculateAndWriteCentroids(
// write centroids
writeCentroids(centroids, fieldInfo, globalCentroid, centroidOutput);

if (infoStream.isEnabled(IVF_VECTOR_COMPONENT)) {
infoStream.message(
IVF_VECTOR_COMPONENT,
"calculate centroids and assign vectors time ms: " + ((System.nanoTime() - nanoTime) / 1000000.0)
);
infoStream.message(IVF_VECTOR_COMPONENT, "final centroid count: " + centroids.length);
if (logger.isDebugEnabled()) {
logger.debug("calculate centroids and assign vectors time ms: {}", (System.nanoTime() - nanoTime) / 1000000.0);
logger.debug("final centroid count: {}", centroids.length);
}

IntArrayList[] assignmentsByCluster = new IntArrayList[centroids.length];
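The statistics in printClusterQualityStatistics follow Welford's one-pass algorithm: a running mean plus an accumulated sum of squared deviations (m2), with sample variance taken as m2 / (n - 1) at the end. The update step itself is collapsed in the hunk above, so the sketch below reconstructs the standard form with illustrative names; it is not the PR's code verbatim:

    import java.util.Locale;

    final class ClusterStats {
        // One-pass (Welford) min/max/mean/variance over cluster sizes.
        static void summarize(int[] clusterSizes) {
            float min = Float.MAX_VALUE;
            float max = -Float.MAX_VALUE;
            float mean = 0f;
            float m2 = 0f;
            int count = 0;
            for (int size : clusterSizes) {
                count++;
                float delta = size - mean;   // deviation from the previous mean
                mean += delta / count;       // update the running mean
                m2 += delta * (size - mean); // accumulate squared deviations
                min = Math.min(min, size);
                max = Math.max(max, size);
            }
            float variance = count > 1 ? m2 / (count - 1) : 0f; // sample variance
            System.out.printf(
                Locale.ROOT,
                "count: %d min: %.1f max: %.1f mean: %.3f stdDev: %.3f variance: %.3f%n",
                count, min, max, mean, Math.sqrt(variance), variance
            );
        }
    }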
@@ -242,8 +242,8 @@ static final class ESFlatVectorsScorer implements FlatVectorsScorer {
final FlatVectorsScorer delegate;
final VectorScorerFactory factory;

ESFlatVectorsScorer(FlatVectorsScorer delegte) {
this.delegate = delegte;
ESFlatVectorsScorer(FlatVectorsScorer delegate) {
this.delegate = delegate;
factory = VectorScorerFactory.instance().orElse(null);
}

IVFVectorsFormat.java
@@ -45,7 +45,6 @@
*/
public class IVFVectorsFormat extends KnnVectorsFormat {

public static final String IVF_VECTOR_COMPONENT = "IVF";
public static final String NAME = "IVFVectorsFormat";
// centroid ordinals -> centroid values, offsets
public static final String CENTROID_EXTENSION = "cenivf";
IVFVectorsWriter.java
@@ -28,7 +28,6 @@
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.util.InfoStream;
import org.apache.lucene.util.VectorUtil;
import org.elasticsearch.core.IOUtils;
import org.elasticsearch.core.SuppressForbidden;
@@ -134,7 +133,6 @@ abstract CentroidAssignments calculateAndWriteCentroids(
FieldInfo fieldInfo,
FloatVectorValues floatVectorValues,
IndexOutput centroidOutput,
InfoStream infoStream,
float[] globalCentroid
) throws IOException;

@@ -143,7 +141,6 @@ abstract long[] buildAndWritePostingsLists(
CentroidSupplier centroidSupplier,
FloatVectorValues floatVectorValues,
IndexOutput postingsOutput,
InfoStream infoStream,
IntArrayList[] assignmentsByCluster
) throws IOException;

@@ -168,7 +165,6 @@ public final void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
fieldWriter.fieldInfo,
floatVectorValues,
ivfCentroids,
segmentWriteState.infoStream,
globalCentroid
);

@@ -180,7 +176,6 @@ public final void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
centroidSupplier,
floatVectorValues,
ivfClusters,
segmentWriteState.infoStream,
centroidAssignments.assignmentsByCluster()
);
// write posting lists
@@ -313,7 +308,6 @@ public final void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) thro
centroidSupplier,
floatVectorValues,
ivfClusters,
mergeState.infoStream,
centroidAssignments.assignmentsByCluster()
);
assert offsets.length == centroidSupplier.size();