Skip to content

Commit ae1a1d2

Browse files
apply suggestions
1 parent c8c74d6 commit ae1a1d2

File tree

6 files changed

+111
-64
lines changed

6 files changed

+111
-64
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/rerank/HuggingFaceRerankRequest.java

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -56,15 +56,8 @@ public HttpRequest createHttpRequest() {
5656
HttpPost httpPost = new HttpPost(account.uri());
5757

5858
ByteArrayEntity byteEntity = new ByteArrayEntity(
59-
Strings.toString(
60-
new HuggingFaceRerankRequestEntity(
61-
query,
62-
input,
63-
returnDocuments,
64-
topN != null ? topN : model.getTaskSettings().getTopNDocumentsOnly(),
65-
model.getTaskSettings()
66-
)
67-
).getBytes(StandardCharsets.UTF_8)
59+
Strings.toString(new HuggingFaceRerankRequestEntity(query, input, returnDocuments, getTopN(), model.getTaskSettings()))
60+
.getBytes(StandardCharsets.UTF_8)
6861
);
6962
httpPost.setEntity(byteEntity);
7063
httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaTypeWithoutParameters());
@@ -89,7 +82,7 @@ public URI getURI() {
8982
}
9083

9184
public Integer getTopN() {
92-
return topN;
85+
return topN != null ? topN : model.getTaskSettings().getTopNDocumentsOnly();
9386
}
9487

