diff --git a/docs/changelog/128449.yaml b/docs/changelog/128449.yaml
new file mode 100644
index 0000000000000..12798783942e6
--- /dev/null
+++ b/docs/changelog/128449.yaml
@@ -0,0 +1,5 @@
+pr: 128449
+summary: "[Draft] Support concurrent multipart uploads in Azure"
+area: Snapshot/Restore
+type: enhancement
+issues: []
diff --git a/modules/repository-azure/src/main/java/org/elasticsearch/repositories/azure/AzureBlobContainer.java b/modules/repository-azure/src/main/java/org/elasticsearch/repositories/azure/AzureBlobContainer.java
index 08bdc2051b9e3..a040067d7b1b0 100644
--- a/modules/repository-azure/src/main/java/org/elasticsearch/repositories/azure/AzureBlobContainer.java
+++ b/modules/repository-azure/src/main/java/org/elasticsearch/repositories/azure/AzureBlobContainer.java
@@ -15,6 +15,7 @@
 import org.apache.logging.log4j.Logger;
 import org.elasticsearch.ExceptionsHelper;
 import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.common.CheckedBiFunction;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.blobstore.BlobContainer;
 import org.elasticsearch.common.blobstore.BlobPath;
@@ -105,6 +106,22 @@ public void writeBlob(OperationPurpose purpose, String blobName, InputStream inp
         blobStore.writeBlob(purpose, buildKey(blobName), inputStream, blobSize, failIfAlreadyExists);
     }
 
+    @Override
+    public boolean supportsConcurrentMultipartUploads() {
+        return true;
+    }
+
+    @Override
+    public void writeBlobAtomic(
+        OperationPurpose purpose,
+        String blobName,
+        long blobSize,
+        CheckedBiFunction<Long, Long, InputStream, IOException> provider,
+        boolean failIfAlreadyExists
+    ) throws IOException {
+        blobStore.writeBlobAtomic(purpose, buildKey(blobName), blobSize, provider, failIfAlreadyExists);
+    }
+
     @Override
     public void writeBlobAtomic(
         OperationPurpose purpose,
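
The provider passed through here is a CheckedBiFunction<Long, Long, InputStream, IOException>: given an offset and a length it is expected to return an independent, mark-supporting InputStream over exactly that range, so that each part of a large blob can be opened (and retried) on its own. A minimal sketch of such a provider, assuming a local file as the source; the helper name and the in-memory buffering are illustrative only and not part of this change:

    import java.io.ByteArrayInputStream;
    import java.io.IOException;
    import java.io.InputStream;
    import java.nio.ByteBuffer;
    import java.nio.channels.SeekableByteChannel;
    import java.nio.file.Files;
    import java.nio.file.Path;
    import java.nio.file.StandardOpenOption;

    import org.elasticsearch.common.CheckedBiFunction;

    class RangeProviders {
        // Hypothetical helper: serves arbitrary [offset, offset + length) ranges of a local file.
        // ByteArrayInputStream supports mark/reset, which the multipart upload path relies on for retries.
        static CheckedBiFunction<Long, Long, InputStream, IOException> fileRangeProvider(Path source) {
            return (offset, length) -> {
                try (SeekableByteChannel channel = Files.newByteChannel(source, StandardOpenOption.READ)) {
                    ByteBuffer buffer = ByteBuffer.allocate(Math.toIntExact(length));
                    channel.position(offset);
                    while (buffer.hasRemaining() && channel.read(buffer) >= 0) {
                        // keep reading until the requested range is filled or EOF is reached
                    }
                    return new ByteArrayInputStream(buffer.array(), 0, buffer.position());
                }
            };
        }
    }
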
diff --git a/modules/repository-azure/src/main/java/org/elasticsearch/repositories/azure/AzureBlobStore.java b/modules/repository-azure/src/main/java/org/elasticsearch/repositories/azure/AzureBlobStore.java
index 906035e9ac7d9..430459f47ddfb 100644
--- a/modules/repository-azure/src/main/java/org/elasticsearch/repositories/azure/AzureBlobStore.java
+++ b/modules/repository-azure/src/main/java/org/elasticsearch/repositories/azure/AzureBlobStore.java
@@ -19,6 +19,8 @@
 import com.azure.core.http.HttpMethod;
 import com.azure.core.http.rest.ResponseBase;
 import com.azure.core.util.BinaryData;
+import com.azure.core.util.FluxUtil;
+import com.azure.core.util.logging.ClientLogger;
 import com.azure.storage.blob.BlobAsyncClient;
 import com.azure.storage.blob.BlobClient;
 import com.azure.storage.blob.BlobContainerAsyncClient;
@@ -48,6 +50,7 @@
 import org.apache.logging.log4j.Logger;
 import org.apache.logging.log4j.core.util.Throwables;
 import org.elasticsearch.cluster.metadata.RepositoryMetadata;
+import org.elasticsearch.common.CheckedBiFunction;
 import org.elasticsearch.common.UUIDs;
 import org.elasticsearch.common.blobstore.BlobContainer;
 import org.elasticsearch.common.blobstore.BlobPath;
@@ -64,6 +67,7 @@
 import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.common.util.BigArrays;
 import org.elasticsearch.core.CheckedConsumer;
+import org.elasticsearch.core.IOUtils;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.Tuple;
 import org.elasticsearch.repositories.RepositoriesMetrics;
@@ -121,6 +125,7 @@ public class AzureBlobStore implements BlobStore {
     private final ByteSizeValue maxSinglePartUploadSize;
     private final int deletionBatchSize;
     private final int maxConcurrentBatchDeletes;
+    private final int multipartUploadMaxConcurrency;
 
     private final RequestMetricsRecorder requestMetricsRecorder;
     private final AzureClientProvider.RequestMetricsHandler requestMetricsHandler;
@@ -142,6 +147,7 @@ public AzureBlobStore(
         this.maxSinglePartUploadSize = Repository.MAX_SINGLE_PART_UPLOAD_SIZE_SETTING.get(metadata.settings());
         this.deletionBatchSize = Repository.DELETION_BATCH_SIZE_SETTING.get(metadata.settings());
         this.maxConcurrentBatchDeletes = Repository.MAX_CONCURRENT_BATCH_DELETES_SETTING.get(metadata.settings());
+        this.multipartUploadMaxConcurrency = service.getMultipartUploadMaxConcurrency();
 
         List<RequestMatcher> requestMatchers = List.of(
             new RequestMatcher((httpMethod, url) -> httpMethod == HttpMethod.HEAD, Operation.GET_BLOB_PROPERTIES),
@@ -464,6 +470,136 @@ protected void onFailure() {
             }
         }
 
+    void writeBlobAtomic(
+        final OperationPurpose purpose,
+        final String blobName,
+        final long blobSize,
+        final CheckedBiFunction<Long, Long, InputStream, IOException> provider,
+        final boolean failIfAlreadyExists
+    ) throws IOException {
+        try {
+            final List<MultiPart> multiParts;
+            if (blobSize <= getLargeBlobThresholdInBytes()) {
+                multiParts = null;
+            } else {
+                multiParts = computeMultiParts(blobSize, getUploadBlockSize());
+            }
+            if (multiParts == null || multiParts.size() == 1) {
+                logger.debug("{}: uploading blob of size [{}] as single upload", blobName, blobSize);
+                try (var stream = provider.apply(0L, blobSize)) {
+                    var flux = convertStreamToByteBuffer(stream, blobSize, DEFAULT_UPLOAD_BUFFERS_SIZE);
+                    executeSingleUpload(purpose, blobName, flux, blobSize, failIfAlreadyExists);
+                }
+            } else {
+                logger.debug("{}: uploading blob of size [{}] using [{}] parts", blobName, blobSize, multiParts.size());
+                assert blobSize == ((multiParts.size() - 1) * getUploadBlockSize()) + multiParts.getLast().blockSize();
+                assert multiParts.size() > 1;
+
+                final var asyncClient = asyncClient(purpose).getBlobContainerAsyncClient(container)
+                    .getBlobAsyncClient(blobName)
+                    .getBlockBlobAsyncClient();
+
+                Flux.fromIterable(multiParts)
+                    .flatMapSequential(multipart -> stageBlock(asyncClient, blobName, multipart, provider), multipartUploadMaxConcurrency)
+                    .collect(Collectors.toList())
+                    .flatMap(blockIds -> {
+                        logger.debug("{}: all {} parts uploaded, now committing", blobName, multiParts.size());
+                        var response = asyncClient.commitBlockList(
+                            multiParts.stream().map(MultiPart::blockId).toList(),
+                            failIfAlreadyExists == false
+                        );
+                        logger.debug("{}: all {} parts committed", blobName, multiParts.size());
+                        return response;
+                    })
+                    .block();
+            }
+        } catch (final BlobStorageException e) {
+            if (failIfAlreadyExists
+                && e.getStatusCode() == HttpURLConnection.HTTP_CONFLICT
+                && BlobErrorCode.BLOB_ALREADY_EXISTS.equals(e.getErrorCode())) {
+                throw new FileAlreadyExistsException(blobName, null, e.getMessage());
+            }
+            throw new IOException("Unable to write blob " + blobName, e);
+        } catch (Exception e) {
+            throw new IOException("Unable to write blob " + blobName, e);
+        }
+    }
+
+    private record MultiPart(int part, String blockId, long blockOffset, long blockSize, boolean isLast) {}
+
+    private static List<MultiPart> computeMultiParts(long totalSize, long partSize) {
+        if (partSize <= 0) {
+            throw new IllegalArgumentException("Part size must be greater than zero");
+        }
+        if ((totalSize == 0L) || (totalSize <= partSize)) {
+            return List.of(new MultiPart(0, makeMultipartBlockId(), 0L, totalSize, true));
+        }
+
+        long lastPartSize = totalSize % partSize;
+        int parts = Math.toIntExact(totalSize / partSize) + (0L < lastPartSize ? 1 : 0);
+
+        long blockOffset = 0L;
+        var list = new ArrayList<MultiPart>(parts);
+        for (int p = 0; p < parts; p++) {
+            boolean isLast = (p == parts - 1);
+            var multipart = new MultiPart(p, makeMultipartBlockId(), blockOffset, isLast && lastPartSize > 0L ? lastPartSize : partSize, isLast);
+            blockOffset += multipart.blockSize();
+            list.add(multipart);
+        }
+        return List.copyOf(list);
+    }
+
+    private static Mono<String> stageBlock(
+        BlockBlobAsyncClient asyncClient,
+        String blobName,
+        MultiPart multiPart,
+        CheckedBiFunction<Long, Long, InputStream, IOException> provider
+    ) {
+        logger.debug(
+            "{}: staging part [{}] of size [{}] from offset [{}]",
+            blobName,
+            multiPart.part(),
+            multiPart.blockSize(),
+            multiPart.blockOffset()
+        );
+        try {
+            var stream = toSynchronizedInputStream(blobName, provider.apply(multiPart.blockOffset(), multiPart.blockSize()), multiPart);
+            boolean success = false;
+            try {
+                var stageBlock = asyncClient.stageBlock(
+                    multiPart.blockId(),
+                    toFlux(stream, multiPart.blockSize(), DEFAULT_UPLOAD_BUFFERS_SIZE),
+                    multiPart.blockSize()
+                ).doOnSuccess(unused -> {
+                    logger.debug(() -> format("%s: part [%s] of size [%s] uploaded", blobName, multiPart.part(), multiPart.blockSize()));
+                    IOUtils.closeWhileHandlingException(stream);
+                }).doOnCancel(() -> {
+                    logger.warn(() -> format("%s: part [%s] of size [%s] cancelled", blobName, multiPart.part(), multiPart.blockSize()));
+                    IOUtils.closeWhileHandlingException(stream);
+                }).doOnError(t -> {
+                    logger.error(() -> format("%s: part [%s] of size [%s] failed", blobName, multiPart.part(), multiPart.blockSize()), t);
+                    IOUtils.closeWhileHandlingException(stream);
+                });
+                logger.debug(
+                    "{}: part [{}] of size [{}] from offset [{}] staged",
+                    blobName,
+                    multiPart.part(),
+                    multiPart.blockSize(),
+                    multiPart.blockOffset()
+                );
+                success = true;
+                return stageBlock.map(unused -> multiPart.blockId());
+            } finally {
+                if (success != true) {
+                    IOUtils.close(stream);
+                }
+            }
+        } catch (IOException e) {
+            logger.error(() -> format("%s: failed to stage part [%s] of size [%s]", blobName, multiPart.part(), multiPart.blockSize()), e);
+            return FluxUtil.monoError(new ClientLogger(AzureBlobStore.class), new UncheckedIOException(e));
+        }
+    }
+
     public void writeBlob(OperationPurpose purpose, String blobName, InputStream inputStream, long blobSize, boolean failIfAlreadyExists)
         throws IOException {
         assert inputStream.markSupported()
@@ -625,6 +761,118 @@ public synchronized int read() throws IOException {
         // we read the input stream (i.e. when it's rate limited)
     }
 
+    private static InputStream toSynchronizedInputStream(String blobName, InputStream delegate, MultiPart multipart) {
+        assert delegate.markSupported() : "An InputStream with mark support was expected";
+        // We need to introduce a read barrier in order to provide visibility for the underlying
+        // input stream state as the input stream can be read from different threads.
+        // TODO See if this is still needed
+        return new FilterInputStream(delegate) {
+
+            private final boolean isTraceEnabled = logger.isTraceEnabled();
+
+            @Override
+            public synchronized int read(byte[] b, int off, int len) throws IOException {
+                var result = super.read(b, off, len);
+                if (isTraceEnabled) {
+                    logger.trace("{} reads {} bytes from {} part {}", Thread.currentThread(), result, blobName, multipart.part());
+                }
+                return result;
+            }
+
+            @Override
+            public synchronized int read() throws IOException {
+                var result = super.read();
+                if (isTraceEnabled) {
+                    logger.trace("{} reads {} byte from {} part {}", Thread.currentThread(), result, blobName, multipart.part());
+                }
+                return result;
+            }
+
+            @Override
+            public synchronized void mark(int readlimit) {
+                if (isTraceEnabled) {
+                    logger.trace("{} marks stream {} part {}", Thread.currentThread(), blobName, multipart.part());
+                }
+                super.mark(readlimit);
+            }
+
+            @Override
+            public synchronized void reset() throws IOException {
+                if (isTraceEnabled) {
+                    logger.trace("{} resets stream {} part {}", Thread.currentThread(), blobName, multipart.part());
+                }
+                super.reset();
+            }
+
+            @Override
+            public synchronized void close() throws IOException {
+                if (isTraceEnabled) {
+                    logger.trace("{} closes stream {} part {}", Thread.currentThread(), blobName, multipart.part());
+                }
+                super.close();
+            }
+
+            @Override
+            public String toString() {
+                return blobName + " part [" + multipart.part() + "] of size [" + multipart.blockSize() + ']';
+            }
+        };
+    }
+
+    private static Flux<ByteBuffer> toFlux(InputStream stream, long length, int chunkSize) {
+        assert stream.markSupported() : "An InputStream with mark support was expected";
+        // We need to mark the InputStream as it's possible that we need to retry for the same chunk
+        stream.mark(Integer.MAX_VALUE);
+        return Flux.defer(() -> {
+            // TODO Code in this Flux.defer() can be concurrently executed by multiple threads?
+            try {
+                stream.reset();
+            } catch (IOException e) {
+                throw new RuntimeException(e);
+            }
+            final var bytesRead = new AtomicLong(0L);
+            // This flux is subscribed by a downstream operator that finally queues the
+            // buffers into netty output queue. Sadly we are not able to get a signal once
+            // the buffer has been flushed, so we have to allocate those and let the GC
+            // reclaim them (see MonoSendMany). Additionally, that very same operator requests
+            // 128 elements (that's hardcoded) once it's subscribed (later on, it requests
+            // by 64 elements), that's why we provide 64kb buffers.
+
+            // length is at most 100MB so it's safe to cast back to an integer in this case
+            final int parts = (int) length / chunkSize;
+            final long remaining = length % chunkSize;
+            return Flux.range(0, remaining == 0 ? parts : parts + 1).map(i -> i * chunkSize).concatMap(pos -> Mono.fromCallable(() -> {
+                long count = pos + chunkSize > length ? length - pos : chunkSize;
+                int numOfBytesRead = 0;
+                int offset = 0;
+                int len = (int) count;
+                final byte[] buffer = new byte[len];
+                while (numOfBytesRead != -1 && offset < count) {
+                    numOfBytesRead = stream.read(buffer, offset, len);
+                    offset += numOfBytesRead;
+                    len -= numOfBytesRead;
+                    if (numOfBytesRead != -1) {
+                        bytesRead.addAndGet(numOfBytesRead);
+                    }
+                }
+                if (numOfBytesRead == -1 && bytesRead.get() < length) {
+                    throw new IllegalStateException(
+                        format("Input stream [%s] emitted %d bytes, less than the expected %d bytes.", stream, bytesRead, length)
+                    );
+                }
+                return ByteBuffer.wrap(buffer);
+            })).doOnComplete(() -> {
+                if (bytesRead.get() > length) {
+                    throw new IllegalStateException(
+                        format("Input stream [%s] emitted %d bytes, more than the expected %d bytes.", stream, bytesRead, length)
+                    );
+                }
+            });
+            // We need to subscribe on a different scheduler to avoid blocking the io threads when we read the input stream
+        }).subscribeOn(Schedulers.elastic());
+
+    }
+
     /**
      * Returns the number parts of size of {@code partSize} needed to reach {@code totalSize},
      * along with the size of the last (or unique) part.
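
To make the partitioning in computeMultiParts concrete: a blob above the large-blob threshold is cut into fixed-size blocks of getUploadBlockSize(), plus a smaller trailing block when the size is not an exact multiple. The following standalone walk-through mirrors that arithmetic for assumed sizes (a 250 MiB blob and a 100 MiB block size); it is an illustration, not code from this change:

    public class MultiPartMath {
        public static void main(String[] args) {
            long totalSize = 250L * 1024 * 1024; // assumed blob size: 250 MiB
            long partSize = 100L * 1024 * 1024;  // assumed upload block size: 100 MiB

            long lastPartSize = totalSize % partSize;                                        // 50 MiB
            int parts = Math.toIntExact(totalSize / partSize) + (0L < lastPartSize ? 1 : 0); // 3 parts

            long blockOffset = 0L;
            for (int p = 0; p < parts; p++) {
                boolean isLast = p == parts - 1;
                long blockSize = isLast && lastPartSize > 0L ? lastPartSize : partSize;
                // prints parts of 100 MiB, 100 MiB and 50 MiB at offsets 0, 100 MiB and 200 MiB
                System.out.printf("part %d: offset=%d size=%d%n", p, blockOffset, blockSize);
                blockOffset += blockSize;
            }
        }
    }

Each part is then staged as its own block with a generated block id, and the blob only becomes visible once commitBlockList is called with the complete, ordered list of ids.
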
diff --git a/modules/repository-azure/src/main/java/org/elasticsearch/repositories/azure/AzureClientProvider.java b/modules/repository-azure/src/main/java/org/elasticsearch/repositories/azure/AzureClientProvider.java
index eed04bcdf57bf..6f1b60a798789 100644
--- a/modules/repository-azure/src/main/java/org/elasticsearch/repositories/azure/AzureClientProvider.java
+++ b/modules/repository-azure/src/main/java/org/elasticsearch/repositories/azure/AzureClientProvider.java
@@ -109,6 +109,7 @@ class AzureClientProvider extends AbstractLifecycleComponent {
     private final ConnectionProvider connectionProvider;
     private final ByteBufAllocator byteBufAllocator;
     private final LoopResources nioLoopResources;
+    private final int multipartUploadMaxConcurrency;
     private volatile boolean closed = false;
 
     AzureClientProvider(
@@ -116,7 +117,8 @@ class AzureClientProvider extends AbstractLifecycleComponent {
         String reactorExecutorName,
         EventLoopGroup eventLoopGroup,
         ConnectionProvider connectionProvider,
-        ByteBufAllocator byteBufAllocator
+        ByteBufAllocator byteBufAllocator,
+        int multipartUploadMaxConcurrency
     ) {
         this.threadPool = threadPool;
         this.reactorExecutorName = reactorExecutorName;
@@ -127,6 +129,7 @@ class AzureClientProvider extends AbstractLifecycleComponent {
         // hence we need to use the same instance across all the client instances
         // to avoid creating multiple connection pools.
         this.nioLoopResources = useNative -> eventLoopGroup;
+        this.multipartUploadMaxConcurrency = multipartUploadMaxConcurrency;
     }
 
     static int eventLoopThreadsFromSettings(Settings settings) {
@@ -152,7 +155,14 @@ static AzureClientProvider create(ThreadPool threadPool, Settings settings) {
         // Just to verify that this executor exists
         threadPool.executor(REPOSITORY_THREAD_POOL_NAME);
 
-        return new AzureClientProvider(threadPool, REPOSITORY_THREAD_POOL_NAME, eventLoopGroup, provider, NettyAllocator.getAllocator());
+        return new AzureClientProvider(
+            threadPool,
+            REPOSITORY_THREAD_POOL_NAME,
+            eventLoopGroup,
+            provider,
+            NettyAllocator.getAllocator(),
+            threadPool.info(REPOSITORY_THREAD_POOL_NAME).getMax()
+        );
     }
 
     AzureBlobServiceClient createClient(
@@ -250,6 +260,10 @@ protected void doStop() {
     @Override
     protected void doClose() {}
 
+    public int getMultipartUploadMaxConcurrency() {
+        return multipartUploadMaxConcurrency;
+    }
+
     // visible for testing
     ConnectionProvider getConnectionProvider() {
         return connectionProvider;
diff --git a/modules/repository-azure/src/main/java/org/elasticsearch/repositories/azure/AzureStorageService.java b/modules/repository-azure/src/main/java/org/elasticsearch/repositories/azure/AzureStorageService.java
index e26e984810935..dc8faa7d9e1e6 100644
--- a/modules/repository-azure/src/main/java/org/elasticsearch/repositories/azure/AzureStorageService.java
+++ b/modules/repository-azure/src/main/java/org/elasticsearch/repositories/azure/AzureStorageService.java
@@ -68,6 +68,7 @@ public class AzureStorageService {
     public static final ByteSizeValue MAX_CHUNK_SIZE = ByteSizeValue.ofBytes(MAX_BLOB_SIZE);
 
     private static final long DEFAULT_UPLOAD_BLOCK_SIZE = DEFAULT_BLOCK_SIZE.getBytes();
+    private final int multipartUploadMaxConcurrency;
 
     // 'package' for testing
     volatile Map<String, AzureStorageSettings> storageSettings = emptyMap();
@@ -81,6 +82,7 @@ public AzureStorageService(Settings settings, AzureClientProvider azureClientPro
         refreshSettings(clientsSettings);
         this.azureClientProvider = azureClientProvider;
         this.stateless = DiscoveryNode.isStateless(settings);
+        this.multipartUploadMaxConcurrency = azureClientProvider.getMultipartUploadMaxConcurrency();
     }
 
     public AzureBlobServiceClient client(String clientName, LocationMode locationMode, OperationPurpose purpose) {
@@ -196,4 +198,8 @@ public Set<String> getExtraUsageFeatures(String clientName) {
             return Set.of();
         }
     }
+
+    public int getMultipartUploadMaxConcurrency() {
+        return multipartUploadMaxConcurrency;
+    }
 }
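
The concurrency value plumbed through AzureClientProvider and AzureStorageService above (defaulting to the maximum size of the repository thread pool) is what bounds the flatMapSequential call in AzureBlobStore: at most that many blocks are staged at once, while the collected block ids keep their original order. A toy Reactor example of that operator, unrelated to the Azure SDK and shown only to illustrate the bounded, order-preserving behaviour:

    import java.time.Duration;

    import reactor.core.publisher.Flux;
    import reactor.core.publisher.Mono;

    public class BoundedConcurrencyDemo {
        public static void main(String[] args) {
            Flux.range(0, 6)
                // at most 2 simulated "stage block" calls are in flight at any time,
                // yet results are emitted in the original order 0..5
                .flatMapSequential(part -> Mono.delay(Duration.ofMillis(100)).thenReturn(part), 2)
                .doOnNext(part -> System.out.println("staged part " + part))
                .blockLast();
        }
    }
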
diff --git a/server/src/main/java/org/elasticsearch/common/blobstore/BlobContainer.java b/server/src/main/java/org/elasticsearch/common/blobstore/BlobContainer.java
index a8ca895480779..b33fa9cd6b117 100644
--- a/server/src/main/java/org/elasticsearch/common/blobstore/BlobContainer.java
+++ b/server/src/main/java/org/elasticsearch/common/blobstore/BlobContainer.java
@@ -10,6 +10,7 @@
 package org.elasticsearch.common.blobstore;
 
 import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.common.CheckedBiFunction;
 import org.elasticsearch.common.blobstore.support.BlobMetadata;
 import org.elasticsearch.common.bytes.BytesArray;
 import org.elasticsearch.common.bytes.BytesReference;
@@ -144,6 +145,24 @@ void writeMetadataBlob(
         CheckedConsumer<OutputStream, IOException> writer
     ) throws IOException;
 
+    /**
+     * Indicates if the implementation supports writing large blobs using concurrent multipart uploads.
+     * @return {@code true} if the implementation supports writing large blobs using concurrent multipart uploads, {@code false} otherwise
+     */
+    default boolean supportsConcurrentMultipartUploads() {
+        return false;
+    }
+
+    default void writeBlobAtomic(
+        OperationPurpose purpose,
+        String blobName,
+        long blobSize,
+        CheckedBiFunction<Long, Long, InputStream, IOException> provider,
+        boolean failIfAlreadyExists
+    ) throws IOException {
+        throw new UnsupportedOperationException();
+    }
+
     /**
      * Reads blob content from the input stream and writes it to the container in a new blob with the given name,
      * using an atomic write operation if the implementation supports it.
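
Taken together, a caller with random access to its source data might use the new surface roughly as follows. The helper below is hypothetical; the stream-based writeBlobAtomic overload used as the fallback is the pre-existing one whose javadoc appears in the context above:

    import java.io.IOException;
    import java.io.InputStream;

    import org.elasticsearch.common.CheckedBiFunction;
    import org.elasticsearch.common.blobstore.BlobContainer;
    import org.elasticsearch.common.blobstore.OperationPurpose;

    class BlobWriter {
        // Hypothetical helper: prefer the concurrent multipart overload when the container supports it,
        // otherwise fall back to the single-stream atomic write.
        static void writeAtomic(
            BlobContainer container,
            OperationPurpose purpose,
            String name,
            long size,
            CheckedBiFunction<Long, Long, InputStream, IOException> rangeProvider,
            InputStream wholeBlob
        ) throws IOException {
            if (container.supportsConcurrentMultipartUploads()) {
                container.writeBlobAtomic(purpose, name, size, rangeProvider, true);
            } else {
                container.writeBlobAtomic(purpose, name, wholeBlob, size, true);
            }
        }
    }
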