 import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
 import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
 import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
-import org.elasticsearch.xpack.inference.services.ServiceUtils;
+import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder;
 
 import java.util.ArrayList;
 import java.util.Collections;
@@ -499,49 +499,38 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M
 
     @Override
     public void checkModelConfig(Model model, ActionListener<Model> listener) {
-        if (model instanceof CustomElandEmbeddingModel elandModel && elandModel.getTaskType() == TaskType.TEXT_EMBEDDING) {
-            // At this point the inference endpoint configuration has not been persisted yet, if we attempt to do inference using the
-            // inference id we'll get an error because the trained model code needs to use the persisted inference endpoint to retrieve the
-            // model id. To get around this we'll have the getEmbeddingSize() method use the model id instead of inference id. So we need
-            // to create a temporary model that overrides the inference id with the model id.
-            var temporaryModelWithModelId = new CustomElandEmbeddingModel(
-                elandModel.getServiceSettings().modelId(),
-                elandModel.getTaskType(),
-                elandModel.getConfigurations().getService(),
-                elandModel.getServiceSettings(),
-                elandModel.getConfigurations().getChunkingSettings()
+        ModelValidatorBuilder.buildModelValidator(model.getTaskType(), true).validate(this, model, listener);
+    }
+
+    @Override
+    public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
+        if (model instanceof CustomElandEmbeddingModel embeddingsModel) {
+            var serviceSettings = embeddingsModel.getServiceSettings();
+
+            var updatedServiceSettings = new CustomElandInternalTextEmbeddingServiceSettings(
+                serviceSettings.getNumAllocations(),
+                serviceSettings.getNumThreads(),
+                serviceSettings.modelId(),
+                serviceSettings.getAdaptiveAllocationsSettings(),
+                embeddingSize,
+                serviceSettings.similarity(),
+                serviceSettings.elementType()
             );
 
-            ServiceUtils.getEmbeddingSize(
-                temporaryModelWithModelId,
-                this,
-                listener.delegateFailureAndWrap((l, size) -> l.onResponse(updateModelWithEmbeddingDetails(elandModel, size)))
+            return new CustomElandEmbeddingModel(
+                model.getInferenceEntityId(),
+                model.getTaskType(),
+                model.getConfigurations().getService(),
+                updatedServiceSettings,
+                model.getConfigurations().getChunkingSettings()
             );
         } else {
-            listener.onResponse(model);
+            // TODO: This is for the E5 case which is text embedding but we didn't previously update the dimensions. Figure out if we do
+            // need to update the dimensions?
+            return model;
         }
     }
 
-    private static CustomElandEmbeddingModel updateModelWithEmbeddingDetails(CustomElandEmbeddingModel model, int embeddingSize) {
-        CustomElandInternalTextEmbeddingServiceSettings serviceSettings = new CustomElandInternalTextEmbeddingServiceSettings(
-            model.getServiceSettings().getNumAllocations(),
-            model.getServiceSettings().getNumThreads(),
-            model.getServiceSettings().modelId(),
-            model.getServiceSettings().getAdaptiveAllocationsSettings(),
-            embeddingSize,
-            model.getServiceSettings().similarity(),
-            model.getServiceSettings().elementType()
-        );
-
-        return new CustomElandEmbeddingModel(
-            model.getInferenceEntityId(),
-            model.getTaskType(),
-            model.getConfigurations().getService(),
-            serviceSettings,
-            model.getConfigurations().getChunkingSettings()
-        );
-    }
-
     @Override
     public void infer(
         Model model,
@@ -904,7 +893,10 @@ private List<Model> defaultConfigs(boolean useLinuxOptimizedModel) {
 
     @Override
     boolean isDefaultId(String inferenceId) {
-        return DEFAULT_ELSER_ID.equals(inferenceId) || DEFAULT_E5_ID.equals(inferenceId);
+        // return DEFAULT_ELSER_ID.equals(inferenceId) || DEFAULT_E5_ID.equals(inferenceId);
+        // TODO: This is a temporary override to ensure that we always deploy models on infer to run a validation call.
+        // Figure out if this is what we actually want to do?
+        return true;
     }
 
     static EmbeddingRequestChunker.EmbeddingType embeddingTypeFromTaskTypeAndSettings(
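
Reviewer note: the core of this change is that `checkModelConfig` no longer computes the embedding size itself. It now builds a validator via `ModelValidatorBuilder.buildModelValidator(model.getTaskType(), true)` and lets that validator drive the check, calling back into the service through the new `updateModelWithEmbeddingDetails` override to fold the discovered dimensions into the service settings. The sketch below is a minimal, self-contained illustration of that delegation pattern under assumed validator internals (the diff only shows the call site); `InferenceService`, `ModelValidator`, `ModelValidators`, `inferEmbeddingSize`, and the simplified `Model` record are hypothetical stand-ins, not the real Elasticsearch API.

```java
import java.util.function.BiConsumer;

// Simplified stand-ins for the production types; everything in this file is
// hypothetical and only mirrors the shape of the real Elasticsearch classes.
record Model(String inferenceEntityId, String taskType, int embeddingSize) {}

interface InferenceService {
    // Assumed hook: run a test inference and report the discovered embedding size.
    void inferEmbeddingSize(Model model, BiConsumer<Model, Integer> onSize);

    // Mirrors the new override in the diff: rebuild service settings once the size is known.
    Model updateModelWithEmbeddingDetails(Model model, int embeddingSize);
}

interface ModelValidator {
    void validate(InferenceService service, Model model, BiConsumer<Model, Exception> listener);
}

final class ModelValidators {
    // Analogous to ModelValidatorBuilder.buildModelValidator(taskType, ...): pick a
    // validator per task type. The internals here are assumed, not taken from the diff.
    static ModelValidator forTaskType(String taskType) {
        if ("text_embedding".equals(taskType)) {
            // Embedding case: run a test call, then fold the discovered dimensions
            // back into the model via updateModelWithEmbeddingDetails.
            return (service, model, listener) -> service.inferEmbeddingSize(
                model,
                (m, size) -> listener.accept(service.updateModelWithEmbeddingDetails(m, size), null)
            );
        }
        // Other task types: a plain "does inference succeed" check would go here.
        return (service, model, listener) -> listener.accept(model, null);
    }
}
```

Centralizing the embedding-size handling in a task-type-aware validator is what lets `checkModelConfig` shrink to a one-liner here, at the cost of the open TODOs above about the E5 path and the temporary `isDefaultId` override.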