Skip to content

Commit a52a1d8

Browse files
apply suggestions
1 parent 88d6929 commit a52a1d8

File tree

10 files changed

+178
-114
lines changed

10 files changed

+178
-114
lines changed

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/RankedDocsResultsTests.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,4 +82,20 @@ private List<RankedDocsResults.RankedDoc> rankedDocsNullStringToEmpty(List<Ranke
8282
protected RankedDocsResults doParseInstance(XContentParser parser) throws IOException {
8383
return RankedDocsResults.createParser(true).apply(parser, null);
8484
}
85+
86+
public record RerankExpectation(Map<String, Object> rankedDocFields) {}
87+
88+
public static Map<String, Object> buildExpectationRerank(List<RerankExpectation> rerank) {
89+
return Map.of(
90+
RankedDocsResults.RERANK,
91+
rerank.stream()
92+
.map(
93+
rerankExpectation -> Map.of(
94+
RankedDocsResults.RankedDoc.NAME,
95+
rerankExpectation.rankedDocFields
96+
)
97+
)
98+
.toList()
99+
);
100+
}
85101
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
import org.elasticsearch.xpack.inference.services.ServiceUtils;
2727
import org.elasticsearch.xpack.inference.services.huggingface.action.HuggingFaceActionCreator;
2828

29-
import java.util.Collections;
3029
import java.util.Map;
3130

3231
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
@@ -58,11 +57,7 @@ public void parseRequestConfig(
5857
) {
5958
try {
6059
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
61-
Map<String, Object> taskSettingsMap = Collections.emptyMap();
62-
63-
if (TaskType.RERANK.equals(taskType)) {
64-
taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS);
65-
}
60+
Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
6661

6762
ChunkingSettings chunkingSettings = null;
6863
if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
@@ -102,12 +97,8 @@ public HuggingFaceModel parsePersistedConfigWithSecrets(
10297
Map<String, Object> secrets
10398
) {
10499
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
100+
Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
105101
Map<String, Object> secretSettingsMap = removeFromMapOrThrowIfNull(secrets, ModelSecrets.SECRET_SETTINGS);
106-
Map<String, Object> taskSettingsMap = Collections.emptyMap();
107-
108-
if (TaskType.RERANK.equals(taskType)) {
109-
taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS);
110-
}
111102

112103
ChunkingSettings chunkingSettings = null;
113104
if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
@@ -131,11 +122,7 @@ public HuggingFaceModel parsePersistedConfigWithSecrets(
131122
@Override
132123
public HuggingFaceModel parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map<String, Object> config) {
133124
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
134-
Map<String, Object> taskSettingsMap = Collections.emptyMap();
135-
136-
if (TaskType.RERANK.equals(taskType)) {
137-
taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS);
138-
}
125+
Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
139126

140127
ChunkingSettings chunkingSettings = null;
141128
if (TaskType.TEXT_EMBEDDING.equals(taskType)) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -76,43 +76,43 @@ public HuggingFaceService(HttpRequestSender.Factory factory, ServiceComponents s
7676
}
7777

7878
@Override
79-
protected HuggingFaceModel createModel(HuggingFaceModelParameters input) {
80-
return switch (input.taskType()) {
79+
protected HuggingFaceModel createModel(HuggingFaceModelParameters params) {
80+
return switch (params.taskType()) {
8181
case RERANK -> new HuggingFaceRerankModel(
82-
input.inferenceEntityId(),
83-
input.taskType(),
82+
params.inferenceEntityId(),
83+
params.taskType(),
8484
NAME,
85-
input.serviceSettings(),
86-
input.taskSettings(),
87-
input.secretSettings(),
88-
input.context()
85+
params.serviceSettings(),
86+
params.taskSettings(),
87+
params.secretSettings(),
88+
params.context()
8989
);
9090
case TEXT_EMBEDDING -> new HuggingFaceEmbeddingsModel(
91-
input.inferenceEntityId(),
92-
input.taskType(),
91+
params.inferenceEntityId(),
92+
params.taskType(),
9393
NAME,
94-
input.serviceSettings(),
95-
input.chunkingSettings(),
96-
input.secretSettings(),
97-
input.context()
94+
params.serviceSettings(),
95+
params.chunkingSettings(),
96+
params.secretSettings(),
97+
params.context()
9898
);
9999
case SPARSE_EMBEDDING -> new HuggingFaceElserModel(
100-
input.inferenceEntityId(),
101-
input.taskType(),
100+
params.inferenceEntityId(),
101+
params.taskType(),
102102
NAME,
103-
input.serviceSettings(),
104-
input.secretSettings(),
105-
input.context()
103+
params.serviceSettings(),
104+
params.secretSettings(),
105+
params.context()
106106
);
107107
case CHAT_COMPLETION, COMPLETION -> new HuggingFaceChatCompletionModel(
108-
input.inferenceEntityId(),
109-
input.taskType(),
108+
params.inferenceEntityId(),
109+
params.taskType(),
110110
NAME,
111-
input.serviceSettings(),
112-
input.secretSettings(),
113-
input.context()
111+
params.serviceSettings(),
112+
params.secretSettings(),
113+
params.context()
114114
);
115-
default -> throw new ElasticsearchStatusException(input.failureMessage(), RestStatus.BAD_REQUEST);
115+
default -> throw new ElasticsearchStatusException(params.failureMessage(), RestStatus.BAD_REQUEST);
116116
};
117117
}
118118

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreator.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,12 @@ public class HuggingFaceActionCreator implements HuggingFaceActionVisitor {
5151
OpenAiChatCompletionResponseEntity::fromResponse
5252
);
5353
private static final ResponseHandler RERANK_HANDLER = new HuggingFaceResponseHandler("hugging face rerank", (request, response) -> {
54-
var errorMessage = format(INVALID_REQUEST_TYPE_MESSAGE, "RERANK", request != null ? request.getClass().getName() : "null");
55-
5654
if ((request instanceof HuggingFaceRerankRequest) == false) {
55+
var errorMessage = format(
56+
INVALID_REQUEST_TYPE_MESSAGE,
57+
"RERANK",
58+
request != null ? request.getClass().getSimpleName() : "null"
59+
);
5760
throw new IllegalArgumentException(errorMessage);
5861
}
5962
return HuggingFaceRerankResponseEntity.fromResponse((HuggingFaceRerankRequest) request, response);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/rerank/HuggingFaceRerankRequest.java

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,20 +62,19 @@ public HttpRequest createHttpRequest() {
6262
input,
6363
returnDocuments,
6464
topN != null ? topN : model.getTaskSettings().getTopNDocumentsOnly(),
65-
model.getTaskSettings(),
66-
model.getServiceSettings().modelId()
65+
model.getTaskSettings()
6766
)
6867
).getBytes(StandardCharsets.UTF_8)
6968
);
7069
httpPost.setEntity(byteEntity);
71-
httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType());
70+
httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaTypeWithoutParameters());
7271

7372
decorateWithAuth(httpPost);
7473

7574
return new HttpRequest(httpPost, getInferenceEntityId());
7675
}
7776

78-
public void decorateWithAuth(HttpPost httpPost) {
77+
void decorateWithAuth(HttpPost httpPost) {
7978
httpPost.setHeader(createAuthBearerHeader(model.apiKey()));
8079
}
8180

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/rerank/HuggingFaceRerankRequestEntity.java

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
import java.util.Objects;
1818

1919
public record HuggingFaceRerankRequestEntity(
20-
String model,
2120
String query,
2221
List<String> documents,
2322
@Nullable Boolean returnDocuments,
@@ -34,17 +33,6 @@ public record HuggingFaceRerankRequestEntity(
3433
Objects.requireNonNull(taskSettings);
3534
}
3635

37-
public HuggingFaceRerankRequestEntity(
38-
String query,
39-
List<String> input,
40-
@Nullable Boolean returnDocuments,
41-
@Nullable Integer topN,
42-
HuggingFaceRerankTaskSettings taskSettings,
43-
String model
44-
) {
45-
this(model, query, input, returnDocuments, topN, taskSettings);
46-
}
47-
4836
@Override
4937
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
5038
builder.startObject();

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/rerank/HuggingFaceRerankModel.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ public HuggingFaceRerankModel(
4444
}
4545

4646
// Should only be used directly for testing
47-
public HuggingFaceRerankModel(
47+
HuggingFaceRerankModel(
4848
String inferenceEntityId,
4949
TaskType taskType,
5050
String service,

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/response/HuggingFaceRerankResponseEntity.java

Lines changed: 16 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -23,16 +23,11 @@
2323
import java.util.Comparator;
2424
import java.util.List;
2525

26-
import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
2726
import static org.elasticsearch.common.xcontent.XContentParserUtils.parseList;
28-
import static org.elasticsearch.core.Strings.format;
2927
import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken;
3028

3129
public class HuggingFaceRerankResponseEntity {
3230

33-
private static final String FAILED_TO_FIND_FIELD_TEMPLATE = "Failed to find required field [%s] in Hugging Face rerank response";
34-
private static final String INVALID_ID_FIELD_FORMAT_TEMPLATE = "Expected numeric value for record ID field in Hugging Face rerank ";
35-
3631
/**
3732
* Parses the Hugging Face rerank response.
3833
@@ -41,10 +36,9 @@ public class HuggingFaceRerankResponseEntity {
4136
* <pre>
4237
* <code>
4338
* {
44-
* "input": ["luke", "like", "leia", "chewy","r2d2", "star", "wars"],
39+
* "texts": ["luke", "leia"],
4540
* "query": "star wars main character",
46-
* "return_documents": true,
47-
* "top_n": 1
41+
* "return_text": true
4842
* }
4943
* </code>
5044
* </pre>
@@ -53,15 +47,18 @@ public class HuggingFaceRerankResponseEntity {
5347
5448
* <pre>
5549
* <code>
56-
* {
57-
* "rerank": [
58-
* {
59-
* "index": 5,
60-
* "relevance_score": -0.06920313,
61-
* "text": "star"
62-
* }
63-
* ]
64-
* }
50+
* [
51+
* {
52+
* "index": 0,
53+
* "score": -0.07996220886707306,
54+
* "text": "luke"
55+
* },
56+
* {
57+
* "index": 1,
58+
* "score": -0.08393221348524094,
59+
* "text": "leia"
60+
* }
61+
* ]
6562
* </code>
6663
* </pre>
6764
*/
@@ -71,10 +68,6 @@ public static RankedDocsResults fromResponse(HuggingFaceRerankRequest request, H
7168

7269
try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) {
7370
moveToFirstToken(jsonParser);
74-
75-
XContentParser.Token token = jsonParser.currentToken();
76-
ensureExpectedToken(XContentParser.Token.START_ARRAY, token, jsonParser);
77-
7871
var rankedDocs = doParse(jsonParser);
7972
var rankedDocsByRelevanceStream = rankedDocs.stream()
8073
.sorted(Comparator.comparingDouble(RankedDocsResults.RankedDoc::relevanceScore).reversed());
@@ -88,24 +81,11 @@ public static RankedDocsResults fromResponse(HuggingFaceRerankRequest request, H
8881
private static List<RankedDocsResults.RankedDoc> doParse(XContentParser parser) throws IOException {
8982
return parseList(parser, (listParser, index) -> {
9083
var parsedRankedDoc = HuggingFaceRerankResponseEntity.RankedDocEntry.parse(parser);
91-
92-
if (parsedRankedDoc.id == null) {
93-
throw new IllegalStateException(format(FAILED_TO_FIND_FIELD_TEMPLATE, RankedDocEntry.ID.getPreferredName()));
94-
}
95-
96-
if (parsedRankedDoc.score == null) {
97-
throw new IllegalStateException(format(FAILED_TO_FIND_FIELD_TEMPLATE, RankedDocEntry.SCORE.getPreferredName()));
98-
}
99-
100-
try {
101-
return new RankedDocsResults.RankedDoc(parsedRankedDoc.id, parsedRankedDoc.score, parsedRankedDoc.text);
102-
} catch (NumberFormatException e) {
103-
throw new IllegalStateException(format(INVALID_ID_FIELD_FORMAT_TEMPLATE, parsedRankedDoc.id));
104-
}
84+
return new RankedDocsResults.RankedDoc(parsedRankedDoc.id, parsedRankedDoc.score, parsedRankedDoc.text);
10585
});
10686
}
10787

108-
private record RankedDocEntry(@Nullable Integer id, @Nullable Float score, @Nullable String text) {
88+
private record RankedDocEntry(Integer id, Float score, @Nullable String text) {
10989

11090
private static final ParseField TEXT = new ParseField("text");
11191
private static final ParseField SCORE = new ParseField("score");

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/request/rerank/HuggingFaceRerankRequestEntityTests.java

Lines changed: 9 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,11 @@
1818
import java.io.IOException;
1919
import java.util.List;
2020

21-
import static org.elasticsearch.xpack.inference.MatchersUtils.equalToIgnoringWhitespaceInJsonString;
21+
import static org.elasticsearch.common.xcontent.XContentHelper.stripWhitespace;
2222

2323
public class HuggingFaceRerankRequestEntityTests extends ESTestCase {
2424
private static final String INPUT = "texts";
2525
private static final String QUERY = "query";
26-
private static final String INFERENCE_ID = "model";
2726
private static final Integer TOP_N = 8;
2827
private static final Boolean RETURN_DOCUMENTS = false;
2928

@@ -33,36 +32,28 @@ public void testXContent_WritesAllFields_WhenTheyAreDefined() throws IOException
3332
List.of(INPUT),
3433
Boolean.TRUE,
3534
TOP_N,
36-
new HuggingFaceRerankTaskSettings(TOP_N, RETURN_DOCUMENTS),
37-
INFERENCE_ID
35+
new HuggingFaceRerankTaskSettings(TOP_N, RETURN_DOCUMENTS)
3836
);
3937

4038
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
4139
entity.toXContent(builder, ToXContent.EMPTY_PARAMS);
4240
String xContentResult = Strings.toString(builder);
43-
44-
assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString("""
41+
String expected = """
4542
{"texts":["texts"],
4643
"query":"query",
4744
"return_text":true,
48-
"top_n":8}"""));
45+
"top_n":8}""";
46+
assertEquals(stripWhitespace(expected), xContentResult);
4947
}
5048

5149
public void testXContent_WritesMinimalFields() throws IOException {
52-
var entity = new HuggingFaceRerankRequestEntity(
53-
QUERY,
54-
List.of(INPUT),
55-
null,
56-
null,
57-
new HuggingFaceRerankTaskSettings(null, null),
58-
INFERENCE_ID
59-
);
50+
var entity = new HuggingFaceRerankRequestEntity(QUERY, List.of(INPUT), null, null, new HuggingFaceRerankTaskSettings(null, null));
6051

6152
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
6253
entity.toXContent(builder, ToXContent.EMPTY_PARAMS);
6354
String xContentResult = Strings.toString(builder);
64-
65-
assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString("""
66-
{"texts":["texts"],"query":"query"}"""));
55+
String expected = """
56+
{"texts":["texts"],"query":"query"}""";
57+
assertEquals(stripWhitespace(expected), xContentResult);
6758
}
6859
}

0 commit comments

Comments
 (0)