Skip to content

Commit fde54fd

Browse files
committed
Revert "Don't add to DenseVectorFieldMapper yet"
This reverts commit feb7aee.
1 parent 3f91f54 commit fde54fd

File tree

16 files changed

+257
-47
lines changed

16 files changed

+257
-47
lines changed

server/src/main/java/org/elasticsearch/index/mapper/BlockDocValuesReader.java

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import org.apache.lucene.util.BytesRef;
2323
import org.elasticsearch.common.io.stream.ByteArrayStreamInput;
2424
import org.elasticsearch.index.IndexVersion;
25+
import org.elasticsearch.index.codec.vectors.BFloat16;
2526
import org.elasticsearch.index.mapper.BlockLoader.BlockFactory;
2627
import org.elasticsearch.index.mapper.BlockLoader.BooleanBuilder;
2728
import org.elasticsearch.index.mapper.BlockLoader.Builder;
@@ -36,6 +37,9 @@
3637
import org.elasticsearch.search.fetch.StoredFieldsSpec;
3738

3839
import java.io.IOException;
40+
import java.nio.ByteBuffer;
41+
import java.nio.ByteOrder;
42+
import java.nio.ShortBuffer;
3943

4044
import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.COSINE_MAGNITUDE_FIELD_SUFFIX;
4145

