Skip to content

Commit bf53f97

Browse files
[ML] Migrate model_version to model_id when parsing persistent elser inference endpoints (#124769)
* Handling model_version for prexisting endpoints * Update docs/changelog/124769.yaml
1 parent 1bee2cc commit bf53f97

File tree

3 files changed

+59
-2
lines changed

3 files changed

+59
-2
lines changed

docs/changelog/124769.yaml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
pr: 124769
2+
summary: Migrate `model_version` to `model_id` when parsing persistent elser inference
3+
endpoints
4+
area: Machine Learning
5+
type: bug
6+
issues:
7+
- 124675

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import org.elasticsearch.common.settings.Settings;
1919
import org.elasticsearch.common.util.LazyInitializable;
2020
import org.elasticsearch.core.Nullable;
21+
import org.elasticsearch.core.Strings;
2122
import org.elasticsearch.core.TimeValue;
2223
import org.elasticsearch.inference.ChunkedInference;
2324
import org.elasticsearch.inference.ChunkingSettings;
@@ -111,6 +112,13 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
111112
private static final Logger logger = LogManager.getLogger(ElasticsearchInternalService.class);
112113
private static final DeprecationLogger DEPRECATION_LOGGER = DeprecationLogger.getLogger(ElasticsearchInternalService.class);
113114

115+
/**
116+
* Fix for https://github.com/elastic/elasticsearch/issues/124675
117+
* In 8.13.0 we transitioned from model_version to model_id. Any elser inference endpoints created prior to 8.13.0 will still use
118+
* service_settings.model_version.
119+
*/
120+
private static final String OLD_MODEL_ID_FIELD_NAME = "model_version";
121+
114122
private final Settings settings;
115123

116124
public ElasticsearchInternalService(InferenceServiceExtension.InferenceServiceFactoryContext context) {
@@ -489,14 +497,18 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M
489497
Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
490498
Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
491499

500+
migrateModelVersionToModelId(serviceSettingsMap);
501+
492502
ChunkingSettings chunkingSettings = null;
493503
if (TaskType.TEXT_EMBEDDING.equals(taskType) || TaskType.SPARSE_EMBEDDING.equals(taskType)) {
494504
chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS));
495505
}
496506

497507
String modelId = (String) serviceSettingsMap.get(MODEL_ID);
498508
if (modelId == null) {
499-
throw new IllegalArgumentException("Error parsing request config, model id is missing");
509+
throw new IllegalArgumentException(
510+
Strings.format("Error parsing request config, model id is missing for inference id: %s", inferenceEntityId)
511+
);
500512
}
501513

502514
if (MULTILINGUAL_E5_SMALL_VALID_IDS.contains(modelId)) {
@@ -536,6 +548,18 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M
536548
}
537549
}
538550

551+
/**
552+
* Fix for https://github.com/elastic/elasticsearch/issues/124675
553+
* In 8.13.0 we transitioned from model_version to model_id. Any elser inference endpoints created prior to 8.13.0 will still use
554+
* service_settings.model_version. We need to look for that key and migrate it to model_id.
555+
*/
556+
private void migrateModelVersionToModelId(Map<String, Object> serviceSettingsMap) {
557+
if (serviceSettingsMap.containsKey(OLD_MODEL_ID_FIELD_NAME)) {
558+
String modelId = ServiceUtils.removeAsType(serviceSettingsMap, OLD_MODEL_ID_FIELD_NAME, String.class);
559+
serviceSettingsMap.put(ElserInternalServiceSettings.MODEL_ID, modelId);
560+
}
561+
}
562+
539563
@Override
540564
public void checkModelConfig(Model model, ActionListener<Model> listener) {
541565
if (model instanceof CustomElandEmbeddingModel elandModel && elandModel.getTaskType() == TaskType.TEXT_EMBEDDING) {

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@
100100
import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86;
101101
import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.NAME;
102102
import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.OLD_ELSER_SERVICE_NAME;
103+
import static org.hamcrest.Matchers.containsString;
103104
import static org.hamcrest.Matchers.hasSize;
104105
import static org.hamcrest.Matchers.instanceOf;
105106
import static org.hamcrest.Matchers.is;
@@ -709,6 +710,30 @@ private ActionListener<Model> getElserModelVerificationActionListener(
709710

710711
public void testParsePersistedConfig() {
711712

713+
// Parsing a persistent configuration using model_version succeeds
714+
{
715+
var service = createService(mock(Client.class));
716+
var settings = new HashMap<String, Object>();
717+
settings.put(
718+
ModelConfigurations.SERVICE_SETTINGS,
719+
new HashMap<>(
720+
Map.of(
721+
ElasticsearchInternalServiceSettings.NUM_ALLOCATIONS,
722+
1,
723+
ElasticsearchInternalServiceSettings.NUM_THREADS,
724+
4,
725+
"model_version",
726+
".elser_model_2"
727+
)
728+
)
729+
);
730+
731+
var model = service.parsePersistedConfig(randomInferenceEntityId, TaskType.TEXT_EMBEDDING, settings);
732+
assertThat(model, instanceOf(ElserInternalModel.class));
733+
ElserInternalModel elserInternalModel = (ElserInternalModel) model;
734+
assertThat(elserInternalModel.getServiceSettings().modelId(), is(".elser_model_2"));
735+
}
736+
712737
// Null model variant
713738
{
714739
var service = createService(mock(Client.class));
@@ -727,11 +752,12 @@ public void testParsePersistedConfig() {
727752
)
728753
);
729754

730-
expectThrows(
755+
var exception = expectThrows(
731756
IllegalArgumentException.class,
732757
() -> service.parsePersistedConfig(randomInferenceEntityId, TaskType.TEXT_EMBEDDING, settings)
733758
);
734759

760+
assertThat(exception.getMessage(), containsString(randomInferenceEntityId));
735761
}
736762

737763
// Invalid model variant

0 commit comments

Comments
 (0)