Skip to content

Commit 6f70507

Browse files
authored
Test ES|QL bfloat16 support (#138499)
1 parent 5dec80b commit 6f70507

File tree

2 files changed

+13
-6
lines changed

2 files changed

+13
-6
lines changed

x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ private enum VectorSourceOptions {
6161
.collect(Collectors.toSet());
6262

6363
public static final float DELTA = 1e-7F;
64+
public static final float BFLOAT16_DELTA = 1e-2F;
6465

6566
private final ElementType elementType;
6667
private final DenseVectorFieldMapper.VectorSimilarity similarity;
@@ -70,7 +71,7 @@ private enum VectorSourceOptions {
7071
@ParametersFactory
7172
public static Iterable<Object[]> parameters() throws Exception {
7273
List<Object[]> params = new ArrayList<>();
73-
for (ElementType elementType : List.of(ElementType.BYTE, ElementType.FLOAT, ElementType.BIT)) {
74+
for (ElementType elementType : List.of(ElementType.BYTE, ElementType.FLOAT, ElementType.BIT, ElementType.BFLOAT16)) {
7475
// Test all similarities
7576
for (DenseVectorFieldMapper.VectorSimilarity similarity : DenseVectorFieldMapper.VectorSimilarity.values()) {
7677
if (elementType == ElementType.BIT && similarity != DenseVectorFieldMapper.VectorSimilarity.L2_NORM) {
@@ -137,8 +138,10 @@ public void testRetrieveTopNDenseVectorFieldData() {
137138
} else {
138139
assertNotNull(actualVector);
139140
assertEquals(expectedVector.size(), actualVector.size());
141+
142+
float delta = elementType == ElementType.BFLOAT16 ? BFLOAT16_DELTA : DELTA;
140143
for (int i = 0; i < expectedVector.size(); i++) {
141-
assertEquals(expectedVector.get(i).floatValue(), actualVector.get(i).floatValue(), DELTA);
144+
assertEquals(expectedVector.get(i).floatValue(), actualVector.get(i).floatValue(), delta);
142145
}
143146
}
144147
});
@@ -167,12 +170,14 @@ public void testRetrieveDenseVectorFieldData() {
167170
} else {
168171
assertNotNull(actualVector);
169172
assertEquals(expectedVector.size(), actualVector.size());
173+
174+
float delta = elementType == ElementType.BFLOAT16 ? BFLOAT16_DELTA : DELTA;
170175
for (int i = 0; i < actualVector.size(); i++) {
171176
assertEquals(
172177
"Actual: " + actualVector + "; expected: " + expectedVector,
173178
expectedVector.get(i).floatValue(),
174179
actualVector.get(i).floatValue(),
175-
DELTA
180+
delta
176181
);
177182
}
178183
}
@@ -253,12 +258,13 @@ public void setup() throws IOException {
253258
} else {
254259
for (int j = 0; j < numDims; j++) {
255260
switch (elementType) {
256-
case FLOAT -> vector.add(randomFloatBetween(0F, 1F, true));
261+
case FLOAT, BFLOAT16 -> vector.add(randomFloatBetween(0F, 1F, true));
257262
case BYTE, BIT -> vector.add((byte) randomIntBetween(-128, 127));
258263
default -> throw new IllegalArgumentException("Unexpected element type: " + elementType);
259264
}
260265
}
261-
if ((elementType == ElementType.FLOAT) && (similarity == DenseVectorFieldMapper.VectorSimilarity.DOT_PRODUCT || rarely())) {
266+
if ((elementType == ElementType.FLOAT || elementType == ElementType.BFLOAT16)
267+
&& (similarity == DenseVectorFieldMapper.VectorSimilarity.DOT_PRODUCT || rarely())) {
262268
// Normalize the vector
263269
float magnitude = DenseVector.getMagnitude(vector);
264270
vector.replaceAll(number -> number.floatValue() / magnitude);

x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ public static Iterable<Object[]> parameters() throws Exception {
5757
List<Object[]> params = new ArrayList<>();
5858
for (String indexType : ALL_DENSE_VECTOR_INDEX_TYPES) {
5959
params.add(new Object[] { DenseVectorFieldMapper.ElementType.FLOAT, indexType });
60+
params.add(new Object[] { DenseVectorFieldMapper.ElementType.BFLOAT16, indexType });
6061
}
6162
for (String indexType : NON_QUANTIZED_DENSE_VECTOR_INDEX_TYPES) {
6263
params.add(new Object[] { DenseVectorFieldMapper.ElementType.BYTE, indexType });
@@ -264,7 +265,7 @@ public void setup() throws IOException {
264265
List<Number> vector = new ArrayList<>(numDims);
265266
for (int j = 0; j < numDims; j++) {
266267
switch (elementType) {
267-
case FLOAT -> vector.add(randomFloatBetween(0F, 1F, true));
268+
case FLOAT, BFLOAT16 -> vector.add(randomFloatBetween(0F, 1F, true));
268269
case BYTE, BIT -> vector.add((byte) (randomFloatBetween(0F, 1F, true) * 127.0f));
269270
default -> throw new IllegalArgumentException("Unexpected element type: " + elementType);
270271
}

0 commit comments

Comments
 (0)