Skip to content

Commit 2540fdd

Browse files
committed
Use byte or float vectors as appropriate
1 parent 062f2a7 commit 2540fdd

File tree

3 files changed

+117
-60
lines changed

3 files changed

+117
-60
lines changed

modules/lang-painless/src/yamlRestTest/resources/rest-api-spec/test/painless/145_dense_vector_byte_basic.yml

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -126,18 +126,18 @@ setup:
126126
script:
127127
source: "dotProduct(params.query_vector, 'vector')"
128128
params:
129-
query_vector: [0.0, 111.0, -13.0, 14.0, -124.0]
129+
query_vector: [0.5, 111.3, -13.0, 14.8, -156.0]
130130

131131
- match: {hits.total: 3}
132132

133133
- match: {hits.hits.0._id: "2"}
134-
- match: {hits.hits.0._score: 28732.0}
134+
- match: {hits.hits.0._score: 32865.2}
135135

136136
- match: {hits.hits.1._id: "3"}
137-
- match: {hits.hits.1._score: 17439.0}
137+
- match: {hits.hits.1._score: 21413.4}
138138

139139
- match: {hits.hits.2._id: "1"}
140-
- match: {hits.hits.2._score: 1632.0}
140+
- match: {hits.hits.2._score: 1862.3}
141141
---
142142
"Cosine Similarity":
143143
- do:
@@ -251,18 +251,18 @@ setup:
251251
script:
252252
source: "cosineSimilarity(params.query_vector, 'vector')"
253253
params:
254-
query_vector: [0.0, 111.0, -13.0, 14.0, -124.0]
254+
query_vector: [0.5, 111.3, -13.0, 14.8, -156.0]
255255

256256
- match: {hits.total: 3}
257257

258258
- match: {hits.hits.0._id: "2"}
259-
- gte: {hits.hits.0._score: 0.995}
260-
- lte: {hits.hits.0._score: 0.998}
259+
- gte: {hits.hits.0._score: 0.989}
260+
- lte: {hits.hits.0._score: 0.992}
261261

262262
- match: {hits.hits.1._id: "3"}
263-
- gte: {hits.hits.1._score: 0.829}
264-
- lte: {hits.hits.1._score: 0.832}
263+
- gte: {hits.hits.1._score: 0.885}
264+
- lte: {hits.hits.1._score: 0.888}
265265

266266
- match: {hits.hits.2._id: "1"}
267-
- gte: {hits.hits.2._score: 0.509}
268-
- lte: {hits.hits.2._score: 0.512}
267+
- gte: {hits.hits.2._score: 0.505}
268+
- lte: {hits.hits.2._score: 0.508}

server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java

Lines changed: 51 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -346,16 +346,17 @@ IndexFieldData.Builder fielddataBuilder(DenseVectorFieldType denseVectorFieldTyp
346346
}
347347

