Skip to content

Commit 722b209

Browse files
thecoopncordon
authored andcommitted
Reapply ES|QL bfloat16 support tests (elastic#138499) (elastic#138584)
Add bfloat16 tests with a feature flag this time
1 parent 8ef32e4 commit 722b209

File tree

2 files changed

+21
-7
lines changed

2 files changed

+21
-7
lines changed

x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import org.elasticsearch.cluster.metadata.IndexMetadata;
1515
import org.elasticsearch.common.settings.Settings;
1616
import org.elasticsearch.index.codec.vectors.BFloat16;
17+
import org.elasticsearch.index.codec.vectors.es93.ES93GenericFlatVectorsFormat;
1718
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
1819
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType;
1920
import org.elasticsearch.script.field.vectors.DenseVector;
@@ -61,16 +62,20 @@ private enum VectorSourceOptions {
6162
.collect(Collectors.toSet());
6263

6364
public static final float DELTA = 1e-7F;
65+
public static final float BFLOAT16_DELTA = 1e-2F;
6466

6567
private final ElementType elementType;
6668
private final DenseVectorFieldMapper.VectorSimilarity similarity;
6769
private final VectorSourceOptions sourceOptions;
6870
private final boolean index;
6971

7072
@ParametersFactory
71-
public static Iterable<Object[]> parameters() throws Exception {
73+
public static Iterable<Object[]> parameters() {
7274
List<Object[]> params = new ArrayList<>();
73-
for (ElementType elementType : List.of(ElementType.BYTE, ElementType.FLOAT, ElementType.BIT)) {
75+
ElementType[] elementTypes = ES93GenericFlatVectorsFormat.GENERIC_VECTOR_FORMAT.isEnabled()
76+
? ElementType.values()
77+
: new ElementType[] { ElementType.BYTE, ElementType.FLOAT, ElementType.BIT };
78+
for (ElementType elementType : elementTypes) {
7479
// Test all similarities
7580
for (DenseVectorFieldMapper.VectorSimilarity similarity : DenseVectorFieldMapper.VectorSimilarity.values()) {
7681
if (elementType == ElementType.BIT && similarity != DenseVectorFieldMapper.VectorSimilarity.L2_NORM) {
@@ -137,8 +142,10 @@ public void testRetrieveTopNDenseVectorFieldData() {
137142
} else {
138143
assertNotNull(actualVector);
139144
assertEquals(expectedVector.size(), actualVector.size());
145+
146+
float delta = elementType == ElementType.BFLOAT16 ? BFLOAT16_DELTA : DELTA;
140147
for (int i = 0; i < expectedVector.size(); i++) {
141-
assertEquals(expectedVector.get(i).floatValue(), actualVector.get(i).floatValue(), DELTA);
148+
assertEquals(expectedVector.get(i).floatValue(), actualVector.get(i).floatValue(), delta);
142149
}
143150
}
144151
});
@@ -167,12 +174,14 @@ public void testRetrieveDenseVectorFieldData() {
167174
} else {
168175
assertNotNull(actualVector);
169176
assertEquals(expectedVector.size(), actualVector.size());
177+
178+
float delta = elementType == ElementType.BFLOAT16 ? BFLOAT16_DELTA : DELTA;
170179
for (int i = 0; i < actualVector.size(); i++) {
171180
assertEquals(
172181
"Actual: " + actualVector + "; expected: " + expectedVector,
173182
expectedVector.get(i).floatValue(),
174183
actualVector.get(i).floatValue(),
175-
DELTA
184+
delta
176185
);
177186
}
178187
}
@@ -253,12 +262,13 @@ public void setup() throws IOException {
253262
} else {
254263
for (int j = 0; j < numDims; j++) {
255264
switch (elementType) {
256-
case FLOAT -> vector.add(randomFloatBetween(0F, 1F, true));
265+
case FLOAT, BFLOAT16 -> vector.add(randomFloatBetween(0F, 1F, true));
257266
case BYTE, BIT -> vector.add((byte) randomIntBetween(-128, 127));
258267
default -> throw new IllegalArgumentException("Unexpected element type: " + elementType);
259268
}
260269
}
261-
if ((elementType == ElementType.FLOAT) && (similarity == DenseVectorFieldMapper.VectorSimilarity.DOT_PRODUCT || rarely())) {
270+
if ((elementType == ElementType.FLOAT || elementType == ElementType.BFLOAT16)
271+
&& (similarity == DenseVectorFieldMapper.VectorSimilarity.DOT_PRODUCT || rarely())) {
262272
// Normalize the vector
263273
float magnitude = DenseVector.getMagnitude(vector);
264274
vector.replaceAll(number -> number.floatValue() / magnitude);

x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/KnnFunctionIT.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import org.elasticsearch.cluster.metadata.IndexMetadata;
1616
import org.elasticsearch.common.settings.Settings;
1717
import org.elasticsearch.index.IndexSettings;
18+
import org.elasticsearch.index.codec.vectors.es93.ES93GenericFlatVectorsFormat;
1819
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
1920
import org.elasticsearch.xcontent.XContentBuilder;
2021
import org.elasticsearch.xcontent.XContentFactory;
@@ -57,6 +58,9 @@ public static Iterable<Object[]> parameters() throws Exception {
5758
List<Object[]> params = new ArrayList<>();
5859
for (String indexType : ALL_DENSE_VECTOR_INDEX_TYPES) {
5960
params.add(new Object[] { DenseVectorFieldMapper.ElementType.FLOAT, indexType });
61+
if (ES93GenericFlatVectorsFormat.GENERIC_VECTOR_FORMAT.isEnabled()) {
62+
params.add(new Object[] { DenseVectorFieldMapper.ElementType.BFLOAT16, indexType });
63+
}
6064
}
6165
for (String indexType : NON_QUANTIZED_DENSE_VECTOR_INDEX_TYPES) {
6266
params.add(new Object[] { DenseVectorFieldMapper.ElementType.BYTE, indexType });
@@ -264,7 +268,7 @@ public void setup() throws IOException {
264268
List<Number> vector = new ArrayList<>(numDims);
265269
for (int j = 0; j < numDims; j++) {
266270
switch (elementType) {
267-
case FLOAT -> vector.add(randomFloatBetween(0F, 1F, true));
271+
case FLOAT, BFLOAT16 -> vector.add(randomFloatBetween(0F, 1F, true));
268272
case BYTE, BIT -> vector.add((byte) (randomFloatBetween(0F, 1F, true) * 127.0f));
269273
default -> throw new IllegalArgumentException("Unexpected element type: " + elementType);
270274
}

0 commit comments

Comments
 (0)