import com.azure.core.http.HttpMethod;
import com.azure.core.http.rest.ResponseBase;
import com.azure.core.util.BinaryData;
+import com.azure.core.util.FluxUtil;
+import com.azure.core.util.logging.ClientLogger;
import com.azure.storage.blob.BlobAsyncClient;
import com.azure.storage.blob.BlobClient;
import com.azure.storage.blob.BlobContainerAsyncClient;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.core.util.Throwables;
import org.elasticsearch.cluster.metadata.RepositoryMetadata;
+import org.elasticsearch.common.CheckedBiFunction;
import org.elasticsearch.common.UUIDs;
import org.elasticsearch.common.blobstore.BlobContainer;
import org.elasticsearch.common.blobstore.BlobPath;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.core.CheckedConsumer;
+import org.elasticsearch.core.IOUtils;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.repositories.RepositoriesMetrics;
@@ -121,6 +125,7 @@ public class AzureBlobStore implements BlobStore {
    private final ByteSizeValue maxSinglePartUploadSize;
    private final int deletionBatchSize;
    private final int maxConcurrentBatchDeletes;
+    private final int multipartUploadMaxConcurrency;

    private final RequestMetricsRecorder requestMetricsRecorder;
    private final AzureClientProvider.RequestMetricsHandler requestMetricsHandler;
@@ -142,6 +147,7 @@ public AzureBlobStore(
        this.maxSinglePartUploadSize = Repository.MAX_SINGLE_PART_UPLOAD_SIZE_SETTING.get(metadata.settings());
        this.deletionBatchSize = Repository.DELETION_BATCH_SIZE_SETTING.get(metadata.settings());
        this.maxConcurrentBatchDeletes = Repository.MAX_CONCURRENT_BATCH_DELETES_SETTING.get(metadata.settings());
+        this.multipartUploadMaxConcurrency = service.getMultipartUploadMaxConcurrency();

        List<RequestMatcher> requestMatchers = List.of(
            new RequestMatcher((httpMethod, url) -> httpMethod == HttpMethod.HEAD, Operation.GET_BLOB_PROPERTIES),
@@ -464,6 +470,136 @@ protected void onFailure() {
        }
    }

+    void writeBlobAtomic(
+        final OperationPurpose purpose,
+        final String blobName,
+        final long blobSize,
+        final CheckedBiFunction<Long, Long, InputStream, IOException> provider,
+        final boolean failIfAlreadyExists
+    ) throws IOException {
+        try {
+            final List<MultiPart> multiParts;
+            if (blobSize <= getLargeBlobThresholdInBytes()) {
+                multiParts = null;
+            } else {
+                multiParts = computeMultiParts(blobSize, getUploadBlockSize());
+            }
+            if (multiParts == null || multiParts.size() == 1) {
+                logger.debug("{}: uploading blob of size [{}] as single upload", blobName, blobSize);
+                try (var stream = provider.apply(0L, blobSize)) {
+                    var flux = convertStreamToByteBuffer(stream, blobSize, DEFAULT_UPLOAD_BUFFERS_SIZE);
+                    executeSingleUpload(purpose, blobName, flux, blobSize, failIfAlreadyExists);
+                }
+            } else {
+                logger.debug("{}: uploading blob of size [{}] using [{}] parts", blobName, blobSize, multiParts.size());
+                assert blobSize == ((multiParts.size() - 1) * getUploadBlockSize()) + multiParts.getLast().blockSize();
+                assert multiParts.size() > 1;
+
+                final var asyncClient = asyncClient(purpose).getBlobContainerAsyncClient(container)
+                    .getBlobAsyncClient(blobName)
+                    .getBlockBlobAsyncClient();
+
+                Flux.fromIterable(multiParts)
+                    .flatMapSequential(multipart -> stageBlock(asyncClient, blobName, multipart, provider), multipartUploadMaxConcurrency)
+                    .collect(Collectors.toList())
+                    .flatMap(blockIds -> {
+                        logger.debug("{}: all {} parts uploaded, now committing", blobName, multiParts.size());
+                        var response = asyncClient.commitBlockList(
+                            multiParts.stream().map(MultiPart::blockId).toList(),
+                            failIfAlreadyExists == false
+                        );
+                        logger.debug("{}: all {} parts committed", blobName, multiParts.size());
+                        return response;
+                    })
+                    .block();
+            }
+        } catch (final BlobStorageException e) {
+            if (failIfAlreadyExists
+                && e.getStatusCode() == HttpURLConnection.HTTP_CONFLICT
+                && BlobErrorCode.BLOB_ALREADY_EXISTS.equals(e.getErrorCode())) {
+                throw new FileAlreadyExistsException(blobName, null, e.getMessage());
+            }
+            throw new IOException("Unable to write blob " + blobName, e);
+        } catch (Exception e) {
+            throw new IOException("Unable to write blob " + blobName, e);
+        }
+    }
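+
+    // The multipart path relies on the block blob two-phase protocol: stageBlock uploads uncommitted blocks that are
+    // not visible to readers, and the blob only appears (or is replaced) once commitBlockList succeeds with the full
+    // list of block ids, which is what keeps writeBlobAtomic atomic. Passing failIfAlreadyExists == false as the
+    // "overwrite" flag makes the commit conditional, so committing over an existing blob surfaces as a 409
+    // BLOB_ALREADY_EXISTS error, mapped to FileAlreadyExistsException above. Blocks that are staged but never
+    // committed are eventually discarded by the service.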
+
+    private record MultiPart(int part, String blockId, long blockOffset, long blockSize, boolean isLast) {}
+
+    private static List<MultiPart> computeMultiParts(long totalSize, long partSize) {
+        if (partSize <= 0) {
+            throw new IllegalArgumentException("Part size must be greater than zero");
+        }
+        if ((totalSize == 0L) || (totalSize <= partSize)) {
+            return List.of(new MultiPart(0, makeMultipartBlockId(), 0L, totalSize, true));
+        }
+
+        long lastPartSize = totalSize % partSize;
+        int parts = Math.toIntExact(totalSize / partSize) + (0L < lastPartSize ? 1 : 0);
+
+        long blockOffset = 0L;
+        var list = new ArrayList<MultiPart>(parts);
+        for (int p = 0; p < parts; p++) {
+            boolean isLast = (p == parts - 1);
+            // when totalSize is an exact multiple of partSize, the last part is a full-size part rather than an empty one
+            long blockSize = isLast && lastPartSize > 0L ? lastPartSize : partSize;
+            var multipart = new MultiPart(p, makeMultipartBlockId(), blockOffset, blockSize, isLast);
+            blockOffset += multipart.blockSize();
+            list.add(multipart);
+        }
+        return List.copyOf(list);
+    }
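+
+    // Worked example with hypothetical sizes (the real value comes from getUploadBlockSize()): for totalSize = 120 MB
+    // and partSize = 50 MB, lastPartSize = 20 MB and parts = 3, producing parts of 50 MB at offset 0, 50 MB at 50 MB
+    // and 20 MB at 100 MB. For an exact multiple such as totalSize = 100 MB, parts = 2 and both parts are 50 MB.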
+
+    private static Mono<String> stageBlock(
+        BlockBlobAsyncClient asyncClient,
+        String blobName,
+        MultiPart multiPart,
+        CheckedBiFunction<Long, Long, InputStream, IOException> provider
+    ) {
+        logger.debug(
+            "{}: staging part [{}] of size [{}] from offset [{}]",
+            blobName,
+            multiPart.part(),
+            multiPart.blockSize(),
+            multiPart.blockOffset()
+        );
+        try {
+            var stream = toSynchronizedInputStream(blobName, provider.apply(multiPart.blockOffset(), multiPart.blockSize()), multiPart);
+            boolean success = false;
+            try {
+                var stageBlock = asyncClient.stageBlock(
+                    multiPart.blockId(),
+                    toFlux(stream, multiPart.blockSize(), DEFAULT_UPLOAD_BUFFERS_SIZE),
+                    multiPart.blockSize()
+                ).doOnSuccess(unused -> {
+                    logger.debug(() -> format("%s: part [%s] of size [%s] uploaded", blobName, multiPart.part(), multiPart.blockSize()));
+                    IOUtils.closeWhileHandlingException(stream);
+                }).doOnCancel(() -> {
+                    logger.warn(() -> format("%s: part [%s] of size [%s] cancelled", blobName, multiPart.part(), multiPart.blockSize()));
+                    IOUtils.closeWhileHandlingException(stream);
+                }).doOnError(t -> {
+                    logger.error(() -> format("%s: part [%s] of size [%s] failed", blobName, multiPart.part(), multiPart.blockSize()), t);
+                    IOUtils.closeWhileHandlingException(stream);
+                });
+                logger.debug(
+                    "{}: part [{}] of size [{}] from offset [{}] staged",
+                    blobName,
+                    multiPart.part(),
+                    multiPart.blockSize(),
+                    multiPart.blockOffset()
+                );
+                success = true;
+                return stageBlock.map(unused -> multiPart.blockId());
+            } finally {
+                if (success == false) {
+                    IOUtils.close(stream);
+                }
+            }
+        } catch (IOException e) {
+            logger.error(() -> format("%s: failed to stage part [%s] of size [%s]", blobName, multiPart.part(), multiPart.blockSize()), e);
+            return FluxUtil.monoError(new ClientLogger(AzureBlobStore.class), new UncheckedIOException(e));
+        }
+    }
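+
+    // The Mono returned above is lazy: the actual Put Block request only runs once flatMapSequential in
+    // writeBlobAtomic subscribes to it. The success flag therefore only guards the assembly phase; once the pipeline
+    // is built, closing the per-part stream is delegated to the doOnSuccess/doOnCancel/doOnError hooks.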
+
    public void writeBlob(OperationPurpose purpose, String blobName, InputStream inputStream, long blobSize, boolean failIfAlreadyExists)
        throws IOException {
        assert inputStream.markSupported()
@@ -625,6 +761,118 @@ public synchronized int read() throws IOException {
        // we read the input stream (i.e. when it's rate limited)


+    private static InputStream toSynchronizedInputStream(String blobName, InputStream delegate, MultiPart multipart) {
+        assert delegate.markSupported() : "An InputStream with mark support was expected";
+        // We need to introduce a read barrier in order to provide visibility for the underlying
+        // input stream state as the input stream can be read from different threads.
+        // TODO See if this is still needed
+        return new FilterInputStream(delegate) {
+
+            private final boolean isTraceEnabled = logger.isTraceEnabled();
+
+            @Override
+            public synchronized int read(byte[] b, int off, int len) throws IOException {
+                var result = super.read(b, off, len);
+                if (isTraceEnabled) {
+                    logger.trace("{} reads {} bytes from {} part {}", Thread.currentThread(), result, blobName, multipart.part());
+                }
+                return result;
+            }
+
+            @Override
+            public synchronized int read() throws IOException {
+                var result = super.read();
+                if (isTraceEnabled) {
+                    logger.trace("{} reads {} byte from {} part {}", Thread.currentThread(), result, blobName, multipart.part());
+                }
+                return result;
+            }
+
+            @Override
+            public synchronized void mark(int readlimit) {
+                if (isTraceEnabled) {
+                    logger.trace("{} marks stream {} part {}", Thread.currentThread(), blobName, multipart.part());
+                }
+                super.mark(readlimit);
+            }
+
+            @Override
+            public synchronized void reset() throws IOException {
+                if (isTraceEnabled) {
+                    logger.trace("{} resets stream {} part {}", Thread.currentThread(), blobName, multipart.part());
+                }
+                super.reset();
+            }
+
+            @Override
+            public synchronized void close() throws IOException {
+                if (isTraceEnabled) {
+                    logger.trace("{} closes stream {} part {}", Thread.currentThread(), blobName, multipart.part());
+                }
+                super.close();
+            }
+
+            @Override
+            public String toString() {
+                return blobName + " part [" + multipart.part() + "] of size [" + multipart.blockSize() + ']';
+            }
+        };
+    }
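+
+    // Why the wrapper is synchronized and mark-capable: a retried stageBlock resubscribes to the part's Flux, which
+    // calls reset() and re-reads the same byte range, possibly from a different reactor thread than the previous
+    // attempt. The synchronized methods provide the cross-thread visibility described in the comment above.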
+
+    private static Flux<ByteBuffer> toFlux(InputStream stream, long length, int chunkSize) {
+        assert stream.markSupported() : "An InputStream with mark support was expected";
+        // We need to mark the InputStream as it's possible that we need to retry for the same chunk
+        stream.mark(Integer.MAX_VALUE);
+        return Flux.defer(() -> {
+            // TODO Code in this Flux.defer() can be concurrently executed by multiple threads?
+            try {
+                stream.reset();
+            } catch (IOException e) {
+                throw new RuntimeException(e);
+            }
+            final var bytesRead = new AtomicLong(0L);
+            // This flux is subscribed by a downstream operator that finally queues the
+            // buffers into the netty output queue. Sadly we are not able to get a signal once
+            // a buffer has been flushed, so we have to allocate the buffers and let the GC
+            // reclaim them (see MonoSendMany). Additionally, that very same operator requests
+            // 128 elements (that's hardcoded) once it's subscribed (later on, it requests
+            // by 64 elements), that's why we provide 64kb buffers.
+
+            // length is at most 100MB so it's safe to cast back to an integer in this case
+            final int parts = (int) length / chunkSize;
+            final long remaining = length % chunkSize;
+            return Flux.range(0, remaining == 0 ? parts : parts + 1).map(i -> i * chunkSize).concatMap(pos -> Mono.fromCallable(() -> {
+                long count = pos + chunkSize > length ? length - pos : chunkSize;
+                int numOfBytesRead = 0;
+                int offset = 0;
+                int len = (int) count;
+                final byte[] buffer = new byte[len];
+                while (numOfBytesRead != -1 && offset < count) {
+                    numOfBytesRead = stream.read(buffer, offset, len);
+                    offset += numOfBytesRead;
+                    len -= numOfBytesRead;
+                    if (numOfBytesRead != -1) {
+                        bytesRead.addAndGet(numOfBytesRead);
+                    }
+                }
+                if (numOfBytesRead == -1 && bytesRead.get() < length) {
+                    throw new IllegalStateException(
+                        format("Input stream [%s] emitted %d bytes, less than the expected %d bytes.", stream, bytesRead.get(), length)
+                    );
+                }
+                return ByteBuffer.wrap(buffer);
+            })).doOnComplete(() -> {
+                if (bytesRead.get() > length) {
+                    throw new IllegalStateException(
+                        format("Input stream [%s] emitted %d bytes, more than the expected %d bytes.", stream, bytesRead.get(), length)
+                    );
+                }
+            });
+            // We need to subscribe on a different scheduler to avoid blocking the io threads when we read the input stream
+        }).subscribeOn(Schedulers.elastic());
+
+    }
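+
+    // Chunking example, assuming the 64 KB DEFAULT_UPLOAD_BUFFERS_SIZE mentioned in the comment above: for a 150 KB
+    // part, parts = 2 and remaining = 22 KB, so Flux.range emits offsets 0, 64 KB and 128 KB and the callable produces
+    // ByteBuffers of 64 KB, 64 KB and 22 KB, re-read from the marked stream on every (re)subscription.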
+
    /**
     * Returns the number of parts of size {@code partSize} needed to reach {@code totalSize},
     * along with the size of the last (or unique) part.