diff --git a/docs/changelog/124769.yaml b/docs/changelog/124769.yaml new file mode 100644 index 0000000000000..11d4cbebb0c9a --- /dev/null +++ b/docs/changelog/124769.yaml @@ -0,0 +1,7 @@ +pr: 124769 +summary: Migrate `model_version` to `model_id` when parsing persistent elser inference + endpoints +area: Machine Learning +type: bug +issues: + - 124675 diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java index ddc5e3e1aa36c..f4d361ab319e8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java @@ -18,6 +18,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.Strings; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; @@ -111,6 +112,13 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi private static final Logger logger = LogManager.getLogger(ElasticsearchInternalService.class); private static final DeprecationLogger DEPRECATION_LOGGER = DeprecationLogger.getLogger(ElasticsearchInternalService.class); + /** + * Fix for https://github.com/elastic/elasticsearch/issues/124675 + * In 8.13.0 we transitioned from model_version to model_id. Any elser inference endpoints created prior to 8.13.0 will still use + * service_settings.model_version. 
+ */ + private static final String OLD_MODEL_ID_FIELD_NAME = "model_version"; + private final Settings settings; public ElasticsearchInternalService(InferenceServiceExtension.InferenceServiceFactoryContext context) { @@ -489,6 +497,8 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + migrateModelVersionToModelId(serviceSettingsMap); + ChunkingSettings chunkingSettings = null; if (TaskType.TEXT_EMBEDDING.equals(taskType) || TaskType.SPARSE_EMBEDDING.equals(taskType)) { chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); @@ -496,7 +506,9 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M String modelId = (String) serviceSettingsMap.get(MODEL_ID); if (modelId == null) { - throw new IllegalArgumentException("Error parsing request config, model id is missing"); + throw new IllegalArgumentException( + Strings.format("Error parsing request config, model id is missing for inference id: %s", inferenceEntityId) + ); } if (MULTILINGUAL_E5_SMALL_VALID_IDS.contains(modelId)) { @@ -536,6 +548,18 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M } } + /** + * Fix for https://github.com/elastic/elasticsearch/issues/124675 + * In 8.13.0 we transitioned from model_version to model_id. Any elser inference endpoints created prior to 8.13.0 will still use + * service_settings.model_version. We need to look for that key and migrate it to model_id. + * <p> + * When the old key is present, its value is read as a String through ServiceUtils.removeAsType (which, judging by its name, + * presumably also removes the old key - confirm against ServiceUtils) and is then stored under ElserInternalServiceSettings.MODEL_ID, + * replacing any value already stored under that key. When the old key is absent the map is left untouched. + * + * @param serviceSettingsMap the service settings map taken from the persisted config; modified in place when the old key is present
+ */ + private void migrateModelVersionToModelId(Map serviceSettingsMap) { + if (serviceSettingsMap.containsKey(OLD_MODEL_ID_FIELD_NAME)) { + String modelId = ServiceUtils.removeAsType(serviceSettingsMap, OLD_MODEL_ID_FIELD_NAME, String.class); + serviceSettingsMap.put(ElserInternalServiceSettings.MODEL_ID, modelId); + } + } + @Override public void checkModelConfig(Model model, ActionListener listener) { if (model instanceof CustomElandEmbeddingModel elandModel && elandModel.getTaskType() == TaskType.TEXT_EMBEDDING) { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java index d1ce79b863c61..d8886e1eea471 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java @@ -710,6 +710,30 @@ private ActionListener getElserModelVerificationActionListener( public void testParsePersistedConfig() { + // Parsing a persistent configuration using model_version succeeds + { + var service = createService(mock(Client.class)); + var settings = new HashMap(); + settings.put( + ModelConfigurations.SERVICE_SETTINGS, + new HashMap<>( + Map.of( + ElasticsearchInternalServiceSettings.NUM_ALLOCATIONS, + 1, + ElasticsearchInternalServiceSettings.NUM_THREADS, + 4, + "model_version", + ".elser_model_2" + ) + ) + ); + + var model = service.parsePersistedConfig(randomInferenceEntityId, TaskType.TEXT_EMBEDDING, settings); + assertThat(model, instanceOf(ElserInternalModel.class)); + ElserInternalModel elserInternalModel = (ElserInternalModel) model; + assertThat(elserInternalModel.getServiceSettings().modelId(), is(".elser_model_2")); + } + // Null 
model variant { var service = createService(mock(Client.class)); @@ -728,11 +752,12 @@ public void testParsePersistedConfig() { ) ); - expectThrows( + var exception = expectThrows( IllegalArgumentException.class, () -> service.parsePersistedConfig(randomInferenceEntityId, TaskType.TEXT_EMBEDDING, settings) ); + assertThat(exception.getMessage(), containsString(randomInferenceEntityId)); } // Invalid model variant