Skip to content

Commit 458ff31

Browse files
committed
PR feedback
1 parent 7214e48 commit 458ff31

File tree

4 files changed

+31
-35
lines changed

4 files changed

+31
-35
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googlevertexai/GoogleVertexAiRerankRequest.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,8 @@ public HttpRequest createHttpRequest() {
5858
query,
5959
input,
6060
returnDocuments,
61-
topN,
62-
model.getServiceSettings().modelId(),
63-
model.getTaskSettings().topN()
61+
topN != null ? topN : model.getTaskSettings().topN(),
62+
model.getServiceSettings().modelId()
6463
)
6564
).getBytes(StandardCharsets.UTF_8)
6665
);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googlevertexai/GoogleVertexAiRerankRequestEntity.java

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,7 @@ public record GoogleVertexAiRerankRequestEntity(
2020
List<String> inputs,
2121
@Nullable Boolean returnDocuments,
2222
@Nullable Integer topN,
23-
@Nullable String model,
24-
@Nullable Integer taskSettingsTopN
23+
@Nullable String model
2524
) implements ToXContentObject {
2625

2726
private static final String MODEL_FIELD = "model";
@@ -66,8 +65,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
6665
// prefer the root level top_n over task settings
6766
if (topN != null) {
6867
builder.field(TOP_N_FIELD, topN);
69-
} else if (taskSettingsTopN != null) {
70-
builder.field(TOP_N_FIELD, taskSettingsTopN);
7168
}
7269

7370
if (returnDocuments != null) {

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/googlevertexai/GoogleVertexAiRerankRequestEntityTests.java

Lines changed: 4 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
public class GoogleVertexAiRerankRequestEntityTests extends ESTestCase {
2323
public void testXContent_SingleRequest_WritesAllFieldsIfDefined() throws IOException {
24-
var entity = new GoogleVertexAiRerankRequestEntity("query", List.of("abc"), Boolean.TRUE, 10, "model", 8);
24+
var entity = new GoogleVertexAiRerankRequestEntity("query", List.of("abc"), Boolean.TRUE, 10, "model");
2525

2626
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
2727
entity.toXContent(builder, null);
@@ -44,7 +44,7 @@ public void testXContent_SingleRequest_WritesAllFieldsIfDefined() throws IOExcep
4444
}
4545

4646
public void testXContent_SingleRequest_WritesMinimalFields() throws IOException {
47-
var entity = new GoogleVertexAiRerankRequestEntity("query", List.of("abc"), null, null, null, null);
47+
var entity = new GoogleVertexAiRerankRequestEntity("query", List.of("abc"), null, null, null);
4848

4949
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
5050
entity.toXContent(builder, null);
@@ -64,7 +64,7 @@ public void testXContent_SingleRequest_WritesMinimalFields() throws IOException
6464
}
6565

6666
public void testXContent_MultipleRequests_WritesAllFieldsIfDefined() throws IOException {
67-
var entity = new GoogleVertexAiRerankRequestEntity("query", List.of("abc", "def"), Boolean.FALSE, 12, "model", 8);
67+
var entity = new GoogleVertexAiRerankRequestEntity("query", List.of("abc", "def"), Boolean.FALSE, 12, "model");
6868

6969
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
7070
entity.toXContent(builder, null);
@@ -91,7 +91,7 @@ public void testXContent_MultipleRequests_WritesAllFieldsIfDefined() throws IOEx
9191
}
9292

9393
public void testXContent_MultipleRequests_WritesMinimalFields() throws IOException {
94-
var entity = new GoogleVertexAiRerankRequestEntity("query", List.of("abc", "def"), null, null, null, null);
94+
var entity = new GoogleVertexAiRerankRequestEntity("query", List.of("abc", "def"), null, null, null);
9595

9696
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
9797
entity.toXContent(builder, null);
@@ -113,27 +113,4 @@ public void testXContent_MultipleRequests_WritesMinimalFields() throws IOExcepti
113113
}
114114
"""));
115115
}
116-
117-
public void testXContent_SingleRequest_UsesTaskSettingsTopNIfRootIsNotDefined() throws IOException {
118-
var entity = new GoogleVertexAiRerankRequestEntity("query", List.of("abc"), null, null, "model", 8);
119-
120-
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
121-
entity.toXContent(builder, null);
122-
String xContentResult = Strings.toString(builder);
123-
124-
assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString("""
125-
{
126-
"model": "model",
127-
"query": "query",
128-
"records": [
129-
{
130-
"id": "0",
131-
"content": "abc"
132-
}
133-
],
134-
"topN": 8
135-
}
136-
"""));
137-
}
138-
139116
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/googlevertexai/GoogleVertexAiRerankRequestTests.java

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,31 @@ public void testCreateRequest_WithTopNSet() throws IOException {
5353
var input = "input";
5454
var query = "query";
5555
var topN = 1;
56+
var taskSettingsTopN = 3;
5657

57-
var request = createRequest(query, input, null, topN, null, null);
58+
var request = createRequest(query, input, null, topN, null, taskSettingsTopN);
59+
var httpRequest = request.createHttpRequest();
60+
61+
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
62+
var httpPost = (HttpPost) httpRequest.httpRequestBase();
63+
64+
assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType()));
65+
assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is(AUTH_HEADER_VALUE));
66+
67+
var requestMap = entityAsMap(httpPost.getEntity().getContent());
68+
69+
assertThat(requestMap, aMapWithSize(3));
70+
assertThat(requestMap.get("records"), is(List.of(Map.of("id", "0", "content", input))));
71+
assertThat(requestMap.get("query"), is(query));
72+
assertThat(requestMap.get("topN"), is(topN));
73+
}
74+
75+
public void testCreateRequest_UsesTaskSettingsTopNWhenRootLevelIsNull() throws IOException {
76+
var input = "input";
77+
var query = "query";
78+
var topN = 1;
79+
80+
var request = createRequest(query, input, null, null, null, topN);
5881
var httpRequest = request.createHttpRequest();
5982

6083
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));

0 commit comments

Comments
 (0)