Skip to content

Commit 7369c08

Browse files
authored
Add new multi_dense_vector field for brute-force search (#116275)
This adds a new `multi_dense_vector` field that focuses on the maxSim usecase provided by Col[BERT|Pali]. Indexing vectors in HNSW as it stands makes no sense. Performance wise or for cost. However, we should totally support rescoring and brute-force search over vectors with maxSim. This is step one of many. Behind a feature flag, this adds support for indexing any number of vectors of the same dimension. Supports bit/byte/float. Scripting support will be a follow up. Marking as non-issue as its behind a flag and unusable currently.
1 parent 66123cf commit 7369c08

File tree

10 files changed

+1451
-65
lines changed

10 files changed

+1451
-65
lines changed
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
setup:
2+
- requires:
3+
capabilities:
4+
- method: POST
5+
path: /_search
6+
capabilities: [ multi_dense_vector_field_mapper ]
7+
test_runner_features: capabilities
8+
reason: "Support for multi dense vector field mapper capability required"
9+
---
10+
"Test create multi-vector field":
11+
- do:
12+
indices.create:
13+
index: test
14+
body:
15+
mappings:
16+
properties:
17+
vector1:
18+
type: multi_dense_vector
19+
dims: 3
20+
- do:
21+
index:
22+
index: test
23+
id: "1"
24+
body:
25+
vector1: [[2, -1, 1]]
26+
- do:
27+
index:
28+
index: test
29+
id: "2"
30+
body:
31+
vector1: [[2, -1, 1], [3, 4, 5]]
32+
- do:
33+
index:
34+
index: test
35+
id: "3"
36+
body:
37+
vector1: [[2, -1, 1], [3, 4, 5], [6, 7, 8]]
38+
- do:
39+
indices.refresh: {}
40+
---
41+
"Test create dynamic dim multi-vector field":
42+
- do:
43+
indices.create:
44+
index: test
45+
body:
46+
mappings:
47+
properties:
48+
name:
49+
type: keyword
50+
vector1:
51+
type: multi_dense_vector
52+
- do:
53+
index:
54+
index: test
55+
id: "1"
56+
body:
57+
vector1: [[2, -1, 1]]
58+
- do:
59+
index:
60+
index: test
61+
id: "2"
62+
body:
63+
vector1: [[2, -1, 1], [3, 4, 5]]
64+
- do:
65+
index:
66+
index: test
67+
id: "3"
68+
body:
69+
vector1: [[2, -1, 1], [3, 4, 5], [6, 7, 8]]
70+
- do:
71+
cluster.health:
72+
wait_for_events: languid
73+
74+
# verify some other dimension will fail
75+
- do:
76+
catch: bad_request
77+
index:
78+
index: test
79+
id: "4"
80+
body:
81+
vector1: [[2, -1, 1], [3, 4, 5], [6, 7, 8, 9]]
82+
---
83+
"Test dynamic dim mismatch fails multi-vector field":
84+
- do:
85+
indices.create:
86+
index: test
87+
body:
88+
mappings:
89+
properties:
90+
vector1:
91+
type: multi_dense_vector
92+
- do:
93+
catch: bad_request
94+
index:
95+
index: test
96+
id: "1"
97+
body:
98+
vector1: [[2, -1, 1], [2]]
99+
---
100+
"Test static dim mismatch fails multi-vector field":
101+
- do:
102+
indices.create:
103+
index: test
104+
body:
105+
mappings:
106+
properties:
107+
vector1:
108+
type: multi_dense_vector
109+
dims: 3
110+
- do:
111+
catch: bad_request
112+
index:
113+
index: test
114+
id: "1"
115+
body:
116+
vector1: [[2, -1, 1], [2]]
117+
---
118+
"Test poorly formatted multi-vector field":
119+
- do:
120+
indices.create:
121+
index: poorly_formatted_vector
122+
body:
123+
mappings:
124+
properties:
125+
vector1:
126+
type: multi_dense_vector
127+
dims: 3
128+
- do:
129+
catch: bad_request
130+
index:
131+
index: poorly_formatted_vector
132+
id: "1"
133+
body:
134+
vector1: [[[2, -1, 1]]]
135+
- do:
136+
catch: bad_request
137+
index:
138+
index: poorly_formatted_vector
139+
id: "1"
140+
body:
141+
vector1: [[2, -1, 1], [[2, -1, 1]]]

server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java

Lines changed: 76 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -416,13 +416,18 @@ public double computeSquaredMagnitude(VectorData vectorData) {
416416
return VectorUtil.dotProduct(vectorData.asByteVector(), vectorData.asByteVector());
417417
}
418418

419-
private VectorData parseVectorArray(DocumentParserContext context, DenseVectorFieldMapper fieldMapper) throws IOException {
419+
private VectorData parseVectorArray(
420+
DocumentParserContext context,
421+
int dims,
422+
IntBooleanConsumer dimChecker,
423+
VectorSimilarity similarity
424+
) throws IOException {
420425
int index = 0;
421-
byte[] vector = new byte[fieldMapper.fieldType().dims];
426+
byte[] vector = new byte[dims];
422427
float squaredMagnitude = 0;
423428
for (XContentParser.Token token = context.parser().nextToken(); token != Token.END_ARRAY; token = context.parser()
424429
.nextToken()) {
425-
fieldMapper.checkDimensionExceeded(index, context);
430+
dimChecker.accept(index, false);
426431
ensureExpectedToken(Token.VALUE_NUMBER, token, context.parser());
427432
final int value;
428433
if (context.parser().numberType() != XContentParser.NumberType.INT) {
@@ -460,30 +465,31 @@ private VectorData parseVectorArray(DocumentParserContext context, DenseVectorFi
460465
vector[index++] = (byte) value;
461466
squaredMagnitude += value * value;
462467
}
463-
fieldMapper.checkDimensionMatches(index, context);
464-
checkVectorMagnitude(fieldMapper.fieldType().similarity, errorByteElementsAppender(vector), squaredMagnitude);
468+
dimChecker.accept(index, true);
469+
checkVectorMagnitude(similarity, errorByteElementsAppender(vector), squaredMagnitude);
465470
return VectorData.fromBytes(vector);
466471
}
467472

468-
private VectorData parseHexEncodedVector(DocumentParserContext context, DenseVectorFieldMapper fieldMapper) throws IOException {
473+
private VectorData parseHexEncodedVector(
474+
DocumentParserContext context,
475+
IntBooleanConsumer dimChecker,
476+
VectorSimilarity similarity
477+
) throws IOException {
469478
byte[] decodedVector = HexFormat.of().parseHex(context.parser().text());
470-
fieldMapper.checkDimensionMatches(decodedVector.length, context);
479+
dimChecker.accept(decodedVector.length, true);
471480
VectorData vectorData = VectorData.fromBytes(decodedVector);
472481
double squaredMagnitude = computeSquaredMagnitude(vectorData);
473-
checkVectorMagnitude(
474-
fieldMapper.fieldType().similarity,
475-
errorByteElementsAppender(decodedVector),
476-
(float) squaredMagnitude
477-
);
482+
checkVectorMagnitude(similarity, errorByteElementsAppender(decodedVector), (float) squaredMagnitude);
478483
return vectorData;
479484
}
480485

481486
@Override
482-
VectorData parseKnnVector(DocumentParserContext context, DenseVectorFieldMapper fieldMapper) throws IOException {
487+
VectorData parseKnnVector(DocumentParserContext context, int dims, IntBooleanConsumer dimChecker, VectorSimilarity similarity)
488+
throws IOException {
483489
XContentParser.Token token = context.parser().currentToken();
484490
return switch (token) {
485-
case START_ARRAY -> parseVectorArray(context, fieldMapper);
486-
case VALUE_STRING -> parseHexEncodedVector(context, fieldMapper);
491+
case START_ARRAY -> parseVectorArray(context, dims, dimChecker, similarity);
492+
case VALUE_STRING -> parseHexEncodedVector(context, dimChecker, similarity);
487493
default -> throw new ParsingException(
488494
context.parser().getTokenLocation(),
489495
format("Unsupported type [%s] for provided value [%s]", token, context.parser().text())
@@ -493,7 +499,13 @@ VectorData parseKnnVector(DocumentParserContext context, DenseVectorFieldMapper
493499

494500
@Override
495501
public void parseKnnVectorAndIndex(DocumentParserContext context, DenseVectorFieldMapper fieldMapper) throws IOException {
496-
VectorData vectorData = parseKnnVector(context, fieldMapper);
502+
VectorData vectorData = parseKnnVector(context, fieldMapper.fieldType().dims, (i, end) -> {
503+
if (end) {
504+
fieldMapper.checkDimensionMatches(i, context);
505+
} else {
506+
fieldMapper.checkDimensionExceeded(i, context);
507+
}
508+
}, fieldMapper.fieldType().similarity);
497509
Field field = createKnnVectorField(
498510
fieldMapper.fieldType().name(),
499511
vectorData.asByteVector(),
@@ -677,21 +689,22 @@ && isNotUnitVector(squaredMagnitude)) {
677689
}
678690

679691
@Override
680-
VectorData parseKnnVector(DocumentParserContext context, DenseVectorFieldMapper fieldMapper) throws IOException {
692+
VectorData parseKnnVector(DocumentParserContext context, int dims, IntBooleanConsumer dimChecker, VectorSimilarity similarity)
693+
throws IOException {
681694
int index = 0;
682695
float squaredMagnitude = 0;
683-
float[] vector = new float[fieldMapper.fieldType().dims];
696+
float[] vector = new float[dims];
684697
for (Token token = context.parser().nextToken(); token != Token.END_ARRAY; token = context.parser().nextToken()) {
685-
fieldMapper.checkDimensionExceeded(index, context);
698+
dimChecker.accept(index, false);
686699
ensureExpectedToken(Token.VALUE_NUMBER, token, context.parser());
687700
float value = context.parser().floatValue(true);
688701
vector[index] = value;
689702
squaredMagnitude += value * value;
690703
index++;
691704
}
692-
fieldMapper.checkDimensionMatches(index, context);
705+
dimChecker.accept(index, true);
693706
checkVectorBounds(vector);
694-
checkVectorMagnitude(fieldMapper.fieldType().similarity, errorFloatElementsAppender(vector), squaredMagnitude);
707+
checkVectorMagnitude(similarity, errorFloatElementsAppender(vector), squaredMagnitude);
695708
return VectorData.fromFloats(vector);
696709
}
697710

@@ -816,12 +829,17 @@ public double computeSquaredMagnitude(VectorData vectorData) {
816829
return count;
817830
}
818831

819-
private VectorData parseVectorArray(DocumentParserContext context, DenseVectorFieldMapper fieldMapper) throws IOException {
832+
private VectorData parseVectorArray(
833+
DocumentParserContext context,
834+
int dims,
835+
IntBooleanConsumer dimChecker,
836+
VectorSimilarity similarity
837+
) throws IOException {
820838
int index = 0;
821-
byte[] vector = new byte[fieldMapper.fieldType().dims / Byte.SIZE];
839+
byte[] vector = new byte[dims / Byte.SIZE];
822840
for (XContentParser.Token token = context.parser().nextToken(); token != Token.END_ARRAY; token = context.parser()
823841
.nextToken()) {
824-
fieldMapper.checkDimensionExceeded(index, context);
842+
dimChecker.accept(index * Byte.SIZE, false);
825843
ensureExpectedToken(Token.VALUE_NUMBER, token, context.parser());
826844
final int value;
827845
if (context.parser().numberType() != XContentParser.NumberType.INT) {
@@ -856,35 +874,25 @@ private VectorData parseVectorArray(DocumentParserContext context, DenseVectorFi
856874
+ "];"
857875
);
858876
}
859-
if (index >= vector.length) {
860-
throw new IllegalArgumentException(
861-
"The number of dimensions for field ["
862-
+ fieldMapper.fieldType().name()
863-
+ "] should be ["
864-
+ fieldMapper.fieldType().dims
865-
+ "] but found ["
866-
+ (index + 1) * Byte.SIZE
867-
+ "]"
868-
);
869-
}
870877
vector[index++] = (byte) value;
871878
}
872-
fieldMapper.checkDimensionMatches(index * Byte.SIZE, context);
879+
dimChecker.accept(index * Byte.SIZE, true);
873880
return VectorData.fromBytes(vector);
874881
}
875882

876-
private VectorData parseHexEncodedVector(DocumentParserContext context, DenseVectorFieldMapper fieldMapper) throws IOException {
883+
private VectorData parseHexEncodedVector(DocumentParserContext context, IntBooleanConsumer dimChecker) throws IOException {
877884
byte[] decodedVector = HexFormat.of().parseHex(context.parser().text());
878-
fieldMapper.checkDimensionMatches(decodedVector.length * Byte.SIZE, context);
885+
dimChecker.accept(decodedVector.length * Byte.SIZE, true);
879886
return VectorData.fromBytes(decodedVector);
880887
}
881888

882889
@Override
883-
VectorData parseKnnVector(DocumentParserContext context, DenseVectorFieldMapper fieldMapper) throws IOException {
890+
VectorData parseKnnVector(DocumentParserContext context, int dims, IntBooleanConsumer dimChecker, VectorSimilarity similarity)
891+
throws IOException {
884892
XContentParser.Token token = context.parser().currentToken();
885893
return switch (token) {
886-
case START_ARRAY -> parseVectorArray(context, fieldMapper);
887-
case VALUE_STRING -> parseHexEncodedVector(context, fieldMapper);
894+
case START_ARRAY -> parseVectorArray(context, dims, dimChecker, similarity);
895+
case VALUE_STRING -> parseHexEncodedVector(context, dimChecker);
888896
default -> throw new ParsingException(
889897
context.parser().getTokenLocation(),
890898
format("Unsupported type [%s] for provided value [%s]", token, context.parser().text())
@@ -894,7 +902,13 @@ VectorData parseKnnVector(DocumentParserContext context, DenseVectorFieldMapper
894902

895903
@Override
896904
public void parseKnnVectorAndIndex(DocumentParserContext context, DenseVectorFieldMapper fieldMapper) throws IOException {
897-
VectorData vectorData = parseKnnVector(context, fieldMapper);
905+
VectorData vectorData = parseKnnVector(context, fieldMapper.fieldType().dims, (i, end) -> {
906+
if (end) {
907+
fieldMapper.checkDimensionMatches(i, context);
908+
} else {
909+
fieldMapper.checkDimensionExceeded(i, context);
910+
}
911+
}, fieldMapper.fieldType().similarity);
898912
Field field = createKnnVectorField(
899913
fieldMapper.fieldType().name(),
900914
vectorData.asByteVector(),
@@ -958,7 +972,12 @@ public void checkDimensions(Integer dvDims, int qvDims) {
958972

959973
abstract void parseKnnVectorAndIndex(DocumentParserContext context, DenseVectorFieldMapper fieldMapper) throws IOException;
960974

961-
abstract VectorData parseKnnVector(DocumentParserContext context, DenseVectorFieldMapper fieldMapper) throws IOException;
975+
abstract VectorData parseKnnVector(
976+
DocumentParserContext context,
977+
int dims,
978+
IntBooleanConsumer dimChecker,
979+
VectorSimilarity similarity
980+
) throws IOException;
962981

963982
abstract int getNumBytes(int dimensions);
964983

@@ -2180,7 +2199,13 @@ private void parseBinaryDocValuesVectorAndIndex(DocumentParserContext context) t
21802199
: elementType.getNumBytes(dims);
21812200

21822201
ByteBuffer byteBuffer = elementType.createByteBuffer(indexCreatedVersion, numBytes);
2183-
VectorData vectorData = elementType.parseKnnVector(context, this);
2202+
VectorData vectorData = elementType.parseKnnVector(context, dims, (i, b) -> {
2203+
if (b) {
2204+
checkDimensionMatches(i, context);
2205+
} else {
2206+
checkDimensionExceeded(i, context);
2207+
}
2208+
}, fieldType().similarity);
21842209
vectorData.addToBuffer(byteBuffer);
21852210
if (indexCreatedVersion.onOrAfter(MAGNITUDE_STORED_INDEX_VERSION)) {
21862211
// encode vector magnitude at the end
@@ -2433,4 +2458,11 @@ public String fieldName() {
24332458
return fullPath();
24342459
}
24352460
}
2461+
2462+
/**
2463+
* @FunctionalInterface for a function that takes a int and boolean
2464+
*/
2465+
interface IntBooleanConsumer {
2466+
void accept(int value, boolean isComplete);
2467+
}
24362468
}

0 commit comments

Comments
 (0)