Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,25 +24,12 @@ public interface InferenceServiceResults extends NamedWriteable, ChunkedToXConte

/**
* <p>Transform the result to match the format required for the TransportCoordinatedInferenceAction.
* For the inference plugin TextEmbeddingResults, the {@link #transformToLegacyFormat()} transforms the
* results into an intermediate format only used by the plugin's return value. It doesn't align with what the
* TransportCoordinatedInferenceAction expects. TransportCoordinatedInferenceAction expects an ml plugin
* TextEmbeddingResults.</p>
*
* <p>For other results like SparseEmbeddingResults, this method can be a pass through to the transformToLegacyFormat.</p>
* TransportCoordinatedInferenceAction expects an ml plugin TextEmbeddingResults or SparseEmbeddingResults.</p>
*/
default List<? extends InferenceResults> transformToCoordinationFormat() {
throw new UnsupportedOperationException("transformToCoordinationFormat() is not implemented");
}

/**
* Transform the result to match the format required for versions prior to
* {@link org.elasticsearch.TransportVersions#V_8_12_0}
*/
default List<? extends InferenceResults> transformToLegacyFormat() {
throw new UnsupportedOperationException("transformToLegacyFormat() is not implemented");
}

/**
* Convert the result to a map to aid with test assertions
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -580,12 +580,8 @@ public Flow.Publisher<InferenceServiceResults.Result> publisher() {

@Override
public void writeTo(StreamOutput out) throws IOException {
if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_12_0)) {
out.writeNamedWriteable(results);
} else {
out.writeNamedWriteable(results.transformToLegacyFormat().get(0));
}
// streaming isn't supported via Writeable yet
out.writeNamedWriteable(results);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,6 @@ public List<? extends InferenceResults> transformToCoordinationFormat() {
return results;
}

@Override
public List<? extends InferenceResults> transformToLegacyFormat() {
throw new UnsupportedOperationException();
}

public List<Result> getResults() {
return results;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -196,11 +196,6 @@ public List<? extends InferenceResults> transformToCoordinationFormat() {
throw new UnsupportedOperationException("Coordination format not supported by " + NAME);
}

@Override
public List<? extends InferenceResults> transformToLegacyFormat() {
throw new UnsupportedOperationException("Legacy format not supported by " + NAME);
}

@Override
public Map<String, Object> asMap() {
Map<String, Object> map = new LinkedHashMap<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,6 @@ public Map<String, Object> asMap() {

@Override
public List<? extends InferenceResults> transformToCoordinationFormat() {
return transformToLegacyFormat();
}

@Override
public List<? extends InferenceResults> transformToLegacyFormat() {
return embeddings.stream()
.map(
embedding -> new TextExpansionResults(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,16 +83,6 @@ public List<? extends InferenceResults> transformToCoordinationFormat() {
.toList();
}

@Override
@SuppressWarnings("deprecation")
public List<? extends InferenceResults> transformToLegacyFormat() {
var legacyEmbedding = new LegacyTextEmbeddingResults(
embeddings.stream().map(embedding -> new LegacyTextEmbeddingResults.Embedding(embedding.toFloatArray())).toList()
);

return List.of(legacyEmbedding);
}

public Map<String, Object> asMap() {
Map<String, Object> map = new LinkedHashMap<>();
map.put(TEXT_EMBEDDING_BITS, embeddings);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,16 +85,6 @@ public List<? extends InferenceResults> transformToCoordinationFormat() {
.toList();
}

@Override
@SuppressWarnings("deprecation")
public List<? extends InferenceResults> transformToLegacyFormat() {
var legacyEmbedding = new LegacyTextEmbeddingResults(
embeddings.stream().map(embedding -> new LegacyTextEmbeddingResults.Embedding(embedding.toFloatArray())).toList()
);

return List.of(legacyEmbedding);
}

public Map<String, Object> asMap() {
Map<String, Object> map = new LinkedHashMap<>();
map.put(TEXT_EMBEDDING_BYTES, embeddings);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,16 +122,6 @@ public List<? extends InferenceResults> transformToCoordinationFormat() {
return embeddings.stream().map(embedding -> new MlTextEmbeddingResults(TEXT_EMBEDDING, embedding.asDoubleArray(), false)).toList();
}

@Override
@SuppressWarnings("deprecation")
public List<? extends InferenceResults> transformToLegacyFormat() {
var legacyEmbedding = new LegacyTextEmbeddingResults(
embeddings.stream().map(embedding -> new LegacyTextEmbeddingResults.Embedding(embedding.values)).toList()
);

return List.of(legacyEmbedding);
}

public Map<String, Object> asMap() {
Map<String, Object> map = new LinkedHashMap<>();
map.put(TEXT_EMBEDDING, embeddings);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -194,11 +194,6 @@ public List<? extends InferenceResults> transformToCoordinationFormat() {
return List.of();
}

@Override
public List<? extends InferenceResults> transformToLegacyFormat() {
return List.of();
}

@Override
public Map<String, Object> asMap() {
return Map.of();
Expand Down Expand Up @@ -283,11 +278,6 @@ public List<? extends InferenceResults> transformToCoordinationFormat() {
return List.of();
}

@Override
public List<? extends InferenceResults> transformToLegacyFormat() {
return List.of();
}

@Override
public Map<String, Object> asMap() {
return Map.of();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,11 +154,6 @@ public List<? extends InferenceResults> transformToCoordinationFormat() {
throw new UnsupportedOperationException("Not implemented");
}

@Override
public List<? extends InferenceResults> transformToLegacyFormat() {
throw new UnsupportedOperationException("not implemented");
}

@Override
public Map<String, Object> asMap() {
throw new UnsupportedOperationException("Not implemented");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,6 @@
import java.util.ArrayList;
import java.util.List;

import static org.elasticsearch.TransportVersions.V_8_12_0;
import static org.elasticsearch.xpack.core.inference.action.InferenceAction.Response.transformToServiceResults;

public class InferenceActionResponseTests extends AbstractBWCWireSerializationTestCase<InferenceAction.Response> {

@Override
Expand Down Expand Up @@ -58,61 +55,6 @@ protected InferenceAction.Response mutateInstance(InferenceAction.Response insta

@Override
protected InferenceAction.Response mutateInstanceForVersion(InferenceAction.Response instance, TransportVersion version) {
if (version.before(V_8_12_0)) {
var singleResultList = instance.getResults().transformToLegacyFormat().subList(0, 1);
return new InferenceAction.Response(transformToServiceResults(singleResultList));
}

return instance;
}

public void testSerializesInferenceServiceResultsAddedVersion() throws IOException {
var instance = createTestInstance();
var copy = copyWriteable(instance, getNamedWriteableRegistry(), instanceReader(), V_8_12_0);
assertOnBWCObject(copy, instance, V_8_12_0);
}

public void testSerializesOpenAiAddedVersion_UsingLegacyTextEmbeddingResult() throws IOException {
var embeddingResults = LegacyMlTextEmbeddingResultsTests.createRandomResults().transformToTextEmbeddingResults();
var instance = new InferenceAction.Response(embeddingResults);
var copy = copyWriteable(instance, getNamedWriteableRegistry(), instanceReader(), V_8_12_0);
assertOnBWCObject(copy, instance, V_8_12_0);
}

public void testSerializesOpenAiAddedVersion_UsingSparseEmbeddingResult() throws IOException {
var embeddingResults = SparseEmbeddingResultsTests.createRandomResults();
var instance = new InferenceAction.Response(embeddingResults);
var copy = copyWriteable(instance, getNamedWriteableRegistry(), instanceReader(), V_8_12_0);
assertOnBWCObject(copy, instance, V_8_12_0);
}

public void testSerializesMultipleInputsVersion_UsingLegacyTextEmbeddingResult() throws IOException {
var embeddingResults = TextEmbeddingFloatResultsTests.createRandomResults();
var instance = new InferenceAction.Response(embeddingResults);
var copy = copyWriteable(instance, getNamedWriteableRegistry(), instanceReader(), V_8_12_0);
assertOnBWCObject(copy, instance, V_8_12_0);
}

public void testSerializesMultipleInputsVersion_UsingSparseEmbeddingResult() throws IOException {
var embeddingResults = SparseEmbeddingResultsTests.createRandomResults();
var instance = new InferenceAction.Response(embeddingResults);
var copy = copyWriteable(instance, getNamedWriteableRegistry(), instanceReader(), V_8_12_0);
assertOnBWCObject(copy, instance, V_8_12_0);
}

// Technically we should never see a text embedding result in the transport version of this test because support
// for it wasn't added until openai
public void testSerializesSingleInputVersion_UsingLegacyTextEmbeddingResult() throws IOException {
var embeddingResults = TextEmbeddingFloatResultsTests.createRandomResults();
var instance = new InferenceAction.Response(embeddingResults);
var copy = copyWriteable(instance, getNamedWriteableRegistry(), instanceReader(), V_8_12_0);
assertOnBWCObject(copy, instance, V_8_12_0);
}

public void testSerializesSingleVersion_UsingSparseEmbeddingResult() throws IOException {
var embeddingResults = SparseEmbeddingResultsTests.createRandomResults().transformToLegacyFormat().subList(0, 1);
var instance = new InferenceAction.Response(transformToServiceResults(embeddingResults));
var copy = copyWriteable(instance, getNamedWriteableRegistry(), instanceReader(), V_8_12_0);
assertOnBWCObject(copy, instance, V_8_12_0);
}
}