2121import org .elasticsearch .action .bulk .BulkItemResponse ;
2222import org .elasticsearch .action .bulk .BulkResponse ;
2323import org .elasticsearch .action .index .IndexRequest ;
24+ import org .elasticsearch .action .index .IndexRequestBuilder ;
2425import org .elasticsearch .action .search .SearchRequest ;
2526import org .elasticsearch .action .search .SearchResponse ;
2627import org .elasticsearch .action .support .GroupedActionListener ;
@@ -531,11 +532,12 @@ public void updateModelTransaction(Model newModel, Model existingModel, ActionLi
531532
532533 SubscribableListener .<BulkResponse >newForked ((subListener ) -> {
533534 // in this block, we try to update the stored model configurations
534- IndexRequest configRequest = createIndexRequest (
535- Model . documentId ( inferenceEntityId ) ,
535+ var configRequestBuilder = createIndexRequestBuilder (
536+ inferenceEntityId ,
536537 InferenceIndex .INDEX_NAME ,
537538 newModel .getConfigurations (),
538- true
539+ true ,
540+ client
539541 );
540542
541543 ActionListener <BulkResponse > storeConfigListener = subListener .delegateResponse ((l , e ) -> {
@@ -544,7 +546,10 @@ public void updateModelTransaction(Model newModel, Model existingModel, ActionLi
544546 l .onFailure (e );
545547 });
546548
547- client .prepareBulk ().add (configRequest ).setRefreshPolicy (WriteRequest .RefreshPolicy .IMMEDIATE ).execute (storeConfigListener );
549+ client .prepareBulk ()
550+ .add (configRequestBuilder )
551+ .setRefreshPolicy (WriteRequest .RefreshPolicy .IMMEDIATE )
552+ .execute (storeConfigListener );
548553
549554 }).<BulkResponse >andThen ((subListener , configResponse ) -> {
550555 // in this block, we respond to the success or failure of updating the model configurations, then try to store the new secrets
@@ -569,11 +574,12 @@ public void updateModelTransaction(Model newModel, Model existingModel, ActionLi
569574 );
570575 } else {
571576 // Since the model configurations were successfully updated, we can now try to store the new secrets
572- IndexRequest secretsRequest = createIndexRequest (
573- Model . documentId ( newModel . getConfigurations (). getInferenceEntityId ()) ,
577+ var secretsRequestBuilder = createIndexRequestBuilder (
578+ inferenceEntityId ,
574579 InferenceSecretsIndex .INDEX_NAME ,
575580 newModel .getSecrets (),
576- true
581+ true ,
582+ client
577583 );
578584
579585 ActionListener <BulkResponse > storeSecretsListener = subListener .delegateResponse ((l , e ) -> {
@@ -583,20 +589,22 @@ public void updateModelTransaction(Model newModel, Model existingModel, ActionLi
583589 });
584590
585591 client .prepareBulk ()
586- .add (secretsRequest )
592+ .add (secretsRequestBuilder )
587593 .setRefreshPolicy (WriteRequest .RefreshPolicy .IMMEDIATE )
588594 .execute (storeSecretsListener );
589595 }
590596 }).<BulkResponse >andThen ((subListener , secretsResponse ) -> {
591597 // in this block, we respond to the success or failure of updating the model secrets
592598 if (secretsResponse .hasFailures ()) {
593599 // since storing the secrets failed, we will try to restore / roll-back-to the previous model configurations
594- IndexRequest configRequest = createIndexRequest (
595- Model . documentId ( inferenceEntityId ) ,
600+ var configRequestBuilder = createIndexRequestBuilder (
601+ inferenceEntityId ,
596602 InferenceIndex .INDEX_NAME ,
597603 existingModel .getConfigurations (),
598- true
604+ true ,
605+ client
599606 );
607+
600608 logger .error (
601609 "Failed to update inference endpoint secrets [{}], attempting rolling back to previous state" ,
602610 inferenceEntityId
@@ -608,7 +616,7 @@ public void updateModelTransaction(Model newModel, Model existingModel, ActionLi
608616 l .onFailure (e );
609617 });
610618 client .prepareBulk ()
611- .add (configRequest )
619+ .add (configRequestBuilder )
612620 .setRefreshPolicy (WriteRequest .RefreshPolicy .IMMEDIATE )
613621 .execute (rollbackConfigListener );
614622 } else {
@@ -655,24 +663,25 @@ public void storeModel(Model model, ActionListener<Boolean> listener, TimeValue
655663
656664 private void storeModel (Model model , boolean updateClusterState , ActionListener <Boolean > listener , TimeValue timeout ) {
657665 ActionListener <BulkResponse > bulkResponseActionListener = getStoreIndexListener (model , updateClusterState , listener , timeout );
658-
659- IndexRequest configRequest = createIndexRequest (
660- Model . documentId ( model . getConfigurations (). getInferenceEntityId ()) ,
666+ String inferenceEntityId = model . getConfigurations (). getInferenceEntityId ();
667+ var configRequestBuilder = createIndexRequestBuilder (
668+ inferenceEntityId ,
661669 InferenceIndex .INDEX_NAME ,
662670 model .getConfigurations (),
663- false
671+ false ,
672+ client
664673 );
665-
666- IndexRequest secretsRequest = createIndexRequest (
667- Model .documentId (model .getConfigurations ().getInferenceEntityId ()),
674+ var secretsRequestBuilder = createIndexRequestBuilder (
675+ inferenceEntityId ,
668676 InferenceSecretsIndex .INDEX_NAME ,
669677 model .getSecrets (),
670- false
678+ false ,
679+ client
671680 );
672681
673682 client .prepareBulk ()
674- .add (configRequest )
675- .add (secretsRequest )
683+ .add (configRequestBuilder )
684+ .add (secretsRequestBuilder )
676685 .setRefreshPolicy (WriteRequest .RefreshPolicy .IMMEDIATE )
677686 .execute (bulkResponseActionListener );
678687 }
@@ -683,15 +692,24 @@ private ActionListener<BulkResponse> getStoreIndexListener(
683692 ActionListener <Boolean > listener ,
684693 TimeValue timeout
685694 ) {
695+ // If there was a partial failure in writing to the indices, we need to clean up
696+ AtomicBoolean partialFailure = new AtomicBoolean (false );
697+ var cleanupListener = listener .delegateResponse ((delegate , ex ) -> {
698+ if (partialFailure .get ()) {
699+ deleteModel (model .getInferenceEntityId (), ActionListener .running (() -> delegate .onFailure (ex )));
700+ } else {
701+ delegate .onFailure (ex );
702+ }
703+ });
686704 return ActionListener .wrap (bulkItemResponses -> {
687- var inferenceEntityId = model .getConfigurations (). getInferenceEntityId ();
705+ var inferenceEntityId = model .getInferenceEntityId ();
688706
689707 if (bulkItemResponses .getItems ().length == 0 ) {
690708 logger .warn (
691709 format ("Storing inference endpoint [%s] failed, no items were received from the bulk response" , inferenceEntityId )
692710 );
693711
694- listener .onFailure (
712+ cleanupListener .onFailure (
695713 new ElasticsearchStatusException (
696714 format (
697715 "Failed to store inference endpoint [%s], invalid bulk response received. Try reinitializing the service" ,
@@ -707,7 +725,7 @@ private ActionListener<BulkResponse> getStoreIndexListener(
707725
708726 if (failure == null ) {
709727 if (updateClusterState ) {
710- var storeListener = getStoreMetadataListener (inferenceEntityId , listener );
728+ var storeListener = getStoreMetadataListener (inferenceEntityId , cleanupListener );
711729 try {
712730 metadataTaskQueue .submitTask (
713731 "add model [" + inferenceEntityId + "]" ,
@@ -723,29 +741,32 @@ private ActionListener<BulkResponse> getStoreIndexListener(
723741 storeListener .onFailure (exc );
724742 }
725743 } else {
726- listener .onResponse (Boolean .TRUE );
744+ cleanupListener .onResponse (Boolean .TRUE );
727745 }
728746 return ;
729747 }
730748
731- logBulkFailures (model .getConfigurations ().getInferenceEntityId (), bulkItemResponses );
749+ for (BulkItemResponse aResponse : bulkItemResponses .getItems ()) {
750+ logBulkFailure (inferenceEntityId , aResponse );
751+ partialFailure .compareAndSet (false , aResponse .isFailed () == false );
752+ }
732753
733754 if (ExceptionsHelper .unwrapCause (failure .getCause ()) instanceof VersionConflictEngineException ) {
734- listener .onFailure (new ResourceAlreadyExistsException ("Inference endpoint [{}] already exists" , inferenceEntityId ));
755+ cleanupListener .onFailure (new ResourceAlreadyExistsException ("Inference endpoint [{}] already exists" , inferenceEntityId ));
735756 return ;
736757 }
737758
738- listener .onFailure (
759+ cleanupListener .onFailure (
739760 new ElasticsearchStatusException (
740761 format ("Failed to store inference endpoint [%s]" , inferenceEntityId ),
741762 RestStatus .INTERNAL_SERVER_ERROR ,
742763 failure .getCause ()
743764 )
744765 );
745766 }, e -> {
746- String errorMessage = format ("Failed to store inference endpoint [%s]" , model .getConfigurations (). getInferenceEntityId ());
767+ String errorMessage = format ("Failed to store inference endpoint [%s]" , model .getInferenceEntityId ());
747768 logger .warn (errorMessage , e );
748- listener .onFailure (new ElasticsearchStatusException (errorMessage , RestStatus .INTERNAL_SERVER_ERROR , e ));
769+ cleanupListener .onFailure (new ElasticsearchStatusException (errorMessage , RestStatus .INTERNAL_SERVER_ERROR , e ));
749770 });
750771 }
751772
@@ -779,18 +800,12 @@ public void onFailure(Exception exc) {
779800 };
780801 }
781802
782- private static void logBulkFailures (String inferenceEntityId , BulkResponse bulkResponse ) {
783- for (BulkItemResponse item : bulkResponse .getItems ()) {
784- if (item .isFailed ()) {
785- logger .warn (
786- format (
787- "Failed to store inference endpoint [%s] index: [%s] bulk failure message [%s]" ,
788- inferenceEntityId ,
789- item .getIndex (),
790- item .getFailureMessage ()
791- )
792- );
793- }
803+ private static void logBulkFailure (String inferenceEntityId , BulkItemResponse item ) {
804+ if (item .isFailed ()) {
805+ logger .warn (
806+ format ("Failed to store inference endpoint [%s] index: [%s]" , inferenceEntityId , item .getIndex ()),
807+ item .getFailure ().getCause ()
808+ );
794809 }
795810 }
796811
@@ -937,6 +952,33 @@ private static IndexRequest createIndexRequest(String docId, String indexName, T
937952 }
938953 }
939954
955+ static IndexRequestBuilder createIndexRequestBuilder (
956+ String inferenceId ,
957+ String indexName ,
958+ ToXContentObject body ,
959+ boolean allowOverwriting ,
960+ Client client
961+ ) {
962+ try (XContentBuilder xContentBuilder = XContentFactory .jsonBuilder ()) {
963+ XContentBuilder source = body .toXContent (
964+ xContentBuilder ,
965+ new ToXContent .MapParams (Map .of (ModelConfigurations .USE_ID_FOR_INDEX , Boolean .TRUE .toString ()))
966+ );
967+
968+ return new IndexRequestBuilder (client ).setIndex (indexName )
969+ .setCreate (allowOverwriting == false )
970+ .setId (Model .documentId (inferenceId ))
971+ .setSource (source );
972+ } catch (IOException ex ) {
973+ throw new ElasticsearchException (
974+ "Unexpected serialization exception for index [{}] inference ID [{}]" ,
975+ ex ,
976+ indexName ,
977+ inferenceId
978+ );
979+ }
980+ }
981+
940982 private static UnparsedModel modelToUnparsedModel (Model model ) {
941983 try (XContentBuilder builder = XContentFactory .jsonBuilder ()) {
942984 model .getConfigurations ()
0 commit comments