Merged
21 changes: 0 additions & 21 deletions muted-tests.yml
@@ -52,21 +52,9 @@ tests:
- class: org.elasticsearch.xpack.transform.integration.TransformIT
method: testStopWaitForCheckpoint
issue: https://github.com/elastic/elasticsearch/issues/106113
- class: org.elasticsearch.xpack.inference.TextEmbeddingCrudIT
method: testPutE5Small_withPlatformAgnosticVariant
issue: https://github.com/elastic/elasticsearch/issues/113983
- class: org.elasticsearch.xpack.inference.TextEmbeddingCrudIT
method: testPutE5WithTrainedModelAndInference
issue: https://github.com/elastic/elasticsearch/issues/114023
- class: org.elasticsearch.xpack.inference.TextEmbeddingCrudIT
method: testPutE5Small_withPlatformSpecificVariant
issue: https://github.com/elastic/elasticsearch/issues/113950
- class: org.elasticsearch.xpack.remotecluster.RemoteClusterSecurityWithApmTracingRestIT
method: testTracingCrossCluster
issue: https://github.com/elastic/elasticsearch/issues/112731
- class: org.elasticsearch.xpack.inference.DefaultEndPointsIT
method: testInferDeploysDefaultE5
issue: https://github.com/elastic/elasticsearch/issues/115361
- class: org.elasticsearch.xpack.restart.MLModelDeploymentFullClusterRestartIT
method: testDeploymentSurvivesRestart {cluster=UPGRADED}
issue: https://github.com/elastic/elasticsearch/issues/115528
@@ -110,9 +98,6 @@ tests:
- class: org.elasticsearch.xpack.apmdata.APMYamlTestSuiteIT
method: test {yaml=/10_apm/Test template reinstallation}
issue: https://github.com/elastic/elasticsearch/issues/116445
- class: org.elasticsearch.xpack.inference.DefaultEndPointsIT
method: testMultipleInferencesTriggeringDownloadAndDeploy
issue: https://github.com/elastic/elasticsearch/issues/117208
- class: org.elasticsearch.ingest.geoip.EnterpriseGeoIpDownloaderIT
method: testEnterpriseDownloaderTask
issue: https://github.com/elastic/elasticsearch/issues/115163
@@ -125,9 +110,6 @@ tests:
- class: org.elasticsearch.xpack.test.rest.XPackRestIT
method: test {p0=snapshot/10_basic/Create a source only snapshot and then restore it}
issue: https://github.com/elastic/elasticsearch/issues/117295
- class: org.elasticsearch.xpack.inference.DefaultEndPointsIT
method: testInferDeploysDefaultElser
issue: https://github.com/elastic/elasticsearch/issues/114913
- class: org.elasticsearch.xpack.inference.InferenceRestIT
method: test {p0=inference/40_semantic_text_query/Query a field that uses the default ELSER 2 endpoint}
issue: https://github.com/elastic/elasticsearch/issues/117027
@@ -152,9 +134,6 @@ tests:
- class: org.elasticsearch.xpack.remotecluster.CrossClusterEsqlRCS2UnavailableRemotesIT
method: testEsqlRcs2UnavailableRemoteScenarios
issue: https://github.com/elastic/elasticsearch/issues/117419
- class: org.elasticsearch.xpack.inference.DefaultEndPointsIT
method: testInferDeploysDefaultRerank
issue: https://github.com/elastic/elasticsearch/issues/118184
- class: org.elasticsearch.xpack.esql.action.EsqlActionTaskIT
method: testCancelRequestWhenFailingFetchingPages
issue: https://github.com/elastic/elasticsearch/issues/118193
@@ -1,6 +1,7 @@
apply plugin: 'elasticsearch.internal-java-rest-test'

dependencies {
javaRestTestImplementation project(path: xpackModule('core'))
Contributor Author comment: Need to get XPackSettings.ML_NATIVE_CODE_PLATFORMS into the model server

javaRestTestImplementation project(path: xpackModule('inference'))
clusterPlugins project(':x-pack:plugin:inference:qa:test-service-plugin')
}
@@ -26,6 +26,7 @@
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent;
import org.junit.Before;
import org.junit.ClassRule;

import java.io.IOException;
@@ -41,6 +42,7 @@
import static org.hamcrest.Matchers.hasSize;

public class InferenceBaseRestTest extends ESRestTestCase {

@ClassRule
public static ElasticsearchCluster cluster = ElasticsearchCluster.local()
.distribution(DistributionType.DEFAULT)
@@ -51,6 +53,22 @@ public class InferenceBaseRestTest extends ESRestTestCase {
.feature(FeatureFlag.INFERENCE_UNIFIED_API_ENABLED)
.build();

@ClassRule
public static MlModelServer mlModelServer = new MlModelServer();

@Before
public void setMlModelRepository() throws IOException {
logger.info("setting ML model repository to: {}", mlModelServer.getUrl());
var request = new Request("PUT", "/_cluster/settings");
request.setJsonEntity(Strings.format("""
{
"persistent": {
"xpack.ml.model_repository": "%s"
}
}""", mlModelServer.getUrl()));
assertOK(client().performRequest(request));
}

@Override
protected String getTestRestCluster() {
return cluster.getHttpAddresses();
@@ -0,0 +1,146 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference;

import com.sun.net.httpserver.HttpExchange;
import com.sun.net.httpserver.HttpServer;

import org.apache.http.HttpHeaders;
import org.apache.http.HttpStatus;
import org.apache.http.client.utils.URIBuilder;
import org.elasticsearch.logging.LogManager;
import org.elasticsearch.logging.Logger;
import org.elasticsearch.test.fixture.HttpHeaderParser;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.XPackSettings;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ModelPackageConfig;
import org.junit.rules.TestRule;
import org.junit.runner.Description;
import org.junit.runners.model.Statement;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.InetSocketAddress;
import java.nio.charset.StandardCharsets;
import java.util.Random;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

/**
* Simple model server to serve ML models.
* The URL path corresponds to a file name in this class's resources.
* If the file is found, its content is returned; otherwise a 404 is returned.
* Respects a range header to serve partial content.
*/
public class MlModelServer implements TestRule {

private static final String HOST = "localhost";
private static final Logger logger = LogManager.getLogger(MlModelServer.class);

private int port;

public String getUrl() {
return new URIBuilder().setScheme("http").setHost(HOST).setPort(port).toString();
}

private void handle(HttpExchange exchange) throws IOException {
String rangeHeader = exchange.getRequestHeaders().getFirst(HttpHeaders.RANGE);
HttpHeaderParser.Range range = rangeHeader != null ? HttpHeaderParser.parseRangeHeader(rangeHeader) : null;
logger.info("request: {} range={}", exchange.getRequestURI().getPath(), range);

try (InputStream is = getInputStream(exchange)) {
int httpStatus;
long numBytes;
if (is == null) {
httpStatus = HttpStatus.SC_NOT_FOUND;
numBytes = 0;
} else if (range == null) {
httpStatus = HttpStatus.SC_OK;
numBytes = is.available();
} else {
httpStatus = HttpStatus.SC_PARTIAL_CONTENT;
is.skipNBytes(range.start());
numBytes = range.end() - range.start() + 1;
}
logger.info("response: {} {}", exchange.getRequestURI().getPath(), httpStatus);
exchange.sendResponseHeaders(httpStatus, numBytes);
try (OutputStream os = exchange.getResponseBody()) {
while (numBytes > 0) {
byte[] bytes = is.readNBytes((int) Math.min(1 << 20, numBytes));
os.write(bytes);
numBytes -= bytes.length;
}
}
}
}

private InputStream getInputStream(HttpExchange exchange) throws IOException {
String path = exchange.getRequestURI().getPath().substring(1); // Strip leading slash
String modelId = path.substring(0, path.indexOf('.'));
String extension = path.substring(path.indexOf('.') + 1);

// If a model specifically optimized for some platform is requested,
// serve the default non-optimized model instead, which is compatible.
String defaultModelId = modelId;
for (String platform : XPackSettings.ML_NATIVE_CODE_PLATFORMS) {
defaultModelId = defaultModelId.replace("_" + platform, "");
}

ClassLoader classloader = Thread.currentThread().getContextClassLoader();
InputStream is = classloader.getResourceAsStream(defaultModelId + "." + extension);
if (is != null && modelId.equals(defaultModelId) == false && extension.equals("metadata.json")) {
// When an optimized version is requested, fix the default metadata,
// so that it contains the correct model ID.
try (XContentParser parser = XContentType.JSON.xContent().createParser(XContentParserConfiguration.EMPTY, is.readAllBytes())) {
is.close();
ModelPackageConfig packageConfig = ModelPackageConfig.fromXContentLenient(parser);
packageConfig = new ModelPackageConfig.Builder(packageConfig).setPackedModelId(modelId).build();
is = new ByteArrayInputStream(packageConfig.toString().getBytes(StandardCharsets.UTF_8));
}
}
return is;
}

@Override
public Statement apply(Statement statement, Description description) {
return new Statement() {
@Override
public void evaluate() throws Throwable {
logger.info("Starting ML model server");
HttpServer server = HttpServer.create();
while (true) {
port = new Random().nextInt(10000, 65536);
try {
server.bind(new InetSocketAddress(HOST, port), 1);
} catch (Exception e) {
continue;
}
break;
}
logger.info("Bound ML model server to port {}", port);

ExecutorService executor = Executors.newCachedThreadPool();
server.setExecutor(executor);
server.createContext("/", MlModelServer.this::handle);
server.start();

try {
statement.evaluate();
} finally {
logger.info("Stopping ML model server on port {}", port);
server.stop(1);
executor.shutdown();
}
}
};
}
}
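For context, a minimal usage sketch (not part of this change) showing how the rule can be exercised directly over HTTP once JUnit has started it. The base URL, model id, and byte range are illustrative assumptions; in the real tests the URL comes from mlModelServer.getUrl() and the served files are the resources added in this PR.

// Hedged usage sketch: plain java.net.http calls against a running MlModelServer.
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;

public class MlModelServerUsageSketch {
    public static void main(String[] args) throws Exception {
        String baseUrl = "http://localhost:12345"; // in tests: mlModelServer.getUrl()
        HttpClient client = HttpClient.newHttpClient();

        // A platform-specific id falls back to the default resource, and the metadata
        // is rewritten so that it carries the requested model id.
        HttpRequest metadata = HttpRequest.newBuilder()
            .uri(URI.create(baseUrl + "/elser_model_2_linux-x86_64.metadata.json"))
            .build();
        System.out.println(client.send(metadata, HttpResponse.BodyHandlers.ofString()).body());

        // A Range request is answered with 206 Partial Content and only the requested bytes.
        HttpRequest firstKiB = HttpRequest.newBuilder()
            .uri(URI.create(baseUrl + "/elser_model_2.pt"))
            .header("Range", "bytes=0-1023")
            .build();
        HttpResponse<byte[]> partial = client.send(firstKiB, HttpResponse.BodyHandlers.ofByteArray());
        System.out.println(partial.statusCode() + " " + partial.body().length); // expect 206 and 1024
    }
}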
@@ -18,8 +18,6 @@

import static org.hamcrest.Matchers.containsString;

// This test was previously disabled in CI due to the models being too large
// See "https://github.com/elastic/elasticsearch/issues/105198".
public class TextEmbeddingCrudIT extends InferenceBaseRestTest {

public void testPutE5Small_withNoModelVariant() {
@@ -0,0 +1,25 @@
{
"packaged_model_id": "elser_model_2",
"minimum_version": "11.0.0",
"size": 1859242,
"sha256": "602dbccfb2746e5700bf65d8019b06fb2ec1e3c5bfb980eb2005fc17c1bfe0c0",
"description": "Elastic Learned Sparse EncodeR v2",
"model_type": "pytorch",
"tags": [
"elastic"
],
"inference_config": {
"text_expansion": {
"tokenization": {
"bert": {
"do_lower_case": true,
"with_special_tokens": true,
"max_sequence_length": 512,
"truncate": "first",
"span": -1
}
}
}
},
"vocabulary_file": "elser_model_2.vocab.json"
}
Binary file not shown.

Large diffs are not rendered by default.

@@ -0,0 +1,32 @@
{
"packaged_model_id": "multilingual-e5-small",
"minimum_version": "12.0.0",
"size": 5531160,
"sha256": "92e24566eff554d3a6808cc62731dbecf32db63e01801f3f62210aa9131c7a8b",
"description": "E5 small multilingual",
"model_type": "pytorch",
"tags": [],
"inference_config": {
"text_embedding": {
"tokenization": {
"xlm_roberta": {
"do_lower_case": false,
"with_special_tokens": true,
"max_sequence_length": 512,
"truncate": "first",
"span": -1
}
},
"embedding_size": 384
}
},
"prefix_strings": {
"search": "query: ",
"ingest": "passage: "
},
"metadata": {
"per_allocation_memory_bytes": 557785256,
"per_deployment_memory_bytes": 470031872
},
"vocabulary_file": "multilingual-e5-small.vocab.json"
}
Binary file not shown.

Large diffs are not rendered by default.

@@ -0,0 +1,15 @@
{
"packaged_model_id": "rerank-v1",
"minimum_version": "9.0.0",
"size": 12419194,
"sha256": "8d37d7240175b59a1a82f409e572c4d0136acff875da980ec5e5e1783263a042",
"description": "Elastic Rerank v1",
"model_type": "pytorch",
"tags": [
"curated"
],
"inference_config": {
"text_similarity": {"tokenization": {"deberta_v2": {"truncate": "balanced"}}}
},
"vocabulary_file": "rerank-v1.vocab.json"
}
Binary file not shown.

Large diffs are not rendered by default.

@@ -342,14 +342,17 @@ static InputStream getFileInputStream(URI uri) {
 * in size. The separate range for the final chunk is because when streaming and
 * uploading a large model definition, writing the last part has to be handled
 * as a special case.
 * Fewer ranges may be returned if the stream size is too small.
 * @param sizeInBytes The total size of the stream
 * @param numberOfStreams Divide the bulk of the size into this many streams.
 * @param chunkSizeBytes The size of each chunk
 * @return List of {@code numberOfStreams} + 1 ranges.
 */
*/
static List<RequestRange> split(long sizeInBytes, int numberOfStreams, long chunkSizeBytes) {
int numberOfChunks = (int) ((sizeInBytes + chunkSizeBytes - 1) / chunkSizeBytes);

if (numberOfStreams > numberOfChunks) {
numberOfStreams = numberOfChunks;
}
var ranges = new ArrayList<RequestRange>();

int baseChunksPerStream = numberOfChunks / numberOfStreams;
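For illustration, a hedged worked example (not part of this change) of the clamping the new lines add. The values are made up, and the actual range boundaries still depend on the rest of split(), which is truncated in this diff.

// Hedged worked example of the clamping added above; all numbers are illustrative.
public class SplitClampingExample {
    public static void main(String[] args) {
        long sizeInBytes = 5L * 1024 * 1024;  // a 5 MiB stream
        long chunkSizeBytes = 1024 * 1024;    // 1 MiB chunks
        int numberOfStreams = 8;              // more streams requested than chunks available

        // Same arithmetic as the hunk above: only 5 chunks exist, so the stream count
        // is clamped to 5 and fewer than the requested numberOfStreams + 1 ranges are returned.
        int numberOfChunks = (int) ((sizeInBytes + chunkSizeBytes - 1) / chunkSizeBytes);
        if (numberOfStreams > numberOfChunks) {
            numberOfStreams = numberOfChunks;
        }
        System.out.println(numberOfChunks + " chunks, " + numberOfStreams + " streams");
    }
}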
@@ -19,7 +19,6 @@
import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.injection.guice.Inject;
import org.elasticsearch.rest.RestStatus;
@@ -45,11 +44,9 @@
public class TransportGetTrainedModelPackageConfigAction extends TransportMasterNodeAction<Request, Response> {

private static final Logger logger = LogManager.getLogger(TransportGetTrainedModelPackageConfigAction.class);
private final Settings settings;

@Inject
public TransportGetTrainedModelPackageConfigAction(
Settings settings,
TransportService transportService,
ClusterService clusterService,
ThreadPool threadPool,
@@ -67,12 +64,11 @@ public TransportGetTrainedModelPackageConfigAction(
GetTrainedModelPackageConfigAction.Response::new,
EsExecutors.DIRECT_EXECUTOR_SERVICE
);
this.settings = settings;
}

@Override
protected void masterOperation(Task task, Request request, ClusterState state, ActionListener<Response> listener) throws Exception {
String repository = MachineLearningPackageLoader.MODEL_REPOSITORY.get(settings);
String repository = clusterService.getClusterSettings().get(MachineLearningPackageLoader.MODEL_REPOSITORY);

String packagedModelId = request.getPackagedModelId();
logger.debug(() -> format("Fetch package manifest for [%s] from [%s]", packagedModelId, repository));