x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/DeleteInferenceEndpointAction.java
@@ -7,7 +7,6 @@
 
 package org.elasticsearch.xpack.core.inference.action;
 
-import org.elasticsearch.TransportVersions;
 import org.elasticsearch.action.ActionType;
 import org.elasticsearch.action.support.master.AcknowledgedRequest;
 import org.elasticsearch.action.support.master.AcknowledgedResponse;
@@ -50,13 +49,8 @@ public Request(StreamInput in) throws IOException {
         super(in);
         this.inferenceEndpointId = in.readString();
         this.taskType = TaskType.fromStream(in);
-        if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_15_0)) {
-            this.forceDelete = Boolean.TRUE.equals(in.readOptionalBoolean());
-            this.dryRun = Boolean.TRUE.equals(in.readOptionalBoolean());
-        } else {
-            this.forceDelete = false;
-            this.dryRun = false;
-        }
+        this.forceDelete = Boolean.TRUE.equals(in.readOptionalBoolean());
+        this.dryRun = Boolean.TRUE.equals(in.readOptionalBoolean());
     }
 
     public String getInferenceEndpointId() {
@@ -80,10 +74,8 @@ public void writeTo(StreamOutput out) throws IOException {
         super.writeTo(out);
         out.writeString(inferenceEndpointId);
         taskType.writeTo(out);
-        if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_15_0)) {
-            out.writeOptionalBoolean(forceDelete);
-            out.writeOptionalBoolean(dryRun);
-        }
+        out.writeOptionalBoolean(forceDelete);
+        out.writeOptionalBoolean(dryRun);
     }
 
     @Override
@@ -121,32 +113,17 @@ public Response(boolean acknowledged, Set<String> pipelineIds, Set<String> seman
 
     public Response(StreamInput in) throws IOException {
         super(in);
-        if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_15_0)) {
-            pipelineIds = in.readCollectionAsSet(StreamInput::readString);
-        } else {
-            pipelineIds = Set.of();
-        }
-
-        if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) {
-            indexes = in.readCollectionAsSet(StreamInput::readString);
-            dryRunMessage = in.readOptionalString();
-        } else {
-            indexes = Set.of();
-            dryRunMessage = null;
-        }
-
+        pipelineIds = in.readCollectionAsSet(StreamInput::readString);
+        indexes = in.readCollectionAsSet(StreamInput::readString);
+        dryRunMessage = in.readOptionalString();
     }
 
     @Override
     public void writeTo(StreamOutput out) throws IOException {
         super.writeTo(out);
-        if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_15_0)) {
-            out.writeCollection(pipelineIds, StreamOutput::writeString);
-        }
-        if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) {
-            out.writeCollection(indexes, StreamOutput::writeString);
-            out.writeOptionalString(dryRunMessage);
-        }
+        out.writeCollection(pipelineIds, StreamOutput::writeString);
+        out.writeCollection(indexes, StreamOutput::writeString);
+        out.writeOptionalString(dryRunMessage);
     }
 
     @Override
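The pattern deleted above is the standard transport-version gate: a field added in 8.15 was only read when the peer was new enough to have written it, with a hard-coded default otherwise. Below is a minimal, self-contained sketch of that pattern and its simplification; `Wire`, `Version`, and the method names are stand-ins of my own, not the Elasticsearch `StreamInput`/`TransportVersion` API.

```java
import java.util.Optional;

// Minimal sketch of the BWC gate this diff removes. Wire and Version are
// simplified stand-ins for StreamInput/TransportVersion; only the shape of
// the logic mirrors the real code.
final class VersionGateSketch {

    record Version(int id) {
        boolean onOrAfter(Version other) {
            return id >= other.id;
        }
    }

    static final Version V_8_15_0 = new Version(8_15_00);

    interface Wire {
        Version version();                       // transport version of the peer
        Optional<Boolean> readOptionalBoolean(); // next optional boolean on the stream
    }

    // Before: only consume the bytes if the sender actually wrote them,
    // otherwise fall back to a default. Required while pre-8.15 peers exist.
    static boolean readForceDeleteGated(Wire in) {
        if (in.version().onOrAfter(V_8_15_0)) {
            return in.readOptionalBoolean().orElse(false);
        }
        return false;
    }

    // After: once the minimum transport version is >= 8.15.0, every peer
    // writes the field, so the guard (and the TransportVersions import) can go.
    static boolean readForceDelete(Wire in) {
        return in.readOptionalBoolean().orElse(false);
    }
}
```

Note that the write side drops its guard in the same change: a gated writer paired with an ungated reader would desynchronize the stream.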
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/GetInferenceModelAction.java
@@ -7,7 +7,6 @@
 
 package org.elasticsearch.xpack.core.inference.action;
 
-import org.elasticsearch.TransportVersions;
 import org.elasticsearch.action.ActionResponse;
 import org.elasticsearch.action.ActionType;
 import org.elasticsearch.action.support.master.AcknowledgedRequest;
@@ -19,7 +18,6 @@
 import org.elasticsearch.xcontent.XContentBuilder;
 
 import java.io.IOException;
-import java.util.ArrayList;
 import java.util.List;
 import java.util.Objects;
 
@@ -59,11 +57,7 @@ public Request(StreamInput in) throws IOException {
         super(in);
         this.inferenceEntityId = in.readString();
         this.taskType = TaskType.fromStream(in);
-        if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) {
-            this.persistDefaultConfig = in.readBoolean();
-        } else {
-            this.persistDefaultConfig = PERSIST_DEFAULT_CONFIGS;
-        }
+        this.persistDefaultConfig = in.readBoolean();
     }
 
     public String getInferenceEntityId() {
@@ -83,9 +77,7 @@ public void writeTo(StreamOutput out) throws IOException {
         super.writeTo(out);
         out.writeString(inferenceEntityId);
         taskType.writeTo(out);
-        if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) {
-            out.writeBoolean(this.persistDefaultConfig);
-        }
+        out.writeBoolean(this.persistDefaultConfig);
     }
 
     @Override
@@ -113,12 +105,7 @@ public Response(List<ModelConfigurations> endpoints) {
     }
 
     public Response(StreamInput in) throws IOException {
-        if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_12_0)) {
-            endpoints = in.readCollectionAsList(ModelConfigurations::new);
-        } else {
-            endpoints = new ArrayList<>();
-            endpoints.add(new ModelConfigurations(in));
-        }
+        endpoints = in.readCollectionAsList(ModelConfigurations::new);
    }
 
     public List<ModelConfigurations> getEndpoints() {
@@ -127,11 +114,7 @@ public List<ModelConfigurations> getEndpoints() {
 
     @Override
     public void writeTo(StreamOutput out) throws IOException {
-        if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_12_0)) {
-            out.writeCollection(endpoints);
-        } else {
-            endpoints.get(0).writeTo(out);
-        }
+        out.writeCollection(endpoints);
     }
 
     @Override
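The Response change above follows the same rule: pre-8.12 peers exchanged a single `ModelConfigurations` object where newer peers exchange a length-prefixed list, and the fork can go once no pre-8.12 peer can connect. As a rough illustration of why the read and write gates must always move together, here is a length-prefixed collection round-trip in plain `java.io`; this approximates the wire shape of `writeCollection`/`readCollectionAsList`, it is not the actual `StreamOutput` implementation.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

// Rough illustration: a length-prefixed string collection. Both ends must
// agree on whether the count prefix exists, which is why serialization gates
// are added and removed on the reader and writer in the same change.
final class CollectionWireSketch {

    static void writeStrings(DataOutputStream out, List<String> values) throws IOException {
        out.writeInt(values.size());      // count first...
        for (String value : values) {
            out.writeUTF(value);          // ...then each element
        }
    }

    static List<String> readStrings(DataInputStream in) throws IOException {
        int count = in.readInt();
        List<String> values = new ArrayList<>(count);
        for (int i = 0; i < count; i++) {
            values.add(in.readUTF());
        }
        return values;
    }

    public static void main(String[] args) throws IOException {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        writeStrings(new DataOutputStream(bytes), List.of("endpoint-a", "endpoint-b"));
        DataInputStream in = new DataInputStream(new ByteArrayInputStream(bytes.toByteArray()));
        System.out.println(readStrings(in)); // [endpoint-a, endpoint-b]
    }
}
```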
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java
@@ -7,9 +7,7 @@
 
 package org.elasticsearch.xpack.core.inference.action;
 
-import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.TransportVersion;
-import org.elasticsearch.TransportVersions;
 import org.elasticsearch.action.ActionRequestValidationException;
 import org.elasticsearch.action.ActionResponse;
 import org.elasticsearch.action.ActionType;
@@ -19,23 +17,16 @@
 import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
 import org.elasticsearch.common.xcontent.ChunkedToXContentObject;
 import org.elasticsearch.core.TimeValue;
-import org.elasticsearch.inference.InferenceResults;
 import org.elasticsearch.inference.InferenceServiceResults;
 import org.elasticsearch.inference.InputType;
 import org.elasticsearch.inference.TaskType;
-import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.xcontent.ObjectParser;
 import org.elasticsearch.xcontent.ParseField;
 import org.elasticsearch.xcontent.ToXContent;
 import org.elasticsearch.xcontent.XContentParser;
 import org.elasticsearch.xpack.core.inference.InferenceContext;
-import org.elasticsearch.xpack.core.inference.results.LegacyTextEmbeddingResults;
-import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
-import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
 
 import java.io.IOException;
-import java.util.ArrayList;
-import java.util.EnumSet;
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
@@ -79,12 +70,6 @@ public static Builder builder(String inferenceEntityId, TaskType taskType) {
         PARSER.declareString(Builder::setInferenceTimeout, TIMEOUT);
     }
 
-    private static final EnumSet<InputType> validEnumsBeforeUnspecifiedAdded = EnumSet.of(InputType.INGEST, InputType.SEARCH);
-    private static final EnumSet<InputType> validEnumsBeforeClassificationClusteringAdded = EnumSet.range(
-        InputType.INGEST,
-        InputType.UNSPECIFIED
-    );
-
     public static Builder parseRequest(String inferenceEntityId, TaskType taskType, InferenceContext context, XContentParser parser)
         throws IOException {
         Request.Builder builder = PARSER.apply(parser, null);
@@ -164,25 +149,11 @@ public Request(StreamInput in) throws IOException {
         super(in);
         this.taskType = TaskType.fromStream(in);
         this.inferenceEntityId = in.readString();
-        if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_12_0)) {
-            this.input = in.readStringCollectionAsList();
-        } else {
-            this.input = List.of(in.readString());
-        }
+        this.input = in.readStringCollectionAsList();
         this.taskSettings = in.readGenericMap();
-        if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_13_0)) {
-            this.inputType = in.readEnum(InputType.class);
-        } else {
-            this.inputType = InputType.UNSPECIFIED;
-        }
-
-        if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_14_0)) {
-            this.query = in.readOptionalString();
-            this.inferenceTimeout = in.readTimeValue();
-        } else {
-            this.query = null;
-            this.inferenceTimeout = DEFAULT_TIMEOUT;
-        }
+        this.inputType = in.readEnum(InputType.class);
+        this.query = in.readOptionalString();
+        this.inferenceTimeout = in.readTimeValue();
 
         if (in.getTransportVersion().supports(RERANK_COMMON_OPTIONS_ADDED)) {
             this.returnDocuments = in.readOptionalBoolean();
@@ -298,41 +269,18 @@ public void writeTo(StreamOutput out) throws IOException {
         super.writeTo(out);
         taskType.writeTo(out);
         out.writeString(inferenceEntityId);
-        if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_12_0)) {
-            out.writeStringCollection(input);
-        } else {
-            out.writeString(input.get(0));
-        }
+        out.writeStringCollection(input);
         out.writeGenericMap(taskSettings);
-
-        if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_13_0)) {
-            out.writeEnum(getInputTypeToWrite(inputType, out.getTransportVersion()));
-        }
-
-        if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_14_0)) {
-            out.writeOptionalString(query);
-            out.writeTimeValue(inferenceTimeout);
-        }
+        out.writeEnum(inputType);
+        out.writeOptionalString(query);
+        out.writeTimeValue(inferenceTimeout);
 
         if (out.getTransportVersion().supports(RERANK_COMMON_OPTIONS_ADDED)) {
             out.writeOptionalBoolean(returnDocuments);
             out.writeOptionalInt(topN);
         }
     }
-
-    // default for easier testing
-    static InputType getInputTypeToWrite(InputType inputType, TransportVersion version) {
-        if (version.before(TransportVersions.V_8_13_0)) {
-            if (validEnumsBeforeUnspecifiedAdded.contains(inputType) == false) {
-                return InputType.INGEST;
-            } else if (validEnumsBeforeClassificationClusteringAdded.contains(inputType) == false) {
-                return InputType.UNSPECIFIED;
-            }
-        }
-
-        return inputType;
-    }
 
     @Override
     public boolean equals(Object o) {
         if (this == o) return true;
@@ -509,65 +457,12 @@ public Response(InferenceServiceResults results, Flow.Publisher<InferenceService
     }
 
     public Response(StreamInput in) throws IOException {
-        if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_12_0)) {
-            results = in.readNamedWriteable(InferenceServiceResults.class);
-        } else {
-            // It should only be InferenceResults aka TextEmbeddingResults from ml plugin for
-            // hugging face elser and elser
-            results = transformToServiceResults(List.of(in.readNamedWriteable(InferenceResults.class)));
-        }
+        this.results = in.readNamedWriteable(InferenceServiceResults.class);
         // streaming isn't supported via Writeable yet
         this.isStreaming = false;
         this.publisher = null;
     }
 
-    @SuppressWarnings("deprecation")
-    public static InferenceServiceResults transformToServiceResults(List<? extends InferenceResults> parsedResults) {
-        if (parsedResults.isEmpty()) {
-            throw new ElasticsearchStatusException(
-                "Failed to transform results to response format, expected a non-empty list, please remove and re-add the service",
-                RestStatus.INTERNAL_SERVER_ERROR
-            );
-        }
-
-        if (parsedResults.get(0) instanceof LegacyTextEmbeddingResults openaiResults) {
-            if (parsedResults.size() > 1) {
-                throw new ElasticsearchStatusException(
-                    "Failed to transform results to response format, malformed text embedding result,"
-                        + " please remove and re-add the service",
-                    RestStatus.INTERNAL_SERVER_ERROR
-                );
-            }
-
-            return openaiResults.transformToTextEmbeddingResults();
-        } else if (parsedResults.get(0) instanceof TextExpansionResults) {
-            return transformToSparseEmbeddingResult(parsedResults);
-        } else {
-            throw new ElasticsearchStatusException(
-                "Failed to transform results to response format, unknown embedding type received,"
-                    + " please remove and re-add the service",
-                RestStatus.INTERNAL_SERVER_ERROR
-            );
-        }
-    }
-
-    private static SparseEmbeddingResults transformToSparseEmbeddingResult(List<? extends InferenceResults> parsedResults) {
-        List<TextExpansionResults> textExpansionResults = new ArrayList<>(parsedResults.size());
-
-        for (InferenceResults result : parsedResults) {
-            if (result instanceof TextExpansionResults textExpansion) {
-                textExpansionResults.add(textExpansion);
-            } else {
-                throw new ElasticsearchStatusException(
-                    "Failed to transform results to response format, please remove and re-add the service",
-                    RestStatus.INTERNAL_SERVER_ERROR
-                );
-            }
-        }
-
-        return SparseEmbeddingResults.of(textExpansionResults);
-    }
 
     public InferenceServiceResults getResults() {
         return results;
     }
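Most of what this last file drops is the enum-downgrade helper `getInputTypeToWrite`: an old peer rejects enum ordinals it does not know, so newer `InputType` values were mapped to the closest constant the peer could deserialize before writing. Here is a simplified sketch of that idea; the enum is a local stand-in, and the two boolean "peer predates X" flags replace the real `TransportVersion` comparisons.

```java
import java.util.EnumSet;

// Simplified sketch of the deleted getInputTypeToWrite helper. The enum and
// the boolean flags are stand-ins for the real InputType and TransportVersion
// checks; the downgrade mapping is the part that matters.
final class EnumDowngradeSketch {

    enum InputType { INGEST, SEARCH, UNSPECIFIED, CLASSIFICATION, CLUSTERING }

    // Constants an old peer knows, per era.
    private static final EnumSet<InputType> BEFORE_UNSPECIFIED =
        EnumSet.of(InputType.INGEST, InputType.SEARCH);
    private static final EnumSet<InputType> BEFORE_CLASSIFICATION_CLUSTERING =
        EnumSet.range(InputType.INGEST, InputType.UNSPECIFIED);

    // Map a value the peer may not know to the closest constant it can
    // deserialize; writing an unknown ordinal would fail on the receiving node.
    static InputType forWire(boolean peerPredatesUnspecified,
                             boolean peerPredatesClassificationClustering,
                             InputType type) {
        if (peerPredatesUnspecified && !BEFORE_UNSPECIFIED.contains(type)) {
            return InputType.INGEST;
        }
        if (peerPredatesClassificationClustering && !BEFORE_CLASSIFICATION_CLUSTERING.contains(type)) {
            return InputType.UNSPECIFIED;
        }
        return type;
    }
}
```

Once no connectable peer predates the newest constants, every branch is dead code, and the helper, its `EnumSet` tables, and the version imports can all be deleted, which is what the hunks above do.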