348348
@Override
349-
public void checkVectorBounds(float[] vector) {
350-
checkNanAndInfinite(vector);
351-
352-
StringBuilder errorBuilder = null;
349+
StringBuilder checkVectorErrors(float[] vector) {
350+
StringBuilder errors = checkNanAndInfinite(vector);
351+
if (errors != null) {
352+
return errors;
353+
}
353354

354355
for (int index = 0; index < vector.length; ++index) {
355356
float value = vector[index];
356357

357358
if (value % 1.0f != 0.0f) {
358-
errorBuilder = new StringBuilder(
359+
errors = new StringBuilder(
359360
"element_type ["
360361
+ this
361362
+ "] vectors only support non-decimal values but found decimal value ["
@@ -368,7 +369,7 @@ public void checkVectorBounds(float[] vector) {
368369
}
369370

370371
if (value < Byte.MIN_VALUE || value > Byte.MAX_VALUE) {
371-
errorBuilder = new StringBuilder(
372+
errors = new StringBuilder(
372373
"element_type ["
373374
+ this
374375
+ "] vectors only support integers between ["
@@ -385,9 +386,7 @@ public void checkVectorBounds(float[] vector) {
385386
}
386387
}
387388

388-
if (errorBuilder != null) {
389-
throw new IllegalArgumentException(appendErrorElements(errorBuilder, vector).toString());
390-
}
389+
return errors;
391390
}
392391

393392
@Override
@@ -614,8 +613,8 @@ public FloatVectorValues getFloatVectorValues(String fieldName) throws IOExcepti
614613
}
615614

616615
@Override
617-
public void checkVectorBounds(float[] vector) {
618-
checkNanAndInfinite(vector);
616+
StringBuilder checkVectorErrors(float[] vector) {
617+
return checkNanAndInfinite(vector);
619618
}
620619

621620
@Override
@@ -768,16 +767,17 @@ IndexFieldData.Builder fielddataBuilder(DenseVectorFieldType denseVectorFieldTyp
768767
}
769768

770769
@Override
771-
public void checkVectorBounds(float[] vector) {
772-
checkNanAndInfinite(vector);
773-
774-
StringBuilder errorBuilder = null;
770+
StringBuilder checkVectorErrors(float[] vector) {
771+
StringBuilder errors = checkNanAndInfinite(vector);
772+
if (errors != null) {
773+
return errors;
774+
}
775775

776776
for (int index = 0; index < vector.length; ++index) {
777777
float value = vector[index];
778778

779779
if (value % 1.0f != 0.0f) {
780-
errorBuilder = new StringBuilder(
780+
errors = new StringBuilder(
781781
"element_type ["
782782
+ this
783783
+ "] vectors only support non-decimal values but found decimal value ["
@@ -790,7 +790,7 @@ public void checkVectorBounds(float[] vector) {
790790
}
791791

792792
if (value < Byte.MIN_VALUE || value > Byte.MAX_VALUE) {
793-
errorBuilder = new StringBuilder(
793+
errors = new StringBuilder(
794794
"element_type ["
795795
+ this
796796
+ "] vectors only support integers between ["
@@ -807,9 +807,7 @@ public void checkVectorBounds(float[] vector) {
807807
}
808808
}
809809

810-
if (errorBuilder != null) {
811-
throw new IllegalArgumentException(appendErrorElements(errorBuilder, vector).toString());
812-
}
810+
return errors;
813811
}
814812

815813
@Override
@@ -993,7 +991,37 @@ public abstract VectorData parseKnnVector(
993991

994992
public abstract ByteBuffer createByteBuffer(IndexVersion indexVersion, int numBytes);
995993

996-
public abstract void checkVectorBounds(float[] vector);
994+
public static ElementType checkValidVector(float[] vector, ElementType... possibleTypes) {
995+
assert possibleTypes.length != 0;
996+
// we're looking for one valid allowed type
997+
StringBuilder[] errors = new StringBuilder[possibleTypes.length];
998+
for (int i = 0; i < possibleTypes.length; i++) {
999+
StringBuilder error = possibleTypes[i].checkVectorErrors(vector);
1000+
if (error == null) {
1001+
// this one works - all ok
1002+
return possibleTypes[i];
1003+
} else {
1004+
errors[i] = error;
1005+
}
1006+
}
1007+
1008+
// oh dear, none of the types work with this vector. Generate the error message and throw.
1009+
StringBuilder message = new StringBuilder();
1010+
for (int i = 0; i < possibleTypes.length; i++) {
1011+
if (i > 0) message.append(" ");
1012+
message.append("Vector is not a ").append(possibleTypes[i]).append(" vector: ").append(errors[i]);
1013+
}
1014+
throw new IllegalArgumentException(appendErrorElements(message, vector).toString());
1015+
}
1016+
1017+
public void checkVectorBounds(float[] vector) {
1018+
StringBuilder errors = checkVectorErrors(vector);
1019+
if (errors != null) {
1020+
throw new IllegalArgumentException(appendErrorElements(errors, vector).toString());
1021+
}
1022+
}
1023+
1024+
abstract StringBuilder checkVectorErrors(float[] vector);
9971025

9981026
abstract void checkVectorMagnitude(
9991027
VectorSimilarity similarity,
@@ -1017,7 +1045,7 @@ public int parseDimensionCount(DocumentParserContext context) throws IOException
10171045
return index;
10181046
}
10191047

1020-
void checkNanAndInfinite(float[] vector) {
1048+
StringBuilder checkNanAndInfinite(float[] vector) {
10211049
StringBuilder errorBuilder = null;
10221050

10231051
for (int index = 0; index < vector.length; ++index) {
@@ -1044,9 +1072,7 @@ void checkNanAndInfinite(float[] vector) {
10441072
}
10451073
}
10461074

1047-
if (errorBuilder != null) {
1048-
throw new IllegalArgumentException(appendErrorElements(errorBuilder, vector).toString());
1049-
}
1075+
return errorBuilder;
10501076
}
10511077

10521078
static StringBuilder appendErrorElements(StringBuilder errorBuilder, float[] vector) {

server/src/main/java/org/elasticsearch/script/VectorScoreScriptUtils.java

Lines changed: 55 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
import org.elasticsearch.ExceptionsHelper;
1313
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
14+
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType;
1415
import org.elasticsearch.script.field.vectors.DenseVector;
1516
import org.elasticsearch.script.field.vectors.DenseVectorDocValuesField;
1617

@@ -42,7 +43,10 @@ void setNextVector() {
4243
}
4344

4445
public static class ByteDenseVectorFunction extends DenseVectorFunction {
45-
protected final byte[] queryVector;
46+
// either byteQueryVector or floatQueryVector will be non-null
47+
protected final byte[] byteQueryVector;
48+
protected final float[] floatQueryVector;
49+
// only valid if byteQueryVector is used
4650
protected final float qvMagnitude;
4751

4852
/**
@@ -52,21 +56,39 @@ public static class ByteDenseVectorFunction extends DenseVectorFunction {
5256
* @param field The vector field.
5357
* @param queryVector The query vector.
5458
*/
55-
public ByteDenseVectorFunction(ScoreScript scoreScript, DenseVectorDocValuesField field, List<Number> queryVector) {
59+
public ByteDenseVectorFunction(
60+
ScoreScript scoreScript,
61+
DenseVectorDocValuesField field,
62+
List<Number> queryVector,
63+
ElementType... allowedTypes
64+
) {
5665
super(scoreScript, field);
5766
field.getElementType().checkDimensions(field.get().getDims(), queryVector.size());
58-
this.queryVector = new byte[queryVector.size()];
59-
float[] validateValues = new float[queryVector.size()];
60-
int queryMagnitude = 0;
67+
float[] floatValues = new float[queryVector.size()];
6168
for (int i = 0; i < queryVector.size(); i++) {
62-
final Number number = queryVector.get(i);
63-
byte value = number.byteValue();
64-
this.queryVector[i] = value;
65-
queryMagnitude += value * value;
66-
validateValues[i] = number.floatValue();
69+
floatValues[i] = queryVector.get(i).floatValue();
6770
}
68-
this.qvMagnitude = (float) Math.sqrt(queryMagnitude);
69-
field.getElementType().checkVectorBounds(validateValues);
71+
72+
switch (ElementType.checkValidVector(floatValues, allowedTypes)) {
73+
case FLOAT:
74+
byteQueryVector = null;
75+
floatQueryVector = floatValues;
76+
qvMagnitude = -1; // invalid valid, not used for float vectors
77+
break;
78+
case BYTE:
79+
floatQueryVector = null;
80+
byteQueryVector = new byte[floatValues.length];
81+
float queryMagnitude = 0;
82+
for (int i = 0; i < floatValues.length; i++) {
83+
byteQueryVector[i] = (byte) floatValues[i];
84+
queryMagnitude += floatValues[i] * floatValues[i];
85+
}
86+
this.qvMagnitude = (float) Math.sqrt(queryMagnitude);
87+
break;
88+
default:
89+
throw new AssertionError("Unexpected element type");
90+
}
91+
7092
}
7193

7294
/**
@@ -78,7 +100,8 @@ public ByteDenseVectorFunction(ScoreScript scoreScript, DenseVectorDocValuesFiel
78100
*/
79101
public ByteDenseVectorFunction(ScoreScript scoreScript, DenseVectorDocValuesField field, byte[] queryVector) {
80102
super(scoreScript, field);
81-
this.queryVector = queryVector;
103+
byteQueryVector = queryVector;
104+
floatQueryVector = null;
82105
float queryMagnitude = 0.0f;
83106
for (byte value : queryVector) {
84107
queryMagnitude += value * value;
@@ -115,7 +138,7 @@ public FloatDenseVectorFunction(
115138
queryMagnitude += value * value;
116139
}
117140
queryMagnitude = Math.sqrt(queryMagnitude);
118-
field.getElementType().checkVectorBounds(this.queryVector);
141+
DenseVectorFieldMapper.ElementType.FLOAT.checkVectorBounds(this.queryVector);
119142

120143
if (normalizeQuery) {
121144
for (int dim = 0; dim < this.queryVector.length; dim++) {
@@ -133,7 +156,7 @@ public interface L1NormInterface {
133156
public static class ByteL1Norm extends ByteDenseVectorFunction implements L1NormInterface {
134157

135158
public ByteL1Norm(ScoreScript scoreScript, DenseVectorDocValuesField field, List<Number> queryVector) {
136-
super(scoreScript, field, queryVector);
159+
super(scoreScript, field, queryVector, ElementType.BYTE);
137160
}
138161

139162
public ByteL1Norm(ScoreScript scoreScript, DenseVectorDocValuesField field, byte[] queryVector) {
@@ -142,7 +165,7 @@ public ByteL1Norm(ScoreScript scoreScript, DenseVectorDocValuesField field, byte
142165

143166
public double l1norm() {
144167
setNextVector();
145-
return field.get().l1Norm(queryVector);
168+
return field.get().l1Norm(byteQueryVector);
146169
}
147170
}
148171

@@ -197,7 +220,7 @@ public interface HammingDistanceInterface {
197220
public static class ByteHammingDistance extends ByteDenseVectorFunction implements HammingDistanceInterface {
198221

199222
public ByteHammingDistance(ScoreScript scoreScript, DenseVectorDocValuesField field, List<Number> queryVector) {
200-
super(scoreScript, field, queryVector);
223+
super(scoreScript, field, queryVector, ElementType.BYTE);
201224
}
202225

203226
public ByteHammingDistance(ScoreScript scoreScript, DenseVectorDocValuesField field, byte[] queryVector) {
@@ -206,7 +229,7 @@ public ByteHammingDistance(ScoreScript scoreScript, DenseVectorDocValuesField fi
206229

207230
public int hamming() {
208231
setNextVector();
209-
return field.get().hamming(queryVector);
232+
return field.get().hamming(byteQueryVector);
210233
}
211234
}
212235

@@ -243,7 +266,7 @@ public interface L2NormInterface {
243266
public static class ByteL2Norm extends ByteDenseVectorFunction implements L2NormInterface {
244267

245268
public ByteL2Norm(ScoreScript scoreScript, DenseVectorDocValuesField field, List<Number> queryVector) {
246-
super(scoreScript, field, queryVector);
269+
super(scoreScript, field, queryVector, ElementType.BYTE);
247270
}
248271

249272
public ByteL2Norm(ScoreScript scoreScript, DenseVectorDocValuesField field, byte[] queryVector) {
@@ -252,7 +275,7 @@ public ByteL2Norm(ScoreScript scoreScript, DenseVectorDocValuesField field, byte
252275

253276
public double l2norm() {
254277
setNextVector();
255-
return field.get().l2Norm(queryVector);
278+
return field.get().l2Norm(byteQueryVector);
256279
}
257280
}
258281

@@ -388,7 +411,7 @@ public double dotProduct() {
388411
public static class ByteDotProduct extends ByteDenseVectorFunction implements DotProductInterface {
389412

390413
public ByteDotProduct(ScoreScript scoreScript, DenseVectorDocValuesField field, List<Number> queryVector) {
391-
super(scoreScript, field, queryVector);
414+
super(scoreScript, field, queryVector, ElementType.BYTE, ElementType.FLOAT);
392415
}
393416

394417
public ByteDotProduct(ScoreScript scoreScript, DenseVectorDocValuesField field, byte[] queryVector) {
@@ -397,7 +420,11 @@ public ByteDotProduct(ScoreScript scoreScript, DenseVectorDocValuesField field,
397420

398421
public double dotProduct() {
399422
setNextVector();
400-
return field.get().dotProduct(queryVector);
423+
if (floatQueryVector != null) {
424+
return field.get().dotProduct(floatQueryVector);
425+
} else {
426+
return field.get().dotProduct(byteQueryVector);
427+
}
401428
}
402429
}
403430

@@ -461,7 +488,7 @@ public interface CosineSimilarityInterface {
461488
public static class ByteCosineSimilarity extends ByteDenseVectorFunction implements CosineSimilarityInterface {
462489

463490
public ByteCosineSimilarity(ScoreScript scoreScript, DenseVectorDocValuesField field, List<Number> queryVector) {
464-
super(scoreScript, field, queryVector);
491+
super(scoreScript, field, queryVector, ElementType.BYTE, ElementType.FLOAT);
465492
}
466493

467494
public ByteCosineSimilarity(ScoreScript scoreScript, DenseVectorDocValuesField field, byte[] queryVector) {
@@ -470,7 +497,11 @@ public ByteCosineSimilarity(ScoreScript scoreScript, DenseVectorDocValuesField f
470497

471498
public double cosineSimilarity() {
472499
setNextVector();
473-
return field.get().cosineSimilarity(queryVector, qvMagnitude);
500+
if (floatQueryVector != null) {
501+
return field.get().cosineSimilarity(floatQueryVector);
502+
} else {
503+
return field.get().cosineSimilarity(byteQueryVector, qvMagnitude);
504+
}
474505
}
475506
}
476507

0 commit comments

Comments
 (0)