Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import org.elasticsearch.common.util.FeatureFlag;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.index.mapper.vectors.VectorsFormatProvider;
import org.elasticsearch.logging.LogManager;
import org.elasticsearch.logging.Logger;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.plugins.internal.InternalVectorFormatProviderPlugin;
import org.elasticsearch.xpack.gpu.codec.ES92GpuHnswSQVectorsFormat;
Expand All @@ -21,6 +23,8 @@

public class GPUPlugin extends Plugin implements InternalVectorFormatProviderPlugin {

private static final Logger logger = LogManager.getLogger(GPUPlugin.class);

public static final FeatureFlag GPU_FORMAT = new FeatureFlag("gpu_vectors_indexing");

/**
Expand Down Expand Up @@ -49,6 +53,29 @@ public enum GpuMode {
Setting.Property.Dynamic
);

/** The default minimum number of vectors required before building on the GPU. */
public static final int DEFAULT_MIN_NUM_VECTORS_FOR_GPU_BUILD = 10_000;

public static final int MIN_NUM_VECTORS_FOR_GPU_BUILD = tinySegmentProperty();

public static int tinySegmentProperty() {
int v = DEFAULT_MIN_NUM_VECTORS_FOR_GPU_BUILD;
String str = System.getProperty("gpu.tiny.segment.size");
if (str != null) {
try {
int parsedValue = Integer.parseInt(str);
if (parsedValue > 1) {
v = parsedValue;
} else {
logger.warn("Ignoring gpu.tiny.segment.size. Value too small:" + parsedValue);
}
} catch (NumberFormatException e) {
logger.warn("Bad gpu.tiny.segment.size. Not a number:" + str);
}
}
return v;
}

@Override
public List<Setting<?>> getSettings() {
if (GPU_FORMAT.isEnabled()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.util.function.Supplier;

import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.MAX_DIMS_COUNT;
import static org.elasticsearch.xpack.gpu.GPUPlugin.MIN_NUM_VECTORS_FOR_GPU_BUILD;
import static org.elasticsearch.xpack.gpu.codec.ES92GpuHnswVectorsFormat.DEFAULT_BEAM_WIDTH;
import static org.elasticsearch.xpack.gpu.codec.ES92GpuHnswVectorsFormat.DEFAULT_MAX_CONN;

Expand All @@ -33,18 +34,32 @@ public class ES92GpuHnswSQVectorsFormat extends KnnVectorsFormat {
static final int MAXIMUM_BEAM_WIDTH = 3200;
private final int maxConn;
private final int beamWidth;
// The threshold to use to bypass HNSW graph building for tiny segments on the GPU.
private final int tinySegmentsThreshold;

/** The format for storing, reading, merging vectors on disk */
private final FlatVectorsFormat flatVectorsFormat;
private final Supplier<CuVSResourceManager> cuVSResourceManagerSupplier;

public ES92GpuHnswSQVectorsFormat() {
this(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, null, 7, false);
this(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, null, 7, false, CuVSResourceManager::pooling, MIN_NUM_VECTORS_FOR_GPU_BUILD);
}

public ES92GpuHnswSQVectorsFormat(int maxConn, int beamWidth, Float confidenceInterval, int bits, boolean compress) {
this(maxConn, beamWidth, confidenceInterval, bits, compress, CuVSResourceManager::pooling, MIN_NUM_VECTORS_FOR_GPU_BUILD);
}

public ES92GpuHnswSQVectorsFormat(
int maxConn,
int beamWidth,
Float confidenceInterval,
int bits,
boolean compress,
Supplier<CuVSResourceManager> cuVSResourceManagerSupplier,
int tinySegmentsThreshold
) {
super(NAME);
this.cuVSResourceManagerSupplier = CuVSResourceManager::pooling;
this.cuVSResourceManagerSupplier = cuVSResourceManagerSupplier;
if (maxConn <= 0 || maxConn > MAXIMUM_MAX_CONN) {
throw new IllegalArgumentException(
"maxConn must be positive and less than or equal to " + MAXIMUM_MAX_CONN + "; maxConn=" + maxConn
Expand All @@ -55,8 +70,12 @@ public ES92GpuHnswSQVectorsFormat(int maxConn, int beamWidth, Float confidenceIn
"beamWidth must be positive and less than or equal to " + MAXIMUM_BEAM_WIDTH + "; beamWidth=" + beamWidth
);
}
if (tinySegmentsThreshold < 2) {
throw new IllegalArgumentException("tinySegmentsThreshold must be greater than 1, got:" + tinySegmentsThreshold);
}
this.maxConn = maxConn;
this.beamWidth = beamWidth;
this.tinySegmentsThreshold = tinySegmentsThreshold;
this.flatVectorsFormat = new ES814ScalarQuantizedVectorsFormat(confidenceInterval, bits, compress);
}

Expand All @@ -67,7 +86,8 @@ public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException
state,
maxConn,
beamWidth,
flatVectorsFormat.fieldsWriter(state)
flatVectorsFormat.fieldsWriter(state),
tinySegmentsThreshold
);
}

Expand All @@ -90,6 +110,8 @@ public String toString() {
+ maxConn
+ ", beamWidth="
+ beamWidth
+ ", tinySegmentsThreshold="
+ tinySegmentsThreshold
+ ", flatVectorFormat="
+ flatVectorsFormat
+ ")";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import java.util.function.Supplier;

import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.MAX_DIMS_COUNT;
import static org.elasticsearch.xpack.gpu.GPUPlugin.MIN_NUM_VECTORS_FOR_GPU_BUILD;

/**
* Codec format for GPU-accelerated vector indexes. This format is designed to
Expand All @@ -38,7 +39,6 @@ public class ES92GpuHnswVectorsFormat extends KnnVectorsFormat {

static final int DEFAULT_MAX_CONN = 16; // graph degree
public static final int DEFAULT_BEAM_WIDTH = 128; // intermediate graph degree
static final int MIN_NUM_VECTORS_FOR_GPU_BUILD = 2;

private static final FlatVectorsFormat flatVectorsFormat = new Lucene99FlatVectorsFormat(
FlatVectorScorerUtil.getLucene99FlatVectorsScorer()
Expand All @@ -49,20 +49,31 @@ public class ES92GpuHnswVectorsFormat extends KnnVectorsFormat {
// Intermediate graph degree, the number of connections for each node before pruning
private final int beamWidth;
private final Supplier<CuVSResourceManager> cuVSResourceManagerSupplier;
// The threshold to use to bypass HNSW graph building for tiny segments on the GPU.
private final int tinySegmentsThreshold;

public ES92GpuHnswVectorsFormat() {
this(CuVSResourceManager::pooling, DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH);
this(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, CuVSResourceManager::pooling, MIN_NUM_VECTORS_FOR_GPU_BUILD);
}

public ES92GpuHnswVectorsFormat(int maxConn, int beamWidth) {
this(CuVSResourceManager::pooling, maxConn, beamWidth);
};
this(maxConn, beamWidth, CuVSResourceManager::pooling, MIN_NUM_VECTORS_FOR_GPU_BUILD);
}

public ES92GpuHnswVectorsFormat(Supplier<CuVSResourceManager> cuVSResourceManagerSupplier, int maxConn, int beamWidth) {
public ES92GpuHnswVectorsFormat(
int maxConn,
int beamWidth,
Supplier<CuVSResourceManager> cuVSResourceManagerSupplier,
int tinySegmentsThreshold
) {
super(NAME);
if (tinySegmentsThreshold < 2) {
throw new IllegalArgumentException("tinySegmentsThreshold must be greater than 1, got:" + tinySegmentsThreshold);
}
this.cuVSResourceManagerSupplier = cuVSResourceManagerSupplier;
this.maxConn = maxConn;
this.beamWidth = beamWidth;
this.tinySegmentsThreshold = tinySegmentsThreshold;
}

@Override
Expand All @@ -72,7 +83,8 @@ public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException
state,
maxConn,
beamWidth,
flatVectorsFormat.fieldsWriter(state)
flatVectorsFormat.fieldsWriter(state),
tinySegmentsThreshold
);
}

Expand All @@ -95,6 +107,8 @@ public String toString() {
+ maxConn
+ ", beamWidth="
+ beamWidth
+ ", tinySegmentsThreshold="
+ tinySegmentsThreshold
+ ", flatVectorFormat="
+ flatVectorsFormat.getName()
+ ")";
Expand Down
Loading