Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ record CmdLineArgs(
float filterSelectivity,
long seed,
VectorSimilarityFunction vectorSpace,
int rawVectorSize,
int quantizeBits,
VectorEncoding vectorEncoding,
int dimensions,
Expand Down Expand Up @@ -80,6 +81,7 @@ record CmdLineArgs(
static final ParseField FORCE_MERGE_FIELD = new ParseField("force_merge");
static final ParseField VECTOR_SPACE_FIELD = new ParseField("vector_space");
static final ParseField QUANTIZE_BITS_FIELD = new ParseField("quantize_bits");
static final ParseField RAW_VECTOR_SIZE_FIELD = new ParseField("raw_vector_size");
static final ParseField VECTOR_ENCODING_FIELD = new ParseField("vector_encoding");
static final ParseField DIMENSIONS_FIELD = new ParseField("dimensions");
static final ParseField EARLY_TERMINATION_FIELD = new ParseField("early_termination");
Expand Down Expand Up @@ -123,6 +125,7 @@ static CmdLineArgs fromXContent(XContentParser parser) throws IOException {
PARSER.declareBoolean(Builder::setReindex, REINDEX_FIELD);
PARSER.declareBoolean(Builder::setForceMerge, FORCE_MERGE_FIELD);
PARSER.declareString(Builder::setVectorSpace, VECTOR_SPACE_FIELD);
PARSER.declareInt(Builder::setRawVectorSize, RAW_VECTOR_SIZE_FIELD);
PARSER.declareInt(Builder::setQuantizeBits, QUANTIZE_BITS_FIELD);
PARSER.declareString(Builder::setVectorEncoding, VECTOR_ENCODING_FIELD);
PARSER.declareInt(Builder::setDimensions, DIMENSIONS_FIELD);
Expand Down Expand Up @@ -161,6 +164,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.field(REINDEX_FIELD.getPreferredName(), reindex);
builder.field(FORCE_MERGE_FIELD.getPreferredName(), forceMerge);
builder.field(VECTOR_SPACE_FIELD.getPreferredName(), vectorSpace.name().toLowerCase(Locale.ROOT));
builder.field(RAW_VECTOR_SIZE_FIELD.getPreferredName(), rawVectorSize);
builder.field(QUANTIZE_BITS_FIELD.getPreferredName(), quantizeBits);
builder.field(VECTOR_ENCODING_FIELD.getPreferredName(), vectorEncoding.name().toLowerCase(Locale.ROOT));
builder.field(DIMENSIONS_FIELD.getPreferredName(), dimensions);
Expand Down Expand Up @@ -196,6 +200,7 @@ static class Builder {
private boolean reindex = false;
private boolean forceMerge = false;
private VectorSimilarityFunction vectorSpace = VectorSimilarityFunction.EUCLIDEAN;
private int rawVectorSize = 32;
private int quantizeBits = 8;
private VectorEncoding vectorEncoding = VectorEncoding.FLOAT32;
private int dimensions;
Expand Down Expand Up @@ -305,6 +310,11 @@ public Builder setVectorSpace(String vectorSpace) {
return this;
}

/**
 * Stores the value parsed for the {@code raw_vector_size} option and returns this builder for chaining.
 * No validation is applied here.
 * NOTE(review): presumably the width in bits of each stored raw vector component - confirm against the code that consumes it.
 *
 * @param rawVectorSize the value to record on this builder
 * @return this builder
 */
public Builder setRawVectorSize(int rawVectorSize) {
this.rawVectorSize = rawVectorSize;
return this;
}

public Builder setQuantizeBits(int quantizeBits) {
this.quantizeBits = quantizeBits;
return this;
Expand Down Expand Up @@ -380,6 +390,7 @@ public CmdLineArgs build() {
filterSelectivity,
seed,
vectorSpace,
rawVectorSize,
quantizeBits,
vectorEncoding,
dimensions,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.index.codec.vectors;

import org.apache.lucene.util.BitUtil;

import java.nio.ByteOrder;
import java.nio.ShortBuffer;

/**
 * Static conversions between {@code float} values and their 16-bit bfloat16 encoding.
 * <p>
 * A bfloat16 value is the upper 16 bits of the IEEE 754 single-precision bit pattern: the sign,
 * the full 8-bit exponent and the 7 most significant fraction bits. This class is not instantiable.
 */
public class BFloat16 {

    /** Number of bytes taken by one encoded bfloat16 value. */
    public static final int BYTES = Short.BYTES;

    private BFloat16() {}

    /**
     * Encodes a single {@code float} by dropping the low 16 bits of its bit pattern.
     * <p>
     * Dropping bits rounds towards zero. The special values keep their class:
     * <ul>
     *   <li>zero (zero exponent, zero fraction) stays zero;</li>
     *   <li>infinity (all-ones exponent, zero fraction) stays infinity;</li>
     *   <li>NaN (all-ones exponent, non-zero fraction) stays NaN, because
     *       {@link Float#floatToIntBits(float)} reports every NaN as {@code 0x7fc0_0000}, whose
     *       non-zero fraction bit lies in the retained upper half.</li>
     * </ul>
     * Denormals (zero exponent, non-zero fraction) lose their low fraction bits like any other value.
     *
     * @param f the value to encode
     * @return the upper 16 bits of {@code f}'s bit pattern
     */
    public static short floatToBFloat16(float f) {
        final int bits = Float.floatToIntBits(f);
        return (short) (bits >>> 16);
    }

    /**
     * Decodes a single bfloat16 value by placing it in the upper 16 bits of a {@code float}
     * bit pattern and zero-filling the lower 16 bits.
     *
     * @param bf the encoded value
     * @return the decoded {@code float}
     */
    public static float bFloat16ToFloat(short bf) {
        final int bits = bf << 16;
        return Float.intBitsToFloat(bits);
    }

    /**
     * Encodes every element of {@code floats}, in order, into {@code bFloats} using relative puts.
     * With assertions enabled, the buffer must have exactly {@code floats.length} values remaining
     * and must be little-endian.
     *
     * @param floats  the values to encode
     * @param bFloats the destination buffer; its position advances by {@code floats.length}
     */
    public static void floatToBFloat16(float[] floats, ShortBuffer bFloats) {
        assert bFloats.remaining() == floats.length;
        assert bFloats.order() == ByteOrder.LITTLE_ENDIAN;
        for (int i = 0; i < floats.length; i++) {
            bFloats.put(floatToBFloat16(floats[i]));
        }
    }

    /**
     * Decodes bfloat16 values stored back to back in {@code bfBytes}, each read as a little-endian
     * short, into {@code floats}. With assertions enabled, {@code bfBytes} must hold exactly two
     * bytes per element of {@code floats}.
     *
     * @param bfBytes the encoded values
     * @param floats  the destination array, filled completely
     */
    public static void bFloat16ToFloat(byte[] bfBytes, float[] floats) {
        assert floats.length * 2 == bfBytes.length;
        int index = 0;
        while (index < floats.length) {
            final short encoded = (short) BitUtil.VH_LE_SHORT.get(bfBytes, index * 2);
            floats[index] = bFloat16ToFloat(encoded);
            index++;
        }
    }

    /**
     * Decodes {@code floats.length} values from {@code bFloats} using relative gets.
     * With assertions enabled, the buffer must have exactly that many values remaining and must
     * be little-endian.
     *
     * @param bFloats the source buffer; its position advances by {@code floats.length}
     * @param floats  the destination array, filled completely
     */
    public static void bFloat16ToFloat(ShortBuffer bFloats, float[] floats) {
        assert floats.length == bFloats.remaining();
        assert bFloats.order() == ByteOrder.LITTLE_ENDIAN;
        int index = 0;
        while (index < floats.length) {
            floats[index] = bFloat16ToFloat(bFloats.get());
            index++;
        }
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,81 @@

import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.store.FlushInfo;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.MergeInfo;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.index.codec.vectors.es818.DirectIOHint;
import org.elasticsearch.index.store.FsDirectoryFactory;

import java.io.IOException;
import java.util.Set;

public abstract class DirectIOCapableFlatVectorsFormat extends AbstractFlatVectorsFormat {
/**
 * Creates the format, passing {@code name} unchanged to the {@link AbstractFlatVectorsFormat} constructor.
 *
 * @param name the format name forwarded to the superclass
 */
protected DirectIOCapableFlatVectorsFormat(String name) {
super(name);
}

/**
 * Creates the concrete reader for the given segment state.
 * {@code fieldsReader(SegmentReadState, boolean)} calls this once with the caller's state and, when
 * direct I/O is in use, a second time with a state whose context is a {@code DirectIOContext}.
 *
 * @param state the segment read state to open the reader against
 * @return a new reader for {@code state}
 * @throws IOException if the reader cannot be created
 */
protected abstract FlatVectorsReader createReader(SegmentReadState state) throws IOException;

/**
 * Reports whether direct I/O may be used for the given segment, which is the case exactly when
 * {@link FsDirectoryFactory#isHybridFs} returns {@code true} for the state's directory.
 *
 * @param state the segment read state whose directory is inspected
 * @return the result of the hybrid-fs check on {@code state.directory}
 */
static boolean canUseDirectIO(SegmentReadState state) {
return FsDirectoryFactory.isHybridFs(state.directory);
}

/**
 * Opens a reader without requesting direct I/O; equivalent to calling
 * {@code fieldsReader(state, false)}.
 *
 * @param state the segment read state to open the reader against
 * @return the reader produced by the two-argument overload
 * @throws IOException if the reader cannot be created
 */
@Override
public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException {
return fieldsReader(state, false);
}

public abstract FlatVectorsReader fieldsReader(SegmentReadState state, boolean useDirectIO) throws IOException;
/**
 * Opens a reader for the given segment, optionally using direct I/O.
 * <p>
 * Direct I/O is used only when all of the following hold: the state's context reports
 * {@link IOContext.Context#DEFAULT}, {@code useDirectIO} is {@code true}, and
 * {@link #canUseDirectIO(SegmentReadState)} accepts the state. In that case two readers are created -
 * one against a copy of the state whose context is a {@link DirectIOContext} carrying the original
 * hints, and one against the unmodified state - and both are handed to a {@link MergeReaderWrapper}.
 * Otherwise a single reader is created against the unmodified state.
 * <p>
 * If creating the second reader (or the wrapper) fails, the direct I/O reader that was already
 * created is closed before the failure propagates, so it is not leaked; a failure while closing
 * is attached to the original exception as suppressed.
 *
 * @param state       the segment read state to open the reader against
 * @param useDirectIO whether the caller wants direct I/O when it is available
 * @return the reader to use for {@code state}
 * @throws IOException if a reader cannot be created
 */
public FlatVectorsReader fieldsReader(SegmentReadState state, boolean useDirectIO) throws IOException {
    if (state.context.context() == IOContext.Context.DEFAULT && useDirectIO && canUseDirectIO(state)) {
        // only override the context for the random-access use case
        SegmentReadState directIOState = new SegmentReadState(
            state.directory,
            state.segmentInfo,
            state.fieldInfos,
            new DirectIOContext(state.context.hints()),
            state.segmentSuffix
        );
        FlatVectorsReader directIOReader = createReader(directIOState);
        try {
            // Use mmap for merges and direct I/O for searches.
            return new MergeReaderWrapper(directIOReader, createReader(state));
        } catch (Throwable failure) {
            // the first reader is not owned by anyone yet; release it before propagating the failure
            try {
                directIOReader.close();
            } catch (Throwable closeFailure) {
                failure.addSuppressed(closeFailure);
            }
            throw failure;
        }
    } else {
        return createReader(state);
    }
}

/**
 * An {@link IOContext} that reports {@link Context#DEFAULT} and whose hint set always contains
 * {@link DirectIOHint#INSTANCE}. It carries no merge or flush information.
 */
static class DirectIOContext implements IOContext {

    /** The hints supplied at construction together with {@link DirectIOHint#INSTANCE}. */
    final Set<FileOpenHint> hints;

    /**
     * @param requestedHints the hints to carry; {@link DirectIOHint#INSTANCE} is always added to them
     */
    DirectIOContext(Set<FileOpenHint> requestedHints) {
        final Set<FileOpenHint> directIOOnly = Set.of(DirectIOHint.INSTANCE);
        this.hints = Sets.union(requestedHints, directIOOnly);
    }

    /** Always {@link Context#DEFAULT}. */
    @Override
    public Context context() {
        return Context.DEFAULT;
    }

    /** Returns the stored hint set, which includes {@link DirectIOHint#INSTANCE}. */
    @Override
    public Set<FileOpenHint> hints() {
        return hints;
    }

    /**
     * Builds a new context from the supplied hints only (plus {@link DirectIOHint#INSTANCE});
     * the hints held by this instance are not carried over.
     */
    @Override
    public IOContext withHints(FileOpenHint... hints) {
        final Set<FileOpenHint> replacement = Set.of(hints);
        return new DirectIOContext(replacement);
    }

    /** Always {@code null}. */
    @Override
    public MergeInfo mergeInfo() {
        return null;
    }

    /** Always {@code null}. */
    @Override
    public FlushInfo flushInfo() {
        return null;
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import org.elasticsearch.index.codec.vectors.DirectIOCapableFlatVectorsFormat;
import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer;
import org.elasticsearch.index.codec.vectors.es93.DirectIOCapableLucene99FlatVectorsFormat;
import org.elasticsearch.index.codec.vectors.es93.ES93BFloat16FlatVectorsFormat;

import java.io.IOException;
import java.util.Map;
Expand Down Expand Up @@ -58,12 +59,17 @@ public class ES920DiskBBQVectorsFormat extends KnnVectorsFormat {
public static final int VERSION_DIRECT_IO = 1;
public static final int VERSION_CURRENT = VERSION_DIRECT_IO;

private static final DirectIOCapableFlatVectorsFormat rawVectorFormat = new DirectIOCapableLucene99FlatVectorsFormat(
private static final DirectIOCapableFlatVectorsFormat float32VectorFormat = new DirectIOCapableLucene99FlatVectorsFormat(
FlatVectorScorerUtil.getLucene99FlatVectorsScorer()
);
private static final DirectIOCapableFlatVectorsFormat bfloat16VectorFormat = new ES93BFloat16FlatVectorsFormat(
FlatVectorScorerUtil.getLucene99FlatVectorsScorer()
);
private static final Map<String, DirectIOCapableFlatVectorsFormat> supportedFormats = Map.of(
rawVectorFormat.getName(),
rawVectorFormat
float32VectorFormat.getName(),
float32VectorFormat,
bfloat16VectorFormat.getName(),
bfloat16VectorFormat
);

// This dynamically sets the cluster probe based on the `k` requested and the number of clusters.
Expand All @@ -79,12 +85,13 @@ public class ES920DiskBBQVectorsFormat extends KnnVectorsFormat {
private final int vectorPerCluster;
private final int centroidsPerParentCluster;
private final boolean useDirectIO;
private final DirectIOCapableFlatVectorsFormat rawVectorFormat;

public ES920DiskBBQVectorsFormat(int vectorPerCluster, int centroidsPerParentCluster) {
this(vectorPerCluster, centroidsPerParentCluster, false);
this(vectorPerCluster, centroidsPerParentCluster, false, false);
}

public ES920DiskBBQVectorsFormat(int vectorPerCluster, int centroidsPerParentCluster, boolean useDirectIO) {
public ES920DiskBBQVectorsFormat(int vectorPerCluster, int centroidsPerParentCluster, boolean useDirectIO, boolean useBFloat16) {
super(NAME);
if (vectorPerCluster < MIN_VECTORS_PER_CLUSTER || vectorPerCluster > MAX_VECTORS_PER_CLUSTER) {
throw new IllegalArgumentException(
Expand All @@ -109,6 +116,7 @@ public ES920DiskBBQVectorsFormat(int vectorPerCluster, int centroidsPerParentClu
this.vectorPerCluster = vectorPerCluster;
this.centroidsPerParentCluster = centroidsPerParentCluster;
this.useDirectIO = useDirectIO;
this.rawVectorFormat = useBFloat16 ? bfloat16VectorFormat : float32VectorFormat;
}

/** Constructs a format using the given graph construction parameters and scalar quantization. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,9 @@
import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsWriter;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.store.FlushInfo;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.MergeInfo;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.index.codec.vectors.DirectIOCapableFlatVectorsFormat;
import org.elasticsearch.index.codec.vectors.MergeReaderWrapper;
import org.elasticsearch.index.codec.vectors.es818.DirectIOHint;
import org.elasticsearch.index.store.FsDirectoryFactory;

import java.io.IOException;
import java.util.Set;

public class DirectIOCapableLucene99FlatVectorsFormat extends DirectIOCapableFlatVectorsFormat {

Expand All @@ -45,72 +37,12 @@ protected FlatVectorsScorer flatVectorsScorer() {
}

@Override
public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
return new Lucene99FlatVectorsWriter(state, vectorsScorer);
}

static boolean canUseDirectIO(SegmentReadState state) {
return FsDirectoryFactory.isHybridFs(state.directory);
protected FlatVectorsReader createReader(SegmentReadState state) throws IOException {
return new Lucene99FlatVectorsReader(state, vectorsScorer);
}

@Override
public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException {
return fieldsReader(state, false);
}

@Override
public FlatVectorsReader fieldsReader(SegmentReadState state, boolean useDirectIO) throws IOException {
if (state.context.context() == IOContext.Context.DEFAULT && useDirectIO && canUseDirectIO(state)) {
// only override the context for the random-access use case
SegmentReadState directIOState = new SegmentReadState(
state.directory,
state.segmentInfo,
state.fieldInfos,
new DirectIOContext(state.context.hints()),
state.segmentSuffix
);
// Use mmap for merges and direct I/O for searches.
return new MergeReaderWrapper(
new Lucene99FlatVectorsReader(directIOState, vectorsScorer),
new Lucene99FlatVectorsReader(state, vectorsScorer)
);
} else {
return new Lucene99FlatVectorsReader(state, vectorsScorer);
}
}

static class DirectIOContext implements IOContext {

final Set<FileOpenHint> hints;

DirectIOContext(Set<FileOpenHint> hints) {
// always add DirectIOHint to the hints given
this.hints = Sets.union(hints, Set.of(DirectIOHint.INSTANCE));
}

@Override
public Context context() {
return Context.DEFAULT;
}

@Override
public MergeInfo mergeInfo() {
return null;
}

@Override
public FlushInfo flushInfo() {
return null;
}

@Override
public Set<FileOpenHint> hints() {
return hints;
}

@Override
public IOContext withHints(FileOpenHint... hints) {
return new DirectIOContext(Set.of(hints));
}
public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
return new Lucene99FlatVectorsWriter(state, vectorsScorer);
}
}
Loading