Skip to content

Commit 82668b4

Browse files
authored
Add basic implementations of float-byte script comparisons (#122381)
Add implementations of `cosineSimilarity` and `dotProduct` to query byte vector fields using float vectors
1 parent 5697f7f commit 82668b4

File tree

14 files changed

+325
-169
lines changed

14 files changed

+325
-169
lines changed

docs/changelog/122381.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
pr: 122381
2+
summary: Adds implementations of dotProduct and cosineSimilarity painless methods to operate on float vectors for byte fields
3+
area: Vector Search
4+
type: enhancement
5+
issues:
6+
- 117274

libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,19 @@ public static float ipFloatBit(float[] q, byte[] d) {
8080
return IMPL.ipFloatBit(q, d);
8181
}
8282

83+
/**
84+
* Compute the inner product of two vectors, where the query vector is a float vector and the document vector is a byte vector.
85+
* @param q the query vector
86+
* @param d the document vector
87+
* @return the inner product of the two vectors
88+
*/
89+
public static float ipFloatByte(float[] q, byte[] d) {
90+
if (q.length != d.length) {
91+
throw new IllegalArgumentException("vector dimensions incompatible: " + q.length + "!= " + d.length);
92+
}
93+
return IMPL.ipFloatByte(q, d);
94+
}
95+
8396
/**
8497
* AND bit count computed over signed bytes.
8598
* Copied from Lucene's XOR implementation

libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorUtilSupport.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,11 @@ public float ipFloatBit(float[] q, byte[] d) {
3939
return ipFloatBitImpl(q, d);
4040
}
4141

42+
/** Float-query / byte-document inner product; forwards to the static {@code ipFloatByteImpl} loop in this class. */
@Override
43+
public float ipFloatByte(float[] q, byte[] d) {
44+
return ipFloatByteImpl(q, d);
45+
}
46+
4247
public static int ipByteBitImpl(byte[] q, byte[] d) {
4348
assert q.length == d.length * Byte.SIZE;
4449
int acc0 = 0;
@@ -101,4 +106,12 @@ public static long ipByteBinByteImpl(byte[] q, byte[] d) {
101106
}
102107
return ret;
103108
}
109+
110+
public static float ipFloatByteImpl(float[] q, byte[] d) {
111+
float ret = 0;
112+
for (int i = 0; i < q.length; i++) {
113+
ret += q[i] * d[i];
114+
}
115+
return ret;
116+
}
104117
}

libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorUtilSupport.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,6 @@ public interface ESVectorUtilSupport {
1818
int ipByteBit(byte[] q, byte[] d);
1919

2020
float ipFloatBit(float[] q, byte[] d);
21+
22+
/** Inner product of a float query vector {@code q} and a byte document vector {@code d}. */
float ipFloatByte(float[] q, byte[] d);
2123
}

libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,11 @@ public float ipFloatBit(float[] q, byte[] d) {
5858
return DefaultESVectorUtilSupport.ipFloatBitImpl(q, d);
5959
}
6060

61+
/** Float-query / byte-document inner product; delegates to the plain-loop {@code DefaultESVectorUtilSupport.ipFloatByteImpl}. */
@Override
62+
public float ipFloatByte(float[] q, byte[] d) {
63+
return DefaultESVectorUtilSupport.ipFloatByteImpl(q, d);
64+
}
65+
6166
private static final VectorSpecies<Byte> BYTE_SPECIES_128 = ByteVector.SPECIES_128;
6267
private static final VectorSpecies<Byte> BYTE_SPECIES_256 = ByteVector.SPECIES_256;
6368

libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,28 @@ public void testIpByteBit() {
3232
public void testIpFloatBit() {
3333
float[] q = new float[16];
3434
byte[] d = new byte[] { (byte) Integer.parseInt("01100010", 2), (byte) Integer.parseInt("10100111", 2) };
35-
random().nextFloat();
35+
for (int i = 0; i < q.length; i++) {
36+
q[i] = random().nextFloat();
37+
}
3638
float expected = q[1] + q[2] + q[6] + q[8] + q[10] + q[13] + q[14] + q[15];
3739
assertEquals(expected, ESVectorUtil.ipFloatBit(q, d), 1e-6);
3840
}
3941

42+
public void testIpFloatByte() {
43+
float[] q = new float[16];
44+
byte[] d = new byte[16];
45+
for (int i = 0; i < q.length; i++) {
46+
q[i] = random().nextFloat();
47+
}
48+
random().nextBytes(d);
49+
50+
float expected = 0;
51+
for (int i = 0; i < q.length; i++) {
52+
expected += q[i] * d[i];
53+
}
54+
assertEquals(expected, ESVectorUtil.ipFloatByte(q, d), 1e-6);
55+
}
56+
4057
public void testBitAndCount() {
4158
testBasicBitAndImpl(ESVectorUtil::andBitCountLong);
4259
}

modules/lang-painless/src/yamlRestTest/resources/rest-api-spec/test/painless/145_dense_vector_byte_basic.yml

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,38 @@ setup:
107107
- match: {hits.hits.2._id: "1"}
108108
- match: {hits.hits.2._score: 1632.0}
109109
---
110+
"Dot Product float":
111+
- requires:
112+
capabilities:
113+
- path: /_search
114+
capabilities: [byte_float_dot_product_capability]
115+
test_runner_features: [capabilities]
116+
reason: "float vector queries capability added"
117+
- do:
118+
headers:
119+
Content-Type: application/json
120+
search:
121+
rest_total_hits_as_int: true
122+
body:
123+
query:
124+
script_score:
125+
query: {match_all: {} }
126+
script:
127+
source: "dotProduct(params.query_vector, 'vector')"
128+
params:
129+
query_vector: [0.5, 111.3, -13.0, 14.8, -156.0]
130+
131+
- match: {hits.total: 3}
132+
133+
- match: {hits.hits.0._id: "2"}
134+
- match: {hits.hits.0._score: 32865.2}
135+
136+
- match: {hits.hits.1._id: "3"}
137+
- match: {hits.hits.1._score: 21413.4}
138+
139+
- match: {hits.hits.2._id: "1"}
140+
- match: {hits.hits.2._score: 1862.3}
141+
---
110142
"Cosine Similarity":
111143
- do:
112144
headers:
@@ -198,3 +230,39 @@ setup:
198230
- match: {hits.hits.2._id: "1"}
199231
- gte: {hits.hits.2._score: 0.509}
200232
- lte: {hits.hits.2._score: 0.512}
233+
234+
---
235+
"Cosine Similarity float":
236+
- requires:
237+
capabilities:
238+
- path: /_search
239+
capabilities: [byte_float_dot_product_capability]
240+
test_runner_features: [capabilities]
241+
reason: "float vector queries capability added"
242+
- do:
243+
headers:
244+
Content-Type: application/json
245+
search:
246+
rest_total_hits_as_int: true
247+
body:
248+
query:
249+
script_score:
250+
query: {match_all: {} }
251+
script:
252+
source: "cosineSimilarity(params.query_vector, 'vector')"
253+
params:
254+
query_vector: [0.5, 111.3, -13.0, 14.8, -156.0]
255+
256+
- match: {hits.total: 3}
257+
258+
- match: {hits.hits.0._id: "2"}
259+
- gte: {hits.hits.0._score: 0.989}
260+
- lte: {hits.hits.0._score: 0.992}
261+
262+
- match: {hits.hits.1._id: "3"}
263+
- gte: {hits.hits.1._score: 0.885}
264+
- lte: {hits.hits.1._score: 0.888}
265+
266+
- match: {hits.hits.2._id: "1"}
267+
- gte: {hits.hits.2._score: 0.505}
268+
- lte: {hits.hits.2._score: 0.508}

server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java

Lines changed: 58 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -346,16 +346,17 @@ IndexFieldData.Builder fielddataBuilder(DenseVectorFieldType denseVectorFieldTyp
346346
}
347347

348348
@Override
349-
public void checkVectorBounds(float[] vector) {
350-
checkNanAndInfinite(vector);
351-
352-
StringBuilder errorBuilder = null;
349+
StringBuilder checkVectorErrors(float[] vector) {
350+
StringBuilder errors = checkNanAndInfinite(vector);
351+
if (errors != null) {
352+
return errors;
353+
}
353354

354355
for (int index = 0; index < vector.length; ++index) {
355356
float value = vector[index];
356357

357358
if (value % 1.0f != 0.0f) {
358-
errorBuilder = new StringBuilder(
359+
errors = new StringBuilder(
359360
"element_type ["
360361
+ this
361362
+ "] vectors only support non-decimal values but found decimal value ["
@@ -368,7 +369,7 @@ public void checkVectorBounds(float[] vector) {
368369
}
369370

370371
if (value < Byte.MIN_VALUE || value > Byte.MAX_VALUE) {
371-
errorBuilder = new StringBuilder(
372+
errors = new StringBuilder(
372373
"element_type ["
373374
+ this
374375
+ "] vectors only support integers between ["
@@ -385,9 +386,7 @@ public void checkVectorBounds(float[] vector) {
385386
}
386387
}
387388

388-
if (errorBuilder != null) {
389-
throw new IllegalArgumentException(appendErrorElements(errorBuilder, vector).toString());
390-
}
389+
return errors;
391390
}
392391

393392
@Override
@@ -614,8 +613,8 @@ public FloatVectorValues getFloatVectorValues(String fieldName) throws IOExcepti
614613
}
615614

616615
@Override
617-
public void checkVectorBounds(float[] vector) {
618-
checkNanAndInfinite(vector);
616+
StringBuilder checkVectorErrors(float[] vector) {
617+
return checkNanAndInfinite(vector);
619618
}
620619

621620
@Override
@@ -768,16 +767,17 @@ IndexFieldData.Builder fielddataBuilder(DenseVectorFieldType denseVectorFieldTyp
768767
}
769768

770769
@Override
771-
public void checkVectorBounds(float[] vector) {
772-
checkNanAndInfinite(vector);
773-
774-
StringBuilder errorBuilder = null;
770+
StringBuilder checkVectorErrors(float[] vector) {
771+
StringBuilder errors = checkNanAndInfinite(vector);
772+
if (errors != null) {
773+
return errors;
774+
}
775775

776776
for (int index = 0; index < vector.length; ++index) {
777777
float value = vector[index];
778778

779779
if (value % 1.0f != 0.0f) {
780-
errorBuilder = new StringBuilder(
780+
errors = new StringBuilder(
781781
"element_type ["
782782
+ this
783783
+ "] vectors only support non-decimal values but found decimal value ["
@@ -790,7 +790,7 @@ public void checkVectorBounds(float[] vector) {
790790
}
791791

792792
if (value < Byte.MIN_VALUE || value > Byte.MAX_VALUE) {
793-
errorBuilder = new StringBuilder(
793+
errors = new StringBuilder(
794794
"element_type ["
795795
+ this
796796
+ "] vectors only support integers between ["
@@ -807,9 +807,7 @@ public void checkVectorBounds(float[] vector) {
807807
}
808808
}
809809

810-
if (errorBuilder != null) {
811-
throw new IllegalArgumentException(appendErrorElements(errorBuilder, vector).toString());
812-
}
810+
return errors;
813811
}
814812

815813
@Override
@@ -993,7 +991,44 @@ public abstract VectorData parseKnnVector(
993991

994992
public abstract ByteBuffer createByteBuffer(IndexVersion indexVersion, int numBytes);
995993

996-
public abstract void checkVectorBounds(float[] vector);
994+
/**
995+
* Checks the input {@code vector} is one of the {@code possibleTypes},
996+
* and returns the first type that it matches
997+
*/
998+
public static ElementType checkValidVector(float[] vector, ElementType... possibleTypes) {
999+
assert possibleTypes.length != 0;
1000+
// we're looking for one valid allowed type
1001+
// assume the types are in order of specificity
1002+
StringBuilder[] errors = new StringBuilder[possibleTypes.length];
1003+
for (int i = 0; i < possibleTypes.length; i++) {
1004+
StringBuilder error = possibleTypes[i].checkVectorErrors(vector);
1005+
if (error == null) {
1006+
// this one works - use it
1007+
return possibleTypes[i];
1008+
} else {
1009+
errors[i] = error;
1010+
}
1011+
}
1012+
1013+
// oh dear, none of the possible types work with this vector. Generate the error message and throw.
1014+
StringBuilder message = new StringBuilder();
1015+
for (int i = 0; i < possibleTypes.length; i++) {
1016+
if (i > 0) {
1017+
message.append(" ");
1018+
}
1019+
message.append("Vector is not a ").append(possibleTypes[i]).append(" vector: ").append(errors[i]);
1020+
}
1021+
throw new IllegalArgumentException(appendErrorElements(message, vector).toString());
1022+
}
1023+
1024+
public void checkVectorBounds(float[] vector) {
1025+
StringBuilder errors = checkVectorErrors(vector);
1026+
if (errors != null) {
1027+
throw new IllegalArgumentException(appendErrorElements(errors, vector).toString());
1028+
}
1029+
}
1030+
1031+
/**
 * Validates {@code vector} against this element type without throwing.
 *
 * @return {@code null} when the vector is acceptable, otherwise a builder holding the
 *         error text (callers pass it through {@code appendErrorElements} before throwing)
 */
abstract StringBuilder checkVectorErrors(float[] vector);
9971032

9981033
abstract void checkVectorMagnitude(
9991034
VectorSimilarity similarity,
@@ -1017,7 +1052,7 @@ public int parseDimensionCount(DocumentParserContext context) throws IOException
10171052
return index;
10181053
}
10191054

1020-
void checkNanAndInfinite(float[] vector) {
1055+
StringBuilder checkNanAndInfinite(float[] vector) {
10211056
StringBuilder errorBuilder = null;
10221057

10231058
for (int index = 0; index < vector.length; ++index) {
@@ -1044,9 +1079,7 @@ void checkNanAndInfinite(float[] vector) {
10441079
}
10451080
}
10461081

1047-
if (errorBuilder != null) {
1048-
throw new IllegalArgumentException(appendErrorElements(errorBuilder, vector).toString());
1049-
}
1082+
return errorBuilder;
10501083
}
10511084

10521085
static StringBuilder appendErrorElements(StringBuilder errorBuilder, float[] vector) {

server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ private SearchCapabilities() {}
2525
private static final String BIT_DENSE_VECTOR_SYNTHETIC_SOURCE_CAPABILITY = "bit_dense_vector_synthetic_source";
2626
/** Support Byte and Float with Bit dot product. */
2727
private static final String BYTE_FLOAT_BIT_DOT_PRODUCT_CAPABILITY = "byte_float_bit_dot_product_with_bugfix";
28+
/** Support float query vectors on byte vectors */
29+
private static final String BYTE_FLOAT_DOT_PRODUCT_CAPABILITY = "byte_float_dot_product_capability";
2830
/** Support docvalue_fields parameter for `dense_vector` field. */
2931
private static final String DENSE_VECTOR_DOCVALUE_FIELDS = "dense_vector_docvalue_fields";
3032
/** Support transforming rank rrf queries to the corresponding rrf retriever. */
@@ -50,6 +52,7 @@ private SearchCapabilities() {}
5052
capabilities.add(RANGE_REGEX_INTERVAL_QUERY_CAPABILITY);
5153
capabilities.add(BIT_DENSE_VECTOR_SYNTHETIC_SOURCE_CAPABILITY);
5254
capabilities.add(BYTE_FLOAT_BIT_DOT_PRODUCT_CAPABILITY);
55+
capabilities.add(BYTE_FLOAT_DOT_PRODUCT_CAPABILITY);
5356
capabilities.add(DENSE_VECTOR_DOCVALUE_FIELDS);
5457
capabilities.add(TRANSFORM_RANK_RRF_TO_RETRIEVER);
5558
capabilities.add(NESTED_RETRIEVER_INNER_HITS_SUPPORT);

0 commit comments

Comments
 (0)