Skip to content
Merged
6 changes: 6 additions & 0 deletions docs/changelog/125239.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pr: 125239
summary: Adding common rerank options to Perform Inference API
area: Machine Learning
type: enhancement
issues:
- 111273
2 changes: 2 additions & 0 deletions server/src/main/java/org/elasticsearch/TransportVersions.java
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ static TransportVersion def(int id) {
public static final TransportVersion ESQL_AGGREGATE_METRIC_DOUBLE_LITERAL_8_19 = def(8_841_0_12);
public static final TransportVersion INFERENCE_MODEL_REGISTRY_METADATA_8_19 = def(8_841_0_13);
public static final TransportVersion INTRODUCE_LIFECYCLE_TEMPLATE_8_19 = def(8_841_0_14);
public static final TransportVersion RERANK_COMMON_OPTIONS_ADDED_8_19 = def(8_841_0_15);
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0 = def(9_000_0_00);
public static final TransportVersion REMOVE_SNAPSHOT_FAILURES_90 = def(9_000_0_01);
public static final TransportVersion TRANSPORT_STATS_HANDLING_TIME_REQUIRED_90 = def(9_000_0_02);
Expand Down Expand Up @@ -200,6 +201,7 @@ static TransportVersion def(int id) {
public static final TransportVersion INDEXING_STATS_INCLUDES_RECENT_WRITE_LOAD = def(9_034_0_00);
public static final TransportVersion ESQL_AGGREGATE_METRIC_DOUBLE_LITERAL = def(9_035_0_00);
public static final TransportVersion INDEX_METADATA_INCLUDES_RECENT_WRITE_LOAD = def(9_036_0_00);
public static final TransportVersion RERANK_COMMON_OPTIONS_ADDED = def(9_037_0_00);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,18 +91,22 @@ default boolean hideFromConfigurationApi() {
/**
* Perform inference on the model.
*
* @param model The model
* @param query Inference query, mainly for re-ranking
* @param input Inference input
* @param stream Stream inference results
* @param taskSettings Settings in the request to override the model's defaults
* @param inputType For search, ingest etc
* @param timeout The timeout for the request
* @param listener Inference result listener
* @param model The model
* @param query Inference query, mainly for re-ranking
* @param returnDocuments For re-ranking task type, whether to return documents
* @param topN For re-ranking task type, how many docs to return
* @param input Inference input
* @param stream Stream inference results
* @param taskSettings Settings in the request to override the model's defaults
* @param inputType For search, ingest etc
* @param timeout The timeout for the request
* @param listener Inference result listener
*/
void infer(
Model model,
@Nullable String query,
@Nullable Boolean returnDocuments,
@Nullable Integer topN,
List<String> input,
boolean stream,
Map<String, Object> taskSettings,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ public static class Request extends BaseInferenceActionRequest {
public static final ParseField INPUT_TYPE = new ParseField("input_type");
public static final ParseField TASK_SETTINGS = new ParseField("task_settings");
public static final ParseField QUERY = new ParseField("query");
public static final ParseField RETURN_DOCUMENTS = new ParseField("return_documents");
public static final ParseField TOP_N = new ParseField("top_n");
public static final ParseField TIMEOUT = new ParseField("timeout");

static final ObjectParser<Request.Builder, Void> PARSER = new ObjectParser<>(NAME, Request.Builder::new);
Expand All @@ -68,6 +70,8 @@ public static class Request extends BaseInferenceActionRequest {
PARSER.declareString(Request.Builder::setInputType, INPUT_TYPE);
PARSER.declareObject(Request.Builder::setTaskSettings, (p, c) -> p.mapOrdered(), TASK_SETTINGS);
PARSER.declareString(Request.Builder::setQuery, QUERY);
PARSER.declareBoolean(Request.Builder::setReturnDocuments, RETURN_DOCUMENTS);
PARSER.declareInt(Request.Builder::setTopN, TOP_N);
PARSER.declareString(Builder::setInferenceTimeout, TIMEOUT);
}

Expand All @@ -89,6 +93,8 @@ public static Builder parseRequest(String inferenceEntityId, TaskType taskType,
private final TaskType taskType;
private final String inferenceEntityId;
private final String query;
private final Boolean returnDocuments;
private final Integer topN;
private final List<String> input;
private final Map<String, Object> taskSettings;
private final InputType inputType;
Expand All @@ -99,6 +105,8 @@ public Request(
TaskType taskType,
String inferenceEntityId,
String query,
Boolean returnDocuments,
Integer topN,
List<String> input,
Map<String, Object> taskSettings,
InputType inputType,
Expand All @@ -109,6 +117,8 @@ public Request(
taskType,
inferenceEntityId,
query,
returnDocuments,
topN,
input,
taskSettings,
inputType,
Expand All @@ -122,6 +132,8 @@ public Request(
TaskType taskType,
String inferenceEntityId,
String query,
Boolean returnDocuments,
Integer topN,
List<String> input,
Map<String, Object> taskSettings,
InputType inputType,
Expand All @@ -133,6 +145,8 @@ public Request(
this.taskType = taskType;
this.inferenceEntityId = inferenceEntityId;
this.query = query;
this.returnDocuments = returnDocuments;
this.topN = topN;
this.input = input;
this.taskSettings = taskSettings;
this.inputType = inputType;
Expand Down Expand Up @@ -164,6 +178,15 @@ public Request(StreamInput in) throws IOException {
this.inferenceTimeout = DEFAULT_TIMEOUT;
}

if (in.getTransportVersion().onOrAfter(TransportVersions.RERANK_COMMON_OPTIONS_ADDED)
|| in.getTransportVersion().isPatchFrom(TransportVersions.RERANK_COMMON_OPTIONS_ADDED_8_19)) {
this.returnDocuments = in.readOptionalBoolean();
this.topN = in.readOptionalInt();
} else {
this.returnDocuments = null;
this.topN = null;
}

// streaming is not supported yet for transport traffic
this.stream = false;
}
Expand All @@ -184,6 +207,14 @@ public String getQuery() {
return query;
}

public Boolean getReturnDocuments() {
return returnDocuments;
}

public Integer getTopN() {
return topN;
}

public Map<String, Object> getTaskSettings() {
return taskSettings;
}
Expand Down Expand Up @@ -225,6 +256,17 @@ public ActionRequestValidationException validate() {
e.addValidationError(format("Field [query] cannot be empty for task type [%s]", TaskType.RERANK));
return e;
}
} else if (taskType.equals(TaskType.ANY) == false) {
if (returnDocuments != null) {
var e = new ActionRequestValidationException();
e.addValidationError(format("Field [return_documents] cannot be specified for task type [%s]", taskType));
return e;
}
if (topN != null) {
var e = new ActionRequestValidationException();
e.addValidationError(format("Field [top_n] cannot be specified for task type [%s]", taskType));
return e;
}
}

if (taskType.equals(TaskType.TEXT_EMBEDDING) == false
Expand Down Expand Up @@ -258,6 +300,12 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalString(query);
out.writeTimeValue(inferenceTimeout);
}

if (out.getTransportVersion().onOrAfter(TransportVersions.RERANK_COMMON_OPTIONS_ADDED)
|| out.getTransportVersion().isPatchFrom(TransportVersions.RERANK_COMMON_OPTIONS_ADDED_8_19)) {
out.writeOptionalBoolean(returnDocuments);
out.writeOptionalInt(topN);
}
}

// default for easier testing
Expand All @@ -283,6 +331,8 @@ public boolean equals(Object o) {
&& taskType == request.taskType
&& Objects.equals(inferenceEntityId, request.inferenceEntityId)
&& Objects.equals(query, request.query)
&& Objects.equals(returnDocuments, request.returnDocuments)
&& Objects.equals(topN, request.topN)
&& Objects.equals(input, request.input)
&& Objects.equals(taskSettings, request.taskSettings)
&& inputType == request.inputType
Expand All @@ -296,6 +346,8 @@ public int hashCode() {
taskType,
inferenceEntityId,
query,
returnDocuments,
topN,
input,
taskSettings,
inputType,
Expand All @@ -312,6 +364,8 @@ public static class Builder {
private InputType inputType = InputType.UNSPECIFIED;
private Map<String, Object> taskSettings = Map.of();
private String query;
private Boolean returnDocuments;
private Integer topN;
private TimeValue timeout = DEFAULT_TIMEOUT;
private boolean stream = false;
private InferenceContext context;
Expand All @@ -338,6 +392,16 @@ public Builder setQuery(String query) {
return this;
}

public Builder setReturnDocuments(Boolean returnDocuments) {
this.returnDocuments = returnDocuments;
return this;
}

public Builder setTopN(Integer topN) {
this.topN = topN;
return this;
}

public Builder setInputType(InputType inputType) {
this.inputType = inputType;
return this;
Expand Down Expand Up @@ -373,7 +437,19 @@ public Builder setContext(InferenceContext context) {
}

public Request build() {
return new Request(taskType, inferenceEntityId, query, input, taskSettings, inputType, timeout, stream, context);
return new Request(
taskType,
inferenceEntityId,
query,
returnDocuments,
topN,
input,
taskSettings,
inputType,
timeout,
stream,
context
);
}
}

Expand All @@ -384,6 +460,10 @@ public String toString() {
+ this.getInferenceEntityId()
+ ", query="
+ this.getQuery()
+ ", returnDocuments="
+ this.getReturnDocuments()
+ ", topN="
+ this.getTopN()
+ ", input="
+ this.getInput()
+ ", taskSettings="
Expand Down
Loading
Loading