Changes shown are from 2 of this pull request's 27 commits.

Commits:
854c69b  Changed script to work with Elastic 6.0.0. Inline scripts are now dep… (Jan 30, 2019)
7cc89d0  renamed engine appropriately (Zakkery, Jan 31, 2019)
ec09b19  optimize plugin and use float32 (#19) (ran22, Feb 10, 2019)
41961da  updated jackson version (Feb 10, 2019)
f0c27c5  added a cosine score test (Feb 25, 2019)
f688c3a  Broke it back into multiple files for simplicity, added support for e… (Zakkery, Apr 4, 2019)
2aa0f87  Removed unused import (Zakkery, Apr 4, 2019)
44fafbd  Update README.md (lior-k, Apr 4, 2019)
3678458  Removed unused method (Zakkery, Apr 4, 2019)
63d84a3  changed a double usage in a test to float. as part of the move to flo… (Apr 23, 2019)
d9e04a1  Fixed vector to be float instead of double (Zakkery, Apr 23, 2019)
71fae4e  Changed script to work with Elastic 6.0.0. Inline scripts are now dep… (Jan 30, 2019)
343ee31  renamed engine appropriately (Zakkery, Jan 31, 2019)
9fdb85f  Broke it back into multiple files for simplicity, added support for e… (Zakkery, Apr 4, 2019)
b97a7c6  Removed unused import (Zakkery, Apr 4, 2019)
4479ef2  Removed unused method (Zakkery, Apr 4, 2019)
028c8b3  Fixed vector to be float instead of double (Zakkery, Apr 23, 2019)
4a2ab17  Merge branch 'es-6.0' of https://github.com/Zakkery/fast-elasticsearc… (Zakkery, Apr 23, 2019)
aa8be17  Changed script to work with Elastic 6.0.0. Inline scripts are now dep… (Jan 30, 2019)
51032b5  renamed engine appropriately (Zakkery, Jan 31, 2019)
a0365c4  Broke it back into multiple files for simplicity, added support for e… (Zakkery, Apr 4, 2019)
ac7efae  Removed unused import (Zakkery, Apr 4, 2019)
79a8753  Removed unused method (Zakkery, Apr 4, 2019)
3489091  Fixed vector to be float instead of double (Zakkery, Apr 23, 2019)
c5b2843  Fixed vector to be float instead of double (Zakkery, Apr 23, 2019)
cd5caa1  Merge branch 'es-6.0' of https://github.com/Zakkery/fast-elasticsearc… (Zakkery, Apr 25, 2019)
1f1d5cd  Fixed testing and checking of sizes (Zakkery, Apr 25, 2019)
README.md (4 changes: 2 additions, 2 deletions)

@@ -10,7 +10,7 @@ give it a try.


## Elasticsearch version
- * Currently designed for Elasticsearch 5.6.0.
+ * Currently designed for Elasticsearch 6.0.0.
* for Elasticsearch 5.2.2 use branch `es-5.2.2`
* for Elasticsearch 2.4.4 use branch `es-2.4.4`

@@ -146,7 +146,7 @@ func convertBase64ToArray(base64Str string) ([]float64, error) {
"boost_mode": "replace",
"script_score": {
"script": {
- "inline": "binary_vector_score",
+ "source": "binary_vector_score",
"lang": "knn",
"params": {
"cosine": false,
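For context, a minimal client-side sketch of the updated request syntax shown in the README hunk above: in 6.x the script is referenced with `source` instead of the deprecated `inline` key. This is an illustration, not part of the diff; the index name `images`, the field `embedding_vector`, and the example vector are assumptions.

```java
// A hedged sketch, not part of this PR: issuing the query with the low-level REST client.
// Index name ("images"), field name ("embedding_vector"), and the vector are illustrative only.
import java.util.Collections;

import org.apache.http.HttpHost;
import org.apache.http.entity.ContentType;
import org.apache.http.nio.entity.NStringEntity;
import org.elasticsearch.client.Response;
import org.elasticsearch.client.RestClient;

public class KnnQueryExample {
    public static void main(String[] args) throws Exception {
        RestClient client = RestClient.builder(new HttpHost("localhost", 9200, "http")).build();
        // "source" replaces the deprecated "inline" key in Elasticsearch 6.x
        String body =
            "{\"query\": {\"function_score\": {" +
            "  \"boost_mode\": \"replace\"," +
            "  \"script_score\": {\"script\": {" +
            "    \"source\": \"binary_vector_score\", \"lang\": \"knn\"," +
            "    \"params\": {\"cosine\": false, \"field\": \"embedding_vector\"," +
            "      \"vector\": [0.1, 0.2, 0.3]}}}}}}";
        Response response = client.performRequest(
            "GET", "/images/_search", Collections.emptyMap(),
            new NStringEntity(body, ContentType.APPLICATION_JSON));
        System.out.println(response.getStatusLine());
        client.close();
    }
}
```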
pom.xml (12 changes: 3 additions, 9 deletions)

@@ -7,7 +7,7 @@
<name>elasticsearch-binary-vector-scoring</name>
<groupId>com.liorkn.elasticsearch</groupId>
<artifactId>elasticsearch-binary-vector-scoring</artifactId>
- <version>5.6.0</version>
+ <version>6.0.0</version>
<description>ElasticSearch Plugin for Binary Vector Scoring</description>

<licenses>
@@ -27,7 +27,7 @@
<elasticsearch.license.headerDefinition>${project.basedir}/src/main/resources/license-check/license_header_definition.xml</elasticsearch.license.headerDefinition>

<tests.ifNoTests>warn</tests.ifNoTests>
- <elasticsearch.version>5.6.0</elasticsearch.version>
+ <elasticsearch.version>6.0.0</elasticsearch.version>
<commons-io.version>2.4</commons-io.version>
<httpcore.version>4.4.8</httpcore.version>
<junit.version>4.12</junit.version>
@@ -65,7 +65,7 @@

<dependency>
<groupId>org.elasticsearch.plugin</groupId>
- <artifactId>transport-netty3-client</artifactId>
+ <artifactId>transport-netty4-client</artifactId>
<version>${elasticsearch.version}</version>
<scope>test</scope>
</dependency>
@@ -86,12 +86,6 @@
</exclusion>
</exclusions>
</dependency>
- <!--<dependency>-->
- <!--<groupId>org.elasticsearch.plugin</groupId>-->
- <!--<artifactId>transport-netty3-client</artifactId>-->
- <!--<version>${elasticsearch.version}</version>-->
- <!--<scope>test</scope>-->
- <!--</dependency>-->
<dependency>
<groupId>org.codelibs.elasticsearch.module</groupId>
<artifactId>lang-painless</artifactId>
src/main/java/com/liorkn/elasticsearch/plugin/VectorScoringPlugin.java (154 changes: 149 additions, 5 deletions)

@@ -13,23 +13,167 @@
*/
package com.liorkn.elasticsearch.plugin;

- import com.liorkn.elasticsearch.service.VectorScoringScriptEngineService;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.ByteBuffer;
import java.nio.DoubleBuffer;
import java.util.Collection;
import java.util.Map;

import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.PostingsEnum;
import org.apache.lucene.index.Term;
import org.apache.lucene.store.ByteArrayDataInput;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.plugins.ScriptPlugin;
- import org.elasticsearch.script.ScriptEngineService;

import org.elasticsearch.script.ScriptContext;
import org.elasticsearch.script.ScriptEngine;
import org.elasticsearch.script.SearchScript;
import java.util.ArrayList;
/**
* This class is instantiated when Elasticsearch loads the plugin for the
* first time. If you change the name of this plugin, make sure to update the
* src/main/resources/es-plugin.properties file, which points to this class.
*/
public final class VectorScoringPlugin extends Plugin implements ScriptPlugin {

- public final ScriptEngineService getScriptEngineService(Settings settings) {
- return new VectorScoringScriptEngineService(settings);
@Override
public ScriptEngine getScriptEngine(Settings settings, Collection<ScriptContext<?>> contexts) {
return new VectorScoringPluginEngine();
}

/** This {@link ScriptEngine} uses Lucene segment details (binary doc values) to score documents by their similarity to the submitted query vector. */
private static class VectorScoringPluginEngine implements ScriptEngine {
@Override
public String getType() {
return "knn";
}

private static final int DOUBLE_SIZE = 8; // bytes per stored vector component (Java double)

@Override
public <T> T compile(String scriptName, String scriptSource, ScriptContext<T> context, Map<String, String> params) {

if (context.equals(SearchScript.CONTEXT) == false) {
throw new IllegalArgumentException(getType() + " scripts cannot be used for context [" + context.name + "]");
}

// we use the script "source" as the script identifier
if ("binary_vector_score".equals(scriptSource)) {
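// the factory captures the script params (p) and the search lookup; newInstance() below builds one script per index segment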
SearchScript.Factory factory = (p, lookup) -> new SearchScript.LeafFactory() {
final String field;
final boolean cosine;
{
if (p.containsKey("vector") == false) {
throw new IllegalArgumentException("Missing parameter [vector]");
}
if (p.containsKey("field") == false) {
throw new IllegalArgumentException("Missing parameter [field]");
}
if (p.containsKey("cosine") == false) {
throw new IllegalArgumentException("Missing parameter [cosine]");
}
field = p.get("field").toString();
cosine = (boolean) p.get("cosine");
}

// the query vector is supplied via the "vector" script parameter; elements are expected to be doubles
final ArrayList<Double> searchVector = (ArrayList<Double>) p.get("vector");
double magnitude;
{
if (cosine) {
// precompute the query vector norm once; it is reused for every scored document
double queryVectorNorm = 0.0;
for (Double v : this.searchVector) {
queryVectorNorm += v.doubleValue() * v.doubleValue();
}
magnitude = Math.sqrt(queryVectorNorm);
} else {
magnitude = 0.0;
}
}

@Override
public SearchScript newInstance(LeafReaderContext context) throws IOException {
return new SearchScript(p, lookup, context) {
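// per-segment accessor for the binary doc values that hold the stored vector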
BinaryDocValues docAccess = context.reader().getBinaryDocValues(field);
int currentDocid = -1;

@Override
public void setDocument(int docid) {
// Move to desired document
try {
docAccess.advanceExact(docid);
} catch (IOException e) {
throw new UncheckedIOException(e);
}
currentDocid = docid;
}

@Override
public double runAsDouble() {
if (currentDocid < 0) {
return 0.0;
}
// compute the score for the current document
final int size = searchVector.size();

try {
final byte[] bytes = docAccess.binaryValue().bytes;
final ByteArrayDataInput input = new ByteArrayDataInput(bytes);
input.readVInt(); // number of stored values (should be 1); this read MUST happen here because it advances the input past the prefix
final int len = input.readVInt(); // number of bytes that make up the stored vector
// if the submitted vector has a different length than the stored one, score 0
if (len != size * DOUBLE_SIZE) {
return 0.0;
}

final int position = input.getPosition();
// interpret the remaining bytes as big-endian doubles (ByteBuffer's default byte order)
final DoubleBuffer doubleBuffer = ByteBuffer.wrap(bytes, position, len).asDoubleBuffer();

final double[] docVector = new double[size];
doubleBuffer.get(docVector);
double docVectorNorm = 0.0;
double score = 0.0;
for (int i = 0; i < size; i++) {
// accumulate the document vector norm (only needed for cosine)
if (cosine) {
docVectorNorm += docVector[i] * docVector[i];
}
// dot product
score += docVector[i] * searchVector.get(i).doubleValue();
}
if (cosine) {
// cosine similarity score
if (docVectorNorm == 0 || magnitude == 0) {
return 0.0;
} else {
return score / (Math.sqrt(docVectorNorm) * magnitude);
}
} else {
return score;
}
} catch (Exception e) {
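// any failure while reading or decoding the stored vector results in a zero score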
return 0;
}
}
};
}

@Override
public boolean needs_score() {
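// this script never reads the document's relevance _score, so Elasticsearch can skip computing it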
return false;
}
};
return context.factoryClazz.cast(factory);
}
throw new IllegalArgumentException("Unknown script name " + scriptSource);
}

@Override
public void close() {
// optionally close resources
}
}
}
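As a usage note, here is a hedged sketch (an assumption, not code from this repository) of how a client could produce the payload this script decodes. The runAsDouble() path above reads the field's binary doc values as consecutive 8-byte doubles, so the indexed value is expected to be the base64 encoding of the vector's components as big-endian doubles (ByteBuffer's default byte order).

```java
// A hedged sketch, not part of this PR: encode a vector into the base64 form expected
// in the indexed binary field, mirroring the DoubleBuffer decoding in runAsDouble() above.
import java.nio.ByteBuffer;
import java.util.Base64;

public final class VectorEncoder {

    /** Packs each component as an 8-byte big-endian double and base64-encodes the result. */
    public static String encode(double[] vector) {
        ByteBuffer buffer = ByteBuffer.allocate(vector.length * Double.BYTES); // 8 bytes per component
        for (double component : vector) {
            buffer.putDouble(component); // ByteBuffer defaults to big-endian
        }
        return Base64.getEncoder().encodeToString(buffer.array());
    }

    public static void main(String[] args) {
        // Example: the printed string would be indexed into the plugin's binary field.
        System.out.println(encode(new double[] {0.1, 0.2, 0.3}));
    }
}
```

As the readVInt() calls in the script suggest, Elasticsearch prefixes the stored bytes with a value count and a byte length, so the client supplies only the raw component bytes.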