Skip to content

Commit 37e31f3

Browse files
Add test for IbmWatonxRankedRequest
1 parent 65e187c commit 37e31f3

File tree

3 files changed

+156
-0
lines changed

3 files changed

+156
-0
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/ibmwatsonx/IbmWatsonxRerankRequest.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,18 @@ public Request truncate() {
9292
return this; // TODO?
9393
}
9494

95+
public String getQuery(){
96+
return query;
97+
}
98+
99+
public List<String> getInput(){
100+
return input;
101+
}
102+
103+
public IbmWatsonxRerankModel getModel(){
104+
return model;
105+
}
106+
95107
@Override
96108
public boolean[] getTruncationInfo() {
97109
return null;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.request.ibmwatsonx.rerank;
9+
10+
import org.apache.http.HttpHeaders;
11+
import org.apache.http.client.methods.HttpPost;
12+
import org.elasticsearch.core.Nullable;
13+
import org.elasticsearch.core.Strings;
14+
import org.elasticsearch.test.ESTestCase;
15+
import org.elasticsearch.xcontent.XContentType;
16+
import org.elasticsearch.xpack.inference.external.request.Request;
17+
import org.elasticsearch.xpack.inference.external.request.ibmwatsonx.IbmWatsonxRerankRequest;
18+
import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsModelTests;
19+
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankModel;
20+
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankModelTests;
21+
22+
import java.io.IOException;
23+
import java.net.URI;
24+
import java.util.List;
25+
import java.util.Map;
26+
27+
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
28+
import static org.hamcrest.Matchers.aMapWithSize;
29+
import static org.hamcrest.Matchers.endsWith;
30+
import static org.hamcrest.Matchers.instanceOf;
31+
import static org.hamcrest.Matchers.is;
32+
33+
public class IbmWatsonxRerankRequestTests extends ESTestCase {
34+
private static final String AUTH_HEADER_VALUE = "foo";
35+
36+
public void testCreateRequest() throws IOException {
37+
var model = "model";
38+
var projectId = "project_id";
39+
URI uri = null;
40+
try {
41+
uri = new URI("http://abc.com");
42+
} catch (Exception ignored) {}
43+
var apiVersion = "2023-05-04";
44+
var apiKey = "api_key";
45+
var query = "database";
46+
List<String> input = List.of("greenland", "google","john", "mysql","potter", "grammar");
47+
48+
var request = createRequest(model, projectId, uri, apiVersion, apiKey, query, input);
49+
var httpRequest = request.createHttpRequest();
50+
51+
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
52+
var httpPost = (HttpPost) httpRequest.httpRequestBase();
53+
54+
assertThat(httpPost.getURI().toString(), endsWith(Strings.format("%s=%s", "version", apiVersion)));
55+
assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType()));
56+
57+
var requestMap = entityAsMap(httpPost.getEntity().getContent());
58+
assertThat(requestMap, aMapWithSize(5));
59+
assertThat(
60+
requestMap,
61+
is(
62+
63+
Map.of("project_id", "project_id", "model_id", "model", "inputs",
64+
List.of(Map.of("text", "greenland"),
65+
Map.of("text", "google"),
66+
Map.of("text", "john"),
67+
Map.of("text", "mysql"),
68+
Map.of("text", "potter"),
69+
Map.of("text", "grammar")
70+
),
71+
"query", "database",
72+
"parameters",
73+
Map.of("return_options",
74+
Map.of("top_n", 2,
75+
"inputs", true),
76+
"truncate_input_tokens", 100)
77+
)
78+
)
79+
);
80+
}
81+
82+
public static IbmWatsonxRerankRequest createRequest(
83+
String model,
84+
String projectId,
85+
URI uri,
86+
String apiVersion,
87+
String apiKey,
88+
String query,
89+
List<String> input
90+
) {
91+
var embeddingsModel = IbmWatsonxRerankModelTests.createModel(model, projectId, uri, apiVersion, apiKey);
92+
93+
return new IbmWatsonxRerankWithoutAuthRequest(
94+
query,
95+
input,
96+
embeddingsModel
97+
);
98+
}
99+
100+
private static class IbmWatsonxRerankWithoutAuthRequest extends IbmWatsonxRerankRequest {
101+
IbmWatsonxRerankWithoutAuthRequest(String query, List<String> input, IbmWatsonxRerankModel model) {
102+
super(query, input, model);
103+
}
104+
105+
@Override
106+
public void decorateWithAuth(HttpPost httpPost) {
107+
httpPost.setHeader(HttpHeaders.AUTHORIZATION, AUTH_HEADER_VALUE);
108+
}
109+
}
110+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank;
9+
10+
import org.elasticsearch.common.settings.SecureString;
11+
import org.elasticsearch.inference.TaskType;
12+
import org.elasticsearch.test.ESTestCase;
13+
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
14+
15+
import java.net.URI;
16+
17+
public class IbmWatsonxRerankModelTests extends ESTestCase {
18+
public static IbmWatsonxRerankModel createModel(
19+
String model,
20+
String projectId,
21+
URI uri,
22+
String apiVersion,
23+
String apiKey
24+
) {
25+
return new IbmWatsonxRerankModel(
26+
"id",
27+
TaskType.RERANK,
28+
"service",
29+
new IbmWatsonxRerankServiceSettings(uri, apiVersion, model, projectId, null),
30+
new IbmWatsonxRerankTaskSettings(2, true, 100),
31+
new DefaultSecretSettings(new SecureString(apiKey.toCharArray()))
32+
);
33+
}
34+
}

0 commit comments

Comments
 (0)