Skip to content

Commit 15aca4a

Browse files
ymao1, elasticsearchmachine
authored and committed
Adding common rerank options to Perform Inference API (elastic#125239)
* wip
* Adding rerank common options
* Linting
* Linting
* [CI] Auto commit changes from spotless
* Update docs/changelog/125239.yaml
* PR feedback

---------

Co-authored-by: elasticsearchmachine <[email protected]>
1 parent c205ac2 commit 15aca4a

File tree

66 files changed

+1248
-229
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below for content that may be hidden.

66 files changed

+1248
-229
lines changed

docs/changelog/125239.yaml

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,6 @@
1+
pr: 125239
2+
summary: Adding common rerank options to Perform Inference API
3+
area: Machine Learning
4+
type: enhancement
5+
issues:
6+
- 111273

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -155,6 +155,7 @@ static TransportVersion def(int id) {
155155
public static final TransportVersion ESQL_AGGREGATE_METRIC_DOUBLE_LITERAL_8_19 = def(8_841_0_12);
156156
public static final TransportVersion INFERENCE_MODEL_REGISTRY_METADATA_8_19 = def(8_841_0_13);
157157
public static final TransportVersion INTRODUCE_LIFECYCLE_TEMPLATE_8_19 = def(8_841_0_14);
158+
public static final TransportVersion RERANK_COMMON_OPTIONS_ADDED_8_19 = def(8_841_0_15);
158159
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0 = def(9_000_0_00);
159160
public static final TransportVersion REMOVE_SNAPSHOT_FAILURES_90 = def(9_000_0_01);
160161
public static final TransportVersion TRANSPORT_STATS_HANDLING_TIME_REQUIRED_90 = def(9_000_0_02);
@@ -201,6 +202,7 @@ static TransportVersion def(int id) {
201202
public static final TransportVersion INDEXING_STATS_INCLUDES_RECENT_WRITE_LOAD = def(9_034_0_00);
202203
public static final TransportVersion ESQL_AGGREGATE_METRIC_DOUBLE_LITERAL = def(9_035_0_00);
203204
public static final TransportVersion INDEX_METADATA_INCLUDES_RECENT_WRITE_LOAD = def(9_036_0_00);
205+
public static final TransportVersion RERANK_COMMON_OPTIONS_ADDED = def(9_037_0_00);
204206

205207
/*
206208
* STOP! READ THIS FIRST! No, really,

server/src/main/java/org/elasticsearch/inference/InferenceService.java

Lines changed: 12 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -91,18 +91,22 @@ default boolean hideFromConfigurationApi() {
9191
/**
9292
* Perform inference on the model.
9393
*
94-
* @param model The model
95-
* @param query Inference query, mainly for re-ranking
96-
* @param input Inference input
97-
* @param stream Stream inference results
98-
* @param taskSettings Settings in the request to override the model's defaults
99-
* @param inputType For search, ingest etc
100-
* @param timeout The timeout for the request
101-
* @param listener Inference result listener
94+
* @param model The model
95+
* @param query Inference query, mainly for re-ranking
96+
* @param returnDocuments For re-ranking task type, whether to return documents
97+
* @param topN For re-ranking task type, how many docs to return
98+
* @param input Inference input
99+
* @param stream Stream inference results
100+
* @param taskSettings Settings in the request to override the model's defaults
101+
* @param inputType For search, ingest etc
102+
* @param timeout The timeout for the request
103+
* @param listener Inference result listener
102104
*/
103105
void infer(
104106
Model model,
105107
@Nullable String query,
108+
@Nullable Boolean returnDocuments,
109+
@Nullable Integer topN,
106110
List<String> input,
107111
boolean stream,
108112
Map<String, Object> taskSettings,

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java

Lines changed: 81 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -60,6 +60,8 @@ public static class Request extends BaseInferenceActionRequest {
6060
public static final ParseField INPUT_TYPE = new ParseField("input_type");
6161
public static final ParseField TASK_SETTINGS = new ParseField("task_settings");
6262
public static final ParseField QUERY = new ParseField("query");
63+
public static final ParseField RETURN_DOCUMENTS = new ParseField("return_documents");
64+
public static final ParseField TOP_N = new ParseField("top_n");
6365
public static final ParseField TIMEOUT = new ParseField("timeout");
6466

6567
static final ObjectParser<Request.Builder, Void> PARSER = new ObjectParser<>(NAME, Request.Builder::new);
@@ -68,6 +70,8 @@ public static class Request extends BaseInferenceActionRequest {
6870
PARSER.declareString(Request.Builder::setInputType, INPUT_TYPE);
6971
PARSER.declareObject(Request.Builder::setTaskSettings, (p, c) -> p.mapOrdered(), TASK_SETTINGS);
7072
PARSER.declareString(Request.Builder::setQuery, QUERY);
73+
PARSER.declareBoolean(Request.Builder::setReturnDocuments, RETURN_DOCUMENTS);
74+
PARSER.declareInt(Request.Builder::setTopN, TOP_N);
7175
PARSER.declareString(Builder::setInferenceTimeout, TIMEOUT);
7276
}
7377

@@ -89,6 +93,8 @@ public static Builder parseRequest(String inferenceEntityId, TaskType taskType,
8993
private final TaskType taskType;
9094
private final String inferenceEntityId;
9195
private final String query;
96+
private final Boolean returnDocuments;
97+
private final Integer topN;
9298
private final List<String> input;
9399
private final Map<String, Object> taskSettings;
94100
private final InputType inputType;
@@ -99,6 +105,8 @@ public Request(
99105
TaskType taskType,
100106
String inferenceEntityId,
101107
String query,
108+
Boolean returnDocuments,
109+
Integer topN,
102110
List<String> input,
103111
Map<String, Object> taskSettings,
104112
InputType inputType,
@@ -109,6 +117,8 @@ public Request(
109117
taskType,
110118
inferenceEntityId,
111119
query,
120+
returnDocuments,
121+
topN,
112122
input,
113123
taskSettings,
114124
inputType,
@@ -122,6 +132,8 @@ public Request(
122132
TaskType taskType,
123133
String inferenceEntityId,
124134
String query,
135+
Boolean returnDocuments,
136+
Integer topN,
125137
List<String> input,
126138
Map<String, Object> taskSettings,
127139
InputType inputType,
@@ -133,6 +145,8 @@ public Request(
133145
this.taskType = taskType;
134146
this.inferenceEntityId = inferenceEntityId;
135147
this.query = query;
148+
this.returnDocuments = returnDocuments;
149+
this.topN = topN;
136150
this.input = input;
137151
this.taskSettings = taskSettings;
138152
this.inputType = inputType;
@@ -164,6 +178,15 @@ public Request(StreamInput in) throws IOException {
164178
this.inferenceTimeout = DEFAULT_TIMEOUT;
165179
}
166180

181+
if (in.getTransportVersion().onOrAfter(TransportVersions.RERANK_COMMON_OPTIONS_ADDED)
182+
|| in.getTransportVersion().isPatchFrom(TransportVersions.RERANK_COMMON_OPTIONS_ADDED_8_19)) {
183+
this.returnDocuments = in.readOptionalBoolean();
184+
this.topN = in.readOptionalInt();
185+
} else {
186+
this.returnDocuments = null;
187+
this.topN = null;
188+
}
189+
167190
// streaming is not supported yet for transport traffic
168191
this.stream = false;
169192
}
@@ -184,6 +207,14 @@ public String getQuery() {
184207
return query;
185208
}
186209

210+
public Boolean getReturnDocuments() {
211+
return returnDocuments;
212+
}
213+
214+
public Integer getTopN() {
215+
return topN;
216+
}
217+
187218
public Map<String, Object> getTaskSettings() {
188219
return taskSettings;
189220
}
@@ -225,6 +256,17 @@ public ActionRequestValidationException validate() {
225256
e.addValidationError(format("Field [query] cannot be empty for task type [%s]", TaskType.RERANK));
226257
return e;
227258
}
259+
} else if (taskType.equals(TaskType.ANY) == false) {
260+
if (returnDocuments != null) {
261+
var e = new ActionRequestValidationException();
262+
e.addValidationError(format("Field [return_documents] cannot be specified for task type [%s]", taskType));
263+
return e;
264+
}
265+
if (topN != null) {
266+
var e = new ActionRequestValidationException();
267+
e.addValidationError(format("Field [top_n] cannot be specified for task type [%s]", taskType));
268+
return e;
269+
}
228270
}
229271

230272
if (taskType.equals(TaskType.TEXT_EMBEDDING) == false
@@ -258,6 +300,12 @@ public void writeTo(StreamOutput out) throws IOException {
258300
out.writeOptionalString(query);
259301
out.writeTimeValue(inferenceTimeout);
260302
}
303+
304+
if (out.getTransportVersion().onOrAfter(TransportVersions.RERANK_COMMON_OPTIONS_ADDED)
305+
|| out.getTransportVersion().isPatchFrom(TransportVersions.RERANK_COMMON_OPTIONS_ADDED_8_19)) {
306+
out.writeOptionalBoolean(returnDocuments);
307+
out.writeOptionalInt(topN);
308+
}
261309
}
262310

263311
// default for easier testing
@@ -283,6 +331,8 @@ public boolean equals(Object o) {
283331
&& taskType == request.taskType
284332
&& Objects.equals(inferenceEntityId, request.inferenceEntityId)
285333
&& Objects.equals(query, request.query)
334+
&& Objects.equals(returnDocuments, request.returnDocuments)
335+
&& Objects.equals(topN, request.topN)
286336
&& Objects.equals(input, request.input)
287337
&& Objects.equals(taskSettings, request.taskSettings)
288338
&& inputType == request.inputType
@@ -296,6 +346,8 @@ public int hashCode() {
296346
taskType,
297347
inferenceEntityId,
298348
query,
349+
returnDocuments,
350+
topN,
299351
input,
300352
taskSettings,
301353
inputType,
@@ -312,6 +364,8 @@ public static class Builder {
312364
private InputType inputType = InputType.UNSPECIFIED;
313365
private Map<String, Object> taskSettings = Map.of();
314366
private String query;
367+
private Boolean returnDocuments;
368+
private Integer topN;
315369
private TimeValue timeout = DEFAULT_TIMEOUT;
316370
private boolean stream = false;
317371
private InferenceContext context;
@@ -338,6 +392,16 @@ public Builder setQuery(String query) {
338392
return this;
339393
}
340394

395+
public Builder setReturnDocuments(Boolean returnDocuments) {
396+
this.returnDocuments = returnDocuments;
397+
return this;
398+
}
399+
400+
public Builder setTopN(Integer topN) {
401+
this.topN = topN;
402+
return this;
403+
}
404+
341405
public Builder setInputType(InputType inputType) {
342406
this.inputType = inputType;
343407
return this;
@@ -373,7 +437,19 @@ public Builder setContext(InferenceContext context) {
373437
}
374438

375439
public Request build() {
376-
return new Request(taskType, inferenceEntityId, query, input, taskSettings, inputType, timeout, stream, context);
440+
return new Request(
441+
taskType,
442+
inferenceEntityId,
443+
query,
444+
returnDocuments,
445+
topN,
446+
input,
447+
taskSettings,
448+
inputType,
449+
timeout,
450+
stream,
451+
context
452+
);
377453
}
378454
}
379455

@@ -384,6 +460,10 @@ public String toString() {
384460
+ this.getInferenceEntityId()
385461
+ ", query="
386462
+ this.getQuery()
463+
+ ", returnDocuments="
464+
+ this.getReturnDocuments()
465+
+ ", topN="
466+
+ this.getTopN()
387467
+ ", input="
388468
+ this.getInput()
389469
+ ", taskSettings="

0 commit comments

Comments
 (0)