Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/112270.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 112270
summary: Support sparse embedding models in the elasticsearch inference service
area: Machine Learning
type: enhancement
issues: []
3 changes: 2 additions & 1 deletion docs/reference/inference/service-elasticsearch.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ include::inference-shared.asciidoc[tag=task-type]
Available task types:

* `rerank`,
* `sparse_embedding`,
* `text_embedding`.
--

Expand Down Expand Up @@ -182,4 +183,4 @@ PUT _inference/text_embedding/my-e5-model
}
}
------------------------------------------------------------
// TEST[skip:TBD]
// TEST[skip:TBD]
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference;

import org.elasticsearch.client.Request;
import org.elasticsearch.core.Strings;
import org.elasticsearch.inference.TaskType;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Base64;
import java.util.List;
import java.util.stream.Collectors;

public class CustomElandModelIT extends InferenceBaseRestTest {

// The model definition is taken from org.elasticsearch.xpack.ml.integration.TextExpansionQueryIT

static final String BASE_64_ENCODED_MODEL = "UEsDBAAACAgAAAAAAAAAAAAAAAAAA"
+ "AAAAAAUAA4Ac2ltcGxlbW9kZWwvZGF0YS5wa2xGQgoAWlpaWlpaWlpaWoACY19fdG9yY2hfXwpUaW55VG"
+ "V4dEV4cGFuc2lvbgpxACmBfShYCAAAAHRyYWluaW5ncQGJWBYAAABfaXNfZnVsbF9iYWNrd2FyZF9ob29"
+ "rcQJOdWJxAy5QSwcIITmbsFgAAABYAAAAUEsDBBQACAgIAAAAAAAAAAAAAAAAAAAAAAAdAB0Ac2ltcGxl"
+ "bW9kZWwvY29kZS9fX3RvcmNoX18ucHlGQhkAWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWoWRT4+cMAzF7"
+ "/spfASJomF3e0Ga3nrrn8vcELIyxAzRhAQlpjvbT19DWDrdquqBA/bvPT87nVUxwsm41xPd+PNtUi4a77"
+ "KvXs+W8voBAHFSQY3EFCIiHKFp1+p57vs/ShyUccZdoIaz93aBTMR+thbPqru+qKBx8P4q/e8TyxRlmwVc"
+ "tJp66H1YmCyS7WsZwD50A2L5V7pCBADGTTOj0bGGE7noQyqzv5JDfp0o9fZRCWqP37yjhE4+mqX5X3AdF"
+ "ZHGM/2TzOHDpy1IvQWR+OWo3KwsRiKdpcqg4pBFDtm+QJ7nqwIPckrlnGfFJG0uNhOl38Sjut3pCqg26Qu"
+ "Zy8BR9In7ScHHrKkKMW0TIucFrGQXCMpdaDO05O6DpOiy8e4kr0Ed/2YKOIhplW8gPr4ntygrd9ixpx3j9"
+ "UZZVRagl2c6+imWUzBjuf5m+Ch7afphuvvW+r/0dsfn+2N9MZGb9+/SFtCYdhd83CMYp+mGy0LiKNs8y/e"
+ "UuEA8B/d2z4dfUEsHCFSE3IaCAQAAIAMAAFBLAwQUAAgICAAAAAAAAAAAAAAAAAAAAAAAJwApAHNpbXBsZ"
+ "W1vZGVsL2NvZGUvX190b3JjaF9fLnB5LmRlYnVnX3BrbEZCJQBaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlp"
+ "aWlpaWlpaWlpaWlpahZHLbtNAFIZtp03rSVIuLRKXjdk5ojitKJsiFq24lem0KKSqpRIZt55gE9/GM+lNL"
+ "Fgx4i1Ys2aHhIBXgAVICNggHgNm6rqJN2BZGv36/v/MOWeea/Z5RVHurLfRUsfZXOnccx522itrd53O0vL"
+ "qbaKYtsAKUe1pcege7hm9JNtzM8+kOOzNApIX0A3xBXE6YE7g0UWjg2OaZAJXbKvALOnj2GEHKc496ykLkt"
+ "gNt3Jz17hprCUxFqExe7YIpQkNpO1/kfHhPUdtUAdH2/gfmeYiIFW7IkM6IBP2wrDNbMe3Mjf2ksiK3Hjg"
+ "hg7F2DN9l/omZZl5Mmez2QRk0q4WUUB0+1oh9nDwxGdUXJdXPMRZQs352eGaRPV9s2lcMeZFGWBfKJJiw0Y"
+ "gbCMLBaRmXyy4flx6a667Fch55q05QOq2Jg2ANOyZwplhNsjiohVApo7aa21QnNGW5+4GXv8gxK1beBeHSR"
+ "rhmLXWVh+0aBhErZ7bx1ejxMOhlR6QU4ycNqGyk8/yNGCWkwY7/RCD7UEQek4QszCgDJAzZtfErA0VqHBy9"
+ "ugQP9pUfUmgCjVYgWNwHFbhBJyEOgSwBuuwARWZmoI6J9PwLfzEocpRpPrT8DP8wqHG0b4UX+E3DiscvRgl"
+ "XIoi81KKPwioHI5x9EooNKWiy0KOc/T6WF4SssrRuzJ9L2VNRXUhJzj6UKYfS4W/q/5wuh/l4M9R9qsU+y2"
+ "dpoo2hJzkaEET8r6KRONicnRdK9EbUi6raFVIwNGjsrlbpk6ZPi7TbS3fv3LyNjPiEKzG0aG0tvNb6xw90/"
+ "whe6ONjnJcUxobHDUqQ8bIOW79BVBLBwhfSmPKdAIAAE4EAABQSwMEAAAICAAAAAAAAAAAAAAAAAAAAAAAA"
+ "BkABQBzaW1wbGVtb2RlbC9jb25zdGFudHMucGtsRkIBAFqAAikuUEsHCG0vCVcEAAAABAAAAFBLAwQAAAgI"
+ "AAAAAAAAAAAAAAAAAAAAAAAAEwA7AHNpbXBsZW1vZGVsL3ZlcnNpb25GQjcAWlpaWlpaWlpaWlpaWlpaWlp"
+ "aWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWjMKUEsHCNGeZ1UCAAAAAgAAAFBLAQIAAA"
+ "AACAgAAAAAAAAhOZuwWAAAAFgAAAAUAAAAAAAAAAAAAAAAAAAAAABzaW1wbGVtb2RlbC9kYXRhLnBrbFBLA"
+ "QIAABQACAgIAAAAAABUhNyGggEAACADAAAdAAAAAAAAAAAAAAAAAKgAAABzaW1wbGVtb2RlbC9jb2RlL19f"
+ "dG9yY2hfXy5weVBLAQIAABQACAgIAAAAAABfSmPKdAIAAE4EAAAnAAAAAAAAAAAAAAAAAJICAABzaW1wbGVt"
+ "b2RlbC9jb2RlL19fdG9yY2hfXy5weS5kZWJ1Z19wa2xQSwECAAAAAAgIAAAAAAAAbS8JVwQAAAAEAAAAGQAA"
+ "AAAAAAAAAAAAAACEBQAAc2ltcGxlbW9kZWwvY29uc3RhbnRzLnBrbFBLAQIAAAAACAgAAAAAAADRnmdVAgAA"
+ "AAIAAAATAAAAAAAAAAAAAAAAANQFAABzaW1wbGVtb2RlbC92ZXJzaW9uUEsGBiwAAAAAAAAAHgMtAAAAAAAA"
+ "AAAABQAAAAAAAAAFAAAAAAAAAGoBAAAAAAAAUgYAAAAAAABQSwYHAAAAALwHAAAAAAAAAQAAAFBLBQYAAAAABQAFAGoBAABSBgAAAAA=";

static final long RAW_MODEL_SIZE; // size of the model before base64 encoding
static {
RAW_MODEL_SIZE = Base64.getDecoder().decode(BASE_64_ENCODED_MODEL).length;
}

// Test a sparse embedding model deployed with the ml trained models APIs
public void testSparse() throws IOException {
String modelId = "custom-text-expansion-model";

createTextExpansionModel(modelId);
putModelDefinition(modelId, BASE_64_ENCODED_MODEL, RAW_MODEL_SIZE);
putVocabulary(
List.of("these", "are", "my", "words", "the", "washing", "machine", "is", "leaking", "octopus", "comforter", "smells"),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unfortunate

modelId
);

var inferenceConfig = """
{
"service": "elasticsearch",
"service_settings": {
"model_id": "custom-text-expansion-model",
"num_allocations": 1,
"num_threads": 1
}
}
""";

var inferenceId = "sparse-inf";
putModel(inferenceId, inferenceConfig, TaskType.SPARSE_EMBEDDING);
var results = inferOnMockService(inferenceId, List.of("washing", "machine"));
deleteModel(inferenceId);
assertNotNull(results.get("sparse_embedding"));
}

protected void createTextExpansionModel(String modelId) throws IOException {
// with_special_tokens: false for this test with limited vocab
Request request = new Request("PUT", "/_ml/trained_models/" + modelId);
request.setJsonEntity("""
{
"description": "a text expansion model",
"model_type": "pytorch",
"inference_config": {
"text_expansion": {
"tokenization": {
"bert": {
"with_special_tokens": false
}
}
}
}
}""");
client().performRequest(request);
}

protected void putVocabulary(List<String> vocabulary, String modelId) throws IOException {
List<String> vocabularyWithPad = new ArrayList<>();
vocabularyWithPad.add("[PAD]");
vocabularyWithPad.add("[UNK]");
vocabularyWithPad.addAll(vocabulary);
String quotedWords = vocabularyWithPad.stream().map(s -> "\"" + s + "\"").collect(Collectors.joining(","));

Request request = new Request("PUT", "_ml/trained_models/" + modelId + "/vocabulary");
request.setJsonEntity(Strings.format("""
{ "vocabulary": [%s] }
""", quotedWords));
client().performRequest(request);
}

protected void putModelDefinition(String modelId, String base64EncodedModel, long unencodedModelSize) throws IOException {
Request request = new Request("PUT", "_ml/trained_models/" + modelId + "/definition/0");
String body = Strings.format("""
{"total_definition_length":%s,"definition": "%s","total_parts": 1}""", unencodedModelSize, base64EncodedModel);
request.setJsonEntity(body);
client().performRequest(request);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ private String putCohereRerankEndpoint() throws IOException {
"api_key": ""
}
}
""");// TODO remove key
""");
return endpointID;
}

Expand All @@ -61,7 +61,7 @@ private String putCohereRerankEndpointWithDocuments() throws IOException {
"return_documents": true
}
}
""");// TODO remove key
""");
return endpointID;
}

Expand All @@ -81,13 +81,13 @@ private String putCohereRerankEndpointWithTop2() throws IOException {
"service": "cohere",
"service_settings": {
"model_id": "rerank-english-v2.0",
"api_key": "8TNPBvpBO7oN97009HQHzQbBhNrxmREbcJrZCwkK"
"api_key": ""
},
"task_settings": {
"top_n": 2
}
}
""");// TODO remove key
""");
return endpointID;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,10 +154,10 @@ public void isModelDownloaded(Model model, ActionListener<Boolean> listener) {
executeAsyncWithOrigin(client, INFERENCE_ORIGIN, GetTrainedModelsAction.INSTANCE, getRequest, getModelsResponseListener);
} else {
listener.onFailure(
new IllegalArgumentException(
"Unable to determine supported model for ["
new IllegalStateException(
"Can not check the download status of the model used by ["
+ model.getConfigurations().getInferenceEntityId()
+ "] please verify the request and submit a bug report if necessary."
+ "] as the model_id cannot be found."
)
);
}
Expand Down
Loading