
Commit ab63ea7

davidkyle authored and dakrone committed
[ML] Support sparse embedding models in the elasticsearch inference service (elastic#112270)
Adds support for a sparse embedding model created with the ML trained models APIs.
1 parent 83eedf2 commit ab63ea7

8 files changed: +363, -250 lines changed

docs/changelog/112270.yaml

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+pr: 112270
+summary: Support sparse embedding models in the elasticsearch inference service
+area: Machine Learning
+type: enhancement
+issues: []

docs/reference/inference/service-elasticsearch.asciidoc

Lines changed: 2 additions & 1 deletion
@@ -31,6 +31,7 @@ include::inference-shared.asciidoc[tag=task-type]
 Available task types:

 * `rerank`,
+* `sparse_embedding`,
 * `text_embedding`.
 --

@@ -182,4 +183,4 @@ PUT _inference/text_embedding/my-e5-model
     }
 }
 ------------------------------------------------------------
-// TEST[skip:TBD]
+// TEST[skip:TBD]
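The documentation change above only adds sparse_embedding to the list of supported task types. As a rough illustration of what that enables, here is a minimal sketch, in the same low-level Java REST client style as the integration test below, of registering a sparse_embedding endpoint backed by the elasticsearch service. The class name, the localhost:9200 host, and the endpoint name my-sparse-endpoint are assumptions for illustration; the service_settings mirror the test configuration in this commit.

import org.apache.http.HttpHost;
import org.elasticsearch.client.Request;
import org.elasticsearch.client.Response;
import org.elasticsearch.client.RestClient;

public class CreateSparseEmbeddingEndpoint {
    public static void main(String[] args) throws Exception {
        // Assumption: a cluster is reachable on localhost:9200 and already hosts a
        // sparse embedding model created with the ML trained models APIs.
        try (RestClient client = RestClient.builder(new HttpHost("localhost", 9200, "http")).build()) {
            Request put = new Request("PUT", "/_inference/sparse_embedding/my-sparse-endpoint");
            put.setJsonEntity("""
                {
                  "service": "elasticsearch",
                  "service_settings": {
                    "model_id": "custom-text-expansion-model",
                    "num_allocations": 1,
                    "num_threads": 1
                  }
                }
                """);
            Response response = client.performRequest(put);
            System.out.println(response.getStatusLine());
        }
    }
}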
x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/CustomElandModelIT.java

Lines changed: 134 additions & 0 deletions
@@ -0,0 +1,134 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference;
+
+import org.elasticsearch.client.Request;
+import org.elasticsearch.core.Strings;
+import org.elasticsearch.inference.TaskType;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Base64;
+import java.util.List;
+import java.util.stream.Collectors;
+
+public class CustomElandModelIT extends InferenceBaseRestTest {
+
+    // The model definition is taken from org.elasticsearch.xpack.ml.integration.TextExpansionQueryIT
+
+    static final String BASE_64_ENCODED_MODEL = "UEsDBAAACAgAAAAAAAAAAAAAAAAAA"
+        + "AAAAAAUAA4Ac2ltcGxlbW9kZWwvZGF0YS5wa2xGQgoAWlpaWlpaWlpaWoACY19fdG9yY2hfXwpUaW55VG"
+        + "V4dEV4cGFuc2lvbgpxACmBfShYCAAAAHRyYWluaW5ncQGJWBYAAABfaXNfZnVsbF9iYWNrd2FyZF9ob29"
+        + "rcQJOdWJxAy5QSwcIITmbsFgAAABYAAAAUEsDBBQACAgIAAAAAAAAAAAAAAAAAAAAAAAdAB0Ac2ltcGxl"
+        + "bW9kZWwvY29kZS9fX3RvcmNoX18ucHlGQhkAWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWoWRT4+cMAzF7"
+        + "/spfASJomF3e0Ga3nrrn8vcELIyxAzRhAQlpjvbT19DWDrdquqBA/bvPT87nVUxwsm41xPd+PNtUi4a77"
+        + "KvXs+W8voBAHFSQY3EFCIiHKFp1+p57vs/ShyUccZdoIaz93aBTMR+thbPqru+qKBx8P4q/e8TyxRlmwVc"
+        + "tJp66H1YmCyS7WsZwD50A2L5V7pCBADGTTOj0bGGE7noQyqzv5JDfp0o9fZRCWqP37yjhE4+mqX5X3AdF"
+        + "ZHGM/2TzOHDpy1IvQWR+OWo3KwsRiKdpcqg4pBFDtm+QJ7nqwIPckrlnGfFJG0uNhOl38Sjut3pCqg26Qu"
+        + "Zy8BR9In7ScHHrKkKMW0TIucFrGQXCMpdaDO05O6DpOiy8e4kr0Ed/2YKOIhplW8gPr4ntygrd9ixpx3j9"
+        + "UZZVRagl2c6+imWUzBjuf5m+Ch7afphuvvW+r/0dsfn+2N9MZGb9+/SFtCYdhd83CMYp+mGy0LiKNs8y/e"
+        + "UuEA8B/d2z4dfUEsHCFSE3IaCAQAAIAMAAFBLAwQUAAgICAAAAAAAAAAAAAAAAAAAAAAAJwApAHNpbXBsZ"
+        + "W1vZGVsL2NvZGUvX190b3JjaF9fLnB5LmRlYnVnX3BrbEZCJQBaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlp"
+        + "aWlpaWlpaWlpaWlpahZHLbtNAFIZtp03rSVIuLRKXjdk5ojitKJsiFq24lem0KKSqpRIZt55gE9/GM+lNL"
+        + "Fgx4i1Ys2aHhIBXgAVICNggHgNm6rqJN2BZGv36/v/MOWeea/Z5RVHurLfRUsfZXOnccx522itrd53O0vL"
+        + "qbaKYtsAKUe1pcege7hm9JNtzM8+kOOzNApIX0A3xBXE6YE7g0UWjg2OaZAJXbKvALOnj2GEHKc496ykLkt"
+        + "gNt3Jz17hprCUxFqExe7YIpQkNpO1/kfHhPUdtUAdH2/gfmeYiIFW7IkM6IBP2wrDNbMe3Mjf2ksiK3Hjg"
+        + "hg7F2DN9l/omZZl5Mmez2QRk0q4WUUB0+1oh9nDwxGdUXJdXPMRZQs352eGaRPV9s2lcMeZFGWBfKJJiw0Y"
+        + "gbCMLBaRmXyy4flx6a667Fch55q05QOq2Jg2ANOyZwplhNsjiohVApo7aa21QnNGW5+4GXv8gxK1beBeHSR"
+        + "rhmLXWVh+0aBhErZ7bx1ejxMOhlR6QU4ycNqGyk8/yNGCWkwY7/RCD7UEQek4QszCgDJAzZtfErA0VqHBy9"
+        + "ugQP9pUfUmgCjVYgWNwHFbhBJyEOgSwBuuwARWZmoI6J9PwLfzEocpRpPrT8DP8wqHG0b4UX+E3DiscvRgl"
+        + "XIoi81KKPwioHI5x9EooNKWiy0KOc/T6WF4SssrRuzJ9L2VNRXUhJzj6UKYfS4W/q/5wuh/l4M9R9qsU+y2"
+        + "dpoo2hJzkaEET8r6KRONicnRdK9EbUi6raFVIwNGjsrlbpk6ZPi7TbS3fv3LyNjPiEKzG0aG0tvNb6xw90/"
+        + "whe6ONjnJcUxobHDUqQ8bIOW79BVBLBwhfSmPKdAIAAE4EAABQSwMEAAAICAAAAAAAAAAAAAAAAAAAAAAAA"
+        + "BkABQBzaW1wbGVtb2RlbC9jb25zdGFudHMucGtsRkIBAFqAAikuUEsHCG0vCVcEAAAABAAAAFBLAwQAAAgI"
+        + "AAAAAAAAAAAAAAAAAAAAAAAAEwA7AHNpbXBsZW1vZGVsL3ZlcnNpb25GQjcAWlpaWlpaWlpaWlpaWlpaWlp"
+        + "aWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWjMKUEsHCNGeZ1UCAAAAAgAAAFBLAQIAAA"
+        + "AACAgAAAAAAAAhOZuwWAAAAFgAAAAUAAAAAAAAAAAAAAAAAAAAAABzaW1wbGVtb2RlbC9kYXRhLnBrbFBLA"
+        + "QIAABQACAgIAAAAAABUhNyGggEAACADAAAdAAAAAAAAAAAAAAAAAKgAAABzaW1wbGVtb2RlbC9jb2RlL19f"
+        + "dG9yY2hfXy5weVBLAQIAABQACAgIAAAAAABfSmPKdAIAAE4EAAAnAAAAAAAAAAAAAAAAAJICAABzaW1wbGVt"
+        + "b2RlbC9jb2RlL19fdG9yY2hfXy5weS5kZWJ1Z19wa2xQSwECAAAAAAgIAAAAAAAAbS8JVwQAAAAEAAAAGQAA"
+        + "AAAAAAAAAAAAAACEBQAAc2ltcGxlbW9kZWwvY29uc3RhbnRzLnBrbFBLAQIAAAAACAgAAAAAAADRnmdVAgAA"
+        + "AAIAAAATAAAAAAAAAAAAAAAAANQFAABzaW1wbGVtb2RlbC92ZXJzaW9uUEsGBiwAAAAAAAAAHgMtAAAAAAAA"
+        + "AAAABQAAAAAAAAAFAAAAAAAAAGoBAAAAAAAAUgYAAAAAAABQSwYHAAAAALwHAAAAAAAAAQAAAFBLBQYAAAAABQAFAGoBAABSBgAAAAA=";
+
+    static final long RAW_MODEL_SIZE; // size of the model before base64 encoding
+    static {
+        RAW_MODEL_SIZE = Base64.getDecoder().decode(BASE_64_ENCODED_MODEL).length;
+    }
+
+    // Test a sparse embedding model deployed with the ml trained models APIs
+    public void testSparse() throws IOException {
+        String modelId = "custom-text-expansion-model";
+
+        createTextExpansionModel(modelId);
+        putModelDefinition(modelId, BASE_64_ENCODED_MODEL, RAW_MODEL_SIZE);
+        putVocabulary(
+            List.of("these", "are", "my", "words", "the", "washing", "machine", "is", "leaking", "octopus", "comforter", "smells"),
+            modelId
+        );
+
+        var inferenceConfig = """
+            {
+              "service": "elasticsearch",
+              "service_settings": {
+                "model_id": "custom-text-expansion-model",
+                "num_allocations": 1,
+                "num_threads": 1
+              }
+            }
+            """;
+
+        var inferenceId = "sparse-inf";
+        putModel(inferenceId, inferenceConfig, TaskType.SPARSE_EMBEDDING);
+        var results = inferOnMockService(inferenceId, List.of("washing", "machine"));
+        deleteModel(inferenceId);
+        assertNotNull(results.get("sparse_embedding"));
+    }
+
+    protected void createTextExpansionModel(String modelId) throws IOException {
+        // with_special_tokens: false for this test with limited vocab
+        Request request = new Request("PUT", "/_ml/trained_models/" + modelId);
+        request.setJsonEntity("""
+            {
+              "description": "a text expansion model",
+              "model_type": "pytorch",
+              "inference_config": {
+                "text_expansion": {
+                  "tokenization": {
+                    "bert": {
+                      "with_special_tokens": false
+                    }
+                  }
+                }
+              }
+            }""");
+        client().performRequest(request);
+    }
+
+    protected void putVocabulary(List<String> vocabulary, String modelId) throws IOException {
+        List<String> vocabularyWithPad = new ArrayList<>();
+        vocabularyWithPad.add("[PAD]");
+        vocabularyWithPad.add("[UNK]");
+        vocabularyWithPad.addAll(vocabulary);
+        String quotedWords = vocabularyWithPad.stream().map(s -> "\"" + s + "\"").collect(Collectors.joining(","));
+
+        Request request = new Request("PUT", "_ml/trained_models/" + modelId + "/vocabulary");
+        request.setJsonEntity(Strings.format("""
+            { "vocabulary": [%s] }
+            """, quotedWords));
+        client().performRequest(request);
+    }
+
+    protected void putModelDefinition(String modelId, String base64EncodedModel, long unencodedModelSize) throws IOException {
+        Request request = new Request("PUT", "_ml/trained_models/" + modelId + "/definition/0");
+        String body = Strings.format("""
+            {"total_definition_length":%s,"definition": "%s","total_parts": 1}""", unencodedModelSize, base64EncodedModel);
+        request.setJsonEntity(body);
+        client().performRequest(request);
+    }
+}
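The test above exercises the whole flow: create the trained model configuration, upload the definition and vocabulary, register the inference endpoint, and infer through the test harness helpers. Outside the test, the endpoint would be called through the public inference API; the following is a small sketch of such a call, assuming the sparse-inf endpoint created above exists on a local cluster. The class name and the localhost:9200 client setup are assumptions for illustration, not part of this commit.

import org.apache.http.HttpHost;
import org.apache.http.util.EntityUtils;
import org.elasticsearch.client.Request;
import org.elasticsearch.client.Response;
import org.elasticsearch.client.RestClient;

public class SparseEmbeddingInferenceCall {
    public static void main(String[] args) throws Exception {
        try (RestClient client = RestClient.builder(new HttpHost("localhost", 9200, "http")).build()) {
            // Run inference against the endpoint registered under the sparse_embedding task type.
            Request infer = new Request("POST", "/_inference/sparse_embedding/sparse-inf");
            infer.setJsonEntity("""
                { "input": ["washing", "machine"] }
                """);
            Response response = client.performRequest(infer);
            // The response body carries a "sparse_embedding" array of token/weight expansions, one per input.
            System.out.println(EntityUtils.toString(response.getEntity()));
        }
    }
}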

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/RerankingIT.java

Lines changed: 4 additions & 4 deletions
@@ -35,7 +35,7 @@ private String putCohereRerankEndpoint() throws IOException {
                 "api_key": ""
               }
             }
-            """);// TODO remove key
+            """);
         return endpointID;
     }

@@ -61,7 +61,7 @@ private String putCohereRerankEndpointWithDocuments() throws IOException {
                 "return_documents": true
               }
             }
-            """);// TODO remove key
+            """);
         return endpointID;
     }

@@ -81,13 +81,13 @@ private String putCohereRerankEndpointWithTop2() throws IOException {
               "service": "cohere",
               "service_settings": {
                 "model_id": "rerank-english-v2.0",
-                "api_key": "8TNPBvpBO7oN97009HQHzQbBhNrxmREbcJrZCwkK"
+                "api_key": ""
               },
               "task_settings": {
                 "top_n": 2
               }
             }
-            """);// TODO remove key
+            """);
         return endpointID;
     }
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java

Lines changed: 3 additions & 3 deletions
@@ -154,10 +154,10 @@ public void isModelDownloaded(Model model, ActionListener<Boolean> listener) {
             executeAsyncWithOrigin(client, INFERENCE_ORIGIN, GetTrainedModelsAction.INSTANCE, getRequest, getModelsResponseListener);
         } else {
             listener.onFailure(
-                new IllegalArgumentException(
-                    "Unable to determine supported model for ["
+                new IllegalStateException(
+                    "Can not check the download status of the model used by ["
                         + model.getConfigurations().getInferenceEntityId()
-                        + "] please verify the request and submit a bug report if necessary."
+                        + "] as the model_id cannot be found."
                 )
             );
         }