Skip to content

Commit c55628c

Browse files
committed
Added the map embeddings provider
1 parent 0b8a1a9 commit c55628c

File tree

4 files changed

+117
-0
lines changed

4 files changed

+117
-0
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,8 @@
9393
import org.elasticsearch.xpack.inference.mapper.OffsetSourceFieldMapper;
9494
import org.elasticsearch.xpack.inference.mapper.SemanticInferenceMetadataFieldsMapper;
9595
import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
96+
import org.elasticsearch.xpack.inference.queries.EmbeddingsProvider;
97+
import org.elasticsearch.xpack.inference.queries.MapEmbeddingsProvider;
9698
import org.elasticsearch.xpack.inference.queries.SemanticKnnVectorQueryRewriteInterceptor;
9799
import org.elasticsearch.xpack.inference.queries.SemanticMatchQueryRewriteInterceptor;
98100
import org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder;
@@ -426,6 +428,7 @@ public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
426428
entries.add(new NamedWriteableRegistry.Entry(RankDoc.class, TextSimilarityRankDoc.NAME, TextSimilarityRankDoc::new));
427429
entries.add(new NamedWriteableRegistry.Entry(Metadata.ProjectCustom.class, ModelRegistryMetadata.TYPE, ModelRegistryMetadata::new));
428430
entries.add(new NamedWriteableRegistry.Entry(NamedDiff.class, ModelRegistryMetadata.TYPE, ModelRegistryMetadata::readDiffFrom));
431+
entries.add(new NamedWriteableRegistry.Entry(EmbeddingsProvider.class, MapEmbeddingsProvider.NAME, MapEmbeddingsProvider::new));
429432
return entries;
430433
}
431434

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.queries;
9+
10+
import org.elasticsearch.common.io.stream.NamedWriteable;
11+
import org.elasticsearch.inference.InferenceResults;
12+
13+
public interface EmbeddingsProvider extends NamedWriteable {
14+
InferenceResults getEmbeddings(InferenceEndpointKey key);
15+
}
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.queries;
9+
10+
import org.elasticsearch.common.io.stream.StreamInput;
11+
import org.elasticsearch.common.io.stream.StreamOutput;
12+
import org.elasticsearch.common.io.stream.Writeable;
13+
import org.elasticsearch.inference.MinimalServiceSettings;
14+
15+
import java.io.IOException;
16+
import java.util.Objects;
17+
18+
public class InferenceEndpointKey implements Writeable {
19+
private final String inferenceId;
20+
private final MinimalServiceSettings serviceSettings;
21+
22+
public InferenceEndpointKey(String inferenceId, MinimalServiceSettings serviceSettings) {
23+
this.inferenceId = inferenceId;
24+
this.serviceSettings = serviceSettings;
25+
}
26+
27+
public InferenceEndpointKey(StreamInput in) throws IOException {
28+
this.inferenceId = in.readString();
29+
this.serviceSettings = in.readNamedWriteable(MinimalServiceSettings.class);
30+
}
31+
32+
@Override
33+
public void writeTo(StreamOutput out) throws IOException {
34+
out.writeString(inferenceId);
35+
out.writeNamedWriteable(serviceSettings);
36+
}
37+
38+
@Override
39+
public boolean equals(Object o) {
40+
if (this == o) return true;
41+
if (o == null || getClass() != o.getClass()) return false;
42+
InferenceEndpointKey that = (InferenceEndpointKey) o;
43+
return Objects.equals(inferenceId, that.inferenceId) && Objects.equals(serviceSettings, that.serviceSettings);
44+
}
45+
46+
@Override
47+
public int hashCode() {
48+
return Objects.hash(inferenceId, serviceSettings);
49+
}
50+
}
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.queries;
9+
10+
import org.elasticsearch.common.io.stream.StreamInput;
11+
import org.elasticsearch.common.io.stream.StreamOutput;
12+
import org.elasticsearch.inference.InferenceResults;
13+
14+
import java.io.IOException;
15+
import java.util.HashMap;
16+
import java.util.Map;
17+
18+
public class MapEmbeddingsProvider implements EmbeddingsProvider {
19+
public static final String NAME = "map_embeddings_provider";
20+
21+
private final Map<InferenceEndpointKey, InferenceResults> embeddings;
22+
23+
public MapEmbeddingsProvider() {
24+
this.embeddings = new HashMap<>();
25+
}
26+
27+
public MapEmbeddingsProvider(StreamInput in) throws IOException {
28+
embeddings = in.readMap(InferenceEndpointKey::new, i -> i.readNamedWriteable(InferenceResults.class));
29+
}
30+
31+
@Override
32+
public String getWriteableName() {
33+
return NAME;
34+
}
35+
36+
@Override
37+
public void writeTo(StreamOutput out) throws IOException {
38+
out.writeMap(embeddings);
39+
}
40+
41+
@Override
42+
public InferenceResults getEmbeddings(InferenceEndpointKey key) {
43+
return embeddings.get(key);
44+
}
45+
46+
public void addEmbeddings(InferenceEndpointKey key, InferenceResults embeddings) {
47+
this.embeddings.put(key, embeddings);
48+
}
49+
}

0 commit comments

Comments
 (0)