@@ -55,11 +55,14 @@ public static class ByteDenseVectorFunction extends DenseVectorFunction {
5555         * @param scoreScript The script in which this function was referenced. 
5656         * @param field The vector field. 
5757         * @param queryVector The query vector. 
58+          * @param normalizeFloatQuery {@code true} if the query vector is a float vector, then normalize it. 
59+          * @param allowedTypes The types the vector is allowed to be. 
5860         */ 
5961        public  ByteDenseVectorFunction (
6062            ScoreScript  scoreScript ,
6163            DenseVectorDocValuesField  field ,
6264            List <Number > queryVector ,
65+             boolean  normalizeFloatQuery ,
6366            ElementType ... allowedTypes 
6467        ) {
6568            super (scoreScript , field );
@@ -74,6 +77,16 @@ public ByteDenseVectorFunction(
7477                    byteQueryVector  = null ;
7578                    floatQueryVector  = floatValues ;
7679                    qvMagnitude  = -1 ;   // invalid valid, not used for float vectors 
80+ 
81+                     if  (normalizeFloatQuery ) {
82+                         double  queryMagnitude  = 0.0 ;
83+                         for  (float  val  : floatQueryVector ) {
84+                             queryMagnitude  += val  * val ;
85+                         }
86+                         for  (int  i  = 0 ; i  < floatQueryVector .length ; i ++) {
87+                             floatQueryVector [i ] /= (float ) queryMagnitude ;
88+                         }
89+                     }
7790                    break ;
7891                case  BYTE :
7992                    floatQueryVector  = null ;
@@ -156,7 +169,7 @@ public interface L1NormInterface {
156169    public  static  class  ByteL1Norm  extends  ByteDenseVectorFunction  implements  L1NormInterface  {
157170
158171        public  ByteL1Norm (ScoreScript  scoreScript , DenseVectorDocValuesField  field , List <Number > queryVector ) {
159-             super (scoreScript , field , queryVector , ElementType .BYTE );
172+             super (scoreScript , field , queryVector , false ,  ElementType .BYTE );
160173        }
161174
162175        public  ByteL1Norm (ScoreScript  scoreScript , DenseVectorDocValuesField  field , byte [] queryVector ) {
@@ -220,7 +233,7 @@ public interface HammingDistanceInterface {
220233    public  static  class  ByteHammingDistance  extends  ByteDenseVectorFunction  implements  HammingDistanceInterface  {
221234
222235        public  ByteHammingDistance (ScoreScript  scoreScript , DenseVectorDocValuesField  field , List <Number > queryVector ) {
223-             super (scoreScript , field , queryVector , ElementType .BYTE );
236+             super (scoreScript , field , queryVector , false ,  ElementType .BYTE );
224237        }
225238
226239        public  ByteHammingDistance (ScoreScript  scoreScript , DenseVectorDocValuesField  field , byte [] queryVector ) {
@@ -266,7 +279,7 @@ public interface L2NormInterface {
266279    public  static  class  ByteL2Norm  extends  ByteDenseVectorFunction  implements  L2NormInterface  {
267280
268281        public  ByteL2Norm (ScoreScript  scoreScript , DenseVectorDocValuesField  field , List <Number > queryVector ) {
269-             super (scoreScript , field , queryVector , ElementType .BYTE );
282+             super (scoreScript , field , queryVector , false ,  ElementType .BYTE );
270283        }
271284
272285        public  ByteL2Norm (ScoreScript  scoreScript , DenseVectorDocValuesField  field , byte [] queryVector ) {
@@ -411,7 +424,7 @@ public double dotProduct() {
411424    public  static  class  ByteDotProduct  extends  ByteDenseVectorFunction  implements  DotProductInterface  {
412425
413426        public  ByteDotProduct (ScoreScript  scoreScript , DenseVectorDocValuesField  field , List <Number > queryVector ) {
414-             super (scoreScript , field , queryVector , ElementType .BYTE , ElementType .FLOAT );
427+             super (scoreScript , field , queryVector , false ,  ElementType .BYTE , ElementType .FLOAT );
415428        }
416429
417430        public  ByteDotProduct (ScoreScript  scoreScript , DenseVectorDocValuesField  field , byte [] queryVector ) {
@@ -488,7 +501,7 @@ public interface CosineSimilarityInterface {
488501    public  static  class  ByteCosineSimilarity  extends  ByteDenseVectorFunction  implements  CosineSimilarityInterface  {
489502
490503        public  ByteCosineSimilarity (ScoreScript  scoreScript , DenseVectorDocValuesField  field , List <Number > queryVector ) {
491-             super (scoreScript , field , queryVector , ElementType .BYTE , ElementType .FLOAT );
504+             super (scoreScript , field , queryVector , true ,  ElementType .BYTE , ElementType .FLOAT );
492505        }
493506
494507        public  ByteCosineSimilarity (ScoreScript  scoreScript , DenseVectorDocValuesField  field , byte [] queryVector ) {
@@ -498,7 +511,7 @@ public ByteCosineSimilarity(ScoreScript scoreScript, DenseVectorDocValuesField f
498511        public  double  cosineSimilarity () {
499512            setNextVector ();
500513            if  (floatQueryVector  != null ) {
501-                 return  field .get ().cosineSimilarity (floatQueryVector );
514+                 return  field .get ().cosineSimilarity (floatQueryVector ,  false );
502515            } else  {
503516                return  field .get ().cosineSimilarity (byteQueryVector , qvMagnitude );
504517            }
0 commit comments