Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/changelog/125239.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pr: 125239
summary: Adding common rerank options to Perform Inference API
area: Machine Learning
type: enhancement
issues:
- 111273
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ static TransportVersion def(int id) {
public static final TransportVersion ESQL_AGGREGATE_METRIC_DOUBLE_LITERAL = def(8_841_0_12);
public static final TransportVersion INFERENCE_MODEL_REGISTRY_METADATA_8_19 = def(8_841_0_13);
public static final TransportVersion INTRODUCE_LIFECYCLE_TEMPLATE_8_19 = def(8_841_0_14);
public static final TransportVersion RERANK_COMMON_OPTIONS_ADDED_8_19 = def(8_841_0_15);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,18 +91,22 @@ default boolean hideFromConfigurationApi() {
/**
* Perform inference on the model.
*
* @param model The model
* @param query Inference query, mainly for re-ranking
* @param returnDocuments For the re-ranking task type, whether to return documents
* @param topN For the re-ranking task type, how many documents to return
* @param input Inference input
* @param stream Stream inference results
* @param taskSettings Settings in the request to override the model's defaults
* @param inputType For search, ingest, etc.
* @param timeout The timeout for the request
* @param listener Inference result listener
*/
void infer(
Model model,
@Nullable String query,
@Nullable Boolean returnDocuments,
@Nullable Integer topN,
List<String> input,
boolean stream,
Map<String, Object> taskSettings,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ public static class Request extends BaseInferenceActionRequest {
// Names of the request-body fields; each one is registered with PARSER.
public static final ParseField INPUT_TYPE = new ParseField("input_type");
public static final ParseField TASK_SETTINGS = new ParseField("task_settings");
public static final ParseField QUERY = new ParseField("query");
// Rerank options: parsed as a boolean (return_documents) and an int (top_n).
public static final ParseField RETURN_DOCUMENTS = new ParseField("return_documents");
public static final ParseField TOP_N = new ParseField("top_n");
public static final ParseField TIMEOUT = new ParseField("timeout");

static final ObjectParser<Request.Builder, Void> PARSER = new ObjectParser<>(NAME, Request.Builder::new);
Expand All @@ -68,6 +70,8 @@ public static class Request extends BaseInferenceActionRequest {
PARSER.declareString(Request.Builder::setInputType, INPUT_TYPE);
PARSER.declareObject(Request.Builder::setTaskSettings, (p, c) -> p.mapOrdered(), TASK_SETTINGS);
PARSER.declareString(Request.Builder::setQuery, QUERY);
PARSER.declareBoolean(Request.Builder::setReturnDocuments, RETURN_DOCUMENTS);
PARSER.declareInt(Request.Builder::setTopN, TOP_N);
PARSER.declareString(Builder::setInferenceTimeout, TIMEOUT);
}

Expand All @@ -89,6 +93,8 @@ public static Builder parseRequest(String inferenceEntityId, TaskType taskType,
private final TaskType taskType;
private final String inferenceEntityId;
private final String query;
private final Boolean returnDocuments;
private final Integer topN;
private final List<String> input;
private final Map<String, Object> taskSettings;
private final InputType inputType;
Expand All @@ -99,6 +105,8 @@ public Request(
TaskType taskType,
String inferenceEntityId,
String query,
Boolean returnDocuments,
Integer topN,
List<String> input,
Map<String, Object> taskSettings,
InputType inputType,
Expand All @@ -109,6 +117,8 @@ public Request(
taskType,
inferenceEntityId,
query,
returnDocuments,
topN,
input,
taskSettings,
inputType,
Expand All @@ -122,6 +132,8 @@ public Request(
TaskType taskType,
String inferenceEntityId,
String query,
Boolean returnDocuments,
Integer topN,
List<String> input,
Map<String, Object> taskSettings,
InputType inputType,
Expand All @@ -133,6 +145,8 @@ public Request(
this.taskType = taskType;
this.inferenceEntityId = inferenceEntityId;
this.query = query;
this.returnDocuments = returnDocuments;
this.topN = topN;
this.input = input;
this.taskSettings = taskSettings;
this.inputType = inputType;
Expand Down Expand Up @@ -164,6 +178,14 @@ public Request(StreamInput in) throws IOException {
this.inferenceTimeout = DEFAULT_TIMEOUT;
}

if (in.getTransportVersion().onOrAfter(TransportVersions.RERANK_COMMON_OPTIONS_ADDED_8_19)) {
this.returnDocuments = in.readOptionalBoolean();
this.topN = in.readOptionalInt();
} else {
this.returnDocuments = null;
this.topN = null;
}

// streaming is not supported yet for transport traffic
this.stream = false;
}
Expand All @@ -184,6 +206,14 @@ public String getQuery() {
return query;
}

/**
 * Returns the {@code return_documents} option carried by this request.
 *
 * @return the configured flag, or {@code null} when the option was not supplied
 */
public Boolean getReturnDocuments() {
    return this.returnDocuments;
}

/**
 * Returns the {@code top_n} option carried by this request.
 *
 * @return the configured value, or {@code null} when the option was not supplied
 */
public Integer getTopN() {
    return this.topN;
}

/**
 * Returns the task settings map held by this request.
 *
 * @return the task settings exactly as stored on the request (no copy is made)
 */
public Map<String, Object> getTaskSettings() {
    return this.taskSettings;
}
Expand Down Expand Up @@ -225,6 +255,17 @@ public ActionRequestValidationException validate() {
e.addValidationError(format("Field [query] cannot be empty for task type [%s]", TaskType.RERANK));
return e;
}
} else if (taskType.equals(TaskType.ANY) == false) {
if (returnDocuments != null) {
var e = new ActionRequestValidationException();
e.addValidationError(format("Field [return_documents] cannot be specified for task type [%s]", taskType));
return e;
}
if (topN != null) {
var e = new ActionRequestValidationException();
e.addValidationError(format("Field [top_n] cannot be specified for task type [%s]", taskType));
return e;
}
}

if (taskType.equals(TaskType.TEXT_EMBEDDING) == false
Expand Down Expand Up @@ -258,6 +299,11 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalString(query);
out.writeTimeValue(inferenceTimeout);
}

if (out.getTransportVersion().onOrAfter(TransportVersions.RERANK_COMMON_OPTIONS_ADDED_8_19)) {
out.writeOptionalBoolean(returnDocuments);
out.writeOptionalInt(topN);
}
}

// default for easier testing
Expand All @@ -283,6 +329,8 @@ public boolean equals(Object o) {
&& taskType == request.taskType
&& Objects.equals(inferenceEntityId, request.inferenceEntityId)
&& Objects.equals(query, request.query)
&& Objects.equals(returnDocuments, request.returnDocuments)
&& Objects.equals(topN, request.topN)
&& Objects.equals(input, request.input)
&& Objects.equals(taskSettings, request.taskSettings)
&& inputType == request.inputType
Expand All @@ -296,6 +344,8 @@ public int hashCode() {
taskType,
inferenceEntityId,
query,
returnDocuments,
topN,
input,
taskSettings,
inputType,
Expand All @@ -312,6 +362,8 @@ public static class Builder {
// Builder state. Fields with an initializer carry the default used when the
// corresponding setter is never called.
private InputType inputType = InputType.UNSPECIFIED;
private Map<String, Object> taskSettings = Map.of();
private String query;
// Optional rerank options; they stay null unless set explicitly.
private Boolean returnDocuments;
private Integer topN;
private TimeValue timeout = DEFAULT_TIMEOUT;
private boolean stream = false;
private InferenceContext context;
Expand All @@ -338,6 +390,16 @@ public Builder setQuery(String query) {
return this;
}

/**
* Sets the {@code return_documents} option for the request being built.
*
* @param returnDocuments the flag value; may be {@code null} to leave the option unset
* @return this builder, so calls can be chained
*/
public Builder setReturnDocuments(Boolean returnDocuments) {
this.returnDocuments = returnDocuments;
return this;
}

/**
* Sets the {@code top_n} option for the request being built.
*
* @param topN the value to use; may be {@code null} to leave the option unset
* @return this builder, so calls can be chained
*/
public Builder setTopN(Integer topN) {
this.topN = topN;
return this;
}

public Builder setInputType(InputType inputType) {
this.inputType = inputType;
return this;
Expand Down Expand Up @@ -373,7 +435,19 @@ public Builder setContext(InferenceContext context) {
}

/**
 * Creates a {@link Request} from the values collected by this builder.
 *
 * <p>The rerank options {@code returnDocuments} and {@code topN} are passed
 * through as-is and may be {@code null} when they were never set.
 *
 * @return a new request populated from the builder's current state
 */
public Request build() {
    // Single call to the constructor that accepts the rerank options; the
    // former 9-argument call no longer matches any constructor and made the
    // statement below it unreachable.
    return new Request(
        taskType,
        inferenceEntityId,
        query,
        returnDocuments,
        topN,
        input,
        taskSettings,
        inputType,
        timeout,
        stream,
        context
    );
}
}

Expand All @@ -384,6 +458,10 @@ public String toString() {
+ this.getInferenceEntityId()
+ ", query="
+ this.getQuery()
+ ", returnDocuments="
+ this.getReturnDocuments()
+ ", topN="
+ this.getTopN()
+ ", input="
+ this.getInput()
+ ", taskSettings="
Expand Down
Loading