1414
1515package com .liorkn .elasticsearch .script ;
1616
17+ import com .liorkn .elasticsearch .Util ;
1718import org .apache .lucene .index .BinaryDocValues ;
18- import org .apache .lucene .search .Scorer ;
1919import org .apache .lucene .store .ByteArrayDataInput ;
2020import org .elasticsearch .common .Nullable ;
2121import org .elasticsearch .script .ExecutableScript ;
3232 */
3333public final class VectorScoreScript implements LeafSearchScript , ExecutableScript {
3434
35- // private final static ESLogger logger = ESLoggerFactory.getLogger(VectorScoreScript.class.getName());
36- public final static String SCRIPT_NAME = "binary_vector_score" ;
35+ public static final String SCRIPT_NAME = "binary_vector_score" ;
36+
37+ private static final int DOUBLE_SIZE = 8 ;
3738
3839 // the field containing the vectors to be scored against
3940 public final String field ;
4041
41- private static final int DOUBLE_SIZE = 8 ;
42-
4342 private int docId ;
4443 private BinaryDocValues binaryEmbeddingReader ;
4544
@@ -49,24 +48,15 @@ public final class VectorScoreScript implements LeafSearchScript, ExecutableScri
4948 private final boolean cosine ;
5049
5150 @ Override
52- public void setScorer (Scorer scorer ) {
53- }
54- public void setSource (Map <String , Object > source ) {
55- }
56- public float runAsFloat () {
57- return ((Number )this .run ()).floatValue ();
58- }
59-
6051 public long runAsLong () {
6152 return ((Number )this .run ()).longValue ();
6253 }
54+ @ Override
6355 public double runAsDouble () {
6456 return ((Number )this .run ()).doubleValue ();
6557 }
66- public Object unwrap (Object value ) {
67- return value ;
68- }
69-
58+ @ Override
59+ public void setNextVar (String name , Object value ) {}
7060 @ Override
7161 public void setDocument (int docId ) {
7262 this .docId = docId ;
@@ -127,17 +117,27 @@ public VectorScoreScript(Map<String, Object> params) {
127117 this .field = field .toString ();
128118
129119 // get query inputVector - convert to primitive
130- final ArrayList <Double > tmp = (ArrayList <Double >) params .get ("vector" );
131- this .inputVector = new double [tmp .size ()];
132- for (int i = 0 ; i < inputVector .length ; i ++) {
133- inputVector [i ] = tmp .get (i );
120+
121+ final Object vector = params .get ("vector" );
122+ if (vector != null ) {
123+ final ArrayList <Double > tmp = (ArrayList <Double >) vector ;
124+ inputVector = new double [tmp .size ()];
125+ for (int i = 0 ; i < inputVector .length ; i ++) {
126+ inputVector [i ] = tmp .get (i );
127+ }
128+ } else {
129+ final Object encodedVector = params .get ("encoded_vector" );
130+ if (encodedVector == null ) {
131+ throw new IllegalArgumentException ("Must have at 'vector' or 'encoded_vector' as a parameter" );
132+ }
133+ inputVector = Util .convertBase64ToArray ((String ) encodedVector );
134134 }
135135
136136 if (cosine ) {
137137 // calc magnitude
138138 double queryVectorNorm = 0.0 ;
139139 // compute query inputVector norm once
140- for (double v : this . inputVector ) {
140+ for (double v : inputVector ) {
141141 queryVectorNorm += v * v ;
142142 }
143143 magnitude = Math .sqrt (queryVectorNorm );
@@ -146,9 +146,7 @@ public VectorScoreScript(Map<String, Object> params) {
146146 }
147147 }
148148
149- @ Override
150- public void setNextVar (String name , Object value ) {
151- }
149+
152150
153151 /**
154152 * Called for each document
0 commit comments