Skip to content

Commit f14c8bd

Browse files
Add new multi_dense_vector field for brute-force search (#116275) (#116526)
This adds a new `multi_dense_vector` field that focuses on the maxSim use case provided by Col[BERT|Pali]. Indexing these vectors in HNSW as it stands makes no sense, either performance-wise or for cost. However, we should totally support rescoring and brute-force search over vectors with maxSim. This is step one of many. Behind a feature flag, this adds support for indexing any number of vectors of the same dimension. Supports bit/byte/float element types. Scripting support will be a follow-up. Marking as non-issue as it's behind a flag and currently unusable. (cherry picked from commit 7369c08) Co-authored-by: Elastic Machine <[email protected]>
1 parent 44bd65b commit f14c8bd

File tree

10 files changed

+1450
-64
lines changed

10 files changed

+1450
-64
lines changed
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
setup:
2+
- requires:
3+
capabilities:
4+
- method: POST
5+
path: /_search
6+
capabilities: [ multi_dense_vector_field_mapper ]
7+
test_runner_features: capabilities
8+
reason: "Support for multi dense vector field mapper capability required"
9+
---
10+
"Test create multi-vector field":
11+
- do:
12+
indices.create:
13+
index: test
14+
body:
15+
mappings:
16+
properties:
17+
vector1:
18+
type: multi_dense_vector
19+
dims: 3
20+
- do:
21+
index:
22+
index: test
23+
id: "1"
24+
body:
25+
vector1: [[2, -1, 1]]
26+
- do:
27+
index:
28+
index: test
29+
id: "2"
30+
body:
31+
vector1: [[2, -1, 1], [3, 4, 5]]
32+
- do:
33+
index:
34+
index: test
35+
id: "3"
36+
body:
37+
vector1: [[2, -1, 1], [3, 4, 5], [6, 7, 8]]
38+
- do:
39+
indices.refresh: {}
40+
---
41+
"Test create dynamic dim multi-vector field":
42+
- do:
43+
indices.create:
44+
index: test
45+
body:
46+
mappings:
47+
properties:
48+
name:
49+
type: keyword
50+
vector1:
51+
type: multi_dense_vector
52+
- do:
53+
index:
54+
index: test
55+
id: "1"
56+
body:
57+
vector1: [[2, -1, 1]]
58+
- do:
59+
index:
60+
index: test
61+
id: "2"
62+
body:
63+
vector1: [[2, -1, 1], [3, 4, 5]]
64+
- do:
65+
index:
66+
index: test
67+
id: "3"
68+
body:
69+
vector1: [[2, -1, 1], [3, 4, 5], [6, 7, 8]]
70+
- do:
71+
cluster.health:
72+
wait_for_events: languid
73+
74+
# verify some other dimension will fail
75+
- do:
76+
catch: bad_request
77+
index:
78+
index: test
79+
id: "4"
80+
body:
81+
vector1: [[2, -1, 1], [3, 4, 5], [6, 7, 8, 9]]
82+
---
83+
"Test dynamic dim mismatch fails multi-vector field":
84+
- do:
85+
indices.create:
86+
index: test
87+
body:
88+
mappings:
89+
properties:
90+
vector1:
91+
type: multi_dense_vector
92+
- do:
93+
catch: bad_request
94+
index:
95+
index: test
96+
id: "1"
97+
body:
98+
vector1: [[2, -1, 1], [2]]
99+
---
100+
"Test static dim mismatch fails multi-vector field":
101+
- do:
102+
indices.create:
103+
index: test
104+
body:
105+
mappings:
106+
properties:
107+
vector1:
108+
type: multi_dense_vector
109+
dims: 3
110+
- do:
111+
catch: bad_request
112+
index:
113+
index: test
114+
id: "1"
115+
body:
116+
vector1: [[2, -1, 1], [2]]
117+
---
118+
"Test poorly formatted multi-vector field":
119+
- do:
120+
indices.create:
121+
index: poorly_formatted_vector
122+
body:
123+
mappings:
124+
properties:
125+
vector1:
126+
type: multi_dense_vector
127+
dims: 3
128+
- do:
129+
catch: bad_request
130+
index:
131+
index: poorly_formatted_vector
132+
id: "1"
133+
body:
134+
vector1: [[[2, -1, 1]]]
135+
- do:
136+
catch: bad_request
137+
index:
138+
index: poorly_formatted_vector
139+
id: "1"
140+
body:
141+
vector1: [[2, -1, 1], [[2, -1, 1]]]

server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java

Lines changed: 76 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -415,13 +415,18 @@ public double computeSquaredMagnitude(VectorData vectorData) {
415415
return VectorUtil.dotProduct(vectorData.asByteVector(), vectorData.asByteVector());
416416
}
417417

418-
private VectorData parseVectorArray(DocumentParserContext context, DenseVectorFieldMapper fieldMapper) throws IOException {
418+
private VectorData parseVectorArray(
419+
DocumentParserContext context,
420+
int dims,
421+
IntBooleanConsumer dimChecker,
422+
VectorSimilarity similarity
423+
) throws IOException {
419424
int index = 0;
420-
byte[] vector = new byte[fieldMapper.fieldType().dims];
425+
byte[] vector = new byte[dims];
421426
float squaredMagnitude = 0;
422427
for (XContentParser.Token token = context.parser().nextToken(); token != Token.END_ARRAY; token = context.parser()
423428
.nextToken()) {
424-
fieldMapper.checkDimensionExceeded(index, context);
429+
dimChecker.accept(index, false);
425430
ensureExpectedToken(Token.VALUE_NUMBER, token, context.parser());
426431
final int value;
427432
if (context.parser().numberType() != XContentParser.NumberType.INT) {
@@ -459,30 +464,31 @@ private VectorData parseVectorArray(DocumentParserContext context, DenseVectorFi
459464
vector[index++] = (byte) value;
460465
squaredMagnitude += value * value;
461466
}
462-
fieldMapper.checkDimensionMatches(index, context);
463-
checkVectorMagnitude(fieldMapper.fieldType().similarity, errorByteElementsAppender(vector), squaredMagnitude);
467+
dimChecker.accept(index, true);
468+
checkVectorMagnitude(similarity, errorByteElementsAppender(vector), squaredMagnitude);
464469
return VectorData.fromBytes(vector);
465470
}
466471

467-
private VectorData parseHexEncodedVector(DocumentParserContext context, DenseVectorFieldMapper fieldMapper) throws IOException {
472+
private VectorData parseHexEncodedVector(
473+
DocumentParserContext context,
474+
IntBooleanConsumer dimChecker,
475+
VectorSimilarity similarity
476+
) throws IOException {
468477
byte[] decodedVector = HexFormat.of().parseHex(context.parser().text());
469-
fieldMapper.checkDimensionMatches(decodedVector.length, context);
478+
dimChecker.accept(decodedVector.length, true);
470479
VectorData vectorData = VectorData.fromBytes(decodedVector);
471480
double squaredMagnitude = computeSquaredMagnitude(vectorData);
472-
checkVectorMagnitude(
473-
fieldMapper.fieldType().similarity,
474-
errorByteElementsAppender(decodedVector),
475-
(float) squaredMagnitude
476-
);
481+
checkVectorMagnitude(similarity, errorByteElementsAppender(decodedVector), (float) squaredMagnitude);
477482
return vectorData;
478483
}
479484

480485
@Override
481-
VectorData parseKnnVector(DocumentParserContext context, DenseVectorFieldMapper fieldMapper) throws IOException {
486+
VectorData parseKnnVector(DocumentParserContext context, int dims, IntBooleanConsumer dimChecker, VectorSimilarity similarity)
487+
throws IOException {
482488
XContentParser.Token token = context.parser().currentToken();
483489
return switch (token) {
484-
case START_ARRAY -> parseVectorArray(context, fieldMapper);
485-
case VALUE_STRING -> parseHexEncodedVector(context, fieldMapper);
490+
case START_ARRAY -> parseVectorArray(context, dims, dimChecker, similarity);
491+
case VALUE_STRING -> parseHexEncodedVector(context, dimChecker, similarity);
486492
default -> throw new ParsingException(
487493
context.parser().getTokenLocation(),
488494
format("Unsupported type [%s] for provided value [%s]", token, context.parser().text())
@@ -492,7 +498,13 @@ VectorData parseKnnVector(DocumentParserContext context, DenseVectorFieldMapper
492498

493499
@Override
494500
public void parseKnnVectorAndIndex(DocumentParserContext context, DenseVectorFieldMapper fieldMapper) throws IOException {
495-
VectorData vectorData = parseKnnVector(context, fieldMapper);
501+
VectorData vectorData = parseKnnVector(context, fieldMapper.fieldType().dims, (i, end) -> {
502+
if (end) {
503+
fieldMapper.checkDimensionMatches(i, context);
504+
} else {
505+
fieldMapper.checkDimensionExceeded(i, context);
506+
}
507+
}, fieldMapper.fieldType().similarity);
496508
Field field = createKnnVectorField(
497509
fieldMapper.fieldType().name(),
498510
vectorData.asByteVector(),
@@ -676,21 +688,22 @@ && isNotUnitVector(squaredMagnitude)) {
676688
}
677689

678690
@Override
679-
VectorData parseKnnVector(DocumentParserContext context, DenseVectorFieldMapper fieldMapper) throws IOException {
691+
VectorData parseKnnVector(DocumentParserContext context, int dims, IntBooleanConsumer dimChecker, VectorSimilarity similarity)
692+
throws IOException {
680693
int index = 0;
681694
float squaredMagnitude = 0;
682-
float[] vector = new float[fieldMapper.fieldType().dims];
695+
float[] vector = new float[dims];
683696
for (Token token = context.parser().nextToken(); token != Token.END_ARRAY; token = context.parser().nextToken()) {
684-
fieldMapper.checkDimensionExceeded(index, context);
697+
dimChecker.accept(index, false);
685698
ensureExpectedToken(Token.VALUE_NUMBER, token, context.parser());
686699
float value = context.parser().floatValue(true);
687700
vector[index] = value;
688701
squaredMagnitude += value * value;
689702
index++;
690703
}
691-
fieldMapper.checkDimensionMatches(index, context);
704+
dimChecker.accept(index, true);
692705
checkVectorBounds(vector);
693-
checkVectorMagnitude(fieldMapper.fieldType().similarity, errorFloatElementsAppender(vector), squaredMagnitude);
706+
checkVectorMagnitude(similarity, errorFloatElementsAppender(vector), squaredMagnitude);
694707
return VectorData.fromFloats(vector);
695708
}
696709

@@ -815,12 +828,17 @@ public double computeSquaredMagnitude(VectorData vectorData) {
815828
return count;
816829
}
817830

818-
private VectorData parseVectorArray(DocumentParserContext context, DenseVectorFieldMapper fieldMapper) throws IOException {
831+
private VectorData parseVectorArray(
832+
DocumentParserContext context,
833+
int dims,
834+
IntBooleanConsumer dimChecker,
835+
VectorSimilarity similarity
836+
) throws IOException {
819837
int index = 0;
820-
byte[] vector = new byte[fieldMapper.fieldType().dims / Byte.SIZE];
838+
byte[] vector = new byte[dims / Byte.SIZE];
821839
for (XContentParser.Token token = context.parser().nextToken(); token != Token.END_ARRAY; token = context.parser()
822840
.nextToken()) {
823-
fieldMapper.checkDimensionExceeded(index, context);
841+
dimChecker.accept(index * Byte.SIZE, false);
824842
ensureExpectedToken(Token.VALUE_NUMBER, token, context.parser());
825843
final int value;
826844
if (context.parser().numberType() != XContentParser.NumberType.INT) {
@@ -855,35 +873,25 @@ private VectorData parseVectorArray(DocumentParserContext context, DenseVectorFi
855873
+ "];"
856874
);
857875
}
858-
if (index >= vector.length) {
859-
throw new IllegalArgumentException(
860-
"The number of dimensions for field ["
861-
+ fieldMapper.fieldType().name()
862-
+ "] should be ["
863-
+ fieldMapper.fieldType().dims
864-
+ "] but found ["
865-
+ (index + 1) * Byte.SIZE
866-
+ "]"
867-
);
868-
}
869876
vector[index++] = (byte) value;
870877
}
871-
fieldMapper.checkDimensionMatches(index * Byte.SIZE, context);
878+
dimChecker.accept(index * Byte.SIZE, true);
872879
return VectorData.fromBytes(vector);
873880
}
874881

875-
private VectorData parseHexEncodedVector(DocumentParserContext context, DenseVectorFieldMapper fieldMapper) throws IOException {
882+
private VectorData parseHexEncodedVector(DocumentParserContext context, IntBooleanConsumer dimChecker) throws IOException {
876883
byte[] decodedVector = HexFormat.of().parseHex(context.parser().text());
877-
fieldMapper.checkDimensionMatches(decodedVector.length * Byte.SIZE, context);
884+
dimChecker.accept(decodedVector.length * Byte.SIZE, true);
878885
return VectorData.fromBytes(decodedVector);
879886
}
880887

881888
@Override
882-
VectorData parseKnnVector(DocumentParserContext context, DenseVectorFieldMapper fieldMapper) throws IOException {
889+
VectorData parseKnnVector(DocumentParserContext context, int dims, IntBooleanConsumer dimChecker, VectorSimilarity similarity)
890+
throws IOException {
883891
XContentParser.Token token = context.parser().currentToken();
884892
return switch (token) {
885-
case START_ARRAY -> parseVectorArray(context, fieldMapper);
886-
case VALUE_STRING -> parseHexEncodedVector(context, fieldMapper);
893+
case START_ARRAY -> parseVectorArray(context, dims, dimChecker, similarity);
894+
case VALUE_STRING -> parseHexEncodedVector(context, dimChecker);
887895
default -> throw new ParsingException(
888896
context.parser().getTokenLocation(),
889897
format("Unsupported type [%s] for provided value [%s]", token, context.parser().text())
@@ -893,7 +901,13 @@ VectorData parseKnnVector(DocumentParserContext context, DenseVectorFieldMapper
893901

894902
@Override
895903
public void parseKnnVectorAndIndex(DocumentParserContext context, DenseVectorFieldMapper fieldMapper) throws IOException {
896-
VectorData vectorData = parseKnnVector(context, fieldMapper);
904+
VectorData vectorData = parseKnnVector(context, fieldMapper.fieldType().dims, (i, end) -> {
905+
if (end) {
906+
fieldMapper.checkDimensionMatches(i, context);
907+
} else {
908+
fieldMapper.checkDimensionExceeded(i, context);
909+
}
910+
}, fieldMapper.fieldType().similarity);
897911
Field field = createKnnVectorField(
898912
fieldMapper.fieldType().name(),
899913
vectorData.asByteVector(),
@@ -957,7 +971,12 @@ public void checkDimensions(Integer dvDims, int qvDims) {
957971

958972
abstract void parseKnnVectorAndIndex(DocumentParserContext context, DenseVectorFieldMapper fieldMapper) throws IOException;
959973

960-
abstract VectorData parseKnnVector(DocumentParserContext context, DenseVectorFieldMapper fieldMapper) throws IOException;
974+
abstract VectorData parseKnnVector(
975+
DocumentParserContext context,
976+
int dims,
977+
IntBooleanConsumer dimChecker,
978+
VectorSimilarity similarity
979+
) throws IOException;
961980

962981
abstract int getNumBytes(int dimensions);
963982

@@ -2179,7 +2198,13 @@ private void parseBinaryDocValuesVectorAndIndex(DocumentParserContext context) t
21792198
: elementType.getNumBytes(dims);
21802199

21812200
ByteBuffer byteBuffer = elementType.createByteBuffer(indexCreatedVersion, numBytes);
2182-
VectorData vectorData = elementType.parseKnnVector(context, this);
2201+
VectorData vectorData = elementType.parseKnnVector(context, dims, (i, b) -> {
2202+
if (b) {
2203+
checkDimensionMatches(i, context);
2204+
} else {
2205+
checkDimensionExceeded(i, context);
2206+
}
2207+
}, fieldType().similarity);
21832208
vectorData.addToBuffer(byteBuffer);
21842209
if (indexCreatedVersion.onOrAfter(MAGNITUDE_STORED_INDEX_VERSION)) {
21852210
// encode vector magnitude at the end
@@ -2427,4 +2452,11 @@ public String fieldName() {
24272452
return fullPath();
24282453
}
24292454
}
2455+
2456+
/**
2457+
* Functional interface for a function that takes an int and a boolean
2458+
*/
2459+
interface IntBooleanConsumer {
2460+
void accept(int value, boolean isComplete);
2461+
}
24302462
}

0 commit comments

Comments
 (0)