|
14 | 14 | import org.elasticsearch.cluster.metadata.IndexMetadata; |
15 | 15 | import org.elasticsearch.common.settings.Settings; |
16 | 16 | import org.elasticsearch.index.codec.vectors.BFloat16; |
| 17 | +import org.elasticsearch.index.codec.vectors.es93.ES93GenericFlatVectorsFormat; |
17 | 18 | import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; |
18 | 19 | import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType; |
19 | 20 | import org.elasticsearch.script.field.vectors.DenseVector; |
@@ -61,16 +62,20 @@ private enum VectorSourceOptions { |
61 | 62 | .collect(Collectors.toSet()); |
62 | 63 |
|
63 | 64 | public static final float DELTA = 1e-7F; |
| 65 | + public static final float BFLOAT16_DELTA = 1e-2F; |
64 | 66 |
|
65 | 67 | private final ElementType elementType; |
66 | 68 | private final DenseVectorFieldMapper.VectorSimilarity similarity; |
67 | 69 | private final VectorSourceOptions sourceOptions; |
68 | 70 | private final boolean index; |
69 | 71 |
|
70 | 72 | @ParametersFactory |
71 | | - public static Iterable<Object[]> parameters() throws Exception { |
| 73 | + public static Iterable<Object[]> parameters() { |
72 | 74 | List<Object[]> params = new ArrayList<>(); |
73 | | - for (ElementType elementType : List.of(ElementType.BYTE, ElementType.FLOAT, ElementType.BIT)) { |
| 75 | + ElementType[] elementTypes = ES93GenericFlatVectorsFormat.GENERIC_VECTOR_FORMAT.isEnabled() |
| 76 | + ? ElementType.values() |
| 77 | + : new ElementType[] { ElementType.BYTE, ElementType.FLOAT, ElementType.BIT }; |
| 78 | + for (ElementType elementType : elementTypes) { |
74 | 79 | // Test all similarities |
75 | 80 | for (DenseVectorFieldMapper.VectorSimilarity similarity : DenseVectorFieldMapper.VectorSimilarity.values()) { |
76 | 81 | if (elementType == ElementType.BIT && similarity != DenseVectorFieldMapper.VectorSimilarity.L2_NORM) { |
@@ -137,8 +142,10 @@ public void testRetrieveTopNDenseVectorFieldData() { |
137 | 142 | } else { |
138 | 143 | assertNotNull(actualVector); |
139 | 144 | assertEquals(expectedVector.size(), actualVector.size()); |
| 145 | + |
| 146 | + float delta = elementType == ElementType.BFLOAT16 ? BFLOAT16_DELTA : DELTA; |
140 | 147 | for (int i = 0; i < expectedVector.size(); i++) { |
141 | | - assertEquals(expectedVector.get(i).floatValue(), actualVector.get(i).floatValue(), DELTA); |
| 148 | + assertEquals(expectedVector.get(i).floatValue(), actualVector.get(i).floatValue(), delta); |
142 | 149 | } |
143 | 150 | } |
144 | 151 | }); |
@@ -167,12 +174,14 @@ public void testRetrieveDenseVectorFieldData() { |
167 | 174 | } else { |
168 | 175 | assertNotNull(actualVector); |
169 | 176 | assertEquals(expectedVector.size(), actualVector.size()); |
| 177 | + |
| 178 | + float delta = elementType == ElementType.BFLOAT16 ? BFLOAT16_DELTA : DELTA; |
170 | 179 | for (int i = 0; i < actualVector.size(); i++) { |
171 | 180 | assertEquals( |
172 | 181 | "Actual: " + actualVector + "; expected: " + expectedVector, |
173 | 182 | expectedVector.get(i).floatValue(), |
174 | 183 | actualVector.get(i).floatValue(), |
175 | | - DELTA |
| 184 | + delta |
176 | 185 | ); |
177 | 186 | } |
178 | 187 | } |
@@ -253,12 +262,13 @@ public void setup() throws IOException { |
253 | 262 | } else { |
254 | 263 | for (int j = 0; j < numDims; j++) { |
255 | 264 | switch (elementType) { |
256 | | - case FLOAT -> vector.add(randomFloatBetween(0F, 1F, true)); |
| 265 | + case FLOAT, BFLOAT16 -> vector.add(randomFloatBetween(0F, 1F, true)); |
257 | 266 | case BYTE, BIT -> vector.add((byte) randomIntBetween(-128, 127)); |
258 | 267 | default -> throw new IllegalArgumentException("Unexpected element type: " + elementType); |
259 | 268 | } |
260 | 269 | } |
261 | | - if ((elementType == ElementType.FLOAT) && (similarity == DenseVectorFieldMapper.VectorSimilarity.DOT_PRODUCT || rarely())) { |
| 270 | + if ((elementType == ElementType.FLOAT || elementType == ElementType.BFLOAT16) |
| 271 | + && (similarity == DenseVectorFieldMapper.VectorSimilarity.DOT_PRODUCT || rarely())) { |
262 | 272 | // Normalize the vector |
263 | 273 | float magnitude = DenseVector.getMagnitude(vector); |
264 | 274 | vector.replaceAll(number -> number.floatValue() / magnitude); |
|
0 commit comments