Skip to content

Commit 3d5606f

Browse files
committed
Calculate magnitude once only
1 parent 7069f48 commit 3d5606f

File tree

2 files changed

+9
-10
lines changed

2 files changed

+9
-10
lines changed

server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -998,6 +998,7 @@ public abstract VectorData parseKnnVector(
998998
public static ElementType checkValidVector(float[] vector, ElementType... possibleTypes) {
999999
assert possibleTypes.length != 0;
10001000
// we're looking for one valid allowed type
1001+
// assume the types are in order of specificity
10011002
StringBuilder[] errors = new StringBuilder[possibleTypes.length];
10021003
for (int i = 0; i < possibleTypes.length; i++) {
10031004
StringBuilder error = possibleTypes[i].checkVectorErrors(vector);

server/src/main/java/org/elasticsearch/script/VectorScoreScriptUtils.java

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,13 @@ public ByteDenseVectorFunction(
6868
super(scoreScript, field);
6969
field.getElementType().checkDimensions(field.get().getDims(), queryVector.size());
7070
float[] floatValues = new float[queryVector.size()];
71+
double queryMagnitude = 0;
7172
for (int i = 0; i < queryVector.size(); i++) {
72-
floatValues[i] = queryVector.get(i).floatValue();
73+
float value = queryVector.get(i).floatValue();
74+
floatValues[i] = value;
75+
queryMagnitude += value * value;
7376
}
77+
queryMagnitude = Math.sqrt(queryMagnitude);
7478

7579
switch (ElementType.checkValidVector(floatValues, allowedTypes)) {
7680
case FLOAT:
@@ -79,11 +83,6 @@ public ByteDenseVectorFunction(
7983
qvMagnitude = -1; // invalid valid, not used for float vectors
8084

8185
if (normalizeFloatQuery) {
82-
double queryMagnitude = 0.0;
83-
for (float val : floatQueryVector) {
84-
queryMagnitude += val * val;
85-
}
86-
queryMagnitude = Math.sqrt(queryMagnitude);
8786
for (int i = 0; i < floatQueryVector.length; i++) {
8887
floatQueryVector[i] /= (float) queryMagnitude;
8988
}
@@ -92,12 +91,10 @@ public ByteDenseVectorFunction(
9291
case BYTE:
9392
floatQueryVector = null;
9493
byteQueryVector = new byte[floatValues.length];
95-
float queryMagnitude = 0;
9694
for (int i = 0; i < floatValues.length; i++) {
9795
byteQueryVector[i] = (byte) floatValues[i];
98-
queryMagnitude += floatValues[i] * floatValues[i];
9996
}
100-
this.qvMagnitude = (float) Math.sqrt(queryMagnitude);
97+
this.qvMagnitude = (float) queryMagnitude;
10198
break;
10299
default:
103100
throw new AssertionError("Unexpected element type");
@@ -116,7 +113,7 @@ public ByteDenseVectorFunction(ScoreScript scoreScript, DenseVectorDocValuesFiel
116113
super(scoreScript, field);
117114
byteQueryVector = queryVector;
118115
floatQueryVector = null;
119-
float queryMagnitude = 0.0f;
116+
double queryMagnitude = 0.0f;
120117
for (byte value : queryVector) {
121118
queryMagnitude += value * value;
122119
}
@@ -512,6 +509,7 @@ public ByteCosineSimilarity(ScoreScript scoreScript, DenseVectorDocValuesField f
512509
public double cosineSimilarity() {
513510
setNextVector();
514511
if (floatQueryVector != null) {
512+
// float vector is already normalized by the superclass constructor
515513
return field.get().cosineSimilarity(floatQueryVector, false);
516514
} else {
517515
return field.get().cosineSimilarity(byteQueryVector, qvMagnitude);

0 commit comments

Comments
 (0)