1111
1212import org .elasticsearch .ExceptionsHelper ;
1313import org .elasticsearch .index .mapper .vectors .DenseVectorFieldMapper ;
14+ import org .elasticsearch .index .mapper .vectors .DenseVectorFieldMapper .ElementType ;
1415import org .elasticsearch .script .field .vectors .DenseVector ;
1516import org .elasticsearch .script .field .vectors .DenseVectorDocValuesField ;
1617
@@ -42,7 +43,10 @@ void setNextVector() {
4243 }
4344
4445 public static class ByteDenseVectorFunction extends DenseVectorFunction {
45- protected final byte [] queryVector ;
46+ // either byteQueryVector or floatQueryVector will be non-null
47+ protected final byte [] byteQueryVector ;
48+ protected final float [] floatQueryVector ;
49+ // only valid if byteQueryVector is used
4650 protected final float qvMagnitude ;
4751
4852 /**
@@ -52,21 +56,39 @@ public static class ByteDenseVectorFunction extends DenseVectorFunction {
5256 * @param field The vector field.
5357 * @param queryVector The query vector.
5458 */
55- public ByteDenseVectorFunction (ScoreScript scoreScript , DenseVectorDocValuesField field , List <Number > queryVector ) {
59+ public ByteDenseVectorFunction (
60+ ScoreScript scoreScript ,
61+ DenseVectorDocValuesField field ,
62+ List <Number > queryVector ,
63+ ElementType ... allowedTypes
64+ ) {
5665 super (scoreScript , field );
5766 field .getElementType ().checkDimensions (field .get ().getDims (), queryVector .size ());
58- this .queryVector = new byte [queryVector .size ()];
59- float [] validateValues = new float [queryVector .size ()];
60- int queryMagnitude = 0 ;
67+ float [] floatValues = new float [queryVector .size ()];
6168 for (int i = 0 ; i < queryVector .size (); i ++) {
62- final Number number = queryVector .get (i );
63- byte value = number .byteValue ();
64- this .queryVector [i ] = value ;
65- queryMagnitude += value * value ;
66- validateValues [i ] = number .floatValue ();
69+ floatValues [i ] = queryVector .get (i ).floatValue ();
6770 }
68- this .qvMagnitude = (float ) Math .sqrt (queryMagnitude );
69- field .getElementType ().checkVectorBounds (validateValues );
71+
72+ switch (ElementType .checkValidVector (floatValues , allowedTypes )) {
73+ case FLOAT :
74+ byteQueryVector = null ;
75+ floatQueryVector = floatValues ;
76+ qvMagnitude = -1 ; // invalid valid, not used for float vectors
77+ break ;
78+ case BYTE :
79+ floatQueryVector = null ;
80+ byteQueryVector = new byte [floatValues .length ];
81+ float queryMagnitude = 0 ;
82+ for (int i = 0 ; i < floatValues .length ; i ++) {
83+ byteQueryVector [i ] = (byte ) floatValues [i ];
84+ queryMagnitude += floatValues [i ] * floatValues [i ];
85+ }
86+ this .qvMagnitude = (float ) Math .sqrt (queryMagnitude );
87+ break ;
88+ default :
89+ throw new AssertionError ("Unexpected element type" );
90+ }
91+
7092 }
7193
7294 /**
@@ -78,7 +100,8 @@ public ByteDenseVectorFunction(ScoreScript scoreScript, DenseVectorDocValuesFiel
78100 */
79101 public ByteDenseVectorFunction (ScoreScript scoreScript , DenseVectorDocValuesField field , byte [] queryVector ) {
80102 super (scoreScript , field );
81- this .queryVector = queryVector ;
103+ byteQueryVector = queryVector ;
104+ floatQueryVector = null ;
82105 float queryMagnitude = 0.0f ;
83106 for (byte value : queryVector ) {
84107 queryMagnitude += value * value ;
@@ -115,7 +138,7 @@ public FloatDenseVectorFunction(
115138 queryMagnitude += value * value ;
116139 }
117140 queryMagnitude = Math .sqrt (queryMagnitude );
118- field . getElementType () .checkVectorBounds (this .queryVector );
141+ DenseVectorFieldMapper . ElementType . FLOAT .checkVectorBounds (this .queryVector );
119142
120143 if (normalizeQuery ) {
121144 for (int dim = 0 ; dim < this .queryVector .length ; dim ++) {
@@ -133,7 +156,7 @@ public interface L1NormInterface {
133156 public static class ByteL1Norm extends ByteDenseVectorFunction implements L1NormInterface {
134157
135158 public ByteL1Norm (ScoreScript scoreScript , DenseVectorDocValuesField field , List <Number > queryVector ) {
136- super (scoreScript , field , queryVector );
159+ super (scoreScript , field , queryVector , ElementType . BYTE );
137160 }
138161
139162 public ByteL1Norm (ScoreScript scoreScript , DenseVectorDocValuesField field , byte [] queryVector ) {
@@ -142,7 +165,7 @@ public ByteL1Norm(ScoreScript scoreScript, DenseVectorDocValuesField field, byte
142165
143166 public double l1norm () {
144167 setNextVector ();
145- return field .get ().l1Norm (queryVector );
168+ return field .get ().l1Norm (byteQueryVector );
146169 }
147170 }
148171
@@ -197,7 +220,7 @@ public interface HammingDistanceInterface {
197220 public static class ByteHammingDistance extends ByteDenseVectorFunction implements HammingDistanceInterface {
198221
199222 public ByteHammingDistance (ScoreScript scoreScript , DenseVectorDocValuesField field , List <Number > queryVector ) {
200- super (scoreScript , field , queryVector );
223+ super (scoreScript , field , queryVector , ElementType . BYTE );
201224 }
202225
203226 public ByteHammingDistance (ScoreScript scoreScript , DenseVectorDocValuesField field , byte [] queryVector ) {
@@ -206,7 +229,7 @@ public ByteHammingDistance(ScoreScript scoreScript, DenseVectorDocValuesField fi
206229
207230 public int hamming () {
208231 setNextVector ();
209- return field .get ().hamming (queryVector );
232+ return field .get ().hamming (byteQueryVector );
210233 }
211234 }
212235
@@ -243,7 +266,7 @@ public interface L2NormInterface {
243266 public static class ByteL2Norm extends ByteDenseVectorFunction implements L2NormInterface {
244267
245268 public ByteL2Norm (ScoreScript scoreScript , DenseVectorDocValuesField field , List <Number > queryVector ) {
246- super (scoreScript , field , queryVector );
269+ super (scoreScript , field , queryVector , ElementType . BYTE );
247270 }
248271
249272 public ByteL2Norm (ScoreScript scoreScript , DenseVectorDocValuesField field , byte [] queryVector ) {
@@ -252,7 +275,7 @@ public ByteL2Norm(ScoreScript scoreScript, DenseVectorDocValuesField field, byte
252275
253276 public double l2norm () {
254277 setNextVector ();
255- return field .get ().l2Norm (queryVector );
278+ return field .get ().l2Norm (byteQueryVector );
256279 }
257280 }
258281
@@ -388,7 +411,7 @@ public double dotProduct() {
388411 public static class ByteDotProduct extends ByteDenseVectorFunction implements DotProductInterface {
389412
390413 public ByteDotProduct (ScoreScript scoreScript , DenseVectorDocValuesField field , List <Number > queryVector ) {
391- super (scoreScript , field , queryVector );
414+ super (scoreScript , field , queryVector , ElementType . BYTE , ElementType . FLOAT );
392415 }
393416
394417 public ByteDotProduct (ScoreScript scoreScript , DenseVectorDocValuesField field , byte [] queryVector ) {
@@ -397,7 +420,11 @@ public ByteDotProduct(ScoreScript scoreScript, DenseVectorDocValuesField field,
397420
398421 public double dotProduct () {
399422 setNextVector ();
400- return field .get ().dotProduct (queryVector );
423+ if (floatQueryVector != null ) {
424+ return field .get ().dotProduct (floatQueryVector );
425+ } else {
426+ return field .get ().dotProduct (byteQueryVector );
427+ }
401428 }
402429 }
403430
@@ -461,7 +488,7 @@ public interface CosineSimilarityInterface {
461488 public static class ByteCosineSimilarity extends ByteDenseVectorFunction implements CosineSimilarityInterface {
462489
463490 public ByteCosineSimilarity (ScoreScript scoreScript , DenseVectorDocValuesField field , List <Number > queryVector ) {
464- super (scoreScript , field , queryVector );
491+ super (scoreScript , field , queryVector , ElementType . BYTE , ElementType . FLOAT );
465492 }
466493
467494 public ByteCosineSimilarity (ScoreScript scoreScript , DenseVectorDocValuesField field , byte [] queryVector ) {
@@ -470,7 +497,11 @@ public ByteCosineSimilarity(ScoreScript scoreScript, DenseVectorDocValuesField f
470497
471498 public double cosineSimilarity () {
472499 setNextVector ();
473- return field .get ().cosineSimilarity (queryVector , qvMagnitude );
500+ if (floatQueryVector != null ) {
501+ return field .get ().cosineSimilarity (floatQueryVector );
502+ } else {
503+ return field .get ().cosineSimilarity (byteQueryVector , qvMagnitude );
504+ }
474505 }
475506 }
476507
0 commit comments