Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions server/src/main/java/module-info.java
Original file line number Diff line number Diff line change
Expand Up @@ -486,4 +486,5 @@
exports org.elasticsearch.index.codec.vectors to org.elasticsearch.test.knn;
exports org.elasticsearch.index.codec.vectors.es818 to org.elasticsearch.test.knn;
exports org.elasticsearch.inference.telemetry;
exports org.elasticsearch.index.codec.vectors.es91 to org.elasticsearch.test.knn;
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FloatVectorValues;
Expand All @@ -29,6 +28,7 @@
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.hnsw.OrdinalTranslatedKnnCollector;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.elasticsearch.index.codec.vectors.es91.ES91BFloat16FlatVectorsFormat;

import java.io.IOException;
import java.util.Map;
Expand All @@ -39,7 +39,7 @@ public class ES813FlatVectorFormat extends KnnVectorsFormat {

static final String NAME = "ES813FlatVectorFormat";

private static final FlatVectorsFormat format = new Lucene99FlatVectorsFormat(DefaultFlatVectorScorer.INSTANCE);
private static final FlatVectorsFormat format = new ES91BFloat16FlatVectorsFormat(DefaultFlatVectorScorer.INSTANCE);

/**
* Sole constructor
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
import org.apache.lucene.codecs.hnsw.ScalarQuantizedVectorScorer;
import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsReader;
import org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsWriter;
import org.apache.lucene.index.ByteVectorValues;
Expand All @@ -34,6 +33,7 @@
import org.apache.lucene.util.quantization.QuantizedByteVectorValues;
import org.apache.lucene.util.quantization.QuantizedVectorsReader;
import org.apache.lucene.util.quantization.ScalarQuantizer;
import org.elasticsearch.index.codec.vectors.es91.ES91BFloat16FlatVectorsFormat;
import org.elasticsearch.simdvec.VectorScorerFactory;
import org.elasticsearch.simdvec.VectorSimilarityType;

Expand All @@ -48,7 +48,7 @@ public class ES814ScalarQuantizedVectorsFormat extends FlatVectorsFormat {
static final String NAME = "ES814ScalarQuantizedVectorsFormat";
private static final int ALLOWED_BITS = (1 << 8) | (1 << 7) | (1 << 4);

private static final FlatVectorsFormat rawVectorFormat = new Lucene99FlatVectorsFormat(DefaultFlatVectorScorer.INSTANCE);
private static final FlatVectorsFormat rawVectorFormat = new ES91BFloat16FlatVectorsFormat(DefaultFlatVectorScorer.INSTANCE);

static final FlatVectorsScorer flatVectorScorer = new ESFlatVectorsScorer(
new ScalarQuantizedVectorScorer(DefaultFlatVectorScorer.INSTANCE)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.SegmentReadState;
Expand All @@ -24,14 +23,15 @@
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer;
import org.apache.lucene.util.quantization.QuantizedByteVectorValues;
import org.elasticsearch.index.codec.vectors.es91.ES91BFloat16FlatVectorsFormat;

import java.io.IOException;

import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.MAX_DIMS_COUNT;

class ES815BitFlatVectorsFormat extends FlatVectorsFormat {

private static final FlatVectorsFormat delegate = new Lucene99FlatVectorsFormat(FlatBitVectorScorer.INSTANCE);
private static final FlatVectorsFormat delegate = new ES91BFloat16FlatVectorsFormat(FlatBitVectorScorer.INSTANCE);

protected ES815BitFlatVectorsFormat() {
super("ES815BitFlatVectorsFormat");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@
import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import org.elasticsearch.index.codec.vectors.es91.ES91BFloat16FlatVectorsFormat;

import java.io.IOException;

Expand All @@ -47,7 +47,7 @@ public class ES816BinaryQuantizedVectorsFormat extends FlatVectorsFormat {
static final String VECTOR_DATA_EXTENSION = "veb";
static final int DIRECT_MONOTONIC_BLOCK_SHIFT = 16;

private static final FlatVectorsFormat rawVectorFormat = new Lucene99FlatVectorsFormat(
private static final FlatVectorsFormat rawVectorFormat = new ES91BFloat16FlatVectorsFormat(
FlatVectorScorerUtil.getLucene99FlatVectorsScorer()
);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@
import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import org.elasticsearch.core.SuppressForbidden;
import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer;
import org.elasticsearch.index.codec.vectors.es91.ES91BFloat16FlatVectorsFormat;

import java.io.IOException;

Expand Down Expand Up @@ -110,7 +110,7 @@ private static boolean getUseDirectIO() {

private static final FlatVectorsFormat rawVectorFormat = USE_DIRECT_IO
? new DirectIOLucene99FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer())
: new Lucene99FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer());
: new ES91BFloat16FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer());

private static final ES818BinaryFlatVectorsScorer scorer = new ES818BinaryFlatVectorsScorer(
FlatVectorScorerUtil.getLucene99FlatVectorsScorer()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.index.codec.vectors.es91;

import org.apache.lucene.util.BitUtil;

/**
 * Static helpers for converting between 32-bit {@code float} values and a 16-bit representation
 * that keeps only the upper half of the float's IEEE 754 bit pattern (sign, exponent and the
 * most significant fraction bits).
 */
final class BFloat16 {

    /** Number of bytes occupied by one encoded value. */
    public static final int BYTES = Short.BYTES;

    private BFloat16() {
        // static utility class; not meant to be instantiated
    }

    /**
     * Encodes a float by discarding the low 16 bits of its bit pattern.
     *
     * @param f the value to encode
     * @return the upper 16 bits of {@code Float.floatToIntBits(f)}
     */
    public static short floatToBFloat16(float f) {
        // Dropping the low fraction bits rounds towards 0.
        // zero - zero exp, zero fraction
        // denormal - zero exp, non-zero fraction
        // infinity - all-1 exp, zero fraction
        // NaN - all-1 exp, non-zero fraction
        // Float.floatToIntBits collapses every NaN to the canonical 0x7fc0_0000, whose upper half
        // (0x7fc0) still has a non-zero fraction, so a NaN is never truncated into an infinity.
        return (short) (Float.floatToIntBits(f) >>> 16);
    }

    /**
     * Decodes a 16-bit value back to a float; the low 16 bits of the result's bit pattern are zero.
     *
     * @param bf the encoded value
     * @return the float whose upper 16 bits equal {@code bf}
     */
    public static float bFloat16ToFloat(short bf) {
        // bf is sign-extended to int first; the shift pushes the extension bits out again
        return Float.intBitsToFloat(bf << 16);
    }

    /**
     * Encodes every element of {@code f} into a newly allocated array of the same length.
     *
     * @param f the values to encode
     * @return the encoded values, in the same order
     */
    public static short[] floatToBFloat16(float[] f) {
        short[] bf = new short[f.length];
        for (int i = 0; i < f.length; i++) {
            bf[i] = floatToBFloat16(f[i]);
        }
        return bf;
    }

    /**
     * Decodes every element of {@code bf} into a newly allocated array of the same length.
     *
     * @param bf the encoded values
     * @return the decoded floats, in the same order
     */
    public static float[] bFloat16ToFloat(short[] bf) {
        float[] f = new float[bf.length];
        for (int i = 0; i < bf.length; i++) {
            f[i] = bFloat16ToFloat(bf[i]);
        }
        return f;
    }

    /**
     * Decodes encoded values read as consecutive 2-byte shorts from {@code bfBytes} into {@code floats}.
     *
     * @param bfBytes source bytes; must hold exactly {@code floats.length * BYTES} bytes
     * @param floats  destination array, fully overwritten
     */
    public static void bFloat16ToFloat(byte[] bfBytes, float[] floats) {
        assert floats.length * BYTES == bfBytes.length;
        for (int i = 0; i < floats.length; i++) {
            // VH_LE_SHORT: presumably a little-endian short view over the byte array - confirm against Lucene's BitUtil
            floats[i] = bFloat16ToFloat((short) BitUtil.VH_LE_SHORT.get(bfBytes, i * BYTES));
        }
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
* @notice
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* Modifications copyright (C) 2024 Elasticsearch B.V.
*/
package org.elasticsearch.index.codec.vectors.es91;

import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;

import java.io.IOException;

/**
 * A {@link FlatVectorsFormat} registered under the name {@link #NAME} that hands out
 * {@code ES91BFloat16FlatVectorsWriter} and {@code ES91BFloat16FlatVectorsReader} instances,
 * both configured with the {@link FlatVectorsScorer} supplied at construction.
 */
public class ES91BFloat16FlatVectorsFormat extends FlatVectorsFormat {

    static final String NAME = "ES91BFloat16FlatVectorsFormat";
    // NOTE(review): the codec names and extensions below are presumably consumed by the writer/reader
    // for file headers and file naming - they are not referenced in this class; confirm there.
    static final String META_CODEC_NAME = "ES91BFloat16FlatVectorsFormatMeta";
    static final String VECTOR_DATA_CODEC_NAME = "ES91BFloat16FlatVectorsFormatData";
    static final String META_EXTENSION = "vemf";
    static final String VECTOR_DATA_EXTENSION = "vec";

    public static final int VERSION_START = 0;
    public static final int VERSION_CURRENT = VERSION_START;

    static final int DIRECT_MONOTONIC_BLOCK_SHIFT = 16;
    private final FlatVectorsScorer vectorsScorer;

    /**
     * Constructs a format.
     *
     * @param vectorsScorer scorer passed on to every writer and reader this format creates
     */
    public ES91BFloat16FlatVectorsFormat(FlatVectorsScorer vectorsScorer) {
        super(NAME);
        this.vectorsScorer = vectorsScorer;
    }

    /**
     * Creates a writer for the given segment.
     *
     * @param state the segment write state
     * @return a new {@code ES91BFloat16FlatVectorsWriter}
     * @throws IOException if the writer cannot be created
     */
    @Override
    public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
        return new ES91BFloat16FlatVectorsWriter(state, vectorsScorer);
    }

    /**
     * Creates a reader for the given segment.
     *
     * @param state the segment read state
     * @return a new {@code ES91BFloat16FlatVectorsReader}
     * @throws IOException if the reader cannot be created
     */
    @Override
    public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException {
        return new ES91BFloat16FlatVectorsReader(state, vectorsScorer);
    }

    /** Returns this format's own name followed by its scorer, e.g. for logs and diagnostics. */
    @Override
    public String toString() {
        // Report this format's name rather than the Lucene99 format it was derived from.
        return NAME + "(vectorsScorer=" + vectorsScorer + ')';
    }
}
Loading