@@ -61,6 +61,7 @@ private enum VectorSourceOptions {
6161 .collect (Collectors .toSet ());
6262
6363 public static final float DELTA = 1e-7F ;
64+ public static final float BFLOAT16_DELTA = 1e-2F ;
6465
6566 private final ElementType elementType ;
6667 private final DenseVectorFieldMapper .VectorSimilarity similarity ;
@@ -70,7 +71,7 @@ private enum VectorSourceOptions {
7071 @ ParametersFactory
7172 public static Iterable <Object []> parameters () throws Exception {
7273 List <Object []> params = new ArrayList <>();
73- for (ElementType elementType : List .of (ElementType .BYTE , ElementType .FLOAT , ElementType .BIT )) {
74+ for (ElementType elementType : List .of (ElementType .BYTE , ElementType .FLOAT , ElementType .BIT , ElementType . BFLOAT16 )) {
7475 // Test all similarities
7576 for (DenseVectorFieldMapper .VectorSimilarity similarity : DenseVectorFieldMapper .VectorSimilarity .values ()) {
7677 if (elementType == ElementType .BIT && similarity != DenseVectorFieldMapper .VectorSimilarity .L2_NORM ) {
@@ -137,8 +138,10 @@ public void testRetrieveTopNDenseVectorFieldData() {
137138 } else {
138139 assertNotNull (actualVector );
139140 assertEquals (expectedVector .size (), actualVector .size ());
141+
142+ float delta = elementType == ElementType .BFLOAT16 ? BFLOAT16_DELTA : DELTA ;
140143 for (int i = 0 ; i < expectedVector .size (); i ++) {
141- assertEquals (expectedVector .get (i ).floatValue (), actualVector .get (i ).floatValue (), DELTA );
144+ assertEquals (expectedVector .get (i ).floatValue (), actualVector .get (i ).floatValue (), delta );
142145 }
143146 }
144147 });
@@ -167,12 +170,14 @@ public void testRetrieveDenseVectorFieldData() {
167170 } else {
168171 assertNotNull (actualVector );
169172 assertEquals (expectedVector .size (), actualVector .size ());
173+
174+ float delta = elementType == ElementType .BFLOAT16 ? BFLOAT16_DELTA : DELTA ;
170175 for (int i = 0 ; i < actualVector .size (); i ++) {
171176 assertEquals (
172177 "Actual: " + actualVector + "; expected: " + expectedVector ,
173178 expectedVector .get (i ).floatValue (),
174179 actualVector .get (i ).floatValue (),
175- DELTA
180+ delta
176181 );
177182 }
178183 }
@@ -253,12 +258,13 @@ public void setup() throws IOException {
253258 } else {
254259 for (int j = 0 ; j < numDims ; j ++) {
255260 switch (elementType ) {
256- case FLOAT -> vector .add (randomFloatBetween (0F , 1F , true ));
261+ case FLOAT , BFLOAT16 -> vector .add (randomFloatBetween (0F , 1F , true ));
257262 case BYTE , BIT -> vector .add ((byte ) randomIntBetween (-128 , 127 ));
258263 default -> throw new IllegalArgumentException ("Unexpected element type: " + elementType );
259264 }
260265 }
261- if ((elementType == ElementType .FLOAT ) && (similarity == DenseVectorFieldMapper .VectorSimilarity .DOT_PRODUCT || rarely ())) {
266+ if ((elementType == ElementType .FLOAT || elementType == ElementType .BFLOAT16 )
267+ && (similarity == DenseVectorFieldMapper .VectorSimilarity .DOT_PRODUCT || rarely ())) {
262268 // Normalize the vector
263269 float magnitude = DenseVector .getMagnitude (vector );
264270 vector .replaceAll (number -> number .floatValue () / magnitude );
0 commit comments