2121import org .elasticsearch .action .bulk .BulkItemResponse ;
2222import org .elasticsearch .action .bulk .BulkResponse ;
2323import org .elasticsearch .action .index .IndexRequest ;
24+ import org .elasticsearch .action .index .IndexRequestBuilder ;
2425import org .elasticsearch .action .search .SearchRequest ;
2526import org .elasticsearch .action .search .SearchResponse ;
2627import org .elasticsearch .action .support .GroupedActionListener ;
@@ -510,11 +511,12 @@ public void updateModelTransaction(Model newModel, Model existingModel, ActionLi
510511
511512 SubscribableListener .<BulkResponse >newForked ((subListener ) -> {
512513 // in this block, we try to update the stored model configurations
513- IndexRequest configRequest = createIndexRequest (
514- Model . documentId ( inferenceEntityId ) ,
514+ var configRequestBuilder = createIndexRequestBuilder (
515+ inferenceEntityId ,
515516 InferenceIndex .INDEX_NAME ,
516517 newModel .getConfigurations (),
517- true
518+ true ,
519+ client
518520 );
519521
520522 ActionListener <BulkResponse > storeConfigListener = subListener .delegateResponse ((l , e ) -> {
@@ -523,7 +525,10 @@ public void updateModelTransaction(Model newModel, Model existingModel, ActionLi
523525 l .onFailure (e );
524526 });
525527
526- client .prepareBulk ().add (configRequest ).setRefreshPolicy (WriteRequest .RefreshPolicy .IMMEDIATE ).execute (storeConfigListener );
528+ client .prepareBulk ()
529+ .add (configRequestBuilder )
530+ .setRefreshPolicy (WriteRequest .RefreshPolicy .IMMEDIATE )
531+ .execute (storeConfigListener );
527532
528533 }).<BulkResponse >andThen ((subListener , configResponse ) -> {
529534 // in this block, we respond to the success or failure of updating the model configurations, then try to store the new secrets
@@ -548,11 +553,12 @@ public void updateModelTransaction(Model newModel, Model existingModel, ActionLi
548553 );
549554 } else {
550555 // Since the model configurations were successfully updated, we can now try to store the new secrets
551- IndexRequest secretsRequest = createIndexRequest (
552- Model . documentId ( newModel . getConfigurations (). getInferenceEntityId ()) ,
556+ var secretsRequestBuilder = createIndexRequestBuilder (
557+ inferenceEntityId ,
553558 InferenceSecretsIndex .INDEX_NAME ,
554559 newModel .getSecrets (),
555- true
560+ true ,
561+ client
556562 );
557563
558564 ActionListener <BulkResponse > storeSecretsListener = subListener .delegateResponse ((l , e ) -> {
@@ -562,20 +568,22 @@ public void updateModelTransaction(Model newModel, Model existingModel, ActionLi
562568 });
563569
564570 client .prepareBulk ()
565- .add (secretsRequest )
571+ .add (secretsRequestBuilder )
566572 .setRefreshPolicy (WriteRequest .RefreshPolicy .IMMEDIATE )
567573 .execute (storeSecretsListener );
568574 }
569575 }).<BulkResponse >andThen ((subListener , secretsResponse ) -> {
570576 // in this block, we respond to the success or failure of updating the model secrets
571577 if (secretsResponse .hasFailures ()) {
572578 // since storing the secrets failed, we will try to restore / roll-back-to the previous model configurations
573- IndexRequest configRequest = createIndexRequest (
574- Model . documentId ( inferenceEntityId ) ,
579+ var configRequestBuilder = createIndexRequestBuilder (
580+ inferenceEntityId ,
575581 InferenceIndex .INDEX_NAME ,
576582 existingModel .getConfigurations (),
577- true
583+ true ,
584+ client
578585 );
586+
579587 logger .error (
580588 "Failed to update inference endpoint secrets [{}], attempting rolling back to previous state" ,
581589 inferenceEntityId
@@ -587,7 +595,7 @@ public void updateModelTransaction(Model newModel, Model existingModel, ActionLi
587595 l .onFailure (e );
588596 });
589597 client .prepareBulk ()
590- .add (configRequest )
598+ .add (configRequestBuilder )
591599 .setRefreshPolicy (WriteRequest .RefreshPolicy .IMMEDIATE )
592600 .execute (rollbackConfigListener );
593601 } else {
@@ -633,24 +641,25 @@ public void storeModel(Model model, ActionListener<Boolean> listener, TimeValue
633641
634642 private void storeModel (Model model , boolean updateClusterState , ActionListener <Boolean > listener , TimeValue timeout ) {
635643 ActionListener <BulkResponse > bulkResponseActionListener = getStoreIndexListener (model , updateClusterState , listener , timeout );
636-
637- IndexRequest configRequest = createIndexRequest (
638- Model . documentId ( model . getConfigurations (). getInferenceEntityId ()) ,
644+ String inferenceEntityId = model . getConfigurations (). getInferenceEntityId ();
645+ var configRequestBuilder = createIndexRequestBuilder (
646+ inferenceEntityId ,
639647 InferenceIndex .INDEX_NAME ,
640648 model .getConfigurations (),
641- false
649+ false ,
650+ client
642651 );
643-
644- IndexRequest secretsRequest = createIndexRequest (
645- Model .documentId (model .getConfigurations ().getInferenceEntityId ()),
652+ var secretsRequestBuilder = createIndexRequestBuilder (
653+ inferenceEntityId ,
646654 InferenceSecretsIndex .INDEX_NAME ,
647655 model .getSecrets (),
648- false
656+ false ,
657+ client
649658 );
650659
651660 client .prepareBulk ()
652- .add (configRequest )
653- .add (secretsRequest )
661+ .add (configRequestBuilder )
662+ .add (secretsRequestBuilder )
654663 .setRefreshPolicy (WriteRequest .RefreshPolicy .IMMEDIATE )
655664 .execute (bulkResponseActionListener );
656665 }
@@ -661,15 +670,24 @@ private ActionListener<BulkResponse> getStoreIndexListener(
661670 ActionListener <Boolean > listener ,
662671 TimeValue timeout
663672 ) {
673+ // If there was a partial failure in writing to the indices, we need to clean up
674+ AtomicBoolean partialFailure = new AtomicBoolean (false );
675+ var cleanupListener = listener .delegateResponse ((delegate , ex ) -> {
676+ if (partialFailure .get ()) {
677+ deleteModel (model .getInferenceEntityId (), ActionListener .running (() -> delegate .onFailure (ex )));
678+ } else {
679+ delegate .onFailure (ex );
680+ }
681+ });
664682 return ActionListener .wrap (bulkItemResponses -> {
665- var inferenceEntityId = model .getConfigurations (). getInferenceEntityId ();
683+ var inferenceEntityId = model .getInferenceEntityId ();
666684
667685 if (bulkItemResponses .getItems ().length == 0 ) {
668686 logger .warn (
669687 format ("Storing inference endpoint [%s] failed, no items were received from the bulk response" , inferenceEntityId )
670688 );
671689
672- listener .onFailure (
690+ cleanupListener .onFailure (
673691 new ElasticsearchStatusException (
674692 format (
675693 "Failed to store inference endpoint [%s], invalid bulk response received. Try reinitializing the service" ,
@@ -685,7 +703,7 @@ private ActionListener<BulkResponse> getStoreIndexListener(
685703
686704 if (failure == null ) {
687705 if (updateClusterState ) {
688- var storeListener = getStoreMetadataListener (inferenceEntityId , listener );
706+ var storeListener = getStoreMetadataListener (inferenceEntityId , cleanupListener );
689707 try {
690708 metadataTaskQueue .submitTask (
691709 "add model [" + inferenceEntityId + "]" ,
@@ -696,29 +714,32 @@ private ActionListener<BulkResponse> getStoreIndexListener(
696714 storeListener .onFailure (exc );
697715 }
698716 } else {
699- listener .onResponse (Boolean .TRUE );
717+ cleanupListener .onResponse (Boolean .TRUE );
700718 }
701719 return ;
702720 }
703721
704- logBulkFailures (model .getConfigurations ().getInferenceEntityId (), bulkItemResponses );
722+ for (BulkItemResponse aResponse : bulkItemResponses .getItems ()) {
723+ logBulkFailure (inferenceEntityId , aResponse );
724+ partialFailure .compareAndSet (false , aResponse .isFailed () == false );
725+ }
705726
706727 if (ExceptionsHelper .unwrapCause (failure .getCause ()) instanceof VersionConflictEngineException ) {
707- listener .onFailure (new ResourceAlreadyExistsException ("Inference endpoint [{}] already exists" , inferenceEntityId ));
728+ cleanupListener .onFailure (new ResourceAlreadyExistsException ("Inference endpoint [{}] already exists" , inferenceEntityId ));
708729 return ;
709730 }
710731
711- listener .onFailure (
732+ cleanupListener .onFailure (
712733 new ElasticsearchStatusException (
713734 format ("Failed to store inference endpoint [%s]" , inferenceEntityId ),
714735 RestStatus .INTERNAL_SERVER_ERROR ,
715736 failure .getCause ()
716737 )
717738 );
718739 }, e -> {
719- String errorMessage = format ("Failed to store inference endpoint [%s]" , model .getConfigurations (). getInferenceEntityId ());
740+ String errorMessage = format ("Failed to store inference endpoint [%s]" , model .getInferenceEntityId ());
720741 logger .warn (errorMessage , e );
721- listener .onFailure (new ElasticsearchStatusException (errorMessage , RestStatus .INTERNAL_SERVER_ERROR , e ));
742+ cleanupListener .onFailure (new ElasticsearchStatusException (errorMessage , RestStatus .INTERNAL_SERVER_ERROR , e ));
722743 });
723744 }
724745
@@ -752,18 +773,12 @@ public void onFailure(Exception exc) {
752773 };
753774 }
754775
755- private static void logBulkFailures (String inferenceEntityId , BulkResponse bulkResponse ) {
756- for (BulkItemResponse item : bulkResponse .getItems ()) {
757- if (item .isFailed ()) {
758- logger .warn (
759- format (
760- "Failed to store inference endpoint [%s] index: [%s] bulk failure message [%s]" ,
761- inferenceEntityId ,
762- item .getIndex (),
763- item .getFailureMessage ()
764- )
765- );
766- }
776+ private static void logBulkFailure (String inferenceEntityId , BulkItemResponse item ) {
777+ if (item .isFailed ()) {
778+ logger .warn (
779+ format ("Failed to store inference endpoint [%s] index: [%s]" , inferenceEntityId , item .getIndex ()),
780+ item .getFailure ().getCause ()
781+ );
767782 }
768783 }
769784
@@ -896,6 +911,33 @@ private static IndexRequest createIndexRequest(String docId, String indexName, T
896911 }
897912 }
898913
914+ static IndexRequestBuilder createIndexRequestBuilder (
915+ String inferenceId ,
916+ String indexName ,
917+ ToXContentObject body ,
918+ boolean allowOverwriting ,
919+ Client client
920+ ) {
921+ try (XContentBuilder xContentBuilder = XContentFactory .jsonBuilder ()) {
922+ XContentBuilder source = body .toXContent (
923+ xContentBuilder ,
924+ new ToXContent .MapParams (Map .of (ModelConfigurations .USE_ID_FOR_INDEX , Boolean .TRUE .toString ()))
925+ );
926+
927+ return new IndexRequestBuilder (client ).setIndex (indexName )
928+ .setCreate (allowOverwriting == false )
929+ .setId (Model .documentId (inferenceId ))
930+ .setSource (source );
931+ } catch (IOException ex ) {
932+ throw new ElasticsearchException (
933+ "Unexpected serialization exception for index [{}] inference ID [{}]" ,
934+ ex ,
935+ indexName ,
936+ inferenceId
937+ );
938+ }
939+ }
940+
899941 private static UnparsedModel modelToUnparsedModel (Model model ) {
900942 try (XContentBuilder builder = XContentFactory .jsonBuilder ()) {
901943 model .getConfigurations ()
0 commit comments