Skip to content

Commit 65b20bd

Browse files
committed
Added the single embeddings provider
1 parent c55628c commit 65b20bd

File tree

3 files changed

+48
-1
lines changed

3 files changed

+48
-1
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@
9999
import org.elasticsearch.xpack.inference.queries.SemanticMatchQueryRewriteInterceptor;
100100
import org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder;
101101
import org.elasticsearch.xpack.inference.queries.SemanticSparseVectorQueryRewriteInterceptor;
102+
import org.elasticsearch.xpack.inference.queries.SingleEmbeddingsProvider;
102103
import org.elasticsearch.xpack.inference.rank.random.RandomRankBuilder;
103104
import org.elasticsearch.xpack.inference.rank.random.RandomRankRetrieverBuilder;
104105
import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankBuilder;
@@ -429,6 +430,9 @@ public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
429430
entries.add(new NamedWriteableRegistry.Entry(Metadata.ProjectCustom.class, ModelRegistryMetadata.TYPE, ModelRegistryMetadata::new));
430431
entries.add(new NamedWriteableRegistry.Entry(NamedDiff.class, ModelRegistryMetadata.TYPE, ModelRegistryMetadata::readDiffFrom));
431432
entries.add(new NamedWriteableRegistry.Entry(EmbeddingsProvider.class, MapEmbeddingsProvider.NAME, MapEmbeddingsProvider::new));
433+
entries.add(
434+
new NamedWriteableRegistry.Entry(EmbeddingsProvider.class, SingleEmbeddingsProvider.NAME, SingleEmbeddingsProvider::new)
435+
);
432436
return entries;
433437
}
434438

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/MapEmbeddingsProvider.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ public MapEmbeddingsProvider() {
2525
}
2626

2727
public MapEmbeddingsProvider(StreamInput in) throws IOException {
28-
embeddings = in.readMap(InferenceEndpointKey::new, i -> i.readNamedWriteable(InferenceResults.class));
28+
this.embeddings = in.readMap(InferenceEndpointKey::new, i -> i.readNamedWriteable(InferenceResults.class));
2929
}
3030

3131
@Override
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.queries;
9+
10+
import org.elasticsearch.common.io.stream.StreamInput;
11+
import org.elasticsearch.common.io.stream.StreamOutput;
12+
import org.elasticsearch.inference.InferenceResults;
13+
14+
import java.io.IOException;
15+
16+
public class SingleEmbeddingsProvider implements EmbeddingsProvider {
17+
public static final String NAME = "single_embeddings_provider";
18+
19+
private final InferenceResults embeddings;
20+
21+
public SingleEmbeddingsProvider(InferenceResults embeddings) {
22+
this.embeddings = embeddings;
23+
}
24+
25+
public SingleEmbeddingsProvider(StreamInput in) throws IOException {
26+
this.embeddings = in.readNamedWriteable(InferenceResults.class);
27+
}
28+
29+
@Override
30+
public String getWriteableName() {
31+
return NAME;
32+
}
33+
34+
@Override
35+
public void writeTo(StreamOutput out) throws IOException {
36+
out.writeNamedWriteable(embeddings);
37+
}
38+
39+
@Override
40+
public InferenceResults getEmbeddings(InferenceEndpointKey key) {
41+
return embeddings;
42+
}
43+
}

0 commit comments

Comments
 (0)