Skip to content

Commit 6474c6a

Browse files
author
Lior Knaany
committed
added a unit test that performs the vector scoring search
1 parent 5bffaa7 commit 6474c6a

File tree

6 files changed

+162
-67
lines changed

6 files changed

+162
-67
lines changed

pom.xml

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,19 @@
3636

3737

3838
<dependencies>
39+
<dependency>
40+
<groupId>org.apache.logging.log4j</groupId>
41+
<artifactId>log4j-api</artifactId>
42+
<version>2.7</version>
43+
<scope>test</scope>
44+
</dependency>
45+
<dependency>
46+
<groupId>org.apache.logging.log4j</groupId>
47+
<artifactId>log4j-core</artifactId>
48+
<version>2.7</version>
49+
<scope>test</scope>
50+
</dependency>
51+
3952
<dependency>
4053
<groupId>org.elasticsearch</groupId>
4154
<artifactId>elasticsearch</artifactId>
@@ -50,6 +63,13 @@
5063
<scope>test</scope>
5164
</dependency>
5265

66+
<dependency>
67+
<groupId>org.elasticsearch.plugin</groupId>
68+
<artifactId>transport-netty3-client</artifactId>
69+
<version>${elasticsearch.version}</version>
70+
<scope>test</scope>
71+
</dependency>
72+
5373
<dependency>
5474
<groupId>org.codelibs.elasticsearch.module</groupId>
5575
<artifactId>reindex</artifactId>

src/test/com/liorkn/elasticsearch/TestPlugin.java

