Skip to content
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.index.codec.vectors;

import org.apache.lucene.util.BitUtil;

import java.nio.ByteOrder;
import java.nio.ShortBuffer;

/**
 * Conversion utilities between IEEE 754 float32 and bfloat16 (1 sign bit,
 * 8 exponent bits, 7 mantissa bits — exactly the upper 16 bits of a float32).
 */
public final class BFloat16 {

    /** Size in bytes of a single bfloat16 value. */
    public static final int BYTES = Short.BYTES;

    private BFloat16() {}

    /**
     * Converts a float32 to bfloat16 by truncating the low 16 bits (rounds toward zero).
     * Bit-pattern classes are preserved:
     * zero keeps a zero exponent and zero fraction; denormals keep a zero exponent and
     * non-zero fraction; infinity keeps an all-ones exponent and zero fraction; NaN keeps
     * an all-ones exponent and non-zero fraction. The {@code Float.NaN} constant is
     * {@code 0x7fc0_0000}, so common NaN values do not collapse into infinities.
     */
    public static short floatToBFloat16(float f) {
        int bits = Float.floatToIntBits(f);
        return (short) (bits >>> 16);
    }

    /** Expands a bfloat16 back to float32 by placing its bits in the high half. */
    public static float bFloat16ToFloat(short bf) {
        return Float.intBitsToFloat(((int) bf) << 16);
    }

    /** Converts {@code floats} to bfloat16, writing each value into {@code bFloats}. */
    public static void floatToBFloat16(float[] floats, ShortBuffer bFloats) {
        assert bFloats.remaining() == floats.length;
        assert bFloats.order() == ByteOrder.LITTLE_ENDIAN;
        for (int i = 0; i < floats.length; i++) {
            bFloats.put(floatToBFloat16(floats[i]));
        }
    }

    /** Decodes little-endian bfloat16 values from {@code bfBytes} into {@code floats}. */
    public static void bFloat16ToFloat(byte[] bfBytes, float[] floats) {
        assert floats.length * 2 == bfBytes.length;
        for (int i = 0, off = 0; i < floats.length; i++, off += 2) {
            floats[i] = bFloat16ToFloat((short) BitUtil.VH_LE_SHORT.get(bfBytes, off));
        }
    }

