Skip to content

Commit 24d33ce

Browse files
committed
Adding further tests and support
1 parent 49c117b commit 24d33ce

File tree

2 files changed

+130
-28
lines changed

2 files changed

+130
-28
lines changed

server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java

Lines changed: 70 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -808,11 +808,20 @@ VectorData parseStringValue(
808808
}
809809

810810
VectorData parseHexEncodedVector(String s, IntBooleanConsumer dimChecker, VectorSimilarity similarity) {
811-
return parseStringValue(s, dimChecker, similarity, str -> HexFormat.of().parseHex(str));
811+
return parseStringValue(s, dimChecker, similarity, HexFormat.of()::parseHex);
812812
}
813813

814814
VectorData parseBase64EncodedVector(String s, IntBooleanConsumer dimChecker, VectorSimilarity similarity) {
815-
return parseStringValue(s, dimChecker, similarity, str -> Base64.getDecoder().decode(str));
815+
return parseStringValue(s, dimChecker, similarity, Base64.getDecoder()::decode);
816+
}
817+
818+
VectorData parseBase64BinaryEncodedVector(byte[] binaryValue, IntBooleanConsumer dimChecker, VectorSimilarity similarity) {
819+
byte[] decodedVector = Base64.getDecoder().decode(binaryValue);
820+
dimChecker.accept(decodedVector.length, true);
821+
VectorData vectorData = VectorData.fromBytes(decodedVector);
822+
double squaredMagnitude = computeSquaredMagnitude(vectorData);
823+
checkVectorMagnitude(similarity, errorElementsAppender(decodedVector), (float) squaredMagnitude);
824+
return vectorData;
816825
}
817826

