Skip to content

Commit 54fb4ea

Browse files
author
Lior Knaany
committed
added an option to receive a base64 encoded vector as an input
1 parent 5bfd46f commit 54fb4ea

File tree

2 files changed

+54
-25
lines changed

2 files changed

+54
-25
lines changed
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
package com.liorkn.elasticsearch;
2+
3+
import java.nio.ByteBuffer;
4+
import java.nio.DoubleBuffer;
5+
import java.util.Base64;
6+
7+
/**
8+
* Created by Lior Knaany on 4/7/18.
9+
*/
10+
public class Util {

    /**
     * Decodes a Base64 string into an array of doubles.
     *
     * <p>The decoded bytes are read as consecutive 8-byte big-endian doubles (the
     * {@link ByteBuffer} default byte order). If the decoded byte count is not a multiple of
     * eight, the trailing incomplete group is ignored, because
     * {@link ByteBuffer#asDoubleBuffer()} only exposes whole doubles.
     *
     * @param base64Str text in the basic Base64 alphabet, as produced by
     *                  {@link #convertArrayToBase64(double[])}
     * @return the decoded values; empty when the input decodes to fewer than eight bytes
     * @throws IllegalArgumentException if {@code base64Str} is not valid Base64
     * @throws NullPointerException if {@code base64Str} is {@code null}
     */
    public static final double[] convertBase64ToArray(String base64Str) {
        // decode(String) converts characters to bytes with ISO-8859-1 internally, so the
        // result no longer depends on the platform default charset as getBytes() did.
        final byte[] decoded = Base64.getDecoder().decode(base64Str);
        final DoubleBuffer doubleBuffer = ByteBuffer.wrap(decoded).asDoubleBuffer();

        final double[] dims = new double[doubleBuffer.remaining()];
        doubleBuffer.get(dims);
        return dims;
    }

    /**
     * Encodes an array of doubles as a Base64 string.
     *
     * <p>Each value is written as 8 big-endian bytes (the {@link ByteBuffer} default byte
     * order) and the resulting bytes are encoded with the basic Base64 alphabet.
     *
     * @param array the values to encode; may be empty
     * @return the Base64 text, empty for an empty array
     * @throws NullPointerException if {@code array} is {@code null}
     */
    public static final String convertArrayToBase64(double[] array) {
        final ByteBuffer bb = ByteBuffer.allocate(Double.BYTES * array.length);
        for (final double value : array) {
            bb.putDouble(value);
        }
        // The buffer was allocated at exactly the needed size, so its backing array holds
        // precisely the written bytes. encodeToString builds the String with a fixed charset
        // instead of the platform default used by new String(byte[]).
        return Base64.getEncoder().encodeToString(bb.array());
    }
}

src/main/java/com/liorkn/elasticsearch/script/VectorScoreScript.java

Lines changed: 23 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414

1515
package com.liorkn.elasticsearch.script;
1616

17+
import com.liorkn.elasticsearch.Util;
1718
import org.apache.lucene.index.BinaryDocValues;
18-
import org.apache.lucene.search.Scorer;
1919
import org.apache.lucene.store.ByteArrayDataInput;
2020
import org.elasticsearch.common.Nullable;
2121
import org.elasticsearch.script.ExecutableScript;
@@ -32,14 +32,13 @@
3232
*/
3333
public final class VectorScoreScript implements LeafSearchScript, ExecutableScript {
3434

35-
// private final static ESLogger logger = ESLoggerFactory.getLogger(VectorScoreScript.class.getName());
36-
public final static String SCRIPT_NAME = "binary_vector_score";
35+
public static final String SCRIPT_NAME = "binary_vector_score";
36+
37+
private static final int DOUBLE_SIZE = 8;
3738

3839
// the field containing the vectors to be scored against
3940
public final String field;
4041

41-
private static final int DOUBLE_SIZE = 8;
42-
4342
private int docId;
4443
private BinaryDocValues binaryEmbeddingReader;
4544

@@ -49,24 +48,15 @@ public final class VectorScoreScript implements LeafSearchScript, ExecutableScri
4948
private final boolean cosine;
5049

5150
@Override
52-
public void setScorer(Scorer scorer) {
53-
}
54-
public void setSource(Map<String, Object> source) {
55-
}
56-
public float runAsFloat() {
57-
return ((Number)this.run()).floatValue();
58-
}
59-
6051
public long runAsLong() {
6152
return ((Number)this.run()).longValue();
6253
}
54+
@Override
6355
public double runAsDouble() {
6456
return ((Number)this.run()).doubleValue();
6557
}
66-
public Object unwrap(Object value) {
67-
return value;
68-
}
69-
58+
@Override
59+
public void setNextVar(String name, Object value) {
    // No-op: the supplied name/value pair is neither stored nor used.
}
7060
@Override
7161
public void setDocument(int docId) {
7262
this.docId = docId;
@@ -127,17 +117,27 @@ public VectorScoreScript(Map<String, Object> params) {
127117
this.field = field.toString();
128118

129119
// get query inputVector - convert to primitive
130-
final ArrayList<Double> tmp = (ArrayList<Double>) params.get("vector");
131-
this.inputVector = new double[tmp.size()];
132-
for (int i = 0; i < inputVector.length; i++) {
133-
inputVector[i] = tmp.get(i);
120+
121+
final Object vector = params.get("vector");
122+
if(vector != null) {
123+
final ArrayList<Double> tmp = (ArrayList<Double>) vector;
124+
inputVector = new double[tmp.size()];
125+
for (int i = 0; i < inputVector.length; i++) {
126+
inputVector[i] = tmp.get(i);
127+
}
128+
} else {
129+
final Object encodedVector = params.get("encoded_vector");
130+
if(encodedVector == null) {
131+
throw new IllegalArgumentException("Must have at 'vector' or 'encoded_vector' as a parameter");
132+
}
133+
inputVector = Util.convertBase64ToArray((String) encodedVector);
134134
}
135135

136136
if(cosine) {
137137
// calc magnitude
138138
double queryVectorNorm = 0.0;
139139
// compute query inputVector norm once
140-
for (double v : this.inputVector) {
140+
for (double v : inputVector) {
141141
queryVectorNorm += v * v;
142142
}
143143
magnitude = Math.sqrt(queryVectorNorm);
@@ -146,9 +146,7 @@ public VectorScoreScript(Map<String, Object> params) {
146146
}
147147
}
148148

149-
@Override
150-
public void setNextVar(String name, Object value) {
151-
}
149+
152150

153151
/**
154152
* Called for each document

0 commit comments

Comments
 (0)