Skip to content

Commit 36afa61

Browse files
authored
Adding common rerank options to Perform Inference API (#125239) (#125600)
* wip
* Adding rerank common options
* Linting
* Linting
* [CI] Auto commit changes from spotless
* Update docs/changelog/125239.yaml
* PR feedback

---------

Co-authored-by: elasticsearchmachine <[email protected]>

(cherry picked from commit a6f685c)

# Conflicts:
#   server/src/main/java/org/elasticsearch/TransportVersions.java
#   x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java
1 parent f346ff8 commit 36afa61

File tree

66 files changed

+1241
-226
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

66 files changed

+1241
-226
lines changed

docs/changelog/125239.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
pr: 125239
2+
summary: Adding common rerank options to Perform Inference API
3+
area: Machine Learning
4+
type: enhancement
5+
issues:
6+
- 111273

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,7 @@ static TransportVersion def(int id) {
199199
public static final TransportVersion ESQL_AGGREGATE_METRIC_DOUBLE_LITERAL = def(8_841_0_12);
200200
public static final TransportVersion INFERENCE_MODEL_REGISTRY_METADATA_8_19 = def(8_841_0_13);
201201
public static final TransportVersion INTRODUCE_LIFECYCLE_TEMPLATE_8_19 = def(8_841_0_14);
202+
public static final TransportVersion RERANK_COMMON_OPTIONS_ADDED_8_19 = def(8_841_0_15);
202203

203204
/*
204205
* STOP! READ THIS FIRST! No, really,

server/src/main/java/org/elasticsearch/inference/InferenceService.java

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -91,18 +91,22 @@ default boolean hideFromConfigurationApi() {
9191
/**
9292
* Perform inference on the model.
9393
*
94-
* @param model The model
95-
* @param query Inference query, mainly for re-ranking
96-
* @param input Inference input
97-
* @param stream Stream inference results
98-
* @param taskSettings Settings in the request to override the model's defaults
99-
* @param inputType For search, ingest etc
100-
* @param timeout The timeout for the request
101-
* @param listener Inference result listener
94+
* @param model The model
95+
* @param query Inference query, mainly for re-ranking
96+
* @param returnDocuments For re-ranking task type, whether to return documents
97+
* @param topN For re-ranking task type, how many docs to return
98+
* @param input Inference input
99+
* @param stream Stream inference results
100+
* @param taskSettings Settings in the request to override the model's defaults
101+
* @param inputType For search, ingest etc
102+
* @param timeout The timeout for the request
103+
* @param listener Inference result listener
102104
*/
103105
void infer(
104106
Model model,
105107
@Nullable String query,
108+
@Nullable Boolean returnDocuments,
109+
@Nullable Integer topN,
106110
List<String> input,
107111
boolean stream,
108112
Map<String, Object> taskSettings,

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java

Lines changed: 79 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@ public static class Request extends BaseInferenceActionRequest {
6060
public static final ParseField INPUT_TYPE = new ParseField("input_type");
6161
public static final ParseField TASK_SETTINGS = new ParseField("task_settings");
6262
public static final ParseField QUERY = new ParseField("query");
63+
public static final ParseField RETURN_DOCUMENTS = new ParseField("return_documents");
64+
public static final ParseField TOP_N = new ParseField("top_n");
6365
public static final ParseField TIMEOUT = new ParseField("timeout");
6466

6567
static final ObjectParser<Request.Builder, Void> PARSER = new ObjectParser<>(NAME, Request.Builder::new);
@@ -68,6 +70,8 @@ public static class Request extends BaseInferenceActionRequest {
6870
PARSER.declareString(Request.Builder::setInputType, INPUT_TYPE);
6971
PARSER.declareObject(Request.Builder::setTaskSettings, (p, c) -> p.mapOrdered(), TASK_SETTINGS);
7072
PARSER.declareString(Request.Builder::setQuery, QUERY);
73+
PARSER.declareBoolean(Request.Builder::setReturnDocuments, RETURN_DOCUMENTS);
74+
PARSER.declareInt(Request.Builder::setTopN, TOP_N);
7175
PARSER.declareString(Builder::setInferenceTimeout, TIMEOUT);
7276
}
7377

@@ -89,6 +93,8 @@ public static Builder parseRequest(String inferenceEntityId, TaskType taskType,
8993
private final TaskType taskType;
9094
private final String inferenceEntityId;
9195
private final String query;
96+
private final Boolean returnDocuments;
97+
private final Integer topN;
9298
private final List<String> input;
9399
private final Map<String, Object> taskSettings;
94100
private final InputType inputType;
@@ -99,6 +105,8 @@ public Request(
99105
TaskType taskType,
100106
String inferenceEntityId,
101107
String query,
108+
Boolean returnDocuments,
109+
Integer topN,
102110
List<String> input,
103111
Map<String, Object> taskSettings,
104112
InputType inputType,
@@ -109,6 +117,8 @@ public Request(
109117
taskType,
110118
inferenceEntityId,
111119
query,
120+
returnDocuments,
121+
topN,
112122
input,
113123
taskSettings,
114124
inputType,
@@ -122,6 +132,8 @@ public Request(
122132
TaskType taskType,
123133
String inferenceEntityId,
124134
String query,
135+
Boolean returnDocuments,
136+
Integer topN,
125137
List<String> input,
126138
Map<String, Object> taskSettings,
127139
InputType inputType,
@@ -133,6 +145,8 @@ public Request(
133145
this.taskType = taskType;
134146
this.inferenceEntityId = inferenceEntityId;
135147
this.query = query;
148+
this.returnDocuments = returnDocuments;
149+
this.topN = topN;
136150
this.input = input;
137151
this.taskSettings = taskSettings;
138152
this.inputType = inputType;
@@ -164,6 +178,14 @@ public Request(StreamInput in) throws IOException {
164178
this.inferenceTimeout = DEFAULT_TIMEOUT;
165179
}
166180

181+
if (in.getTransportVersion().onOrAfter(TransportVersions.RERANK_COMMON_OPTIONS_ADDED_8_19)) {
182+
this.returnDocuments = in.readOptionalBoolean();
183+
this.topN = in.readOptionalInt();
184+
} else {
185+
this.returnDocuments = null;
186+
this.topN = null;
187+
}
188+
167189
// streaming is not supported yet for transport traffic
168190
this.stream = false;
169191
}
@@ -184,6 +206,14 @@ public String getQuery() {
184206
return query;
185207
}
186208

209+
public Boolean getReturnDocuments() {
210+
return returnDocuments;
211+
}
212+
213+
public Integer getTopN() {
214+
return topN;
215+
}
216+
187217
public Map<String, Object> getTaskSettings() {
188218
return taskSettings;
189219
}
@@ -225,6 +255,17 @@ public ActionRequestValidationException validate() {
225255
e.addValidationError(format("Field [query] cannot be empty for task type [%s]", TaskType.RERANK));
226256
return e;
227257
}
258+
} else if (taskType.equals(TaskType.ANY) == false) {
259+
if (returnDocuments != null) {
260+
var e = new ActionRequestValidationException();
261+
e.addValidationError(format("Field [return_documents] cannot be specified for task type [%s]", taskType));
262+
return e;
263+
}
264+
if (topN != null) {
265+
var e = new ActionRequestValidationException();
266+
e.addValidationError(format("Field [top_n] cannot be specified for task type [%s]", taskType));
267+
return e;
268+
}
228269
}
229270

230271
if (taskType.equals(TaskType.TEXT_EMBEDDING) == false
@@ -258,6 +299,11 @@ public void writeTo(StreamOutput out) throws IOException {
258299
out.writeOptionalString(query);
259300
out.writeTimeValue(inferenceTimeout);
260301
}
302+
303+
if (out.getTransportVersion().onOrAfter(TransportVersions.RERANK_COMMON_OPTIONS_ADDED_8_19)) {
304+
out.writeOptionalBoolean(returnDocuments);
305+
out.writeOptionalInt(topN);
306+
}
261307
}
262308

263309
// default for easier testing
@@ -283,6 +329,8 @@ public boolean equals(Object o) {
283329
&& taskType == request.taskType
284330
&& Objects.equals(inferenceEntityId, request.inferenceEntityId)
285331
&& Objects.equals(query, request.query)
332+
&& Objects.equals(returnDocuments, request.returnDocuments)
333+
&& Objects.equals(topN, request.topN)
286334
&& Objects.equals(input, request.input)
287335
&& Objects.equals(taskSettings, request.taskSettings)
288336
&& inputType == request.inputType
@@ -296,6 +344,8 @@ public int hashCode() {
296344
taskType,
297345
inferenceEntityId,
298346
query,
347+
returnDocuments,
348+
topN,
299349
input,
300350
taskSettings,
301351
inputType,
@@ -312,6 +362,8 @@ public static class Builder {
312362
private InputType inputType = InputType.UNSPECIFIED;
313363
private Map<String, Object> taskSettings = Map.of();
314364
private String query;
365+
private Boolean returnDocuments;
366+
private Integer topN;
315367
private TimeValue timeout = DEFAULT_TIMEOUT;
316368
private boolean stream = false;
317369
private InferenceContext context;
@@ -338,6 +390,16 @@ public Builder setQuery(String query) {
338390
return this;
339391
}
340392

393+
public Builder setReturnDocuments(Boolean returnDocuments) {
394+
this.returnDocuments = returnDocuments;
395+
return this;
396+
}
397+
398+
public Builder setTopN(Integer topN) {
399+
this.topN = topN;
400+
return this;
401+
}
402+
341403
public Builder setInputType(InputType inputType) {
342404
this.inputType = inputType;
343405
return this;
@@ -373,7 +435,19 @@ public Builder setContext(InferenceContext context) {
373435
}
374436

375437
public Request build() {
376-
return new Request(taskType, inferenceEntityId, query, input, taskSettings, inputType, timeout, stream, context);
438+
return new Request(
439+
taskType,
440+
inferenceEntityId,
441+
query,
442+
returnDocuments,
443+
topN,
444+
input,
445+
taskSettings,
446+
inputType,
447+
timeout,
448+
stream,
449+
context
450+
);
377451
}
378452
}
379453

@@ -384,6 +458,10 @@ public String toString() {
384458
+ this.getInferenceEntityId()
385459
+ ", query="
386460
+ this.getQuery()
461+
+ ", returnDocuments="
462+
+ this.getReturnDocuments()
463+
+ ", topN="
464+
+ this.getTopN()
387465
+ ", input="
388466
+ this.getInput()
389467
+ ", taskSettings="

0 commit comments

Comments
 (0)