818827
@Override
@@ -825,9 +834,10 @@ public VectorData parseKnnVector(
825834
XContentParser.Token token = context.parser().currentToken();
826835
return switch (token) {
827836
case START_ARRAY -> parseVectorArray(context, dims, dimChecker, similarity);
837+
case VALUE_EMBEDDED_OBJECT -> parseBase64BinaryEncodedVector(context.parser().binaryValue(), dimChecker, similarity);
828838
case VALUE_STRING -> {
829839
String s = context.parser().text();
830-
if (s.length() == dims * 2 && isMaybeHexString(s)) {
840+
if (s.length() == dims * 2) {
831841
try {
832842
yield parseHexEncodedVector(s, dimChecker, similarity);
833843
} catch (IllegalArgumentException e) {
@@ -865,8 +875,7 @@ static boolean isMaybeHexString(String s) {
865875
}
866876
for (int i = 0; i < len; i++) {
867877
char c = s.charAt(i);
868-
boolean isHexChar = (c >= '0' && c <= '9') || (c >= 'A' && c <= 'F') || (c >= 'a' && c <= 'f');
869-
if (isHexChar == false) {
878+
if (HexFormat.isHexDigit(c) == false) {
870879
return false;
871880
}
872881
}
@@ -877,6 +886,7 @@ static boolean isMaybeHexString(String s) {
877886
public int parseDimensionCount(DocumentParserContext context) throws IOException {
878887
XContentParser.Token currentToken = context.parser().currentToken();
879888
return switch (currentToken) {
889+
case VALUE_EMBEDDED_OBJECT -> Base64.getDecoder().decode(context.parser().binaryValue()).length;
880890
case START_ARRAY -> {
881891
int index = 0;
882892
for (Token token = context.parser().nextToken(); token != Token.END_ARRAY; token = context.parser().nextToken()) {
@@ -1028,13 +1038,28 @@ public double computeSquaredMagnitude(VectorData vectorData) {
10281038
public int parseDimensionCount(DocumentParserContext context) throws IOException {
10291039
XContentParser.Token currentToken = context.parser().currentToken();
10301040
return switch (currentToken) {
1041+
10311042
case START_ARRAY -> {
10321043
int index = 0;
10331044
for (Token token = context.parser().nextToken(); token != Token.END_ARRAY; token = context.parser().nextToken()) {
10341045
index++;
10351046
}
10361047
yield index;
10371048
}
1049+
case VALUE_EMBEDDED_OBJECT -> {
1050+
byte[] vector = Base64.getDecoder().decode(context.parser().binaryValue());
1051+
if (vector.length % Float.BYTES != 0) {
1052+
throw new ParsingException(
1053+
context.parser().getTokenLocation(),
1054+
"Failed to parse object: Embedded vector byte length ["
1055+
+ vector.length
1056+
+ "] is not a multiple of ["
1057+
+ Float.BYTES
1058+
+ "]"
1059+
);
1060+
}
1061+
yield vector.length / Float.BYTES;
1062+
}
10381063
case VALUE_STRING -> {
10391064
byte[] decodedVectorBytes = Base64.getDecoder().decode(context.parser().text());
10401065
if (decodedVectorBytes.length % Float.BYTES != 0) {
@@ -1113,6 +1138,7 @@ VectorDataAndMagnitude parseFloatVectorInput(DocumentParserContext context, int
11131138
XContentParser.Token token = context.parser().currentToken();
11141139
return switch (token) {
11151140
case START_ARRAY -> parseVectorArray(context, dimChecker, dims);
1141+
case VALUE_EMBEDDED_OBJECT -> parseBase64BinaryEncodedVector(context, dimChecker, dims);
11161142
case VALUE_STRING -> parseBase64EncodedVector(context, dimChecker, dims);
11171143
default -> throw new ParsingException(
11181144
context.parser().getTokenLocation(),
@@ -1137,8 +1163,34 @@ VectorDataAndMagnitude parseVectorArray(DocumentParserContext context, IntBoolea
11371163
return new VectorDataAndMagnitude(VectorData.fromFloats(vector), squaredMagnitude);
11381164
}
11391165

1166+
VectorDataAndMagnitude parseBase64BinaryEncodedVector(DocumentParserContext context, IntBooleanConsumer dimChecker, int dims)
1167+
throws IOException {
1168+
// BIG_ENDIAN is the default, but just being explicit here
1169+
byte[] binaryValue = context.parser().binaryValue();
1170+
ByteBuffer byteBuffer = ByteBuffer.wrap(Base64.getDecoder().decode(binaryValue)).order(ByteOrder.BIG_ENDIAN);
1171+
if (byteBuffer.remaining() != dims * Float.BYTES) {
1172+
throw new ParsingException(
1173+
context.parser().getTokenLocation(),
1174+
"Failed to parse object: Embedded vector byte length ["
1175+
+ byteBuffer.remaining()
1176+
+ "] does not match the expected length of ["
1177+
+ (dims * Float.BYTES)
1178+
+ "] for dimension count ["
1179+
+ dims
1180+
+ "]"
1181+
);
1182+
}
1183+
float[] decodedVector = new float[dims];
1184+
byteBuffer.asFloatBuffer().get(decodedVector);
1185+
dimChecker.accept(decodedVector.length, true);
1186+
VectorData vectorData = VectorData.fromFloats(decodedVector);
1187+
float squaredMagnitude = (float) computeSquaredMagnitude(vectorData);
1188+
return new VectorDataAndMagnitude(vectorData, squaredMagnitude);
1189+
}
1190+
11401191
VectorDataAndMagnitude parseBase64EncodedVector(DocumentParserContext context, IntBooleanConsumer dimChecker, int dims)
11411192
throws IOException {
1193+
// BIG_ENDIAN is the default, but just being explicit here
11421194
ByteBuffer byteBuffer = ByteBuffer.wrap(Base64.getDecoder().decode(context.parser().text())).order(ByteOrder.BIG_ENDIAN);
11431195
if (byteBuffer.remaining() != dims * Float.BYTES) {
11441196
throw new ParsingException(
@@ -1244,6 +1296,13 @@ VectorData parseStringValue(
12441296
return VectorData.fromBytes(decodedVector);
12451297
}
12461298

1299+
@Override
1300+
VectorData parseBase64BinaryEncodedVector(byte[] binaryValue, IntBooleanConsumer dimChecker, VectorSimilarity similarity) {
1301+
byte[] decodedVector = Base64.getDecoder().decode(binaryValue);
1302+
dimChecker.accept(decodedVector.length * Byte.SIZE, true);
1303+
return VectorData.fromBytes(decodedVector);
1304+
}
1305+
12471306
@Override
12481307
public int getNumBytes(int dimensions) {
12491308
assert dimensions % Byte.SIZE == 0;
@@ -2456,12 +2515,10 @@ public List<Object> fetchValues(Source source, int doc, List<Object> ignoredValu
24562515
return List.of();
24572516
}
24582517
try {
2459-
if (sourceValue instanceof List<?> v) {
2460-
values.addAll(v);
2461-
} else if (sourceValue instanceof String s) {
2462-
values.add(s);
2463-
} else {
2464-
ignoredValues.add(sourceValue);
2518+
switch (sourceValue) {
2519+
case List<?> v -> values.addAll(v);
2520+
case String s -> values.add(s);
2521+
default -> ignoredValues.add(sourceValue);
24652522
}
24662523
} catch (Exception e) {
24672524
// if parsing fails here then it would have failed at index time
@@ -2922,7 +2979,8 @@ public List<Object> fetchValues(Source source, int doc, List<Object> ignoredValu
29222979
values.add(NumberFieldMapper.NumberType.FLOAT.parse(o, false));
29232980
}
29242981
} else if (sourceValue instanceof String s) {
2925-
if ((element == BYTE_ELEMENT || element == BIT_ELEMENT)
2982+
if ((element.elementType() == BYTE_ELEMENT.elementType()
2983+
|| element.elementType() == BIT_ELEMENT.elementType())
29262984
&& s.length() == dims * 2
29272985
&& ByteElement.isMaybeHexString(s)) {
29282986
byte[] bytes;

x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/DenseVectorFieldTypeIT.java

Lines changed: 60 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -22,22 +22,32 @@
2222
import org.junit.Before;
2323

2424
import java.io.IOException;
25+
import java.nio.ByteBuffer;
2526
import java.util.ArrayList;
2627
import java.util.Arrays;
28+
import java.util.Base64;
2729
import java.util.HashMap;
30+
import java.util.HexFormat;
2831
import java.util.List;
2932
import java.util.Locale;
3033
import java.util.Map;
3134
import java.util.Set;
3235
import java.util.stream.Collectors;
3336

3437
import static org.elasticsearch.index.IndexSettings.INDEX_MAPPER_SOURCE_MODE_SETTING;
38+
import static org.elasticsearch.index.IndexSettings.INDEX_MAPPING_EXCLUDE_SOURCE_VECTORS_SETTING;
3539
import static org.elasticsearch.index.mapper.SourceFieldMapper.Mode.SYNTHETIC;
3640
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
3741
import static org.elasticsearch.xpack.esql.action.EsqlCapabilities.Cap.L2_NORM_VECTOR_SIMILARITY_FUNCTION;
3842

3943
public class DenseVectorFieldTypeIT extends AbstractEsqlIntegTestCase {
4044

45+
private enum VectorSourceOptions {
46+
DEFAULT,
47+
SYNTHETIC,
48+
INCLUDE_SOURCE_VECTORS
49+
}
50+
4151
public static final Set<String> ALL_DENSE_VECTOR_INDEX_TYPES = Arrays.stream(DenseVectorFieldMapper.VectorIndexType.values())
4252
.map(v -> v.getName().toLowerCase(Locale.ROOT))
4353
.collect(Collectors.toSet());
@@ -51,7 +61,7 @@ public class DenseVectorFieldTypeIT extends AbstractEsqlIntegTestCase {
5161

5262
private final ElementType elementType;
5363
private final DenseVectorFieldMapper.VectorSimilarity similarity;
54-
private final boolean synthetic;
64+
private final VectorSourceOptions sourceOptions;
5565
private final boolean index;
5666

5767
@ParametersFactory
@@ -64,13 +74,14 @@ public static Iterable<Object[]> parameters() throws Exception {
6474
if (elementType == ElementType.BIT && similarity != DenseVectorFieldMapper.VectorSimilarity.L2_NORM) {
6575
continue;
6676
}
67-
params.add(new Object[] { elementType, similarity, true, false });
77+
params.add(new Object[] { elementType, similarity, true, VectorSourceOptions.DEFAULT });
78+
params.add(new Object[] { elementType, similarity, true, VectorSourceOptions.SYNTHETIC });
79+
params.add(new Object[] { elementType, similarity, true, VectorSourceOptions.INCLUDE_SOURCE_VECTORS });
6880
}
6981

70-
// No indexing
71-
params.add(new Object[] { elementType, null, false, false });
72-
// No indexing, synthetic source
73-
params.add(new Object[] { elementType, null, false, true });
82+
params.add(new Object[] { elementType, null, false, VectorSourceOptions.DEFAULT });
83+
params.add(new Object[] { elementType, null, false, VectorSourceOptions.SYNTHETIC });
84+
params.add(new Object[] { elementType, null, false, VectorSourceOptions.INCLUDE_SOURCE_VECTORS });
7485
}
7586

7687
return params;
@@ -80,12 +91,12 @@ public DenseVectorFieldTypeIT(
8091
@Name("elementType") ElementType elementType,
8192
@Name("similarity") DenseVectorFieldMapper.VectorSimilarity similarity,
8293
@Name("index") boolean index,
83-
@Name("synthetic") boolean synthetic
94+
@Name("synthetic") VectorSourceOptions sourceOptions
8495
) {
8596
this.elementType = elementType;
8697
this.similarity = similarity;
8798
this.index = index;
88-
this.synthetic = synthetic;
99+
this.sourceOptions = sourceOptions;
89100
}
90101

91102
private final Map<Integer, List<Number>> indexedVectors = new HashMap<>();
@@ -169,7 +180,7 @@ public void testRetrieveDenseVectorFieldData() {
169180
}
170181

171182
public void testNonIndexedDenseVectorField() throws IOException {
172-
createIndexWithDenseVector("no_dense_vectors");
183+
createIndexWithDenseVector("no_dense_vectors", 64);
173184

174185
int numDocs = randomIntBetween(10, 100);
175186
IndexRequestBuilder[] docs = new IndexRequestBuilder[numDocs];
@@ -199,9 +210,11 @@ public void testNonIndexedDenseVectorField() throws IOException {
199210

200211
@Before
201212
public void setup() throws IOException {
202-
createIndexWithDenseVector("test");
203-
204-
int numDims = randomIntBetween(32, 64) * 2; // min 64, even number
213+
int numDims = randomIntBetween(8, 16) * 8; // min 64, even number
214+
createIndexWithDenseVector("test", numDims);
215+
if (elementType == ElementType.BIT) {
216+
numDims = numDims / 8;
217+
}
205218
int numDocs = randomIntBetween(10, 100);
206219
IndexRequestBuilder[] docs = new IndexRequestBuilder[numDocs];
207220
for (int i = 0; i < numDocs; i++) {
@@ -222,15 +235,39 @@ public void setup() throws IOException {
222235
float magnitude = DenseVector.getMagnitude(vector);
223236
vector.replaceAll(number -> number.floatValue() / magnitude);
224237
}
225-
docs[i] = prepareIndex("test").setId("" + i).setSource("id", String.valueOf(i), "vector", vector);
238+
Object vectorToIndex;
239+
if (randomBoolean()) {
240+
vectorToIndex = vector;
241+
} else {
242+
// Test encoded input: base64 for any element type, or hex for byte/bit vectors
243+
vectorToIndex = switch (elementType) {
244+
case FLOAT -> {
245+
float[] array = new float[numDims];
246+
for (int k = 0; k < numDims; k++) {
247+
array[k] = vector.get(k).floatValue();
248+
}
249+
final ByteBuffer buffer = ByteBuffer.allocate(Float.BYTES * numDims);
250+
buffer.asFloatBuffer().put(array);
251+
yield Base64.getEncoder().encode(buffer.array());
252+
}
253+
case BYTE, BIT -> {
254+
byte[] array = new byte[numDims];
255+
for (int k = 0; k < numDims; k++) {
256+
array[k] = vector.get(k).byteValue();
257+
}
258+
yield randomBoolean() ? Base64.getEncoder().encode(array) : HexFormat.of().formatHex(array);
259+
}
260+
};
261+
}
262+
docs[i] = prepareIndex("test").setId("" + i).setSource("id", String.valueOf(i), "vector", vectorToIndex);
226263
indexedVectors.put(i, vector);
227264
}
228265
}
229266

230267
indexRandom(true, docs);
231268
}
232269

233-
private void createIndexWithDenseVector(String indexName) throws IOException {
270+
private void createIndexWithDenseVector(String indexName, int dims) throws IOException {
234271
var client = client().admin().indices();
235272
XContentBuilder mapping = XContentFactory.jsonBuilder()
236273
.startObject()
@@ -240,6 +277,7 @@ private void createIndexWithDenseVector(String indexName) throws IOException {
240277
.endObject()
241278
.startObject("vector")
242279
.field("type", "dense_vector")
280+
.field("dims", dims)
243281
.field("element_type", elementType.toString().toLowerCase(Locale.ROOT))
244282
.field("index", index);
245283
if (index) {
@@ -256,8 +294,14 @@ private void createIndexWithDenseVector(String indexName) throws IOException {
256294
Settings.Builder settingsBuilder = Settings.builder()
257295
.put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0)
258296
.put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, randomIntBetween(1, 5));
259-
if (synthetic) {
260-
settingsBuilder.put(INDEX_MAPPER_SOURCE_MODE_SETTING.getKey(), SYNTHETIC);
297+
switch (sourceOptions) {
298+
// ensure vectors are actually in _source
299+
case INCLUDE_SOURCE_VECTORS -> settingsBuilder.put(INDEX_MAPPING_EXCLUDE_SOURCE_VECTORS_SETTING.getKey(), false);
300+
// ensure synthetic source is on
301+
case SYNTHETIC -> settingsBuilder.put(INDEX_MAPPER_SOURCE_MODE_SETTING.getKey(), SYNTHETIC);
302+
// default, which is vectors outside of source and synthetic off
303+
case DEFAULT -> {
304+
}
261305
}
262306

263307
var createRequest = client.prepareCreate(indexName)

0 commit comments

Comments
 (0)