Skip to content

Commit 1d99668

Browse files
committed
Normalize the vector when used for cosine
1 parent 0a610fe commit 1d99668

File tree

1 file changed

+19
-6
lines changed

1 file changed

+19
-6
lines changed

server/src/main/java/org/elasticsearch/script/VectorScoreScriptUtils.java

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,14 @@ public static class ByteDenseVectorFunction extends DenseVectorFunction {
5555
* @param scoreScript The script in which this function was referenced.
5656
* @param field The vector field.
5757
* @param queryVector The query vector.
58+
* @param normalizeFloatQuery {@code true} to normalize the query vector when it is a float vector.
59+
* @param allowedTypes The types the vector is allowed to be.
5860
*/
5961
public ByteDenseVectorFunction(
6062
ScoreScript scoreScript,
6163
DenseVectorDocValuesField field,
6264
List<Number> queryVector,
65+
boolean normalizeFloatQuery,
6366
ElementType... allowedTypes
6467
) {
6568
super(scoreScript, field);
@@ -74,6 +77,16 @@ public ByteDenseVectorFunction(
7477
byteQueryVector = null;
7578
floatQueryVector = floatValues;
7679
qvMagnitude = -1; // invalid value, not used for float vectors
80+
81+
if (normalizeFloatQuery) {
82+
double queryMagnitude = 0.0;
83+
for (float val : floatQueryVector) {
84+
queryMagnitude += val * val;
85+
}
86+
for (int i = 0; i < floatQueryVector.length; i++) {
87+
// queryMagnitude holds the sum of squares here; divide by its square root (the L2 norm) to normalize.
floatQueryVector[i] /= (float) Math.sqrt(queryMagnitude);
88+
}
89+
}
7790
break;
7891
case BYTE:
7992
floatQueryVector = null;
@@ -156,7 +169,7 @@ public interface L1NormInterface {
156169
public static class ByteL1Norm extends ByteDenseVectorFunction implements L1NormInterface {
157170

158171
public ByteL1Norm(ScoreScript scoreScript, DenseVectorDocValuesField field, List<Number> queryVector) {
159-
super(scoreScript, field, queryVector, ElementType.BYTE);
172+
super(scoreScript, field, queryVector, false, ElementType.BYTE);
160173
}
161174

162175
public ByteL1Norm(ScoreScript scoreScript, DenseVectorDocValuesField field, byte[] queryVector) {
@@ -220,7 +233,7 @@ public interface HammingDistanceInterface {
220233
public static class ByteHammingDistance extends ByteDenseVectorFunction implements HammingDistanceInterface {
221234

222235
public ByteHammingDistance(ScoreScript scoreScript, DenseVectorDocValuesField field, List<Number> queryVector) {
223-
super(scoreScript, field, queryVector, ElementType.BYTE);
236+
super(scoreScript, field, queryVector, false, ElementType.BYTE);
224237
}
225238

226239
public ByteHammingDistance(ScoreScript scoreScript, DenseVectorDocValuesField field, byte[] queryVector) {
@@ -266,7 +279,7 @@ public interface L2NormInterface {
266279
public static class ByteL2Norm extends ByteDenseVectorFunction implements L2NormInterface {
267280

268281
public ByteL2Norm(ScoreScript scoreScript, DenseVectorDocValuesField field, List<Number> queryVector) {
269-
super(scoreScript, field, queryVector, ElementType.BYTE);
282+
super(scoreScript, field, queryVector, false, ElementType.BYTE);
270283
}
271284

272285
public ByteL2Norm(ScoreScript scoreScript, DenseVectorDocValuesField field, byte[] queryVector) {
@@ -411,7 +424,7 @@ public double dotProduct() {
411424
public static class ByteDotProduct extends ByteDenseVectorFunction implements DotProductInterface {
412425

413426
public ByteDotProduct(ScoreScript scoreScript, DenseVectorDocValuesField field, List<Number> queryVector) {
414-
super(scoreScript, field, queryVector, ElementType.BYTE, ElementType.FLOAT);
427+
super(scoreScript, field, queryVector, false, ElementType.BYTE, ElementType.FLOAT);
415428
}
416429

417430
public ByteDotProduct(ScoreScript scoreScript, DenseVectorDocValuesField field, byte[] queryVector) {
@@ -488,7 +501,7 @@ public interface CosineSimilarityInterface {
488501
public static class ByteCosineSimilarity extends ByteDenseVectorFunction implements CosineSimilarityInterface {
489502

490503
public ByteCosineSimilarity(ScoreScript scoreScript, DenseVectorDocValuesField field, List<Number> queryVector) {
491-
super(scoreScript, field, queryVector, ElementType.BYTE, ElementType.FLOAT);
504+
super(scoreScript, field, queryVector, true, ElementType.BYTE, ElementType.FLOAT);
492505
}
493506

494507
public ByteCosineSimilarity(ScoreScript scoreScript, DenseVectorDocValuesField field, byte[] queryVector) {
@@ -498,7 +511,7 @@ public ByteCosineSimilarity(ScoreScript scoreScript, DenseVectorDocValuesField f
498511
public double cosineSimilarity() {
499512
setNextVector();
500513
if (floatQueryVector != null) {
501-
return field.get().cosineSimilarity(floatQueryVector);
514+
return field.get().cosineSimilarity(floatQueryVector, false);
502515
} else {
503516
return field.get().cosineSimilarity(byteQueryVector, qvMagnitude);
504517
}

0 commit comments

Comments
 (0)