AbstractFlatVectorsFormat.java (new file)
@@ -0,0 +1,44 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.index.codec.vectors;

import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
import org.elasticsearch.core.SuppressForbidden;

import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.MAX_DIMS_COUNT;

public abstract class AbstractFlatVectorsFormat extends FlatVectorsFormat {

    protected static final boolean USE_DIRECT_IO = getUseDirectIO();

    @SuppressForbidden(
        reason = "TODO Deprecate any lenient usage of Boolean#parseBoolean https://github.com/elastic/elasticsearch/issues/128993"
    )
    private static boolean getUseDirectIO() {
        return Boolean.parseBoolean(System.getProperty("vector.rescoring.directio", "false"));
    }

    protected AbstractFlatVectorsFormat(String name) {
        super(name);
    }

    protected abstract FlatVectorsScorer flatVectorsScorer();

    @Override
    public int getMaxDimensions(String fieldName) {
        return MAX_DIMS_COUNT;
    }

    @Override
    public String toString() {
        return getName() + "(name=" + getName() + ", flatVectorScorer=" + flatVectorsScorer() + ")";
    }
}
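
Note on the USE_DIRECT_IO flag above: it is read once, when the class is initialized, from the vector.rescoring.directio JVM system property (for example -Dvector.rescoring.directio=true). The short standalone sketch below is not part of this PR; the DirectIoFlagDemo class is hypothetical and only illustrates the lenient Boolean#parseBoolean behaviour that the @SuppressForbidden TODO refers to, where any value other than a case-insensitive "true" silently parses to false.

// Hypothetical demo, not part of this PR.
public class DirectIoFlagDemo {
    public static void main(String[] args) {
        // Case-insensitive match on "true" parses to true.
        System.setProperty("vector.rescoring.directio", "TRUE");
        System.out.println(Boolean.parseBoolean(System.getProperty("vector.rescoring.directio", "false"))); // true

        // Lenient parsing: any other value, e.g. "yes" or a typo, silently becomes false.
        System.setProperty("vector.rescoring.directio", "yes");
        System.out.println(Boolean.parseBoolean(System.getProperty("vector.rescoring.directio", "false"))); // false
    }
}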

AbstractHnswVectorsFormat.java (new file)
@@ -0,0 +1,114 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.index.codec.vectors;

import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat;
import org.apache.lucene.search.TaskExecutor;
import org.apache.lucene.util.hnsw.HnswGraph;

import java.util.concurrent.ExecutorService;

import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH;
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN;
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_NUM_MERGE_WORKER;
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.MAXIMUM_BEAM_WIDTH;
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.MAXIMUM_MAX_CONN;
import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.MAX_DIMS_COUNT;

public abstract class AbstractHnswVectorsFormat extends KnnVectorsFormat {

    /**
     * Controls how many of the nearest neighbor candidates are connected to the new node. Defaults to
     * {@link Lucene99HnswVectorsFormat#DEFAULT_MAX_CONN}. See {@link HnswGraph} for more details.
     */
    protected final int maxConn;

    /**
     * The number of candidate neighbors to track while searching the graph for each newly inserted
     * node. Defaults to {@link Lucene99HnswVectorsFormat#DEFAULT_BEAM_WIDTH}. See {@link HnswGraph}
     * for details.
     */
    protected final int beamWidth;

    protected final int numMergeWorkers;
    protected final TaskExecutor mergeExec;

    /** Constructs a format using default graph construction parameters */
    protected AbstractHnswVectorsFormat(String name) {
        this(name, DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, DEFAULT_NUM_MERGE_WORKER, null);
    }

    /**
     * Constructs a format using the given graph construction parameters.
     *
     * @param maxConn the maximum number of connections to a node in the HNSW graph
     * @param beamWidth the size of the queue maintained during graph construction.
     */
    protected AbstractHnswVectorsFormat(String name, int maxConn, int beamWidth) {
        this(name, maxConn, beamWidth, DEFAULT_NUM_MERGE_WORKER, null);
    }

    /**
     * Constructs a format using the given graph construction parameters and merge executor.
     *
     * @param maxConn the maximum number of connections to a node in the HNSW graph
     * @param beamWidth the size of the queue maintained during graph construction.
     * @param numMergeWorkers number of workers (threads) that will be used when doing merge. If
     *     larger than 1, a non-null {@link ExecutorService} must be passed as mergeExec
     * @param mergeExec the {@link ExecutorService} that will be used by ALL vector writers that are
     *     generated by this format to do the merge
     */
    protected AbstractHnswVectorsFormat(String name, int maxConn, int beamWidth, int numMergeWorkers, ExecutorService mergeExec) {
        super(name);
        if (maxConn <= 0 || maxConn > MAXIMUM_MAX_CONN) {
            throw new IllegalArgumentException(
                "maxConn must be positive and less than or equal to " + MAXIMUM_MAX_CONN + "; maxConn=" + maxConn
            );
        }
        if (beamWidth <= 0 || beamWidth > MAXIMUM_BEAM_WIDTH) {
            throw new IllegalArgumentException(
                "beamWidth must be positive and less than or equal to " + MAXIMUM_BEAM_WIDTH + "; beamWidth=" + beamWidth
            );
        }
        this.maxConn = maxConn;
        this.beamWidth = beamWidth;
        if (numMergeWorkers == 1 && mergeExec != null) {
            throw new IllegalArgumentException("No executor service is needed as we'll use single thread to merge");
        }
        this.numMergeWorkers = numMergeWorkers;
        if (mergeExec != null) {
            this.mergeExec = new TaskExecutor(mergeExec);
        } else {
            this.mergeExec = null;
        }
    }

    protected abstract KnnVectorsFormat flatVectorsFormat();

    @Override
    public int getMaxDimensions(String fieldName) {
        return MAX_DIMS_COUNT;
    }

    @Override
    public String toString() {
        return getName()
            + "(name="
            + getName()
            + ", maxConn="
            + maxConn
            + ", beamWidth="
            + beamWidth
            + ", flatVectorFormat="
            + flatVectorsFormat()
            + ")";
    }
}
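
For illustration only, a minimal hypothetical subclass of AbstractHnswVectorsFormat is sketched below; it is not part of this PR. It shows how a concrete format could opt in to the four-argument constructor and hand the inherited numMergeWorkers and mergeExec through to Lucene's writer for concurrent merging. The class name is invented, ES815BitFlatVectorsFormat is borrowed from this PR purely as an example flat-format delegate, and a real format would also need SPI registration. The concrete ES814/ES815 formats further down follow the same shape.

// Hypothetical sketch, not part of this PR.
package org.elasticsearch.index.codec.vectors;

import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader;
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsWriter;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;

import java.io.IOException;
import java.util.concurrent.ExecutorService;

public final class ConcurrentMergeHnswVectorsFormat extends AbstractHnswVectorsFormat {

    // Example delegate only; any flat vectors format could be used here.
    private static final FlatVectorsFormat flatVectorsFormat = new ES815BitFlatVectorsFormat();

    public ConcurrentMergeHnswVectorsFormat(int maxConn, int beamWidth, int numMergeWorkers, ExecutorService mergeExec) {
        // The base class validates maxConn/beamWidth and wraps mergeExec in a TaskExecutor;
        // passing a non-null mergeExec with numMergeWorkers == 1 throws IllegalArgumentException.
        super("ConcurrentMergeHnswVectorsFormat", maxConn, beamWidth, numMergeWorkers, mergeExec);
    }

    @Override
    protected KnnVectorsFormat flatVectorsFormat() {
        return flatVectorsFormat;
    }

    @Override
    public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
        // numMergeWorkers and mergeExec are the protected fields inherited from
        // AbstractHnswVectorsFormat, so concurrent merging flows through to Lucene's writer.
        return new Lucene99HnswVectorsWriter(state, maxConn, beamWidth, flatVectorsFormat.fieldsWriter(state), numMergeWorkers, mergeExec);
    }

    @Override
    public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException {
        return new Lucene99HnswVectorsReader(state, flatVectorsFormat.fieldsReader(state));
    }
}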

ES814HnswScalarQuantizedVectorsFormat.java
@@ -22,19 +22,11 @@
 
 import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH;
 import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN;
-import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.MAX_DIMS_COUNT;
 
-public final class ES814HnswScalarQuantizedVectorsFormat extends KnnVectorsFormat {
+public final class ES814HnswScalarQuantizedVectorsFormat extends AbstractHnswVectorsFormat {
 
     static final String NAME = "ES814HnswScalarQuantizedVectorsFormat";
 
-    static final int MAXIMUM_MAX_CONN = 512;
-    static final int MAXIMUM_BEAM_WIDTH = 3200;
-
-    private final int maxConn;
-
-    private final int beamWidth;
-
     /** The format for storing, reading, merging vectors on disk */
     private final FlatVectorsFormat flatVectorsFormat;
 
@@ -43,45 +35,22 @@ public ES814HnswScalarQuantizedVectorsFormat() {
     }
 
     public ES814HnswScalarQuantizedVectorsFormat(int maxConn, int beamWidth, Float confidenceInterval, int bits, boolean compress) {
-        super(NAME);
-        if (maxConn <= 0 || maxConn > MAXIMUM_MAX_CONN) {
-            throw new IllegalArgumentException(
-                "maxConn must be positive and less than or equal to " + MAXIMUM_MAX_CONN + "; maxConn=" + maxConn
-            );
-        }
-        if (beamWidth <= 0 || beamWidth > MAXIMUM_BEAM_WIDTH) {
-            throw new IllegalArgumentException(
-                "beamWidth must be positive and less than or equal to " + MAXIMUM_BEAM_WIDTH + "; beamWidth=" + beamWidth
-            );
-        }
-        this.maxConn = maxConn;
-        this.beamWidth = beamWidth;
+        super(NAME, maxConn, beamWidth);
         this.flatVectorsFormat = new ES814ScalarQuantizedVectorsFormat(confidenceInterval, bits, compress);
     }
 
     @Override
-    public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
-        return new Lucene99HnswVectorsWriter(state, maxConn, beamWidth, flatVectorsFormat.fieldsWriter(state), 1, null);
-    }
-
-    @Override
-    public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException {
-        return new Lucene99HnswVectorsReader(state, flatVectorsFormat.fieldsReader(state));
+    protected KnnVectorsFormat flatVectorsFormat() {
+        return flatVectorsFormat;
     }
 
     @Override
-    public int getMaxDimensions(String fieldName) {
-        return MAX_DIMS_COUNT;
+    public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
+        return new Lucene99HnswVectorsWriter(state, maxConn, beamWidth, flatVectorsFormat.fieldsWriter(state), numMergeWorkers, mergeExec);
    }
 
     @Override
-    public String toString() {
-        return "ES814HnswScalarQuantizedVectorsFormat(name=ES814HnswScalarQuantizedVectorsFormat, maxConn="
-            + maxConn
-            + ", beamWidth="
-            + beamWidth
-            + ", flatVectorFormat="
-            + flatVectorsFormat
-            + ")";
+    public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException {
+        return new Lucene99HnswVectorsReader(state, flatVectorsFormat.fieldsReader(state));
     }
 }

ES815HnswBitVectorsFormat.java
@@ -20,63 +20,32 @@
 
 import java.io.IOException;
 
-import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.MAX_DIMS_COUNT;
-
-public class ES815HnswBitVectorsFormat extends KnnVectorsFormat {
+public class ES815HnswBitVectorsFormat extends AbstractHnswVectorsFormat {
 
     static final String NAME = "ES815HnswBitVectorsFormat";
 
-    static final int MAXIMUM_MAX_CONN = 512;
-    static final int MAXIMUM_BEAM_WIDTH = 3200;
-
-    private final int maxConn;
-    private final int beamWidth;
-
     private static final FlatVectorsFormat flatVectorsFormat = new ES815BitFlatVectorsFormat();
 
     public ES815HnswBitVectorsFormat() {
         this(16, 100);
     }
 
-    public ES815HnswBitVectorsFormat(int maxConn, int beamWidth) {
-        super(NAME);
-        if (maxConn <= 0 || maxConn > MAXIMUM_MAX_CONN) {
-            throw new IllegalArgumentException(
-                "maxConn must be positive and less than or equal to " + MAXIMUM_MAX_CONN + "; maxConn=" + maxConn
-            );
-        }
-        if (beamWidth <= 0 || beamWidth > MAXIMUM_BEAM_WIDTH) {
-            throw new IllegalArgumentException(
-                "beamWidth must be positive and less than or equal to " + MAXIMUM_BEAM_WIDTH + "; beamWidth=" + beamWidth
-            );
-        }
-        this.maxConn = maxConn;
-        this.beamWidth = beamWidth;
-    }
-
-    @Override
-    public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
-        return new Lucene99HnswVectorsWriter(state, maxConn, beamWidth, flatVectorsFormat.fieldsWriter(state), 1, null);
+    public ES815HnswBitVectorsFormat(int maxConn, int beamWidth) {
+        super(NAME, maxConn, beamWidth);
     }
 
     @Override
-    public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException {
-        return new Lucene99HnswVectorsReader(state, flatVectorsFormat.fieldsReader(state));
+    protected KnnVectorsFormat flatVectorsFormat() {
+        return flatVectorsFormat;
    }
 
     @Override
-    public String toString() {
-        return "ES815HnswBitVectorsFormat(name=ES815HnswBitVectorsFormat, maxConn="
-            + maxConn
-            + ", beamWidth="
-            + beamWidth
-            + ", flatVectorFormat="
-            + flatVectorsFormat
-            + ")";
+    public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
+        return new Lucene99HnswVectorsWriter(state, maxConn, beamWidth, flatVectorsFormat.fieldsWriter(state), numMergeWorkers, mergeExec);
     }
 
     @Override
-    public int getMaxDimensions(String fieldName) {
-        return MAX_DIMS_COUNT;
+    public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException {
+        return new Lucene99HnswVectorsReader(state, flatVectorsFormat.fieldsReader(state));
     }
 }
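
A small hypothetical usage sketch (the HnswFormatParamsDemo class below is not part of this PR) showing that the argument checks formerly duplicated in each concrete format are now enforced by the shared AbstractHnswVectorsFormat constructor, using the bounds imported from Lucene99HnswVectorsFormat (MAXIMUM_MAX_CONN = 512, MAXIMUM_BEAM_WIDTH = 3200):

// Hypothetical demo, not part of this PR.
import org.elasticsearch.index.codec.vectors.ES815HnswBitVectorsFormat;

public class HnswFormatParamsDemo {
    public static void main(String[] args) {
        // Defaults: maxConn=16, beamWidth=100, single-threaded merges; toString() is inherited.
        System.out.println(new ES815HnswBitVectorsFormat());

        // maxConn must be in (0, 512]; 0 is rejected by the shared base-class check.
        try {
            new ES815HnswBitVectorsFormat(0, 100);
        } catch (IllegalArgumentException e) {
            System.out.println(e.getMessage());
        }

        // beamWidth must be in (0, 3200]; 5000 is rejected the same way.
        try {
            new ES815HnswBitVectorsFormat(16, 5000);
        } catch (IllegalArgumentException e) {
            System.out.println(e.getMessage());
        }
    }
}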

MergeReaderWrapper.java
@@ -7,7 +7,7 @@
  * License v3.0 only", or the "Server Side Public License, v 1".
  */
 
-package org.elasticsearch.index.codec.vectors.es818;
+package org.elasticsearch.index.codec.vectors;
 
 import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
 import org.apache.lucene.index.ByteVectorValues;
@@ -25,19 +25,19 @@
 import java.util.Collection;
 import java.util.Map;
 
-class MergeReaderWrapper extends FlatVectorsReader implements OffHeapStats {
+public class MergeReaderWrapper extends FlatVectorsReader implements OffHeapStats {
 
     private final FlatVectorsReader mainReader;
     private final FlatVectorsReader mergeReader;
 
-    protected MergeReaderWrapper(FlatVectorsReader mainReader, FlatVectorsReader mergeReader) {
+    public MergeReaderWrapper(FlatVectorsReader mainReader, FlatVectorsReader mergeReader) {
        super(mainReader.getFlatVectorScorer());
         this.mainReader = mainReader;
         this.mergeReader = mergeReader;
     }
 
     // For testing
-    FlatVectorsReader getMainReader() {
+    public FlatVectorsReader getMainReader() {
         return mainReader;
     }
 

ES816BinaryFlatVectorsScorer.java
@@ -37,10 +37,10 @@
 import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT;
 
 /** Vector scorer over binarized vector values */
-class ES816BinaryFlatVectorsScorer implements FlatVectorsScorer {
+public class ES816BinaryFlatVectorsScorer implements FlatVectorsScorer {
     private final FlatVectorsScorer nonQuantizedDelegate;
 
-    ES816BinaryFlatVectorsScorer(FlatVectorsScorer nonQuantizedDelegate) {
+    public ES816BinaryFlatVectorsScorer(FlatVectorsScorer nonQuantizedDelegate) {
         this.nonQuantizedDelegate = nonQuantizedDelegate;
     }
 