Lines changed: 0 additions & 65 deletions
This file was deleted.
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
package com.liorkn.elasticsearch;
2+
3+
import com.fasterxml.jackson.core.JsonParser;
4+
import com.fasterxml.jackson.databind.ObjectMapper;
5+
import org.apache.http.HttpHost;
6+
import org.apache.http.entity.ContentType;
7+
import org.apache.http.entity.StringEntity;
8+
import org.apache.http.nio.entity.NStringEntity;
9+
import org.apache.http.util.EntityUtils;
10+
import org.elasticsearch.client.Response;
11+
import org.elasticsearch.client.RestClient;
12+
import org.junit.AfterClass;
13+
import org.junit.Assert;
14+
import org.junit.BeforeClass;
15+
import org.junit.Test;
16+
17+
import java.io.IOException;
18+
import java.util.Collections;
19+
import java.util.HashMap;
20+
import java.util.Map;
21+
22+
/**
23+
* Created by Lior Knaany on 4/7/18.
24+
*/
25+
public class PluginTest {
26+
27+
private static EmbeddedElasticsearchServer esServer;
28+
private static RestClient esClient;
29+
30+
@BeforeClass
31+
public static void init() throws Exception {
32+
esServer = new EmbeddedElasticsearchServer();
33+
esClient = RestClient.builder(new HttpHost("localhost", esServer.getPort(), "http")).build();
34+
35+
// delete test index if exists
36+
try {
37+
esClient.performRequest("DELETE", "/test", Collections.emptyMap());
38+
} catch (Exception e) {}
39+
40+
// create test index
41+
String mappingJson = "{\n" +
42+
" \"mappings\": {\n" +
43+
" \"type\": {\n" +
44+
" \"properties\": {\n" +
45+
" \"embedding_vector\": {\n" +
46+
" \"doc_values\": true,\n" +
47+
" \"type\": \"binary\"\n" +
48+
" },\n" +
49+
" \"job_id\": {\n" +
50+
" \"type\": \"long\"\n" +
51+
" },\n" +
52+
" \"vector\": {\n" +
53+
" \"type\": \"float\"\n" +
54+
" }\n" +
55+
" }\n" +
56+
" }\n" +
57+
" }\n" +
58+
"}";
59+
esClient.performRequest("PUT", "/test", Collections.emptyMap(), new NStringEntity(mappingJson, ContentType.APPLICATION_JSON));
60+
}
61+
62+
public static final ObjectMapper mapper = new ObjectMapper();
63+
static {
64+
mapper.configure(JsonParser.Feature.ALLOW_SINGLE_QUOTES, true);
65+
mapper.configure(JsonParser.Feature.ALLOW_UNQUOTED_FIELD_NAMES, true);
66+
}
67+
68+
69+
@Test
70+
public void test() throws Exception {
71+
final Map<String, String> params = new HashMap<>();
72+
params.put("refresh", "true");
73+
final ObjectMapper mapper = new ObjectMapper();
74+
final TestObject[] objs = {new TestObject(1, new double[] {0.0, 0.5, 1.0}),
75+
new TestObject(2, new double[] {0.2, 0.6, 0.99})};
76+
77+
for (int i = 0; i < objs.length; i++) {
78+
final TestObject t = objs[i];
79+
final String json = mapper.writeValueAsString(t);
80+
System.out.println(json);
81+
final Response put = esClient.performRequest("PUT", "/test/type/" + t.jobId, params, new StringEntity(json, ContentType.APPLICATION_JSON));
82+
System.out.println(put);
83+
System.out.println(EntityUtils.toString(put.getEntity()));
84+
final int statusCode = put.getStatusLine().getStatusCode();
85+
Assert.assertTrue(statusCode == 200 || statusCode == 201);
86+
}
87+
88+
// Test cosine score function
89+
String body = "{" +
90+
" \"query\": {" +
91+
" \"function_score\": {" +
92+
" \"boost_mode\": \"replace\"," +
93+
" \"script_score\": {" +
94+
" \"script\": {" +
95+
" \"inline\": \"binary_vector_score\"," +
96+
" \"lang\": \"knn\"," +
97+
" \"params\": {" +
98+
" \"cosine\": false," +
99+
" \"field\": \"embedding_vector\"," +
100+
" \"vector\": [" +
101+
" 0.1, 0.2, 0.3" +
102+
" ]" +
103+
" }" +
104+
" }" +
105+
" }" +
106+
" }" +
107+
" }," +
108+
" \"size\": 100" +
109+
"}";
110+
final Response res = esClient.performRequest("POST", "/test/_search", Collections.emptyMap(), new NStringEntity(body, ContentType.APPLICATION_JSON));
111+
System.out.println(res);
112+
final String resBody = EntityUtils.toString(res.getEntity());
113+
System.out.println(resBody);
114+
Assert.assertEquals("search should return status code 200", 200, res.getStatusLine().getStatusCode());
115+
Assert.assertTrue(String.format("There should be %d documents in the search response", objs.length), resBody.contains("\"hits\":{\"total\":" + objs.length));
116+
}
117+
118+
@AfterClass
119+
public static void shutdown() {
120+
try {
121+
esClient.close();
122+
esServer.shutdown();
123+
} catch (IOException e) {
124+
e.printStackTrace();
125+
}
126+
}
127+
128+
}

src/test/com/liorkn/elasticsearch/TestObject.java renamed to src/test/java/com/liorkn/elasticsearch/TestObject.java

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,24 @@
99
@JsonNaming(PropertyNamingStrategy.SnakeCaseStrategy.class)
1010
public class TestObject {
1111
int jobId;
12-
String base64Vector;
12+
String embeddingVector;
1313
double[] vector;
1414

15+
public int getJobId() {
16+
return jobId;
17+
}
18+
19+
public String getEmbeddingVector() {
20+
return embeddingVector;
21+
}
22+
23+
public double[] getVector() {
24+
return vector;
25+
}
26+
1527
public TestObject(int jobId, double[] vector) {
1628
this.jobId = jobId;
1729
this.vector = vector;
18-
this.base64Vector = Util.convertArrayToBase64(vector);
30+
this.embeddingVector = Util.convertArrayToBase64(vector);
1931
}
2032
}

0 commit comments

Comments
 (0)