Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/111684.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 111684
summary: Write downloaded model parts async
area: Machine Learning
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,23 @@
import java.util.Locale;

public enum Level {
INFO,
WARNING,
ERROR;
INFO {
public org.apache.logging.log4j.Level log4jLevel() {
return org.apache.logging.log4j.Level.INFO;
}
},
WARNING {
public org.apache.logging.log4j.Level log4jLevel() {
return org.apache.logging.log4j.Level.WARN;
}
},
ERROR {
public org.apache.logging.log4j.Level log4jLevel() {
return org.apache.logging.log4j.Level.ERROR;
}
};

public abstract org.apache.logging.log4j.Level log4jLevel();

/**
* Case-insensitive from string method.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

package org.elasticsearch.xpack.inference;

import org.apache.lucene.tests.util.LuceneTestCase;
import org.elasticsearch.client.Request;
import org.elasticsearch.common.Strings;
import org.elasticsearch.inference.TaskType;
Expand All @@ -19,11 +18,11 @@

import static org.hamcrest.Matchers.containsString;

// Tests disabled in CI due to the models being too large to download. Can be enabled (commented out) for local testing
@LuceneTestCase.AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/105198")
// This test was previously disabled in CI due to the models being too large
// See "https://github.com/elastic/elasticsearch/issues/105198".
public class TextEmbeddingCrudIT extends InferenceBaseRestTest {

public void testPutE5Small_withNoModelVariant() throws IOException {
public void testPutE5Small_withNoModelVariant() {
{
String inferenceEntityId = randomAlphaOfLength(10).toLowerCase();
expectThrows(
Expand Down Expand Up @@ -51,6 +50,7 @@ public void testPutE5Small_withPlatformAgnosticVariant() throws IOException {
deleteTextEmbeddingModel(inferenceEntityId);
}

@AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/105198")
public void testPutE5Small_withPlatformSpecificVariant() throws IOException {
String inferenceEntityId = randomAlphaOfLength(10).toLowerCase();
if ("linux-x86_64".equals(Platforms.PLATFORM_NAME)) {
Expand Down Expand Up @@ -124,7 +124,7 @@ private String noModelIdVariantJsonEntity() {
private String platformAgnosticModelVariantJsonEntity() {
return """
{
"service": "text_embedding",
"service": "elasticsearch",
"service_settings": {
"num_allocations": 1,
"num_threads": 1,
Expand All @@ -137,7 +137,7 @@ private String platformAgnosticModelVariantJsonEntity() {
private String platformSpecificModelVariantJsonEntity() {
return """
{
"service": "text_embedding",
"service": "elasticsearch",
"service_settings": {
"num_allocations": 1,
"num_threads": 1,
Expand All @@ -150,7 +150,7 @@ private String platformSpecificModelVariantJsonEntity() {
private String fakeModelVariantJsonEntity() {
return """
{
"service": "text_embedding",
"service": "elasticsearch",
"service_settings": {
"num_allocations": 1,
"num_threads": 1,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,17 @@
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.settings.Setting;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.plugins.ActionPlugin;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.threadpool.ExecutorBuilder;
import org.elasticsearch.threadpool.FixedExecutorBuilder;
import org.elasticsearch.xpack.core.ml.packageloader.action.GetTrainedModelPackageConfigAction;
import org.elasticsearch.xpack.core.ml.packageloader.action.LoadTrainedModelPackageAction;
import org.elasticsearch.xpack.ml.packageloader.action.ModelDownloadTask;
import org.elasticsearch.xpack.ml.packageloader.action.ModelImporter;
import org.elasticsearch.xpack.ml.packageloader.action.TransportGetTrainedModelPackageConfigAction;
import org.elasticsearch.xpack.ml.packageloader.action.TransportLoadTrainedModelPackage;

Expand All @@ -44,16 +49,15 @@ public class MachineLearningPackageLoader extends Plugin implements ActionPlugin
Setting.Property.Dynamic
);

// re-using thread pool setup by the ml plugin
public static final String UTILITY_THREAD_POOL_NAME = "ml_utility";

// This link will be invalid for serverless, but serverless will never be
// air-gapped, so this message should never be needed.
private static final String MODEL_REPOSITORY_DOCUMENTATION_LINK = format(
"https://www.elastic.co/guide/en/machine-learning/%s/ml-nlp-elser.html#air-gapped-install",
Build.current().version().replaceFirst("^(\\d+\\.\\d+).*", "$1")
);

public static final String MODEL_DOWNLOAD_THREADPOOL_NAME = "model_download";

public MachineLearningPackageLoader() {}

@Override
Expand Down Expand Up @@ -81,6 +85,24 @@ public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
);
}

@Override
public List<ExecutorBuilder<?>> getExecutorBuilders(Settings settings) {
return List.of(modelDownloadExecutor(settings));
}

public static FixedExecutorBuilder modelDownloadExecutor(Settings settings) {
// Threadpool with a fixed number of threads for
// downloading the model definition files
return new FixedExecutorBuilder(
settings,
MODEL_DOWNLOAD_THREADPOOL_NAME,
ModelImporter.NUMBER_OF_STREAMS,
-1, // unbounded queue size
"xpack.ml.model_download_thread_pool",
EsExecutors.TaskTrackingConfig.DO_NOT_TRACK
);
}

@Override
public List<BootstrapCheck> getBootstrapChecks() {
return List.of(new BootstrapCheck() {
Expand Down
Loading