    /** Decodes bfloat16 values from {@code bFloats} into {@code floats}. */
    public static void bFloat16ToFloat(ShortBuffer bFloats, float[] floats) {
        assert floats.length == bFloats.remaining();
        assert bFloats.order() == ByteOrder.LITTLE_ENDIAN;
        for (int i = 0; i < floats.length; i++) {
            floats[i] = bFloat16ToFloat(bFloats.get());
        }
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@

/** Utility class for vector quantization calculations */
public class BQVectorUtils {
private static final float EPSILON = 1e-4f;
// NOTE: this is currently > 1e-4f due to bfloat16
private static final float EPSILON = 1e-2f;

public static double sqrtNewtonRaphson(double x, double curr, double prev) {
return (curr == prev) ? curr : sqrtNewtonRaphson(x, 0.5 * (curr + x / curr), curr);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,81 @@

import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.store.FlushInfo;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.MergeInfo;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.index.codec.vectors.es818.DirectIOHint;
import org.elasticsearch.index.store.FsDirectoryFactory;

import java.io.IOException;
import java.util.Set;

public abstract class DirectIOCapableFlatVectorsFormat extends AbstractFlatVectorsFormat {
protected DirectIOCapableFlatVectorsFormat(String name) {
super(name);
}

/** Creates the concrete flat-vectors reader for the given segment; direct I/O wrapping is handled by {@link #fieldsReader}. */
protected abstract FlatVectorsReader createReader(SegmentReadState state) throws IOException;

/**
 * Whether direct I/O is usable for this segment: true only when the segment's
 * directory is the hybrid filesystem created by {@link FsDirectoryFactory}.
 */
static boolean canUseDirectIO(SegmentReadState state) {
    return FsDirectoryFactory.isHybridFs(state.directory);
}

/** Opens a reader with direct I/O disabled — the default behavior. */
@Override
public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException {
    return fieldsReader(state, false);
}

public abstract FlatVectorsReader fieldsReader(SegmentReadState state, boolean useDirectIO) throws IOException;
/**
 * Opens a reader for this segment, using direct I/O for random-access reads when
 * requested and supported by the underlying directory. When direct I/O applies,
 * the returned reader uses mmap for merges and direct I/O for searches; otherwise
 * a plain reader is returned.
 */
public FlatVectorsReader fieldsReader(SegmentReadState state, boolean useDirectIO) throws IOException {
    // Only override the context for the random-access (DEFAULT) use case.
    boolean overrideForDirectIO = state.context.context() == IOContext.Context.DEFAULT
        && useDirectIO
        && canUseDirectIO(state);
    if (overrideForDirectIO == false) {
        return createReader(state);
    }
    SegmentReadState directIOState = new SegmentReadState(
        state.directory,
        state.segmentInfo,
        state.fieldInfos,
        new DirectIOContext(state.context.hints()),
        state.segmentSuffix
    );
    // Use mmap for merges and direct I/O for searches.
    return new MergeReaderWrapper(createReader(directIOState), createReader(state));
}

/**
 * An {@link IOContext} whose hints always include {@link DirectIOHint}, signalling
 * the directory implementation to open files with direct I/O.
 */
static class DirectIOContext implements IOContext {

    // Hints supplied by the caller plus DirectIOHint, as an immutable union view.
    final Set<FileOpenHint> hints;

    DirectIOContext(Set<FileOpenHint> hints) {
        // always add DirectIOHint to the hints given
        this.hints = Sets.union(hints, Set.of(DirectIOHint.INSTANCE));
    }

    @Override
    public Context context() {
        return Context.DEFAULT;
    }

    // No merge is associated with this context.
    @Override
    public MergeInfo mergeInfo() {
        return null;
    }

    // No flush is associated with this context.
    @Override
    public FlushInfo flushInfo() {
        return null;
    }

    @Override
    public Set<FileOpenHint> hints() {
        return hints;
    }

    @Override
    public IOContext withHints(FileOpenHint... hints) {
        // NOTE(review): this replaces all existing hints with the given ones (the
        // constructor re-adds DirectIOHint) — confirm callers expect replacement,
        // not augmentation, of the previous hint set.
        return new DirectIOContext(Set.of(hints));
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ public QuantizationResult[] multiScalarQuantize(
}

public QuantizationResult scalarQuantize(float[] vector, float[] residualDestination, int[] destination, byte bits, float[] centroid) {
assert similarityFunction != COSINE || VectorUtil.isUnitVector(vector);
assert similarityFunction != COSINE || BQVectorUtils.isUnitVector(vector);
assert similarityFunction != COSINE || VectorUtil.isUnitVector(centroid);
assert vector.length <= destination.length;
assert bits > 0 && bits <= 8;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,12 @@ public class ES920DiskBBQVectorsFormat extends KnnVectorsFormat {
public static final int VERSION_DIRECT_IO = 1;
public static final int VERSION_CURRENT = VERSION_DIRECT_IO;

private static final DirectIOCapableFlatVectorsFormat rawVectorFormat = new DirectIOCapableLucene99FlatVectorsFormat(
private static final DirectIOCapableFlatVectorsFormat float32VectorFormat = new DirectIOCapableLucene99FlatVectorsFormat(
FlatVectorScorerUtil.getLucene99FlatVectorsScorer()
);
private static final Map<String, DirectIOCapableFlatVectorsFormat> supportedFormats = Map.of(
rawVectorFormat.getName(),
rawVectorFormat
float32VectorFormat.getName(),
float32VectorFormat
);

// This dynamically sets the cluster probe based on the `k` requested and the number of clusters.
Expand All @@ -79,6 +79,7 @@ public class ES920DiskBBQVectorsFormat extends KnnVectorsFormat {
private final int vectorPerCluster;
private final int centroidsPerParentCluster;
private final boolean useDirectIO;
private final DirectIOCapableFlatVectorsFormat rawVectorFormat;

public ES920DiskBBQVectorsFormat(int vectorPerCluster, int centroidsPerParentCluster) {
this(vectorPerCluster, centroidsPerParentCluster, false);
Expand Down Expand Up @@ -109,6 +110,7 @@ public ES920DiskBBQVectorsFormat(int vectorPerCluster, int centroidsPerParentClu
this.vectorPerCluster = vectorPerCluster;
this.centroidsPerParentCluster = centroidsPerParentCluster;
this.useDirectIO = useDirectIO;
this.rawVectorFormat = float32VectorFormat;
}

/** Constructs a format using the given graph construction parameters and scalar quantization. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer;
import org.elasticsearch.index.codec.vectors.BQSpaceUtils;
import org.elasticsearch.index.codec.vectors.BQVectorUtils;
import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer;
import org.elasticsearch.simdvec.ESVectorUtil;

Expand Down Expand Up @@ -70,7 +71,7 @@ public RandomVectorScorer getRandomVectorScorer(
assert binarizedVectors.size() > 0 : "BinarizedByteVectorValues must have at least one vector for ES816BinaryFlatVectorsScorer";
OptimizedScalarQuantizer quantizer = binarizedVectors.getQuantizer();
float[] centroid = binarizedVectors.getCentroid();
assert similarityFunction != COSINE || VectorUtil.isUnitVector(target);
assert similarityFunction != COSINE || BQVectorUtils.isUnitVector(target);
float[] scratch = new float[vectorValues.dimension()];
int[] initial = new int[target.length];
byte[] quantized = new byte[BQSpaceUtils.B_QUERY * binarizedVectors.discretizedDimensions() / 8];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,9 @@
import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsWriter;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.store.FlushInfo;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.MergeInfo;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.index.codec.vectors.DirectIOCapableFlatVectorsFormat;
import org.elasticsearch.index.codec.vectors.MergeReaderWrapper;
import org.elasticsearch.index.codec.vectors.es818.DirectIOHint;
import org.elasticsearch.index.store.FsDirectoryFactory;

import java.io.IOException;
import java.util.Set;

public class DirectIOCapableLucene99FlatVectorsFormat extends DirectIOCapableFlatVectorsFormat {

Expand All @@ -45,72 +37,12 @@ protected FlatVectorsScorer flatVectorsScorer() {
}

/** Writing is unaffected by direct I/O support: delegates to the standard Lucene99 flat vectors writer. */
@Override
public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
    return new Lucene99FlatVectorsWriter(state, vectorsScorer);
}

static boolean canUseDirectIO(SegmentReadState state) {
return FsDirectoryFactory.isHybridFs(state.directory);
/** Creates the plain float32 flat-vectors reader; direct I/O wrapping is handled by the base class. */
protected FlatVectorsReader createReader(SegmentReadState state) throws IOException {
    return new Lucene99FlatVectorsReader(state, vectorsScorer);
}

@Override
public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException {
return fieldsReader(state, false);
}

@Override
public FlatVectorsReader fieldsReader(SegmentReadState state, boolean useDirectIO) throws IOException {
if (state.context.context() == IOContext.Context.DEFAULT && useDirectIO && canUseDirectIO(state)) {
// only override the context for the random-access use case
SegmentReadState directIOState = new SegmentReadState(
state.directory,
state.segmentInfo,
state.fieldInfos,
new DirectIOContext(state.context.hints()),
state.segmentSuffix
);
// Use mmap for merges and direct I/O for searches.
return new MergeReaderWrapper(
new Lucene99FlatVectorsReader(directIOState, vectorsScorer),
new Lucene99FlatVectorsReader(state, vectorsScorer)
);
} else {
return new Lucene99FlatVectorsReader(state, vectorsScorer);
}
}

static class DirectIOContext implements IOContext {

final Set<FileOpenHint> hints;

DirectIOContext(Set<FileOpenHint> hints) {
// always add DirectIOHint to the hints given
this.hints = Sets.union(hints, Set.of(DirectIOHint.INSTANCE));
}

@Override
public Context context() {
return Context.DEFAULT;
}

@Override
public MergeInfo mergeInfo() {
return null;
}

@Override
public FlushInfo flushInfo() {
return null;
}

@Override
public Set<FileOpenHint> hints() {
return hints;
}

@Override
public IOContext withHints(FileOpenHint... hints) {
return new DirectIOContext(Set.of(hints));
}
public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
return new Lucene99FlatVectorsWriter(state, vectorsScorer);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
* @notice
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* Modifications copyright (C) 2025 Elasticsearch B.V.
*/
package org.elasticsearch.index.codec.vectors.es93;

import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import org.elasticsearch.index.codec.vectors.DirectIOCapableFlatVectorsFormat;

import java.io.IOException;

public final class ES93BFloat16FlatVectorsFormat extends DirectIOCapableFlatVectorsFormat {

static final String NAME = "ES93BFloat16FlatVectorsFormat";
static final String META_CODEC_NAME = "ES93BFloat16FlatVectorsFormatMeta";
static final String VECTOR_DATA_CODEC_NAME = "ES93BFloat16FlatVectorsFormatData";
static final String META_EXTENSION = "vemf";
static final String VECTOR_DATA_EXTENSION = "vec";

public static final int VERSION_START = 0;
public static final int VERSION_CURRENT = VERSION_START;

static final int DIRECT_MONOTONIC_BLOCK_SHIFT = 16;
private final FlatVectorsScorer vectorsScorer;

public ES93BFloat16FlatVectorsFormat(FlatVectorsScorer vectorsScorer) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we should allow injectable scorers here.

We cannot use the getLucene99FlatVectorsScorer as it will assume vectors are float32, and attempt to score them off-heap if possible, which will completely break with bfloat16

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can use the default vector scorers at the moment, as they work on FloatVectorValues, which has a bfloat16 implementation here. The panama scorer implementation is also ok, as that only works if HasIndexSlice is implemented, which it is not (now)

super(NAME);
this.vectorsScorer = vectorsScorer;
}

/** Returns the scorer supplied at construction time. */
@Override
protected FlatVectorsScorer flatVectorsScorer() {
    return vectorsScorer;
}

/** Creates the bfloat16-backed flat vectors writer for this segment. */
@Override
public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
    return new ES93BFloat16FlatVectorsWriter(state, vectorsScorer);
}

/** Creates the bfloat16-backed flat vectors reader; direct I/O wrapping is handled by the base class. */
@Override
protected FlatVectorsReader createReader(SegmentReadState state) throws IOException {
    return new ES93BFloat16FlatVectorsReader(state, vectorsScorer);
}
}
Loading