 import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
 import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
 import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
+import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction;
 import org.elasticsearch.xpack.core.ml.action.InferModelAction;
 import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings;
+import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentStats;
 import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
 import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.EmptyConfigUpdate;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextEmbeddingConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextEmbeddingConfigUpdate;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextExpansionConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextExpansionConfigUpdate;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextSimilarityConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextSimilarityConfigUpdate;
 import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
 import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
 import java.util.EnumSet;
 import java.util.List;
 import java.util.Map;
+import java.util.Optional;
 import java.util.Set;
 import java.util.function.Consumer;
 import java.util.function.Function;
@@ -134,7 +141,10 @@ public void parseRequestConfig(
         throwIfNotEmptyMap(config, name());

         String modelId = (String) serviceSettingsMap.get(ElasticsearchInternalServiceSettings.MODEL_ID);
-        if (modelId == null) {
+        String deploymentId = (String) serviceSettingsMap.get(ElasticsearchInternalServiceSettings.DEPLOYMENT_ID);
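+        // If a deployment_id was supplied, validate the request settings against that existing deployment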
+        if (deploymentId != null) {
+            validateAgainstDeployment(modelId, deploymentId, taskType, ) // TODO create model
+        } else if (modelId == null) {
             if (OLD_ELSER_SERVICE_NAME.equals(serviceName)) {
                 // TODO complete deprecation of null model ID
                 // throw new ValidationException().addValidationError("Error parsing request config, model id is missing");
@@ -215,6 +225,8 @@ private void customElandCase(
                        + "]. You may need to load it into the cluster using eland."
                );
            } else {
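+                // Fail fast if the model's inference config does not support the requested task type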
+                throwIfUnsupportedTaskType(modelId, taskType, response.getResources().results().get(0).getInferenceConfig());
+
                 var model = createCustomElandModel(
                     inferenceEntityId,
                     taskType,
@@ -553,7 +565,7 @@ public void inferTextEmbedding(
         ActionListener<InferenceServiceResults> listener
     ) {
         var request = buildInferenceRequest(
-            model.getConfigurations().getInferenceEntityId(),
+            model.mlNodeDeploymentId(),
             TextEmbeddingConfigUpdate.EMPTY_INSTANCE,
             inputs,
             inputType,
@@ -579,7 +591,7 @@ public void inferSparseEmbedding(
         ActionListener<InferenceServiceResults> listener
     ) {
         var request = buildInferenceRequest(
-            model.getConfigurations().getInferenceEntityId(),
+            model.mlNodeDeploymentId(),
             TextExpansionConfigUpdate.EMPTY_UPDATE,
             inputs,
             inputType,
@@ -607,7 +619,7 @@ public void inferRerank(
         ActionListener<InferenceServiceResults> listener
     ) {
         var request = buildInferenceRequest(
-            model.getConfigurations().getInferenceEntityId(),
+            model.mlNodeDeploymentId(),
             new TextSimilarityConfigUpdate(query),
             inputs,
             inputType,
@@ -681,7 +693,7 @@ public void chunkedInfer(

         for (var batch : batchedRequests) {
             var inferenceRequest = buildInferenceRequest(
-                model.getConfigurations().getInferenceEntityId(),
+                esModel.mlNodeDeploymentId(),
                 EmptyConfigUpdate.INSTANCE,
                 batch.batch().inputs(),
                 inputType,
@@ -858,4 +870,111 @@ static EmbeddingRequestChunker.EmbeddingType embeddingTypeFromTaskTypeAndSetting
             );
         };
     }
+
+
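+    /**
+     * Looks up the deployment stats for {@code deploymentId} and checks that the deployment serves
+     * {@code modelId} with a task type compatible with {@code taskType}. On success the listener
+     * receives a settings builder populated from the deployment's allocations, threads and adaptive
+     * allocations settings.
+     */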
+    private void validateAgainstDeployment(
+        String modelId,
+        String deploymentId,
+        TaskType taskType,
+        ActionListener<ElasticsearchInternalServiceSettings.Builder> listener
+    ) {
+        getDeployment(deploymentId, listener.delegateFailureAndWrap((l, response) -> {
+            if (response.isPresent()) {
+                if (modelId.equals(response.get().getModelId()) == false) {
+                    listener.onFailure(
+                        new ElasticsearchStatusException(
+                            "Deployment [{}] uses model [{}] which does not match the model [{}] in the request.",
+                            RestStatus.BAD_REQUEST, // TODO better message
+                            deploymentId,
+                            response.get().getModelId(),
+                            modelId
+                        )
+                    );
+                    return;
+                }
+
+                var updatedSettings = new ElasticsearchInternalServiceSettings.Builder().setNumAllocations(
+                    response.get().getNumberOfAllocations()
+                )
+                    .setNumThreads(response.get().getThreadsPerAllocation())
+                    .setAdaptiveAllocationsSettings(response.get().getAdaptiveAllocationsSettings())
+                    .setDeploymentId(deploymentId)
+                    .setModelId(modelId);
+
+                checkTaskTypeForMlNodeModel(
+                    response.get().getModelId(),
+                    taskType,
+                    l.delegateFailureAndWrap((l2, compatibleTaskType) -> { l2.onResponse(updatedSettings); })
+                );
+            }
+        }));
+    }
+
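+    /**
+     * Fetches the trained model stats for {@code deploymentId} and returns the matching
+     * {@link AssignmentStats}, or an empty {@link Optional} if no such deployment exists.
+     */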
+    private void getDeployment(String deploymentId, ActionListener<Optional<AssignmentStats>> listener) {
+        client.execute(
+            GetTrainedModelsStatsAction.INSTANCE,
+            new GetTrainedModelsStatsAction.Request(deploymentId),
+            listener.delegateFailureAndWrap((l, response) -> {
+                l.onResponse(
+                    response.getResources()
+                        .results()
+                        .stream()
+                        .filter(s -> s.getDeploymentStats() != null && s.getDeploymentStats().getDeploymentId().equals(deploymentId))
+                        .map(GetTrainedModelsStatsAction.Response.TrainedModelStats::getDeploymentStats)
+                        .findFirst()
+                );
+            })
+        );
+    }
+
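+    /**
+     * Retrieves the trained model configuration for {@code modelId} and fails the listener if the
+     * model's inference config does not support the requested {@code taskType}.
+     */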
+    private void checkTaskTypeForMlNodeModel(String modelId, TaskType taskType, ActionListener<Boolean> listener) {
+        client.execute(
+            GetTrainedModelsAction.INSTANCE,
+            new GetTrainedModelsAction.Request(modelId),
+            listener.delegateFailureAndWrap((l, response) -> {
+                if (response.getResources().results().isEmpty()) {
+                    l.onFailure(new IllegalStateException("this shouldn't happen"));
+                    return;
+                }
+
+                var inferenceConfig = response.getResources().results().get(0).getInferenceConfig();
+                throwIfUnsupportedTaskType(modelId, taskType, inferenceConfig);
+                l.onResponse(Boolean.TRUE);
+            })
+        );
+    }
+
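+    /**
+     * Throws an {@link ElasticsearchStatusException} (400 Bad Request) if the model's inference config
+     * does not map to any supported task type, or maps to a task type other than the one requested.
+     */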
+    static void throwIfUnsupportedTaskType(String modelId, TaskType taskType, InferenceConfig inferenceConfig) {
+        var deploymentTaskType = inferenceConfigToTaskType(inferenceConfig);
+        if (deploymentTaskType == null) {
+            throw new ElasticsearchStatusException(
+                "Deployed model [{}] has type [{}] which does not map to any supported task types",
+                RestStatus.BAD_REQUEST,
+                modelId,
+                inferenceConfig.getWriteableName()
+            );
+        }
+        if (deploymentTaskType != taskType) {
+            throw new ElasticsearchStatusException(
+                "Deployed model [{}] with type [{}] does not match the requested task type [{}]",
+                RestStatus.BAD_REQUEST,
+                modelId,
+                inferenceConfig.getWriteableName(),
+                taskType
+            );
+        }
+    }
968+
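+    /**
+     * Maps an ML node {@link InferenceConfig} to the inference API task type it serves,
+     * or {@code null} if the config type is not supported by this service.
+     */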
+    static TaskType inferenceConfigToTaskType(InferenceConfig config) {
+        if (config instanceof TextExpansionConfig) {
+            return TaskType.SPARSE_EMBEDDING;
+        } else if (config instanceof TextEmbeddingConfig) {
+            return TaskType.TEXT_EMBEDDING;
+        } else if (config instanceof TextSimilarityConfig) {
+            return TaskType.RERANK;
+        } else {
+            return null;
+        }
+    }
 }