@@ -536,7 +540,8 @@ public Builder builder(BlockFactory factory, int expectedCount) {
536540
@Override
537541
public AllReader reader(LeafReaderContext context) throws IOException {
538542
switch (fieldType.getElementType()) {
539-
case FLOAT -> {
543+
case FLOAT, BFLOAT16 -> {
544+
// BFloat16 is handled by the implementation of FloatVectorValues
540545
FloatVectorValues floatVectorValues = context.reader().getFloatVectorValues(fieldName);
541546
if (floatVectorValues != null) {
542547
if (fieldType.isNormalized()) {
@@ -1052,6 +1057,7 @@ public AllReader reader(LeafReaderContext context) throws IOException {
10521057
}
10531058
return switch (elementType) {
10541059
case FLOAT -> new FloatDenseVectorFromBinary(docValues, dims, indexVersion);
1060+
case BFLOAT16 -> new BFloat16DenseVectorFromBinary(docValues, dims, indexVersion);
10551061
case BYTE -> new ByteDenseVectorFromBinary(docValues, dims, indexVersion);
10561062
case BIT -> new BitDenseVectorFromBinary(docValues, dims, indexVersion);
10571063
};
@@ -1135,6 +1141,26 @@ public String toString() {
11351141
}
11361142
}
11371143

1144+
private static class BFloat16DenseVectorFromBinary extends FloatDenseVectorFromBinary {
1145+
BFloat16DenseVectorFromBinary(BinaryDocValues docValues, int dims, IndexVersion indexVersion) {
1146+
super(docValues, dims, indexVersion);
1147+
}
1148+
1149+
@Override
1150+
protected void decodeDenseVector(BytesRef bytesRef, float[] scratch) {
1151+
VectorEncoderDecoder.decodeDenseVector(indexVersion, bytesRef, scratch);
1152+
ShortBuffer sb = ByteBuffer.wrap(bytesRef.bytes, bytesRef.offset, bytesRef.length)
1153+
.order(ByteOrder.LITTLE_ENDIAN)
1154+
.asShortBuffer();
1155+
BFloat16.bFloat16ToFloat(sb, scratch);
1156+
}
1157+
1158+
@Override
1159+
public String toString() {
1160+
return "BFloat16DenseVectorFromBinary.Bytes";
1161+
}
1162+
}
1163+
11381164
private static class ByteDenseVectorFromBinary extends AbstractDenseVectorFromBinary<byte[]> {
11391165
ByteDenseVectorFromBinary(BinaryDocValues docValues, int dims, IndexVersion indexVersion) {
11401166
this(docValues, dims, indexVersion, dims);

server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java

Lines changed: 46 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
import org.elasticsearch.index.IndexSettings;
4848
import org.elasticsearch.index.IndexVersion;
4949
import org.elasticsearch.index.IndexVersions;
50+
import org.elasticsearch.index.codec.vectors.BFloat16;
5051
import org.elasticsearch.index.codec.vectors.ES813FlatVectorFormat;
5152
import org.elasticsearch.index.codec.vectors.ES813Int8FlatVectorFormat;
5253
import org.elasticsearch.index.codec.vectors.ES814HnswScalarQuantizedVectorsFormat;
@@ -461,6 +462,7 @@ public DenseVectorFieldMapper build(MapperBuilderContext context) {
461462
public enum ElementType {
462463
BYTE,
463464
FLOAT,
465+
BFLOAT16,
464466
BIT;
465467

466468
public static ElementType fromString(String name) {
@@ -475,13 +477,16 @@ public String toString() {
475477

476478
public static final Element BYTE_ELEMENT = new ByteElement();
477479
public static final Element FLOAT_ELEMENT = new FloatElement();
480+
public static final Element BFLOAT16_ELEMENT = new BFloat16Element();
478481
public static final Element BIT_ELEMENT = new BitElement();
479482

480483
public static final Map<String, ElementType> namesToElementType = Map.of(
481484
ElementType.BYTE.toString(),
482485
ElementType.BYTE,
483486
ElementType.FLOAT.toString(),
484487
ElementType.FLOAT,
488+
ElementType.BFLOAT16.toString(),
489+
ElementType.BFLOAT16,
485490
ElementType.BIT.toString(),
486491
ElementType.BIT
487492
);
@@ -491,6 +496,7 @@ public abstract static class Element {
491496
public static Element getElement(ElementType elementType) {
492497
return switch (elementType) {
493498
case FLOAT -> FLOAT_ELEMENT;
499+
case BFLOAT16 -> BFLOAT16_ELEMENT;
494500
case BYTE -> BYTE_ELEMENT;
495501
case BIT -> BIT_ELEMENT;
496502
};
@@ -1056,6 +1062,29 @@ static UnaryOperator<StringBuilder> errorElementsAppender(float[] vector) {
10561062
}
10571063
}
10581064

1065+
private static class BFloat16Element extends FloatElement {
1066+
1067+
@Override
1068+
public ElementType elementType() {
1069+
return ElementType.BFLOAT16;
1070+
}
1071+
1072+
@Override
1073+
public void writeValue(ByteBuffer byteBuffer, float value) {
1074+
byteBuffer.putShort(BFloat16.floatToBFloat16(value));
1075+
}
1076+
1077+
@Override
1078+
public void readAndWriteValue(ByteBuffer byteBuffer, XContentBuilder b) throws IOException {
1079+
b.value(BFloat16.bFloat16ToFloat(byteBuffer.getShort()));
1080+
}
1081+
1082+
@Override
1083+
public int getNumBytes(int dimensions) {
1084+
return dimensions * BFloat16.BYTES;
1085+
}
1086+
}
1087+
10591088
private static class BitElement extends ByteElement {
10601089

10611090
@Override
@@ -1123,7 +1152,7 @@ public enum VectorSimilarity {
11231152
@Override
11241153
float score(float similarity, ElementType elementType, int dim) {
11251154
return switch (elementType) {
1126-
case BYTE, FLOAT -> 1f / (1f + similarity * similarity);
1155+
case BYTE, FLOAT, BFLOAT16 -> 1f / (1f + similarity * similarity);
11271156
case BIT -> (dim - similarity) / dim;
11281157
};
11291158
}
@@ -1138,14 +1167,14 @@ public VectorSimilarityFunction vectorSimilarityFunction(IndexVersion indexVersi
11381167
float score(float similarity, ElementType elementType, int dim) {
11391168
assert elementType != ElementType.BIT;
11401169
return switch (elementType) {
1141-
case BYTE, FLOAT -> (1 + similarity) / 2f;
1170+
case BYTE, FLOAT, BFLOAT16 -> (1 + similarity) / 2f;
11421171
default -> throw new IllegalArgumentException("Unsupported element type [" + elementType + "]");
11431172
};
11441173
}
11451174

11461175
@Override
11471176
public VectorSimilarityFunction vectorSimilarityFunction(IndexVersion indexVersion, ElementType elementType) {
1148-
return indexVersion.onOrAfter(NORMALIZE_COSINE) && ElementType.FLOAT.equals(elementType)
1177+
return indexVersion.onOrAfter(NORMALIZE_COSINE) && (elementType == ElementType.FLOAT || elementType == ElementType.BFLOAT16)
11491178
? VectorSimilarityFunction.DOT_PRODUCT
11501179
: VectorSimilarityFunction.COSINE;
11511180
}
@@ -1155,7 +1184,7 @@ public VectorSimilarityFunction vectorSimilarityFunction(IndexVersion indexVersi
11551184
float score(float similarity, ElementType elementType, int dim) {
11561185
return switch (elementType) {
11571186
case BYTE -> 0.5f + similarity / (float) (dim * (1 << 15));
1158-
case FLOAT -> (1 + similarity) / 2f;
1187+
case FLOAT, BFLOAT16 -> (1 + similarity) / 2f;
11591188
default -> throw new IllegalArgumentException("Unsupported element type [" + elementType + "]");
11601189
};
11611190
}
@@ -1169,7 +1198,7 @@ public VectorSimilarityFunction vectorSimilarityFunction(IndexVersion indexVersi
11691198
@Override
11701199
float score(float similarity, ElementType elementType, int dim) {
11711200
return switch (elementType) {
1172-
case BYTE, FLOAT -> similarity < 0 ? 1 / (1 + -1 * similarity) : similarity + 1;
1201+
case BYTE, FLOAT, BFLOAT16 -> similarity < 0 ? 1 / (1 + -1 * similarity) : similarity + 1;
11731202
default -> throw new IllegalArgumentException("Unsupported element type [" + elementType + "]");
11741203
};
11751204
}
@@ -1478,7 +1507,7 @@ public DenseVectorIndexOptions parseIndexOptions(String fieldName, Map<String, ?
14781507

14791508
@Override
14801509
public boolean supportsElementType(ElementType elementType) {
1481-
return elementType == ElementType.FLOAT;
1510+
return elementType == ElementType.FLOAT || elementType == ElementType.BFLOAT16;
14821511
}
14831512

14841513
@Override
@@ -1502,7 +1531,7 @@ public DenseVectorIndexOptions parseIndexOptions(String fieldName, Map<String, ?
15021531

15031532
@Override
15041533
public boolean supportsElementType(ElementType elementType) {
1505-
return elementType == ElementType.FLOAT;
1534+
return elementType == ElementType.FLOAT || elementType == ElementType.BFLOAT16;
15061535
}
15071536

15081537
@Override
@@ -2044,8 +2073,10 @@ public BBQHnswIndexOptions(int m, int efConstruction, boolean onDiskRescore, Res
20442073

20452074
@Override
20462075
KnnVectorsFormat getVectorsFormat(ElementType elementType) {
2047-
assert elementType == ElementType.FLOAT;
2048-
return new ES93HnswBinaryQuantizedVectorsFormat(m, efConstruction, onDiskRescore, false);
2076+
return switch (elementType) {
2077+
case FLOAT -> new ES93HnswBinaryQuantizedVectorsFormat(m, efConstruction, onDiskRescore, false);
2078+
default -> throw new AssertionError();
2079+
};
20492080
}
20502081

20512082
@Override
@@ -2110,8 +2141,10 @@ static class BBQFlatIndexOptions extends QuantizedIndexOptions {
21102141

21112142
@Override
21122143
KnnVectorsFormat getVectorsFormat(ElementType elementType) {
2113-
assert elementType == ElementType.FLOAT;
2114-
return new ES818BinaryQuantizedVectorsFormat();
2144+
return switch (elementType) {
2145+
case FLOAT -> new ES818BinaryQuantizedVectorsFormat();
2146+
default -> throw new AssertionError();
2147+
};
21152148
}
21162149

21172150
@Override
@@ -2360,7 +2393,7 @@ public Query createExactKnnQuery(VectorData queryVector, Float vectorSimilarity)
23602393
}
23612394
Query knnQuery = switch (element.elementType()) {
23622395
case BYTE -> createExactKnnByteQuery(queryVector.asByteVector());
2363-
case FLOAT -> createExactKnnFloatQuery(queryVector.asFloatVector());
2396+
case FLOAT, BFLOAT16 -> createExactKnnFloatQuery(queryVector.asFloatVector());
23642397
case BIT -> createExactKnnBitQuery(queryVector.asByteVector());
23652398
};
23662399
if (vectorSimilarity != null) {
@@ -2441,7 +2474,7 @@ public Query createKnnQuery(
24412474
knnSearchStrategy,
24422475
hnswEarlyTermination
24432476
);
2444-
case FLOAT -> createKnnFloatQuery(
2477+
case FLOAT, BFLOAT16 -> createKnnFloatQuery(
24452478
queryVector.asFloatVector(),
24462479
k,
24472480
numCands,

server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorDVLeafFieldData.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,14 +69,14 @@ public DocValuesScriptFieldFactory getScriptFieldFactory(String name) {
6969
if (indexed) {
7070
return switch (elementType) {
7171
case BYTE -> new ByteKnnDenseVectorDocValuesField(reader.getByteVectorValues(field), name, dims);
72-
case FLOAT -> new KnnDenseVectorDocValuesField(reader.getFloatVectorValues(field), name, dims);
72+
case FLOAT, BFLOAT16 -> new KnnDenseVectorDocValuesField(reader.getFloatVectorValues(field), name, dims);
7373
case BIT -> new BitKnnDenseVectorDocValuesField(reader.getByteVectorValues(field), name, dims);
7474
};
7575
} else {
7676
BinaryDocValues values = DocValues.getBinary(reader, field);
7777
return switch (elementType) {
7878
case BYTE -> new ByteBinaryDenseVectorDocValuesField(values, name, elementType, dims);
79-
case FLOAT -> new BinaryDenseVectorDocValuesField(values, name, elementType, dims, indexVersion);
79+
case FLOAT, BFLOAT16 -> new BinaryDenseVectorDocValuesField(values, name, elementType, dims, indexVersion);
8080
case BIT -> new BitBinaryDenseVectorDocValuesField(values, name, elementType, dims);
8181
};
8282
}
@@ -138,7 +138,7 @@ public Object nextValue() {
138138
return vectorValue;
139139
}
140140
};
141-
case FLOAT -> new FormattedDocValues() {
141+
case FLOAT, BFLOAT16 -> new FormattedDocValues() {
142142
float[] vector = new float[dims];
143143
private FloatVectorValues floatVectorValues; // use when indexed
144144
private KnnVectorValues.DocIndexIterator iterator; // use when indexed

server/src/main/java/org/elasticsearch/script/VectorScoreScriptUtils.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ public L1Norm(ScoreScript scoreScript, Object queryVector, String fieldName) {
210210
}
211211
throw new IllegalArgumentException("Unsupported input object for byte vectors: " + queryVector.getClass().getName());
212212
}
213-
case FLOAT -> {
213+
case FLOAT, BFLOAT16 -> {
214214
if (queryVector instanceof List) {
215215
yield new FloatL1Norm(scoreScript, field, (List<Number>) queryVector);
216216
}
@@ -320,7 +320,7 @@ public L2Norm(ScoreScript scoreScript, Object queryVector, String fieldName) {
320320
}
321321
throw new IllegalArgumentException("Unsupported input object for byte vectors: " + queryVector.getClass().getName());
322322
}
323-
case FLOAT -> {
323+
case FLOAT, BFLOAT16 -> {
324324
if (queryVector instanceof List) {
325325
yield new FloatL2Norm(scoreScript, field, (List<Number>) queryVector);
326326
}
@@ -478,7 +478,7 @@ public DotProduct(ScoreScript scoreScript, Object queryVector, String fieldName)
478478
}
479479
throw new IllegalArgumentException("Unsupported input object for byte vectors: " + queryVector.getClass().getName());
480480
}
481-
case FLOAT -> {
481+
case FLOAT, BFLOAT16 -> {
482482
if (queryVector instanceof List) {
483483
yield new FloatDotProduct(scoreScript, field, (List<Number>) queryVector);
484484
}
@@ -547,7 +547,7 @@ public CosineSimilarity(ScoreScript scoreScript, Object queryVector, String fiel
547547
}
548548
throw new IllegalArgumentException("Unsupported input object for byte vectors: " + queryVector.getClass().getName());
549549
}
550-
case FLOAT -> {
550+
case FLOAT, BFLOAT16 -> {
551551
if (queryVector instanceof List) {
552552
yield new FloatCosineSimilarity(scoreScript, field, (List<Number>) queryVector);
553553
}

server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTestUtils.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,14 @@ private DenseVectorFieldMapperTestUtils() {}
2222

2323
public static List<SimilarityMeasure> getSupportedSimilarities(DenseVectorFieldMapper.ElementType elementType) {
2424
return switch (elementType) {
25-
case FLOAT, BYTE -> List.of(SimilarityMeasure.values());
25+
case FLOAT, BFLOAT16, BYTE -> List.of(SimilarityMeasure.values());
2626
case BIT -> List.of(SimilarityMeasure.L2_NORM);
2727
};
2828
}
2929

3030
public static int getEmbeddingLength(DenseVectorFieldMapper.ElementType elementType, int dimensions) {
3131
return switch (elementType) {
32-
case FLOAT, BYTE -> dimensions;
32+
case FLOAT, BFLOAT16, BYTE -> dimensions;
3333
case BIT -> {
3434
assert dimensions % Byte.SIZE == 0;
3535
yield dimensions / Byte.SIZE;
@@ -43,7 +43,7 @@ public static int randomCompatibleDimensions(DenseVectorFieldMapper.ElementType
4343
}
4444

4545
return switch (elementType) {
46-
case FLOAT, BYTE -> RandomNumbers.randomIntBetween(random(), 1, max);
46+
case FLOAT, BFLOAT16, BYTE -> RandomNumbers.randomIntBetween(random(), 1, max);
4747
case BIT -> {
4848
if (max < 8) {
4949
throw new IllegalArgumentException("max must be at least 8 for bit vectors");

server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2448,7 +2448,7 @@ protected Object generateRandomInputValue(MappedFieldType ft) {
24482448
DenseVectorFieldType vectorFieldType = (DenseVectorFieldType) ft;
24492449
return switch (vectorFieldType.getElementType()) {
24502450
case BYTE -> randomByteArrayOfLength(vectorFieldType.getVectorDimensions());
2451-
case FLOAT -> randomNormalizedVector(vectorFieldType.getVectorDimensions());
2451+
case FLOAT, BFLOAT16 -> randomNormalizedVector(vectorFieldType.getVectorDimensions());
24522452
case BIT -> randomByteArrayOfLength(vectorFieldType.getVectorDimensions() / 8);
24532453
};
24542454
}
@@ -3043,7 +3043,7 @@ public SyntheticSourceExample example(int maxValues) throws IOException {
30433043
Object value = switch (elementType) {
30443044
case BYTE, BIT:
30453045
yield randomList(dims, dims, ESTestCase::randomByte);
3046-
case FLOAT:
3046+
case FLOAT, BFLOAT16:
30473047
yield randomList(dims, dims, ESTestCase::randomFloat);
30483048
};
30493049
return new SyntheticSourceExample(value, value, this::mapping);

server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ protected void doAssertLuceneQuery(KnnVectorQueryBuilder queryBuilder, Query que
247247
approxFilterQuery,
248248
expectedStrategy
249249
);
250-
case FLOAT -> new ESKnnFloatVectorQuery(
250+
case FLOAT, BFLOAT16 -> new ESKnnFloatVectorQuery(
251251
VECTOR_FIELD,
252252
queryBuilder.queryVector().asFloatVector(),
253253
k,
@@ -268,7 +268,7 @@ protected void doAssertLuceneQuery(KnnVectorQueryBuilder queryBuilder, Query que
268268
yield new DenseVectorQuery.Bytes(queryBuilder.queryVector().asByteVector(), VECTOR_FIELD);
269269
}
270270
}
271-
case FLOAT -> {
271+
case FLOAT, BFLOAT16 -> {
272272
if (filterQuery != null) {
273273
yield new BooleanQuery.Builder().add(
274274
new DenseVectorQuery.Floats(queryBuilder.queryVector().asFloatVector(), VECTOR_FIELD),

x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ private static List<Float> generateEmbedding(String input, int dimensions, Dense
255255
// Copied from DenseVectorFieldMapperTestUtils due to dependency restrictions
256256
private static int getEmbeddingLength(DenseVectorFieldMapper.ElementType elementType, int dimensions) {
257257
return switch (elementType) {
258-
case FLOAT, BYTE -> dimensions;
258+
case FLOAT, BYTE, BFLOAT16 -> dimensions;
259259
case BIT -> {
260260
assert dimensions % Byte.SIZE == 0;
261261
yield dimensions / Byte.SIZE;

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticTextEmbeddingPayload.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ public TextEmbeddingResults<?> responseBody(SageMakerModel model, InvokeEndpoint
9797
return switch (model.apiServiceSettings().elementType()) {
9898
case BIT -> TextEmbeddingBinary.PARSER.apply(p, null);
9999
case BYTE -> TextEmbeddingBytes.PARSER.apply(p, null);
100-
case FLOAT -> TextEmbeddingFloat.PARSER.apply(p, null);
100+
case FLOAT, BFLOAT16 -> TextEmbeddingFloat.PARSER.apply(p, null);
101101
};
102102
}
103103
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticInferenceMetadataFieldsRecoveryTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,7 @@ private static SemanticTextField randomSemanticText(
269269
) throws IOException {
270270
ChunkedInference results = switch (model.getTaskType()) {
271271
case TEXT_EMBEDDING -> switch (model.getServiceSettings().elementType()) {
272-
case FLOAT -> randomChunkedInferenceEmbeddingFloat(model, inputs);
272+
case FLOAT, BFLOAT16 -> randomChunkedInferenceEmbeddingFloat(model, inputs);
273273
case BYTE, BIT -> randomChunkedInferenceEmbeddingByte(model, inputs);
274274
};
275275
case SPARSE_EMBEDDING -> randomChunkedInferenceEmbeddingSparse(inputs, false);

0 commit comments

Comments
 (0)