9588
@Override

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/rerank/HuggingFaceRerankRequestEntity.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ public record HuggingFaceRerankRequestEntity(
2424
HuggingFaceRerankTaskSettings taskSettings
2525
) implements ToXContentObject {
2626

27+
private static final String RETURN_TEXT = "return_text";
2728
private static final String DOCUMENTS_FIELD = "texts";
2829
private static final String QUERY_FIELD = "query";
2930

@@ -42,9 +43,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
4243

4344
// prefer the root level return_documents over task settings
4445
if (returnDocuments != null) {
45-
builder.field(HuggingFaceRerankTaskSettings.RETURN_TEXT, returnDocuments);
46+
builder.field(RETURN_TEXT, returnDocuments);
4647
} else if (taskSettings.getReturnDocuments() != null) {
47-
builder.field(HuggingFaceRerankTaskSettings.RETURN_TEXT, taskSettings.getReturnDocuments());
48+
builder.field(RETURN_TEXT, taskSettings.getReturnDocuments());
4849
}
4950

5051
if (topN != null) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/rerank/HuggingFaceRerankTaskSettings.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
public class HuggingFaceRerankTaskSettings implements TaskSettings {
2929

3030
public static final String NAME = "hugging_face_rerank_task_settings";
31-
public static final String RETURN_TEXT = "return_text";
31+
public static final String RETURN_DOCUMENTS = "return_documents";
3232
public static final String TOP_N_DOCS_ONLY = "top_n";
3333

3434
static final HuggingFaceRerankTaskSettings EMPTY_SETTINGS = new HuggingFaceRerankTaskSettings(null, null);
@@ -40,7 +40,7 @@ public static HuggingFaceRerankTaskSettings fromMap(Map<String, Object> map) {
4040
return EMPTY_SETTINGS;
4141
}
4242

43-
Boolean returnDocuments = extractOptionalBoolean(map, RETURN_TEXT, validationException);
43+
Boolean returnDocuments = extractOptionalBoolean(map, RETURN_DOCUMENTS, validationException);
4444
Integer topNDocumentsOnly = extractOptionalPositiveInteger(
4545
map,
4646
TOP_N_DOCS_ONLY,
@@ -85,7 +85,7 @@ public static HuggingFaceRerankTaskSettings of(Integer topNDocumentsOnly, Boolea
8585
private final Boolean returnDocuments;
8686

8787
public HuggingFaceRerankTaskSettings(StreamInput in) throws IOException {
88-
this(in.readOptionalInt(), in.readOptionalBoolean());
88+
this(in.readOptionalVInt(), in.readOptionalBoolean());
8989
}
9090

9191
public HuggingFaceRerankTaskSettings(@Nullable Integer topNDocumentsOnly, @Nullable Boolean doReturnDocuments) {
@@ -105,7 +105,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
105105
builder.field(TOP_N_DOCS_ONLY, topNDocumentsOnly);
106106
}
107107
if (returnDocuments != null) {
108-
builder.field(RETURN_TEXT, returnDocuments);
108+
builder.field(RETURN_DOCUMENTS, returnDocuments);
109109
}
110110
builder.endObject();
111111
return builder;
@@ -123,7 +123,7 @@ public TransportVersion getMinimalSupportedVersion() {
123123

124124
@Override
125125
public void writeTo(StreamOutput out) throws IOException {
126-
out.writeOptionalInt(topNDocumentsOnly);
126+
out.writeOptionalVInt(topNDocumentsOnly);
127127
out.writeOptionalBoolean(returnDocuments);
128128
}
129129

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceRerankResponseEntity.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,15 +81,15 @@ public static RankedDocsResults fromResponse(HuggingFaceRerankRequest request, H
8181
private static List<RankedDocsResults.RankedDoc> doParse(XContentParser parser) throws IOException {
8282
return parseList(parser, (listParser, index) -> {
8383
var parsedRankedDoc = HuggingFaceRerankResponseEntity.RankedDocEntry.parse(parser);
84-
return new RankedDocsResults.RankedDoc(parsedRankedDoc.id, parsedRankedDoc.score, parsedRankedDoc.text);
84+
return new RankedDocsResults.RankedDoc(parsedRankedDoc.index, parsedRankedDoc.score, parsedRankedDoc.text);
8585
});
8686
}
8787

88-
private record RankedDocEntry(Integer id, Float score, @Nullable String text) {
88+
private record RankedDocEntry(Integer index, Float score, @Nullable String text) {
8989

9090
private static final ParseField TEXT = new ParseField("text");
9191
private static final ParseField SCORE = new ParseField("score");
92-
private static final ParseField ID = new ParseField("index");
92+
private static final ParseField INDEX = new ParseField("index");
9393
private static final ConstructingObjectParser<HuggingFaceRerankResponseEntity.RankedDocEntry, Void> PARSER =
9494
new ConstructingObjectParser<>(
9595
"hugging_face_rerank_response",
@@ -98,7 +98,7 @@ private record RankedDocEntry(Integer id, Float score, @Nullable String text) {
9898
);
9999

100100
static {
101-
PARSER.declareInt(ConstructingObjectParser.constructorArg(), ID);
101+
PARSER.declareInt(ConstructingObjectParser.constructorArg(), INDEX);
102102
PARSER.declareFloat(ConstructingObjectParser.constructorArg(), SCORE);
103103
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), TEXT);
104104
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/rerank/HuggingFaceRerankTaskSettingsTests.java

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,18 @@
77

88
package org.elasticsearch.xpack.inference.services.huggingface.rerank;
99

10+
import org.elasticsearch.TransportVersion;
1011
import org.elasticsearch.common.ValidationException;
1112
import org.elasticsearch.common.io.stream.Writeable;
12-
import org.elasticsearch.test.AbstractWireSerializingTestCase;
13+
import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
1314

1415
import java.io.IOException;
1516
import java.util.HashMap;
1617
import java.util.Map;
1718

1819
import static org.hamcrest.Matchers.containsString;
1920

20-
public class HuggingFaceRerankTaskSettingsTests extends AbstractWireSerializingTestCase<HuggingFaceRerankTaskSettings> {
21+
public class HuggingFaceRerankTaskSettingsTests extends AbstractBWCWireSerializationTestCase<HuggingFaceRerankTaskSettings> {
2122

2223
public static HuggingFaceRerankTaskSettings createRandom() {
2324
var returnDocuments = randomBoolean() ? randomBoolean() : null;
@@ -28,7 +29,7 @@ public static HuggingFaceRerankTaskSettings createRandom() {
2829

2930
public void testFromMap_WithValidValues_ReturnsSettings() {
3031
Map<String, Object> taskMap = Map.of(
31-
HuggingFaceRerankTaskSettings.RETURN_TEXT,
32+
HuggingFaceRerankTaskSettings.RETURN_DOCUMENTS,
3233
true,
3334
HuggingFaceRerankTaskSettings.TOP_N_DOCS_ONLY,
3435
5
@@ -46,18 +47,18 @@ public void testFromMap_WithNullValues_ReturnsSettingsWithNulls() {
4647

4748
public void testFromMap_WithInvalidReturnDocuments_ThrowsValidationException() {
4849
Map<String, Object> taskMap = Map.of(
49-
HuggingFaceRerankTaskSettings.RETURN_TEXT,
50+
HuggingFaceRerankTaskSettings.RETURN_DOCUMENTS,
5051
"invalid",
5152
HuggingFaceRerankTaskSettings.TOP_N_DOCS_ONLY,
5253
5
5354
);
5455
var thrownException = expectThrows(ValidationException.class, () -> HuggingFaceRerankTaskSettings.fromMap(new HashMap<>(taskMap)));
55-
assertThat(thrownException.getMessage(), containsString("field [return_text] is not of the expected type"));
56+
assertThat(thrownException.getMessage(), containsString("field [return_documents] is not of the expected type"));
5657
}
5758

5859
public void testFromMap_WithInvalidTopNDocsOnly_ThrowsValidationException() {
5960
Map<String, Object> taskMap = Map.of(
60-
HuggingFaceRerankTaskSettings.RETURN_TEXT,
61+
HuggingFaceRerankTaskSettings.RETURN_DOCUMENTS,
6162
true,
6263
HuggingFaceRerankTaskSettings.TOP_N_DOCS_ONLY,
6364
"invalid"
@@ -74,7 +75,7 @@ public void UpdatedTaskSettings_WithEmptyMap_ReturnsSameSettings() {
7475

7576
public void testUpdatedTaskSettings_WithNewReturnDocuments_ReturnsUpdatedSettings() {
7677
var initialSettings = new HuggingFaceRerankTaskSettings(5, true);
77-
Map<String, Object> newSettings = Map.of(HuggingFaceRerankTaskSettings.RETURN_TEXT, false);
78+
Map<String, Object> newSettings = Map.of(HuggingFaceRerankTaskSettings.RETURN_DOCUMENTS, false);
7879
HuggingFaceRerankTaskSettings updatedSettings = (HuggingFaceRerankTaskSettings) initialSettings.updatedTaskSettings(newSettings);
7980
assertFalse(updatedSettings.getReturnDocuments());
8081
assertEquals(initialSettings.getTopNDocumentsOnly(), updatedSettings.getTopNDocumentsOnly());
@@ -91,7 +92,7 @@ public void testUpdatedTaskSettings_WithNewTopNDocsOnly_ReturnsUpdatedSettings()
9192
public void testUpdatedTaskSettings_WithMultipleNewValues_ReturnsUpdatedSettings() {
9293
var initialSettings = new HuggingFaceRerankTaskSettings(5, true);
9394
Map<String, Object> newSettings = Map.of(
94-
HuggingFaceRerankTaskSettings.RETURN_TEXT,
95+
HuggingFaceRerankTaskSettings.RETURN_DOCUMENTS,
9596
false,
9697
HuggingFaceRerankTaskSettings.TOP_N_DOCS_ONLY,
9798
7
@@ -115,4 +116,9 @@ protected HuggingFaceRerankTaskSettings createTestInstance() {
115116
protected HuggingFaceRerankTaskSettings mutateInstance(HuggingFaceRerankTaskSettings instance) throws IOException {
116117
return randomValueOtherThan(instance, HuggingFaceRerankTaskSettingsTests::createRandom);
117118
}
119+
120+
@Override
121+
protected HuggingFaceRerankTaskSettings mutateInstanceForVersion(HuggingFaceRerankTaskSettings instance, TransportVersion version) {
122+
return instance;
123+
}
118124
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceRerankResponseEntityTests.java

Lines changed: 82 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -27,45 +27,82 @@
2727
public class HuggingFaceRerankResponseEntityTests extends ESTestCase {
2828
private static final String MISSED_FIELD_INDEX = "index";
2929
private static final String MISSED_FIELD_SCORE = "score";
30+
private static final String RESPONSE_JSON_TWO_DOCS = """
31+
[
32+
{
33+
"index": 4,
34+
"score": -0.22222222222222222,
35+
"text": "ranked second"
36+
},
37+
{
38+
"index": 1,
39+
"score": 1.11111111111111111,
40+
"text": "ranked first"
41+
}
42+
]
43+
""";
44+
private static final List<RankedDocsResultsTests.RerankExpectation> EXPECTED_TWO_DOCS = List.of(
45+
new RankedDocsResultsTests.RerankExpectation(Map.of("index", 1, "relevance_score", 1.11111111111111111F, "text", "ranked first")),
46+
new RankedDocsResultsTests.RerankExpectation(Map.of("index", 4, "relevance_score", -0.22222222222222222F, "text", "ranked second"))
47+
);
3048

31-
public void testFromResponse_CreatesRankedDocsResults() throws IOException {
32-
String responseJson = """
33-
[
34-
{
35-
"index": 0,
36-
"score": -0.07996220886707306,
37-
"text": "luke"
38-
}
39-
]
40-
""";
41-
HuggingFaceRerankRequest huggingFaceRerankRequestMock = mock(HuggingFaceRerankRequest.class);
42-
when(huggingFaceRerankRequestMock.getTopN()).thenReturn(1);
49+
private static final String RESPONSE_JSON_FIVE_DOCS = """
50+
[
51+
{
52+
"index": 1,
53+
"score": 1.11111111111111111,
54+
"text": "ranked first"
55+
},
56+
{
57+
"index": 3,
58+
"score": -0.33333333333333333,
59+
"text": "ranked third"
60+
},
61+
{
62+
"index": 0,
63+
"score": -0.55555555555555555,
64+
"text": "ranked fifth"
65+
},
66+
{
67+
"index": 2,
68+
"score": -0.44444444444444444,
69+
"text": "ranked fourth"
70+
},
71+
{
72+
"index": 4,
73+
"score": -0.22222222222222222,
74+
"text": "ranked second"
75+
}
76+
]
77+
""";
78+
private static final List<RankedDocsResultsTests.RerankExpectation> EXPECTED_FIVE_DOCS = List.of(
79+
new RankedDocsResultsTests.RerankExpectation(Map.of("index", 1, "relevance_score", 1.11111111111111111F, "text", "ranked first")),
80+
new RankedDocsResultsTests.RerankExpectation(Map.of("index", 4, "relevance_score", -0.22222222222222222F, "text", "ranked second")),
81+
new RankedDocsResultsTests.RerankExpectation(Map.of("index", 3, "relevance_score", -0.33333333333333333F, "text", "ranked third")),
82+
new RankedDocsResultsTests.RerankExpectation(Map.of("index", 2, "relevance_score", -0.44444444444444444F, "text", "ranked fourth")),
83+
new RankedDocsResultsTests.RerankExpectation(Map.of("index", 0, "relevance_score", -0.55555555555555555F, "text", "ranked fifth"))
84+
);
4385

44-
RankedDocsResults parsedResults = HuggingFaceRerankResponseEntity.fromResponse(
45-
huggingFaceRerankRequestMock,
46-
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
47-
);
86+
private static final HuggingFaceRerankRequest REQUEST_MOCK = mock(HuggingFaceRerankRequest.class);
4887

49-
assertThat(
50-
parsedResults.asMap(),
51-
is(
52-
buildExpectationRerank(
53-
List.of(
54-
new RankedDocsResultsTests.RerankExpectation(
55-
Map.of("index", 0, "relevance_score", -0.07996220886707306F, "text", "luke")
56-
)
57-
)
58-
)
59-
)
60-
);
88+
public void testFromResponse_CreatesRankedDocsResults_TopNNull_FiveDocs_NoLimit() throws IOException {
89+
assertTopNLimit(null, RESPONSE_JSON_FIVE_DOCS, EXPECTED_FIVE_DOCS);
90+
}
91+
92+
public void testFromResponse_CreatesRankedDocsResults_TopN5_TwoDocs_NoLimit() throws IOException {
93+
assertTopNLimit(5, RESPONSE_JSON_TWO_DOCS, EXPECTED_TWO_DOCS);
94+
}
95+
96+
public void testFromResponse_CreatesRankedDocsResults_TopN2_FiveDocs_Limits() throws IOException {
97+
assertTopNLimit(2, RESPONSE_JSON_FIVE_DOCS, EXPECTED_TWO_DOCS);
6198
}
6299

63100
public void testFails_CreateRankedDocsResults_IndexFieldNull() {
64101
String responseJson = """
65102
[
66103
{
67-
"score": -0.07996220886707306,
68-
"text": "luke"
104+
"score": 1.11111111111111111,
105+
"text": "ranked first"
69106
}
70107
]
71108
""";
@@ -76,25 +113,35 @@ public void testFails_CreateRankedDocsResults_ScoreFieldNull() {
76113
String responseJson = """
77114
[
78115
{
79-
"index": 0,
80-
"text": "luke"
116+
"index": 1,
117+
"text": "ranked first"
81118
}
82119
]
83120
""";
84121
assertMissingFieldThrowsIllegalArgumentException(responseJson, MISSED_FIELD_SCORE);
85122
}
86123

87124
private void assertMissingFieldThrowsIllegalArgumentException(String responseJson, String missingField) {
88-
HuggingFaceRerankRequest huggingFaceRerankRequestMock = mock(HuggingFaceRerankRequest.class);
89-
when(huggingFaceRerankRequestMock.getTopN()).thenReturn(1);
125+
when(REQUEST_MOCK.getTopN()).thenReturn(1);
90126

91127
var thrownException = expectThrows(
92128
IllegalArgumentException.class,
93129
() -> HuggingFaceRerankResponseEntity.fromResponse(
94-
huggingFaceRerankRequestMock,
130+
REQUEST_MOCK,
95131
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
96132
)
97133
);
98134
assertThat(thrownException.getMessage(), is("Required [" + missingField + "]"));
99135
}
136+
137+
private void assertTopNLimit(
138+
Integer topN, String responseJson, List<RankedDocsResultsTests.RerankExpectation> expectation) throws IOException {
139+
when(REQUEST_MOCK.getTopN()).thenReturn(topN);
140+
141+
RankedDocsResults parsedResults = HuggingFaceRerankResponseEntity.fromResponse(
142+
REQUEST_MOCK,
143+
new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8))
144+
);
145+
assertThat(parsedResults.asMap(), is(buildExpectationRerank(expectation)));
146+
}
100147
}

0 commit comments

Comments
 (0)