Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions docs/changelog/124769.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
pr: 124769
summary: Migrate `model_version` to `model_id` when parsing persistent elser inference
endpoints
area: Machine Learning
type: bug
issues:
- 124675
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.LazyInitializable;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.Strings;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.ChunkedInference;
import org.elasticsearch.inference.ChunkingSettings;
Expand Down Expand Up @@ -111,6 +112,13 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
private static final Logger logger = LogManager.getLogger(ElasticsearchInternalService.class);
private static final DeprecationLogger DEPRECATION_LOGGER = DeprecationLogger.getLogger(ElasticsearchInternalService.class);

/**
* Fix for https://github.com/elastic/elasticsearch/issues/124675
* In 8.13.0 we transitioned from model_version to model_id. Any elser inference endpoints created prior to 8.13.0 will still use
* service_settings.model_version.
*/
private static final String OLD_MODEL_ID_FIELD_NAME = "model_version";

private final Settings settings;

public ElasticsearchInternalService(InferenceServiceExtension.InferenceServiceFactoryContext context) {
Expand Down Expand Up @@ -489,14 +497,18 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);

migrateModelVersionToModelId(serviceSettingsMap);

ChunkingSettings chunkingSettings = null;
if (TaskType.TEXT_EMBEDDING.equals(taskType) || TaskType.SPARSE_EMBEDDING.equals(taskType)) {
chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS));
}

String modelId = (String) serviceSettingsMap.get(MODEL_ID);
if (modelId == null) {
throw new IllegalArgumentException("Error parsing request config, model id is missing");
throw new IllegalArgumentException(
Strings.format("Error parsing request config, model id is missing for inference id: %s", inferenceEntityId)
);
}

if (MULTILINGUAL_E5_SMALL_VALID_IDS.contains(modelId)) {
Expand Down Expand Up @@ -536,6 +548,18 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M
}
}

/**
* Fix for https://github.com/elastic/elasticsearch/issues/124675
* In 8.13.0 we transitioned from model_version to model_id. Any elser inference endpoints created prior to 8.13.0 will still use
* service_settings.model_version. We need to look for that key and migrate it to model_id.
*/
private void migrateModelVersionToModelId(Map<String, Object> serviceSettingsMap) {
if (serviceSettingsMap.containsKey(OLD_MODEL_ID_FIELD_NAME)) {
String modelId = ServiceUtils.removeAsType(serviceSettingsMap, OLD_MODEL_ID_FIELD_NAME, String.class);
serviceSettingsMap.put(ElserInternalServiceSettings.MODEL_ID, modelId);
}
}

@Override
public void checkModelConfig(Model model, ActionListener<Model> listener) {
if (model instanceof CustomElandEmbeddingModel elandModel && elandModel.getTaskType() == TaskType.TEXT_EMBEDDING) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@
import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86;
import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.NAME;
import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.OLD_ELSER_SERVICE_NAME;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
Expand Down Expand Up @@ -709,6 +710,30 @@ private ActionListener<Model> getElserModelVerificationActionListener(

public void testParsePersistedConfig() {

// Parsing a persistent configuration using model_version succeeds
{
var service = createService(mock(Client.class));
var settings = new HashMap<String, Object>();
settings.put(
ModelConfigurations.SERVICE_SETTINGS,
new HashMap<>(
Map.of(
ElasticsearchInternalServiceSettings.NUM_ALLOCATIONS,
1,
ElasticsearchInternalServiceSettings.NUM_THREADS,
4,
"model_version",
".elser_model_2"
)
)
);

var model = service.parsePersistedConfig(randomInferenceEntityId, TaskType.TEXT_EMBEDDING, settings);
assertThat(model, instanceOf(ElserInternalModel.class));
ElserInternalModel elserInternalModel = (ElserInternalModel) model;
assertThat(elserInternalModel.getServiceSettings().modelId(), is(".elser_model_2"));
}

// Null model variant
{
var service = createService(mock(Client.class));
Expand All @@ -727,11 +752,12 @@ public void testParsePersistedConfig() {
)
);

expectThrows(
var exception = expectThrows(
IllegalArgumentException.class,
() -> service.parsePersistedConfig(randomInferenceEntityId, TaskType.TEXT_EMBEDDING, settings)
);

assertThat(exception.getMessage(), containsString(randomInferenceEntityId));
}

// Invalid model variant
Expand Down