Skip to content

Commit 06f6edf

Browse files
authored
Add bfloat16 support to rank_vectors (elastic#139463)
1 parent 4a12eba commit 06f6edf

File tree

9 files changed

+343
-22
lines changed

9 files changed

+343
-22
lines changed

docs/changelog/139463.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 139463
2+
summary: Add bfloat16 support to `rank_vectors`
3+
area: Vector Search
4+
type: feature
5+
issues: []

docs/reference/elasticsearch/mapping-reference/rank-vectors.md

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,27 @@ PUT my-rank-vectors-float/_doc/1
3434
```
3535
% TESTSETUP
3636

37-
In addition to the `float` element type, `byte` and `bit` element types are also supported.
37+
In addition to the `float` element type, `bfloat16`, `byte`, and `bit` element types are also supported.
38+
39+
Here is an example of using this field with `bfloat16` elements.
40+
```console
41+
PUT my-rank-vectors-bfloat16
42+
{
43+
"mappings": {
44+
"properties": {
45+
"my_vector": {
46+
"type": "rank_vectors",
47+
"element_type": "bfloat16"
48+
}
49+
}
50+
}
51+
}
52+
53+
PUT my-rank-vectors-bfloat16/_doc/1
54+
{
55+
"my_vector" : [[0.5, 10, 6], [-0.5, 10, 10]]
56+
}
57+
```
3858

3959
Here is an example of using this field with `byte` elements.
4060

@@ -92,6 +112,9 @@ $$$rank-vectors-element-type$$$
92112
`float`
93113
: indexes a 4-byte floating-point value per dimension. This is the default value.
94114

115+
`bfloat16` {applies_to}`stack: ga 9.3`
116+
: indexes a 2-byte floating-point value per dimension.
117+
95118
`byte`
96119
: indexes a 1-byte integer value per dimension.
97120

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.script.field.vectors;
11+
12+
import org.apache.lucene.index.BinaryDocValues;
13+
import org.apache.lucene.util.BytesRef;
14+
import org.elasticsearch.index.codec.vectors.BFloat16;
15+
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
16+
import org.elasticsearch.index.mapper.vectors.RankVectorsScriptDocValues;
17+
18+
import java.io.IOException;
19+
import java.nio.ByteBuffer;
20+
import java.nio.ByteOrder;
21+
import java.nio.ShortBuffer;
22+
import java.util.Iterator;
23+
24+
public class BFloat16RankVectorsDocValuesField extends RankVectorsDocValuesField {
25+
26+
private final BinaryDocValues input;
27+
private final BinaryDocValues magnitudes;
28+
private boolean decoded;
29+
private final int dims;
30+
private BytesRef value;
31+
private BytesRef magnitudesValue;
32+
private BFloat16VectorIterator vectorValues;
33+
private int numVectors;
34+
private float[] buffer;
35+
36+
public BFloat16RankVectorsDocValuesField(
37+
BinaryDocValues input,
38+
BinaryDocValues magnitudes,
39+
String name,
40+
DenseVectorFieldMapper.ElementType elementType,
41+
int dims
42+
) {
43+
super(name, elementType);
44+
this.input = input;
45+
this.magnitudes = magnitudes;
46+
this.dims = dims;
47+
this.buffer = new float[dims];
48+
}
49+
50+
@Override
51+
public void setNextDocId(int docId) throws IOException {
52+
decoded = false;
53+
if (input.advanceExact(docId)) {
54+
boolean magnitudesFound = magnitudes.advanceExact(docId);
55+
assert magnitudesFound;
56+
57+
value = input.binaryValue();
58+
assert value.length % (BFloat16.BYTES * dims) == 0;
59+
numVectors = value.length / (BFloat16.BYTES * dims);
60+
magnitudesValue = magnitudes.binaryValue();
61+
assert magnitudesValue.length == (Float.BYTES * numVectors);
62+
} else {
63+
value = null;
64+
magnitudesValue = null;
65+
numVectors = 0;
66+
}
67+
}
68+
69+
@Override
70+
public RankVectorsScriptDocValues toScriptDocValues() {
71+
return new RankVectorsScriptDocValues(this, dims);
72+
}
73+
74+
@Override
75+
public boolean isEmpty() {
76+
return value == null;
77+
}
78+
79+
@Override
80+
public RankVectors get() {
81+
if (isEmpty()) {
82+
return RankVectors.EMPTY;
83+
}
84+
decodeVectorIfNecessary();
85+
return new FloatRankVectors(vectorValues, magnitudesValue, numVectors, dims);
86+
}
87+
88+
@Override
89+
public RankVectors get(RankVectors defaultValue) {
90+
if (isEmpty()) {
91+
return defaultValue;
92+
}
93+
decodeVectorIfNecessary();
94+
return new FloatRankVectors(vectorValues, magnitudesValue, numVectors, dims);
95+
}
96+
97+
@Override
98+
public RankVectors getInternal() {
99+
return get(null);
100+
}
101+
102+
@Override
103+
public int size() {
104+
return value == null ? 0 : value.length / (BFloat16.BYTES * dims);
105+
}
106+
107+
private void decodeVectorIfNecessary() {
108+
if (decoded == false && value != null) {
109+
vectorValues = new BFloat16VectorIterator(value, buffer, numVectors);
110+
decoded = true;
111+
}
112+
}
113+
114+
public static class BFloat16VectorIterator implements VectorIterator<float[]> {
115+
private final float[] buffer;
116+
private final ShortBuffer vectorValues;
117+
private final BytesRef vectorValueBytesRef;
118+
private final int size;
119+
private int idx = 0;
120+
121+
public BFloat16VectorIterator(BytesRef vectorValues, float[] buffer, int size) {
122+
assert vectorValues.length == (buffer.length * BFloat16.BYTES * size);
123+
this.vectorValueBytesRef = vectorValues;
124+
this.vectorValues = ByteBuffer.wrap(vectorValues.bytes, vectorValues.offset, vectorValues.length)
125+
.order(ByteOrder.LITTLE_ENDIAN)
126+
.asShortBuffer();
127+
this.size = size;
128+
this.buffer = buffer;
129+
}
130+
131+
@Override
132+
public boolean hasNext() {
133+
return idx < size;
134+
}
135+
136+
@Override
137+
public float[] next() {
138+
if (hasNext() == false) {
139+
throw new IllegalArgumentException("No more elements in the iterator");
140+
}
141+
BFloat16.bFloat16ToFloat(vectorValues, buffer);
142+
idx++;
143+
return buffer;
144+
}
145+
146+
@Override
147+
public Iterator<float[]> copy() {
148+
return new BFloat16VectorIterator(vectorValueBytesRef, new float[buffer.length], size);
149+
}
150+
151+
@Override
152+
public void reset() {
153+
idx = 0;
154+
vectorValues.rewind();
155+
}
156+
}
157+
}

x-pack/plugin/rank-vectors/src/main/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsDVLeafFieldData.java

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,13 @@
1111
import org.apache.lucene.index.DocValues;
1212
import org.apache.lucene.index.LeafReader;
1313
import org.apache.lucene.util.BytesRef;
14+
import org.elasticsearch.index.codec.vectors.BFloat16;
1415
import org.elasticsearch.index.fielddata.FormattedDocValues;
1516
import org.elasticsearch.index.fielddata.LeafFieldData;
1617
import org.elasticsearch.index.fielddata.SortedBinaryDocValues;
1718
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
1819
import org.elasticsearch.script.field.DocValuesScriptFieldFactory;
20+
import org.elasticsearch.script.field.vectors.BFloat16RankVectorsDocValuesField;
1921
import org.elasticsearch.script.field.vectors.BitRankVectorsDocValuesField;
2022
import org.elasticsearch.script.field.vectors.ByteRankVectorsDocValuesField;
2123
import org.elasticsearch.script.field.vectors.FloatRankVectorsDocValuesField;
@@ -128,7 +130,46 @@ public Object nextValue() {
128130
return vectors;
129131
}
130132
};
131-
case BFLOAT16 -> throw new IllegalArgumentException("Unsupported element type: bfloat16");
133+
case BFLOAT16 -> new FormattedDocValues() {
134+
private final float[] vector = new float[dims];
135+
private BytesRef ref = null;
136+
private int numVecs = -1;
137+
private final BinaryDocValues binary;
138+
{
139+
try {
140+
binary = DocValues.getBinary(reader, field);
141+
} catch (IOException e) {
142+
throw new IllegalStateException("Cannot load doc values", e);
143+
}
144+
}
145+
146+
@Override
147+
public boolean advanceExact(int docId) throws IOException {
148+
if (binary == null || binary.advanceExact(docId) == false) {
149+
return false;
150+
}
151+
ref = binary.binaryValue();
152+
assert ref.length % (BFloat16.BYTES * dims) == 0;
153+
numVecs = ref.length / (BFloat16.BYTES * dims);
154+
return true;
155+
}
156+
157+
@Override
158+
public int docValueCount() {
159+
return 1;
160+
}
161+
162+
@Override
163+
public Object nextValue() {
164+
List<float[]> vectors = new ArrayList<>(numVecs);
165+
VectorIterator<float[]> iterator = new BFloat16RankVectorsDocValuesField.BFloat16VectorIterator(ref, vector, numVecs);
166+
while (iterator.hasNext()) {
167+
float[] v = iterator.next();
168+
vectors.add(Arrays.copyOf(v, v.length));
169+
}
170+
return vectors;
171+
}
172+
};
132173
};
133174
}
134175

@@ -140,8 +181,8 @@ public DocValuesScriptFieldFactory getScriptFieldFactory(String name) {
140181
return switch (elementType) {
141182
case BYTE -> new ByteRankVectorsDocValuesField(values, magnitudeValues, name, elementType, dims);
142183
case FLOAT -> new FloatRankVectorsDocValuesField(values, magnitudeValues, name, elementType, dims);
184+
case BFLOAT16 -> new BFloat16RankVectorsDocValuesField(values, magnitudeValues, name, elementType, dims);
143185
case BIT -> new BitRankVectorsDocValuesField(values, magnitudeValues, name, elementType, dims);
144-
case BFLOAT16 -> throw new IllegalArgumentException("Unsupported element type: bfloat16");
145186
};
146187
} catch (IOException e) {
147188
throw new IllegalStateException("Cannot load doc values for multi-vector field!", e);

x-pack/plugin/rank-vectors/src/main/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsFieldMapper.java

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import org.apache.lucene.util.BytesRef;
1616
import org.elasticsearch.common.xcontent.support.XContentMapValues;
1717
import org.elasticsearch.index.IndexVersion;
18+
import org.elasticsearch.index.codec.vectors.BFloat16;
1819
import org.elasticsearch.index.fielddata.FieldDataContext;
1920
import org.elasticsearch.index.fielddata.IndexFieldData;
2021
import org.elasticsearch.index.mapper.ArraySourceValueFetcher;
@@ -77,9 +78,6 @@ public static class Builder extends FieldMapper.Builder {
7778
"invalid element_type [" + o + "]; available types are " + namesToElementType.keySet()
7879
);
7980
}
80-
if (elementType == ElementType.BFLOAT16) {
81-
throw new MapperParsingException("Rank vectors does not support bfloat16");
82-
}
8381
return elementType;
8482
},
8583
m -> toType(m).fieldType().element.elementType(),
@@ -497,6 +495,13 @@ private List<List<?>> copyVectorsAsList() throws IOException {
497495
}
498496
vectors.add(vec);
499497
}
498+
case BFLOAT16 -> {
499+
List<Float> vec = new ArrayList<>(dims);
500+
for (int dim = 0; dim < dims; dim++) {
501+
vec.add(BFloat16.bFloat16ToFloat(byteBuffer.getShort()));
502+
}
503+
vectors.add(vec);
504+
}
500505
case BYTE, BIT -> {
501506
List<Byte> vec = new ArrayList<>(dims);
502507
for (int dim = 0; dim < dims; dim++) {

x-pack/plugin/rank-vectors/src/main/java/org/elasticsearch/xpack/rank/vectors/script/RankVectorsScoreScriptUtils.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,8 @@ public static final class MaxSimInvHamming {
179179

180180
public MaxSimInvHamming(ScoreScript scoreScript, Object queryVector, String fieldName) {
181181
RankVectorsDocValuesField field = (RankVectorsDocValuesField) scoreScript.field(fieldName);
182-
if (field.getElementType() == DenseVectorFieldMapper.ElementType.FLOAT) {
182+
if (field.getElementType() == DenseVectorFieldMapper.ElementType.FLOAT
183+
|| field.getElementType() == DenseVectorFieldMapper.ElementType.BFLOAT16) {
183184
throw new IllegalArgumentException("hamming distance is only supported for byte or bit vectors");
184185
}
185186
BytesOrList bytesOrList = parseBytes(queryVector);
@@ -351,13 +352,12 @@ public MaxSimDotProduct(ScoreScript scoreScript, Object queryVector, String fiel
351352
yield new MaxSimByteDotProduct(scoreScript, field, bytesOrList.list);
352353
}
353354
}
354-
case FLOAT -> {
355+
case FLOAT, BFLOAT16 -> {
355356
if (queryVector instanceof List) {
356357
yield new MaxSimFloatDotProduct(scoreScript, field, (List<List<Number>>) queryVector);
357358
}
358359
throw new IllegalArgumentException("Unsupported input object for float vectors: " + queryVector.getClass().getName());
359360
}
360-
case BFLOAT16 -> throw new IllegalArgumentException("Unsupported element type: bfloat16");
361361
};
362362
}
363363

x-pack/plugin/rank-vectors/src/test/java/org/elasticsearch/xpack/rank/vectors/mapper/RankVectorsFieldMapperTests.java

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
import java.util.stream.Stream;
4949

5050
import static org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase.randomNormalizedVector;
51+
import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapperTests.convertToBFloat16List;
5152
import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapperTests.convertToList;
5253
import static org.hamcrest.Matchers.containsString;
5354
import static org.hamcrest.Matchers.equalTo;
@@ -61,9 +62,12 @@ public class RankVectorsFieldMapperTests extends SyntheticVectorsMapperTestCase
6162
private final int dims;
6263

6364
public RankVectorsFieldMapperTests() {
64-
this.elementType = randomFrom(ElementType.BYTE, ElementType.FLOAT, ElementType.BIT);
65+
this.elementType = randomFrom(ElementType.BYTE, ElementType.FLOAT, ElementType.BFLOAT16, ElementType.BIT);
6566
int baseDims = ElementType.BIT == elementType ? 4 * Byte.SIZE : 4;
66-
int randomMultiplier = ElementType.FLOAT == elementType ? randomIntBetween(1, 64) : 1;
67+
int randomMultiplier = switch (elementType) {
68+
case FLOAT, BFLOAT16 -> randomIntBetween(1, 64);
69+
case BYTE, BIT -> 1;
70+
};
6771
this.dims = baseDims * randomMultiplier;
6872
}
6973

@@ -97,11 +101,12 @@ protected Object getSampleValueForDocument(boolean binaryFormat) {
97101
@Override
98102
protected Object getSampleValueForDocument() {
99103
int numVectors = randomIntBetween(1, 16);
100-
return Stream.generate(
101-
() -> elementType == ElementType.FLOAT
102-
? convertToList(randomNormalizedVector(this.dims))
103-
: convertToList(randomByteArrayOfLength(elementType == ElementType.BIT ? this.dims / Byte.SIZE : dims))
104-
).limit(numVectors).toList();
104+
return Stream.generate(switch (elementType) {
105+
case FLOAT -> () -> convertToList(randomNormalizedVector(this.dims));
106+
case BFLOAT16 -> () -> convertToBFloat16List(randomNormalizedVector(this.dims));
107+
case BYTE -> () -> convertToList(randomByteArrayOfLength(dims));
108+
case BIT -> () -> convertToList(randomByteArrayOfLength(dims / Byte.SIZE));
109+
}).limit(numVectors).toList();
105110
}
106111

107112
@Override
@@ -119,6 +124,21 @@ protected void registerParameters(ParameterChecker checker) throws IOException {
119124
checker.registerConflictCheck(
120125
"element_type",
121126
fieldMapping(b -> b.field("type", "rank_vectors").field("dims", dims).field("element_type", "float")),
127+
fieldMapping(b -> b.field("type", "rank_vectors").field("dims", dims).field("element_type", "bfloat16"))
128+
);
129+
checker.registerConflictCheck(
130+
"element_type",
131+
fieldMapping(b -> b.field("type", "rank_vectors").field("dims", dims).field("element_type", "byte")),
132+
fieldMapping(b -> b.field("type", "rank_vectors").field("dims", dims).field("element_type", "bfloat16"))
133+
);
134+
checker.registerConflictCheck(
135+
"element_type",
136+
fieldMapping(b -> b.field("type", "rank_vectors").field("dims", dims).field("element_type", "float")),
137+
fieldMapping(b -> b.field("type", "rank_vectors").field("dims", dims * 8).field("element_type", "bit"))
138+
);
139+
checker.registerConflictCheck(
140+
"element_type",
141+
fieldMapping(b -> b.field("type", "rank_vectors").field("dims", dims).field("element_type", "bfloat16")),
122142
fieldMapping(b -> b.field("type", "rank_vectors").field("dims", dims * 8).field("element_type", "bit"))
123143
);
124144
checker.registerConflictCheck(

0 commit comments

Comments
 (0)