2121import org .elasticsearch .action .bulk .BulkItemResponse ;
2222import org .elasticsearch .action .bulk .BulkResponse ;
2323import org .elasticsearch .action .index .IndexRequest ;
24+ import org .elasticsearch .action .index .IndexRequestBuilder ;
2425import org .elasticsearch .action .search .SearchRequest ;
2526import org .elasticsearch .action .search .SearchResponse ;
2627import org .elasticsearch .action .support .GroupedActionListener ;
@@ -514,11 +515,12 @@ public void updateModelTransaction(Model newModel, Model existingModel, ActionLi
514515
515516 SubscribableListener .<BulkResponse >newForked ((subListener ) -> {
516517 // in this block, we try to update the stored model configurations
517- IndexRequest configRequest = createIndexRequest (
518- Model . documentId ( inferenceEntityId ) ,
518+ var configRequestBuilder = createIndexRequestBuilder (
519+ inferenceEntityId ,
519520 InferenceIndex .INDEX_NAME ,
520521 newModel .getConfigurations (),
521- true
522+ true ,
523+ client
522524 );
523525
524526 ActionListener <BulkResponse > storeConfigListener = subListener .delegateResponse ((l , e ) -> {
@@ -527,7 +529,10 @@ public void updateModelTransaction(Model newModel, Model existingModel, ActionLi
527529 l .onFailure (e );
528530 });
529531
530- client .prepareBulk ().add (configRequest ).setRefreshPolicy (WriteRequest .RefreshPolicy .IMMEDIATE ).execute (storeConfigListener );
532+ client .prepareBulk ()
533+ .add (configRequestBuilder )
534+ .setRefreshPolicy (WriteRequest .RefreshPolicy .IMMEDIATE )
535+ .execute (storeConfigListener );
531536
532537 }).<BulkResponse >andThen ((subListener , configResponse ) -> {
533538 // in this block, we respond to the success or failure of updating the model configurations, then try to store the new secrets
@@ -552,11 +557,12 @@ public void updateModelTransaction(Model newModel, Model existingModel, ActionLi
552557 );
553558 } else {
554559 // Since the model configurations were successfully updated, we can now try to store the new secrets
555- IndexRequest secretsRequest = createIndexRequest (
556- Model . documentId ( newModel . getConfigurations (). getInferenceEntityId ()) ,
560+ var secretsRequestBuilder = createIndexRequestBuilder (
561+ inferenceEntityId ,
557562 InferenceSecretsIndex .INDEX_NAME ,
558563 newModel .getSecrets (),
559- true
564+ true ,
565+ client
560566 );
561567
562568 ActionListener <BulkResponse > storeSecretsListener = subListener .delegateResponse ((l , e ) -> {
@@ -566,20 +572,22 @@ public void updateModelTransaction(Model newModel, Model existingModel, ActionLi
566572 });
567573
568574 client .prepareBulk ()
569- .add (secretsRequest )
575+ .add (secretsRequestBuilder )
570576 .setRefreshPolicy (WriteRequest .RefreshPolicy .IMMEDIATE )
571577 .execute (storeSecretsListener );
572578 }
573579 }).<BulkResponse >andThen ((subListener , secretsResponse ) -> {
574580 // in this block, we respond to the success or failure of updating the model secrets
575581 if (secretsResponse .hasFailures ()) {
576582 // since storing the secrets failed, we will try to restore / roll-back-to the previous model configurations
577- IndexRequest configRequest = createIndexRequest (
578- Model . documentId ( inferenceEntityId ) ,
583+ var configRequestBuilder = createIndexRequestBuilder (
584+ inferenceEntityId ,
579585 InferenceIndex .INDEX_NAME ,
580586 existingModel .getConfigurations (),
581- true
587+ true ,
588+ client
582589 );
590+
583591 logger .error (
584592 "Failed to update inference endpoint secrets [{}], attempting rolling back to previous state" ,
585593 inferenceEntityId
@@ -591,7 +599,7 @@ public void updateModelTransaction(Model newModel, Model existingModel, ActionLi
591599 l .onFailure (e );
592600 });
593601 client .prepareBulk ()
594- .add (configRequest )
602+ .add (configRequestBuilder )
595603 .setRefreshPolicy (WriteRequest .RefreshPolicy .IMMEDIATE )
596604 .execute (rollbackConfigListener );
597605 } else {
@@ -637,24 +645,25 @@ public void storeModel(Model model, ActionListener<Boolean> listener, TimeValue
637645
638646 private void storeModel (Model model , boolean updateClusterState , ActionListener <Boolean > listener , TimeValue timeout ) {
639647 ActionListener <BulkResponse > bulkResponseActionListener = getStoreIndexListener (model , updateClusterState , listener , timeout );
640-
641- IndexRequest configRequest = createIndexRequest (
642- Model . documentId ( model . getConfigurations (). getInferenceEntityId ()) ,
648+ String inferenceEntityId = model . getConfigurations (). getInferenceEntityId ();
649+ var configRequestBuilder = createIndexRequestBuilder (
650+ inferenceEntityId ,
643651 InferenceIndex .INDEX_NAME ,
644652 model .getConfigurations (),
645- false
653+ false ,
654+ client
646655 );
647-
648- IndexRequest secretsRequest = createIndexRequest (
649- Model .documentId (model .getConfigurations ().getInferenceEntityId ()),
656+ var secretsRequestBuilder = createIndexRequestBuilder (
657+ inferenceEntityId ,
650658 InferenceSecretsIndex .INDEX_NAME ,
651659 model .getSecrets (),
652- false
660+ false ,
661+ client
653662 );
654663
655664 client .prepareBulk ()
656- .add (configRequest )
657- .add (secretsRequest )
665+ .add (configRequestBuilder )
666+ .add (secretsRequestBuilder )
658667 .setRefreshPolicy (WriteRequest .RefreshPolicy .IMMEDIATE )
659668 .execute (bulkResponseActionListener );
660669 }
@@ -665,15 +674,24 @@ private ActionListener<BulkResponse> getStoreIndexListener(
665674 ActionListener <Boolean > listener ,
666675 TimeValue timeout
667676 ) {
677+ // If there was a partial failure in writing to the indices, we need to clean up
678+ AtomicBoolean partialFailure = new AtomicBoolean (false );
679+ var cleanupListener = listener .delegateResponse ((delegate , ex ) -> {
680+ if (partialFailure .get ()) {
681+ deleteModel (model .getInferenceEntityId (), ActionListener .running (() -> delegate .onFailure (ex )));
682+ } else {
683+ delegate .onFailure (ex );
684+ }
685+ });
668686 return ActionListener .wrap (bulkItemResponses -> {
669- var inferenceEntityId = model .getConfigurations (). getInferenceEntityId ();
687+ var inferenceEntityId = model .getInferenceEntityId ();
670688
671689 if (bulkItemResponses .getItems ().length == 0 ) {
672690 logger .warn (
673691 format ("Storing inference endpoint [%s] failed, no items were received from the bulk response" , inferenceEntityId )
674692 );
675693
676- listener .onFailure (
694+ cleanupListener .onFailure (
677695 new ElasticsearchStatusException (
678696 format (
679697 "Failed to store inference endpoint [%s], invalid bulk response received. Try reinitializing the service" ,
@@ -689,7 +707,7 @@ private ActionListener<BulkResponse> getStoreIndexListener(
689707
690708 if (failure == null ) {
691709 if (updateClusterState ) {
692- var storeListener = getStoreMetadataListener (inferenceEntityId , listener );
710+ var storeListener = getStoreMetadataListener (inferenceEntityId , cleanupListener );
693711 try {
694712 metadataTaskQueue .submitTask (
695713 "add model [" + inferenceEntityId + "]" ,
@@ -705,29 +723,32 @@ private ActionListener<BulkResponse> getStoreIndexListener(
705723 storeListener .onFailure (exc );
706724 }
707725 } else {
708- listener .onResponse (Boolean .TRUE );
726+ cleanupListener .onResponse (Boolean .TRUE );
709727 }
710728 return ;
711729 }
712730
713- logBulkFailures (model .getConfigurations ().getInferenceEntityId (), bulkItemResponses );
731+ for (BulkItemResponse aResponse : bulkItemResponses .getItems ()) {
732+ logBulkFailure (inferenceEntityId , aResponse );
733+ partialFailure .compareAndSet (false , aResponse .isFailed () == false );
734+ }
714735
715736 if (ExceptionsHelper .unwrapCause (failure .getCause ()) instanceof VersionConflictEngineException ) {
716- listener .onFailure (new ResourceAlreadyExistsException ("Inference endpoint [{}] already exists" , inferenceEntityId ));
737+ cleanupListener .onFailure (new ResourceAlreadyExistsException ("Inference endpoint [{}] already exists" , inferenceEntityId ));
717738 return ;
718739 }
719740
720- listener .onFailure (
741+ cleanupListener .onFailure (
721742 new ElasticsearchStatusException (
722743 format ("Failed to store inference endpoint [%s]" , inferenceEntityId ),
723744 RestStatus .INTERNAL_SERVER_ERROR ,
724745 failure .getCause ()
725746 )
726747 );
727748 }, e -> {
728- String errorMessage = format ("Failed to store inference endpoint [%s]" , model .getConfigurations (). getInferenceEntityId ());
749+ String errorMessage = format ("Failed to store inference endpoint [%s]" , model .getInferenceEntityId ());
729750 logger .warn (errorMessage , e );
730- listener .onFailure (new ElasticsearchStatusException (errorMessage , RestStatus .INTERNAL_SERVER_ERROR , e ));
751+ cleanupListener .onFailure (new ElasticsearchStatusException (errorMessage , RestStatus .INTERNAL_SERVER_ERROR , e ));
731752 });
732753 }
733754
@@ -761,18 +782,12 @@ public void onFailure(Exception exc) {
761782 };
762783 }
763784
764- private static void logBulkFailures (String inferenceEntityId , BulkResponse bulkResponse ) {
765- for (BulkItemResponse item : bulkResponse .getItems ()) {
766- if (item .isFailed ()) {
767- logger .warn (
768- format (
769- "Failed to store inference endpoint [%s] index: [%s] bulk failure message [%s]" ,
770- inferenceEntityId ,
771- item .getIndex (),
772- item .getFailureMessage ()
773- )
774- );
775- }
785+ private static void logBulkFailure (String inferenceEntityId , BulkItemResponse item ) {
786+ if (item .isFailed ()) {
787+ logger .warn (
788+ format ("Failed to store inference endpoint [%s] index: [%s]" , inferenceEntityId , item .getIndex ()),
789+ item .getFailure ().getCause ()
790+ );
776791 }
777792 }
778793
@@ -905,6 +920,33 @@ private static IndexRequest createIndexRequest(String docId, String indexName, T
905920 }
906921 }
907922
923+ static IndexRequestBuilder createIndexRequestBuilder (
924+ String inferenceId ,
925+ String indexName ,
926+ ToXContentObject body ,
927+ boolean allowOverwriting ,
928+ Client client
929+ ) {
930+ try (XContentBuilder xContentBuilder = XContentFactory .jsonBuilder ()) {
931+ XContentBuilder source = body .toXContent (
932+ xContentBuilder ,
933+ new ToXContent .MapParams (Map .of (ModelConfigurations .USE_ID_FOR_INDEX , Boolean .TRUE .toString ()))
934+ );
935+
936+ return new IndexRequestBuilder (client ).setIndex (indexName )
937+ .setCreate (allowOverwriting == false )
938+ .setId (Model .documentId (inferenceId ))
939+ .setSource (source );
940+ } catch (IOException ex ) {
941+ throw new ElasticsearchException (
942+ "Unexpected serialization exception for index [{}] inference ID [{}]" ,
943+ ex ,
944+ indexName ,
945+ inferenceId
946+ );
947+ }
948+ }
949+
908950 private static UnparsedModel modelToUnparsedModel (Model model ) {
909951 try (XContentBuilder builder = XContentFactory .jsonBuilder ()) {
910952 model .getConfigurations ()
0 commit comments