diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/AsyncRequestBody.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/AsyncRequestBody.java index 3fd8c3cc0165..22ac529aafa9 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/AsyncRequestBody.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/AsyncRequestBody.java @@ -33,6 +33,7 @@ import software.amazon.awssdk.annotations.SdkPublicApi; import software.amazon.awssdk.core.FileRequestBodyConfiguration; import software.amazon.awssdk.core.internal.async.ByteBuffersAsyncRequestBody; +import software.amazon.awssdk.core.internal.async.ClosableAsyncRequestBodyAdaptor; import software.amazon.awssdk.core.internal.async.FileAsyncRequestBody; import software.amazon.awssdk.core.internal.async.InputStreamWithExecutorAsyncRequestBody; import software.amazon.awssdk.core.internal.async.SplittingPublisher; @@ -507,17 +508,33 @@ static AsyncRequestBody empty() { * is 2MB and the default buffer size is 8MB. * *

- * By default, if content length of this {@link AsyncRequestBody} is present, each divided {@link AsyncRequestBody} is - * delivered to the subscriber right after it's initialized. On the other hand, if content length is null, it is sent after - * the entire content for that chunk is buffered. In this case, the configured {@code maxMemoryUsageInBytes} must be larger - * than or equal to {@code chunkSizeInBytes}. Note that this behavior may be different if a specific implementation of this - * interface overrides this method. + * Each divided {@link AsyncRequestBody} is sent after the entire content for that chunk is buffered. * * @see AsyncRequestBodySplitConfiguration + * @deprecated Use {@link #splitV2(AsyncRequestBodySplitConfiguration)} instead. */ + @Deprecated default SdkPublisher split(AsyncRequestBodySplitConfiguration splitConfiguration) { Validate.notNull(splitConfiguration, "splitConfiguration"); + return splitV2(splitConfiguration).map(body -> new ClosableAsyncRequestBodyAdaptor(body)); + } + /** + * Converts this {@link AsyncRequestBody} to a publisher of {@link ClosableAsyncRequestBody}s, each of which publishes + * specific portion of the original data, based on the provided {@link AsyncRequestBodySplitConfiguration}. The default chunk + * size is 2MB and the default buffer size is 8MB. + * + *

+ * Each divided {@link ClosableAsyncRequestBody} is sent after the entire content for that chunk is buffered. This behavior + * may be different if a specific implementation of this interface overrides this method. + * + *

+ * Each {@link ClosableAsyncRequestBody} MUST be closed by the user when it is ready to be disposed. + * + * @see AsyncRequestBodySplitConfiguration + */ + default SdkPublisher splitV2(AsyncRequestBodySplitConfiguration splitConfiguration) { + Validate.notNull(splitConfiguration, "splitConfiguration"); return new SplittingPublisher(this, splitConfiguration); } @@ -526,12 +543,26 @@ default SdkPublisher split(AsyncRequestBodySplitConfiguration * avoiding the need to create one manually via {@link AsyncRequestBodySplitConfiguration#builder()}. * * @see #split(AsyncRequestBodySplitConfiguration) + * @deprecated Use {@link #splitV2(Consumer)} instead. */ + @Deprecated default SdkPublisher split(Consumer splitConfiguration) { Validate.notNull(splitConfiguration, "splitConfiguration"); return split(AsyncRequestBodySplitConfiguration.builder().applyMutation(splitConfiguration).build()); } + /** + * This is a convenience method that passes an instance of the {@link AsyncRequestBodySplitConfiguration} builder, + * avoiding the need to create one manually via {@link AsyncRequestBodySplitConfiguration#builder()}. + * + * @see #splitV2(Consumer) + */ + default SdkPublisher splitV2( + Consumer splitConfiguration) { + Validate.notNull(splitConfiguration, "splitConfiguration"); + return splitV2(AsyncRequestBodySplitConfiguration.builder().applyMutation(splitConfiguration).build()); + } + @SdkProtectedApi enum BodyType { FILE("File", "f"), diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/ClosableAsyncRequestBody.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/ClosableAsyncRequestBody.java new file mode 100644 index 000000000000..7f495883c477 --- /dev/null +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/ClosableAsyncRequestBody.java @@ -0,0 +1,26 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.core.async; + +import software.amazon.awssdk.annotations.SdkPublicApi; +import software.amazon.awssdk.utils.SdkAutoCloseable; + +/** + * An extension of {@link AsyncRequestBody} that is closable. + */ +@SdkPublicApi +public interface ClosableAsyncRequestBody extends AsyncRequestBody, SdkAutoCloseable { +} diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/listener/AsyncRequestBodyListener.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/listener/AsyncRequestBodyListener.java index a37b226d4bc3..e0228b896d26 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/listener/AsyncRequestBodyListener.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/listener/AsyncRequestBodyListener.java @@ -23,6 +23,7 @@ import software.amazon.awssdk.annotations.SdkProtectedApi; import software.amazon.awssdk.core.async.AsyncRequestBody; import software.amazon.awssdk.core.async.AsyncRequestBodySplitConfiguration; +import software.amazon.awssdk.core.async.ClosableAsyncRequestBody; import software.amazon.awssdk.core.async.SdkPublisher; import software.amazon.awssdk.utils.Logger; import software.amazon.awssdk.utils.Validate; @@ -76,6 +77,17 @@ public SdkPublisher split(Consumer splitV2(AsyncRequestBodySplitConfiguration splitConfiguration) { + return delegate.splitV2(splitConfiguration); + } + + @Override + public SdkPublisher splitV2( + Consumer splitConfiguration) { + return delegate.splitV2(splitConfiguration); + } + @Override public void subscribe(Subscriber s) { invoke(() -> listener.publisherSubscribe(s), "publisherSubscribe"); diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/ByteBuffersAsyncRequestBody.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/ByteBuffersAsyncRequestBody.java index a304d75ccf94..1ae49d0dfdde 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/ByteBuffersAsyncRequestBody.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/ByteBuffersAsyncRequestBody.java @@ -76,7 +76,9 @@ public final class ByteBuffersAsyncRequestBody implements AsyncRequestBody, SdkA private final Object lock = new Object(); private boolean closed; - private ByteBuffersAsyncRequestBody(String mimetype, Long length, List buffers) { + private ByteBuffersAsyncRequestBody(String mimetype, + Long length, + List buffers) { this.mimetype = mimetype; this.buffers = buffers; this.length = length; diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/ClosableAsyncRequestBodyAdaptor.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/ClosableAsyncRequestBodyAdaptor.java new file mode 100644 index 000000000000..1c3d126981dd --- /dev/null +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/ClosableAsyncRequestBodyAdaptor.java @@ -0,0 +1,67 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.core.internal.async; + +import java.nio.ByteBuffer; +import java.util.Optional; +import java.util.concurrent.atomic.AtomicBoolean; +import org.reactivestreams.Subscriber; +import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.core.async.AsyncRequestBody; +import software.amazon.awssdk.core.async.AsyncRequestBodySplitConfiguration; +import software.amazon.awssdk.core.async.ClosableAsyncRequestBody; +import software.amazon.awssdk.core.exception.NonRetryableException; +import software.amazon.awssdk.core.internal.util.NoopSubscription; + +/** + * Adaptor to convert a {@link ClosableAsyncRequestBody} to an {@link AsyncRequestBody} + * + *

+ * This is needed to maintain backwards compatibility for the deprecated + * {@link AsyncRequestBody#split(AsyncRequestBodySplitConfiguration)} + */ +@SdkInternalApi +public final class ClosableAsyncRequestBodyAdaptor implements AsyncRequestBody { + + private final AtomicBoolean subscribeCalled; + private final ClosableAsyncRequestBody delegate; + + public ClosableAsyncRequestBodyAdaptor(ClosableAsyncRequestBody delegate) { + this.delegate = delegate; + subscribeCalled = new AtomicBoolean(false); + } + + @Override + public Optional contentLength() { + return delegate.contentLength(); + } + + @Override + public void subscribe(Subscriber subscriber) { + if (subscribeCalled.compareAndSet(false, true)) { + delegate.doAfterOnComplete(() -> delegate.close()) + .doAfterOnCancel(() -> delegate.close()) + .doAfterOnError(t -> delegate.close()) + .subscribe(subscriber); + } else { + subscriber.onSubscribe(new NoopSubscription(subscriber)); + subscriber.onError(NonRetryableException.create( + "A retry was attempted, but AsyncRequestBody.split does not " + + "support retries.")); + } + } + +} diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/FileAsyncRequestBody.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/FileAsyncRequestBody.java index f5dcc164f61c..656cd1a38e56 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/FileAsyncRequestBody.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/FileAsyncRequestBody.java @@ -34,6 +34,7 @@ import software.amazon.awssdk.annotations.SdkInternalApi; import software.amazon.awssdk.core.async.AsyncRequestBody; import software.amazon.awssdk.core.async.AsyncRequestBodySplitConfiguration; +import software.amazon.awssdk.core.async.ClosableAsyncRequestBody; import software.amazon.awssdk.core.async.SdkPublisher; import software.amazon.awssdk.core.internal.util.Mimetype; import software.amazon.awssdk.core.internal.util.NoopSubscription; @@ -86,6 +87,11 @@ public SdkPublisher split(AsyncRequestBodySplitConfiguration s return new FileAsyncRequestBodySplitHelper(this, splitConfiguration).split(); } + @Override + public SdkPublisher splitV2(AsyncRequestBodySplitConfiguration splitConfiguration) { + return split(splitConfiguration).map(body -> new ClosableAsyncRequestBodyWrapper(body)); + } + public Path path() { return path; } @@ -436,4 +442,26 @@ private void signalOnError(Throwable t) { private static AsynchronousFileChannel openInputChannel(Path path) throws IOException { return AsynchronousFileChannel.open(path, StandardOpenOption.READ); } + + private static class ClosableAsyncRequestBodyWrapper implements ClosableAsyncRequestBody { + private final AsyncRequestBody body; + + ClosableAsyncRequestBodyWrapper(AsyncRequestBody body) { + this.body = body; + } + + @Override + public Optional contentLength() { + return body.contentLength(); + } + + @Override + public void subscribe(Subscriber s) { + body.subscribe(s); + } + + @Override + public void close() { + } + } } \ No newline at end of file diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/SplittingPublisher.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/SplittingPublisher.java index 12278cf84dca..c44c728746a3 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/SplittingPublisher.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/SplittingPublisher.java @@ -16,6 +16,8 @@ package software.amazon.awssdk.core.internal.async; import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; import java.util.Optional; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; @@ -25,9 +27,8 @@ import software.amazon.awssdk.annotations.SdkInternalApi; import software.amazon.awssdk.core.async.AsyncRequestBody; import software.amazon.awssdk.core.async.AsyncRequestBodySplitConfiguration; +import software.amazon.awssdk.core.async.ClosableAsyncRequestBody; import software.amazon.awssdk.core.async.SdkPublisher; -import software.amazon.awssdk.core.exception.NonRetryableException; -import software.amazon.awssdk.core.internal.util.NoopSubscription; import software.amazon.awssdk.utils.Logger; import software.amazon.awssdk.utils.Validate; import software.amazon.awssdk.utils.async.SimplePublisher; @@ -36,17 +37,17 @@ * Splits an {@link AsyncRequestBody} to multiple smaller {@link AsyncRequestBody}s, each of which publishes a specific portion of * the original data. * - *

If content length is known, each {@link AsyncRequestBody} is sent to the subscriber right after it's initialized. - * Otherwise, it is sent after the entire content for that chunk is buffered. This is required to get content length. + *

Each {@link AsyncRequestBody} is sent after the entire content for that chunk is buffered. */ @SdkInternalApi -public class SplittingPublisher implements SdkPublisher { +public class SplittingPublisher implements SdkPublisher { private static final Logger log = Logger.loggerFor(SplittingPublisher.class); private final AsyncRequestBody upstreamPublisher; private final SplittingSubscriber splittingSubscriber; - private final SimplePublisher downstreamPublisher = new SimplePublisher<>(); + private final SimplePublisher downstreamPublisher = new SimplePublisher<>(); private final long chunkSizeInBytes; private final long bufferSizeInBytes; + private final AtomicBoolean currentBodySent = new AtomicBoolean(false); public SplittingPublisher(AsyncRequestBody asyncRequestBody, AsyncRequestBodySplitConfiguration splitConfiguration) { @@ -62,15 +63,13 @@ public SplittingPublisher(AsyncRequestBody asyncRequestBody, this.splittingSubscriber = new SplittingSubscriber(upstreamPublisher.contentLength().orElse(null)); - if (!upstreamPublisher.contentLength().isPresent()) { - Validate.isTrue(bufferSizeInBytes >= chunkSizeInBytes, - "bufferSizeInBytes must be larger than or equal to " + - "chunkSizeInBytes if the content length is unknown"); - } + Validate.isTrue(bufferSizeInBytes >= chunkSizeInBytes, + "bufferSizeInBytes must be larger than or equal to " + + "chunkSizeInBytes"); } @Override - public void subscribe(Subscriber downstreamSubscriber) { + public void subscribe(Subscriber downstreamSubscriber) { downstreamPublisher.subscribe(downstreamSubscriber); upstreamPublisher.subscribe(splittingSubscriber); } @@ -78,7 +77,10 @@ public void subscribe(Subscriber downstreamSubscriber) private class SplittingSubscriber implements Subscriber { private Subscription upstreamSubscription; private final Long upstreamSize; - private final AtomicInteger chunkNumber = new AtomicInteger(0); + /** + * 1 based index number for each part/chunk + */ + private final AtomicInteger partNumber = new AtomicInteger(1); private volatile DownstreamBody currentBody; private final AtomicBoolean hasOpenUpstreamDemand = new AtomicBoolean(false); private final AtomicLong dataBuffered = new AtomicLong(0); @@ -98,17 +100,15 @@ public void onSubscribe(Subscription s) { this.upstreamSubscription = s; this.currentBody = initializeNextDownstreamBody(upstreamSize != null, calculateChunkSize(upstreamSize), - chunkNumber.get()); + partNumber.get()); // We need to request subscription *after* we set currentBody because onNext could be invoked right away. upstreamSubscription.request(1); } - private DownstreamBody initializeNextDownstreamBody(boolean contentLengthKnown, long chunkSize, int chunkNumber) { - DownstreamBody body = new DownstreamBody(contentLengthKnown, chunkSize, chunkNumber); - if (contentLengthKnown) { - sendCurrentBody(body); - } - return body; + private DownstreamBody initializeNextDownstreamBody(boolean contentLengthKnown, long chunkSize, int partNumber) { + currentBodySent.set(false); + log.debug(() -> "initializing next downstream body " + partNumber); + return new DownstreamBody(contentLengthKnown, chunkSize, partNumber); } @Override @@ -157,8 +157,8 @@ public void onNext(ByteBuffer byteBuffer) { } private void completeCurrentBodyAndCreateNewIfNeeded(ByteBuffer byteBuffer) { - completeCurrentBody(); - int currentChunk = chunkNumber.incrementAndGet(); + completeCurrentBodyAndDeliver(); + int nextChunk = partNumber.incrementAndGet(); boolean shouldCreateNewDownstreamRequestBody; Long dataRemaining = totalDataRemaining(); @@ -170,27 +170,47 @@ private void completeCurrentBodyAndCreateNewIfNeeded(ByteBuffer byteBuffer) { if (shouldCreateNewDownstreamRequestBody) { long chunkSize = calculateChunkSize(dataRemaining); - currentBody = initializeNextDownstreamBody(upstreamSize != null, chunkSize, currentChunk); + currentBody = initializeNextDownstreamBody(upstreamSize != null, chunkSize, nextChunk); } } private int amountRemainingInChunk() { - return Math.toIntExact(currentBody.maxLength - currentBody.transferredLength); + return Math.toIntExact(currentBody.maxLength - currentBody.bufferedLength); } - private void completeCurrentBody() { - log.debug(() -> "completeCurrentBody for chunk " + chunkNumber.get()); - currentBody.complete(); - if (upstreamSize == null) { - sendCurrentBody(currentBody); + /** + * Current body could be completed in either onNext or onComplete, so we need to guard against sending the last body + * twice. + */ + private void completeCurrentBodyAndDeliver() { + if (currentBodySent.compareAndSet(false, true)) { + log.debug(() -> "completeCurrentBody for chunk " + currentBody.partNumber); + // For unknown content length, we always create a new DownstreamBody because we don't know if there is data + // left or not, so we need to only send the body if there is actually data + long bufferedLength = currentBody.bufferedLength; + Long totalLength = currentBody.totalLength; + if (bufferedLength > 0) { + if (totalLength != null && totalLength != bufferedLength) { + upstreamSubscription.cancel(); + downstreamPublisher.error(new IllegalStateException( + String.format("Content length of buffered data mismatches " + + "with the expected content length, buffered data content length: %d, " + + "expected length: %d", totalLength, + bufferedLength))); + return; + } + + currentBody.complete(); + sendCurrentBody(currentBody); + } } } @Override public void onComplete() { upstreamComplete = true; - log.trace(() -> "Received onComplete()"); - completeCurrentBody(); + log.debug(() -> "Received onComplete() from upstream AsyncRequestBody"); + completeCurrentBodyAndDeliver(); downstreamPublisher.complete(); } @@ -200,7 +220,8 @@ public void onError(Throwable t) { downstreamPublisher.error(t); } - private void sendCurrentBody(AsyncRequestBody body) { + private void sendCurrentBody(DownstreamBody body) { + log.debug(() -> "sendCurrentBody for chunk " + body.partNumber); downstreamPublisher.send(body).exceptionally(t -> { downstreamPublisher.error(t); upstreamSubscription.cancel(); @@ -227,17 +248,21 @@ private void maybeRequestMoreUpstreamData() { } private boolean shouldRequestMoreData(long buffered) { - return buffered == 0 || buffered + byteBufferSizeHint <= bufferSizeInBytes; + return buffered <= 0 || buffered + byteBufferSizeHint <= bufferSizeInBytes; } private Long totalDataRemaining() { if (upstreamSize == null) { return null; } - return upstreamSize - (chunkNumber.get() * chunkSizeInBytes); + return upstreamSize - ((partNumber.get() - 1) * chunkSizeInBytes); } - private final class DownstreamBody implements AsyncRequestBody { + /** + * AsyncRequestBody for individual part. The entire data is buffered in memory and can be subscribed multiple times + * for retry attempts. The buffered data is cleared upon close + */ + private final class DownstreamBody implements ClosableAsyncRequestBody { /** * The maximum length of the content this AsyncRequestBody can hold. If the upstream content length is known, this is @@ -245,66 +270,54 @@ private final class DownstreamBody implements AsyncRequestBody { */ private final long maxLength; private final Long totalLength; - private final SimplePublisher delegate = new SimplePublisher<>(); - private final int chunkNumber; - private final AtomicBoolean subscribeCalled = new AtomicBoolean(false); - private volatile long transferredLength = 0; + private final int partNumber; + private volatile long bufferedLength = 0; + private volatile ByteBuffersAsyncRequestBody delegate; + private final List buffers = new ArrayList<>(); - private DownstreamBody(boolean contentLengthKnown, long maxLength, int chunkNumber) { + private DownstreamBody(boolean contentLengthKnown, long maxLength, int partNumber) { this.totalLength = contentLengthKnown ? maxLength : null; this.maxLength = maxLength; - this.chunkNumber = chunkNumber; + this.partNumber = partNumber; } @Override public Optional contentLength() { - return totalLength != null ? Optional.of(totalLength) : Optional.of(transferredLength); + return totalLength != null ? Optional.of(totalLength) : Optional.of(bufferedLength); } public void send(ByteBuffer data) { - log.trace(() -> String.format("Sending bytebuffer %s to chunk %d", data, chunkNumber)); + log.debug(() -> String.format("Sending bytebuffer %s to chunk %d", data, partNumber)); int length = data.remaining(); - transferredLength += length; + bufferedLength += length; addDataBuffered(length); - delegate.send(data).whenComplete((r, t) -> { - addDataBuffered(-length); - if (t != null) { - error(t); - } - }); + buffers.add(data); } public void complete() { - log.debug(() -> "Received complete() for chunk number: " + chunkNumber + " length " + transferredLength); - delegate.complete().whenComplete((r, t) -> { - if (t != null) { - error(t); - } - }); - } - - public void error(Throwable error) { - delegate.error(error); + log.debug(() -> "Received complete() for chunk number: " + partNumber + " length " + bufferedLength); + this.delegate = ByteBuffersAsyncRequestBody.of(buffers); } @Override public void subscribe(Subscriber s) { - if (subscribeCalled.compareAndSet(false, true)) { - delegate.subscribe(s); - } else { - s.onSubscribe(new NoopSubscription(s)); - s.onError(NonRetryableException.create( - "A retry was attempted, but AsyncRequestBody.split does not " - + "support retries.")); - } + log.debug(() -> "Subscribe for chunk number: " + partNumber + " length " + bufferedLength); + delegate.subscribe(s); } - private void addDataBuffered(int length) { + private void addDataBuffered(long length) { dataBuffered.addAndGet(length); if (length < 0) { maybeRequestMoreUpstreamData(); } } + + @Override + public void close() { + log.debug(() -> "Closing current body " + partNumber); + delegate.close(); + addDataBuffered(-bufferedLength); + } } } } diff --git a/core/sdk-core/src/test/java/software/amazon/awssdk/core/async/AsyncRequestBodyConfigurationTest.java b/core/sdk-core/src/test/java/software/amazon/awssdk/core/async/AsyncRequestBodySplitConfigurationTest.java similarity index 97% rename from core/sdk-core/src/test/java/software/amazon/awssdk/core/async/AsyncRequestBodyConfigurationTest.java rename to core/sdk-core/src/test/java/software/amazon/awssdk/core/async/AsyncRequestBodySplitConfigurationTest.java index 8b8f78f2b5e9..e932da3bfa1c 100644 --- a/core/sdk-core/src/test/java/software/amazon/awssdk/core/async/AsyncRequestBodyConfigurationTest.java +++ b/core/sdk-core/src/test/java/software/amazon/awssdk/core/async/AsyncRequestBodySplitConfigurationTest.java @@ -23,7 +23,7 @@ import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; -public class AsyncRequestBodyConfigurationTest { +public class AsyncRequestBodySplitConfigurationTest { @Test void equalsHashCode() { diff --git a/core/sdk-core/src/test/java/software/amazon/awssdk/core/async/AsyncRequestBodyTest.java b/core/sdk-core/src/test/java/software/amazon/awssdk/core/async/AsyncRequestBodyTest.java index cdd87822e3d4..f0c72a624119 100644 --- a/core/sdk-core/src/test/java/software/amazon/awssdk/core/async/AsyncRequestBodyTest.java +++ b/core/sdk-core/src/test/java/software/amazon/awssdk/core/async/AsyncRequestBodyTest.java @@ -19,10 +19,12 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; +import static software.amazon.awssdk.core.internal.async.SplittingPublisherTestUtils.verifyIndividualAsyncRequestBody; import com.google.common.jimfs.Configuration; import com.google.common.jimfs.Jimfs; import io.reactivex.Flowable; +import java.io.File; import java.io.IOException; import java.nio.ByteBuffer; import java.nio.charset.Charset; @@ -31,13 +33,20 @@ import java.nio.file.Files; import java.nio.file.Path; import java.util.List; +import java.util.Optional; +import java.util.UUID; import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; import java.util.stream.Collectors; +import java.util.stream.Stream; +import org.apache.commons.lang3.RandomStringUtils; import org.assertj.core.util.Lists; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.ValueSource; import org.reactivestreams.Publisher; @@ -49,16 +58,33 @@ public class AsyncRequestBodyTest { private static final String testString = "Hello!"; - private static final Path path; - - static { - FileSystem fs = Jimfs.newFileSystem(Configuration.unix()); + private static Path path; + private static final int CONTENT_SIZE = 1024; + private static final byte[] CONTENT = + RandomStringUtils.randomAscii(CONTENT_SIZE).getBytes(Charset.defaultCharset()); + private static File fileForSplit; + private static FileSystem fs; + + @BeforeAll + public static void setup() throws IOException { + fs = Jimfs.newFileSystem(Configuration.unix()); path = fs.getPath("./test"); - try { - Files.write(path, testString.getBytes()); - } catch (IOException e) { - e.printStackTrace(); - } + Files.write(path, testString.getBytes()); + + fileForSplit = File.createTempFile("SplittingPublisherTest", UUID.randomUUID().toString()); + Files.write(fileForSplit.toPath(), CONTENT); + } + + @AfterAll + public static void teardown() throws IOException { + fileForSplit.delete(); + fs.close(); + } + + public static Stream asyncRequestBodies() { + return Stream.of(Arguments.of(AsyncRequestBody.fromBytes(CONTENT)), + Arguments.of(AsyncRequestBody.fromFile(b -> b.chunkSizeInBytes(50) + .path(fileForSplit.toPath())))); } @ParameterizedTest @@ -300,6 +326,34 @@ void rewindingByteBufferBuildersReadTheInputBufferFromTheBeginning( assertEquals(bb, publishedBuffer.get()); } + @ParameterizedTest + @MethodSource("asyncRequestBodies") + void legacySplit_shouldWork(AsyncRequestBody delegate) throws Exception { + long chunkSize = 20l; + AsyncRequestBody asyncRequestBody = new AsyncRequestBody() { + @Override + public Optional contentLength() { + return delegate.contentLength(); + } + + @Override + public void subscribe(Subscriber s) { + delegate.subscribe(s); + } + }; + + AsyncRequestBodySplitConfiguration configuration = AsyncRequestBodySplitConfiguration.builder() + .chunkSizeInBytes(chunkSize) + .bufferSizeInBytes(chunkSize) + .build(); + + SdkPublisher split = asyncRequestBody.split(configuration); + verifyIndividualAsyncRequestBody(split, fileForSplit.toPath(), (int) chunkSize); + } + + + + private static Function[] rewindingByteBufferBodyBuilders() { Function fromByteBuffer = AsyncRequestBody::fromByteBuffer; Function fromByteBufferUnsafe = AsyncRequestBody::fromByteBufferUnsafe; @@ -356,4 +410,13 @@ void publisherConstructorHasCorrectContentType() { AsyncRequestBody requestBody = AsyncRequestBody.fromPublisher(bodyPublisher); assertEquals(Mimetype.MIMETYPE_OCTET_STREAM, requestBody.contentType()); } + + @Test + void splitV2_nullConfig_shouldThrowException() { + AsyncRequestBody requestBody = AsyncRequestBody.fromString("hello world"); + AsyncRequestBodySplitConfiguration config = null; + assertThatThrownBy(() -> requestBody.splitV2(config)) + .isInstanceOf(NullPointerException.class) + .hasMessageContaining("splitConfig"); + } } diff --git a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/ClosableAsyncRequestBodyAdaptorTest.java b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/ClosableAsyncRequestBodyAdaptorTest.java new file mode 100644 index 000000000000..1a9a3ffc833b --- /dev/null +++ b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/ClosableAsyncRequestBodyAdaptorTest.java @@ -0,0 +1,170 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.core.internal.async; + + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import io.reactivex.Flowable; +import io.reactivex.FlowableSubscriber; +import io.reactivex.internal.observers.BiConsumerSingleObserver; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.Observable; +import java.util.Observer; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import java.util.function.Consumer; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Mockito; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import software.amazon.awssdk.core.async.ClosableAsyncRequestBody; +import software.amazon.awssdk.core.exception.NonRetryableException; + +public class ClosableAsyncRequestBodyAdaptorTest { + private ClosableAsyncRequestBody closableAsyncRequestBody; + + @BeforeEach + public void setup() { + closableAsyncRequestBody =Mockito.mock(ClosableAsyncRequestBody.class); + Mockito.when(closableAsyncRequestBody.doAfterOnComplete(any(Runnable.class))).thenReturn(closableAsyncRequestBody); + Mockito.when(closableAsyncRequestBody.doAfterOnCancel(any(Runnable.class))).thenReturn(closableAsyncRequestBody); + Mockito.when(closableAsyncRequestBody.doAfterOnError(any(Consumer.class))).thenReturn(closableAsyncRequestBody); + } + + @Test + void resubscribe_shouldThrowException() { + ClosableAsyncRequestBodyAdaptor adaptor = new ClosableAsyncRequestBodyAdaptor(closableAsyncRequestBody); + Subscriber subscriber = Mockito.mock(Subscriber.class); + adaptor.subscribe(subscriber); + + Subscriber anotherSubscriber = Mockito.mock(Subscriber.class); + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Throwable.class); + doNothing().when(anotherSubscriber).onError(exceptionCaptor.capture()); + + adaptor.subscribe(anotherSubscriber); + + assertThat(exceptionCaptor.getValue()) + .isInstanceOf(NonRetryableException.class) + .hasMessageContaining("A retry was attempted"); + } + + @Test + void onComplete_shouldCloseAsyncRequestBody() { + TestClosableAsyncRequestBody asyncRequestBody = new TestClosableAsyncRequestBody(); + ClosableAsyncRequestBodyAdaptor adaptor = new ClosableAsyncRequestBodyAdaptor(asyncRequestBody); + CompletableFuture future = new CompletableFuture<>(); + Subscriber subscriber = new ByteArrayAsyncResponseTransformer.BaosSubscriber(future); + adaptor.subscribe(subscriber); + assertThat(asyncRequestBody.closeInvoked).isTrue(); + } + + @Test + void cancel_shouldCloseAsyncRequestBody() { + TestClosableAsyncRequestBody asyncRequestBody = new TestClosableAsyncRequestBody(); + ClosableAsyncRequestBodyAdaptor adaptor = new ClosableAsyncRequestBodyAdaptor(asyncRequestBody); + Subscriber subscriber = new Subscriber() { + @Override + public void onSubscribe(Subscription s) { + s.cancel(); + } + + @Override + public void onNext(ByteBuffer byteBuffer) { + } + + @Override + public void onError(Throwable t) { + } + + @Override + public void onComplete() { + } + }; + adaptor.subscribe(subscriber); + assertThat(asyncRequestBody.closeInvoked).isTrue(); + } + + @Test + void onError_shouldCloseAsyncRequestBody() { + OnErrorClosableAsyncRequestBody asyncRequestBody = new OnErrorClosableAsyncRequestBody(); + ClosableAsyncRequestBodyAdaptor adaptor = new ClosableAsyncRequestBodyAdaptor(asyncRequestBody); + CompletableFuture future = new CompletableFuture<>(); + Subscriber subscriber = new ByteArrayAsyncResponseTransformer.BaosSubscriber(future); + adaptor.subscribe(subscriber); + assertThat(asyncRequestBody.closeInvoked).isTrue(); + } + + + private static class TestClosableAsyncRequestBody implements ClosableAsyncRequestBody { + private boolean closeInvoked; + + @Override + public Optional contentLength() { + return Optional.empty(); + } + + @Override + public void subscribe(Subscriber s) { + Flowable.just(ByteBuffer.wrap("foo bar".getBytes(StandardCharsets.UTF_8))) + .subscribe(s); + } + + @Override + public void close() { + closeInvoked = true; + } + } + + private static class OnErrorClosableAsyncRequestBody implements ClosableAsyncRequestBody { + private boolean closeInvoked; + + @Override + public Optional contentLength() { + return Optional.empty(); + } + + @Override + public void subscribe(Subscriber s) { + s.onSubscribe(new Subscription() { + @Override + public void request(long n) { + s.onError(new IllegalStateException("foobar")); + } + + @Override + public void cancel() { + + } + }); + } + + @Override + public void close() { + closeInvoked = true; + } + } +} diff --git a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/FileAsyncRequestBodySplitHelperTest.java b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/FileAsyncRequestBodySplitHelperTest.java index 4c5d0748d16d..1edea1a58b1b 100644 --- a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/FileAsyncRequestBodySplitHelperTest.java +++ b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/FileAsyncRequestBodySplitHelperTest.java @@ -77,7 +77,9 @@ public void split_differentChunkSize_shouldSplitCorrectly(int chunkSize) throws ScheduledFuture scheduledFuture = executor.scheduleWithFixedDelay(verifyConcurrentRequests(helper, maxConcurrency), 1, 50, TimeUnit.MICROSECONDS); - verifyIndividualAsyncRequestBody(helper.split(), testFile, chunkSize); + verifyIndividualAsyncRequestBody(helper.split(), + testFile, + chunkSize); scheduledFuture.cancel(true); int expectedMaxConcurrency = (int) (bufferSize / chunkSizeInBytes); assertThat(maxConcurrency.get()).isLessThanOrEqualTo(expectedMaxConcurrency); diff --git a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/SplittingPublisherTest.java b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/SplittingPublisherTest.java index 6f116ca2667c..a1f22d2bf7af 100644 --- a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/SplittingPublisherTest.java +++ b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/SplittingPublisherTest.java @@ -16,7 +16,6 @@ package software.amazon.awssdk.core.internal.async; import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.assertj.core.api.Assertions.fail; import static org.assertj.core.api.AssertionsForClassTypes.assertThat; import static software.amazon.awssdk.core.internal.async.SplittingPublisherTestUtils.verifyIndividualAsyncRequestBody; import static software.amazon.awssdk.utils.FunctionalUtils.invokeSafely; @@ -24,7 +23,6 @@ import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.File; -import java.io.FileInputStream; import java.io.IOException; import java.nio.ByteBuffer; import java.nio.charset.Charset; @@ -42,7 +40,6 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; -import org.reactivestreams.Publisher; import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; import software.amazon.awssdk.core.async.AsyncRequestBody; @@ -72,13 +69,18 @@ public static void afterAll() throws Exception { } @Test - public void split_contentUnknownMaxMemorySmallerThanChunkSize_shouldThrowException() { + public void split_MaxMemorySmallerThanChunkSize_shouldThrowException() { AsyncRequestBody body = AsyncRequestBody.fromPublisher(s -> { }); - assertThatThrownBy(() -> new SplittingPublisher(body, AsyncRequestBodySplitConfiguration.builder() - .chunkSizeInBytes(10L) - .bufferSizeInBytes(5L) - .build())) + AsyncRequestBodySplitConfiguration configuration = AsyncRequestBodySplitConfiguration.builder() + .chunkSizeInBytes(10L) + .bufferSizeInBytes(5L) + .build(); + assertThatThrownBy(() -> new SplittingPublisher(body, configuration)) + .hasMessageContaining("must be larger than or equal"); + + assertThatThrownBy(() -> new SplittingPublisher(AsyncRequestBody.fromString("test"), + configuration)) .hasMessageContaining("must be larger than or equal"); } @@ -169,7 +171,7 @@ private static void verifySplitContent(AsyncRequestBody asyncRequestBody, int ch .bufferSizeInBytes((long) chunkSize * 4) .build()); - verifyIndividualAsyncRequestBody(splittingPublisher, testFile.toPath(), chunkSize); + verifyIndividualAsyncRequestBody(splittingPublisher.map(m -> m), testFile.toPath(), chunkSize); } private static class TestAsyncRequestBody implements AsyncRequestBody { @@ -204,30 +206,6 @@ public void cancel() { } } - private static final class OnlyRequestOnceSubscriber implements Subscriber { - private List asyncRequestBodies = new ArrayList<>(); - - @Override - public void onSubscribe(Subscription s) { - s.request(1); - } - - @Override - public void onNext(AsyncRequestBody requestBody) { - asyncRequestBodies.add(requestBody); - } - - @Override - public void onError(Throwable t) { - - } - - @Override - public void onComplete() { - - } - } - private static final class BaosSubscriber implements Subscriber { private final CompletableFuture resultFuture; diff --git a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/SplittingPublisherTestUtils.java b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/SplittingPublisherTestUtils.java index 04da97adbf42..145a1cecc0ef 100644 --- a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/SplittingPublisherTestUtils.java +++ b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/SplittingPublisherTestUtils.java @@ -15,23 +15,16 @@ package software.amazon.awssdk.core.internal.async; -import static org.assertj.core.api.AssertionsForClassTypes.assertThat; - -import java.io.File; import java.io.FileInputStream; import java.nio.file.Path; import java.util.ArrayList; import java.util.List; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; -import java.util.concurrent.TimeoutException; import org.assertj.core.api.Assertions; -import org.reactivestreams.Publisher; import software.amazon.awssdk.core.async.AsyncRequestBody; +import software.amazon.awssdk.core.async.ClosableAsyncRequestBody; import software.amazon.awssdk.core.async.SdkPublisher; -import software.amazon.awssdk.core.internal.async.ByteArrayAsyncResponseTransformer; -import software.amazon.awssdk.core.internal.async.SplittingPublisherTest; public final class SplittingPublisherTestUtils { @@ -45,6 +38,11 @@ public static void verifyIndividualAsyncRequestBody(SdkPublisher { + if (requestBody instanceof ClosableAsyncRequestBody) { + ((ClosableAsyncRequestBody) requestBody).close(); + } + }); futures.add(baosFuture); }).get(5, TimeUnit.SECONDS); diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/KnownContentLengthAsyncRequestBodySubscriber.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/KnownContentLengthAsyncRequestBodySubscriber.java index 93bc0dfeb6f8..527119415ea6 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/KnownContentLengthAsyncRequestBodySubscriber.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/KnownContentLengthAsyncRequestBodySubscriber.java @@ -34,6 +34,7 @@ import org.reactivestreams.Subscription; import software.amazon.awssdk.annotations.SdkInternalApi; import software.amazon.awssdk.core.async.AsyncRequestBody; +import software.amazon.awssdk.core.async.ClosableAsyncRequestBody; import software.amazon.awssdk.core.async.listener.PublisherListener; import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadResponse; @@ -47,7 +48,7 @@ import software.amazon.awssdk.utils.Pair; @SdkInternalApi -public class KnownContentLengthAsyncRequestBodySubscriber implements Subscriber { +public class KnownContentLengthAsyncRequestBodySubscriber implements Subscriber { private static final Logger log = Logger.loggerFor(KnownContentLengthAsyncRequestBodySubscriber.class); @@ -144,16 +145,21 @@ public void onSubscribe(Subscription s) { } @Override - public void onNext(AsyncRequestBody asyncRequestBody) { + public void onNext(ClosableAsyncRequestBody asyncRequestBody) { if (isPaused || isDone) { return; } int currentPartNum = partNumber.getAndIncrement(); + + log.debug(() -> String.format("Received asyncRequestBody for part number %d with length %s", currentPartNum, + asyncRequestBody.contentLength())); + if (existingParts.containsKey(currentPartNum)) { asyncRequestBody.subscribe(new CancelledSubscriber<>()); - subscription.request(1); asyncRequestBody.contentLength().ifPresent(progressListener::subscriberOnNext); + asyncRequestBody.close(); + subscription.request(1); return; } @@ -178,10 +184,12 @@ public void onNext(AsyncRequestBody asyncRequestBody) { multipartUploadHelper.sendIndividualUploadPartRequest(uploadId, completedPartConsumer, futures, Pair.of(uploadRequest, asyncRequestBody), progressListener) .whenComplete((r, t) -> { + asyncRequestBody.close(); if (t != null) { if (shouldFailRequest()) { multipartUploadHelper.failRequestsElegantly(futures, t, uploadId, returnFuture, putObjectRequest); + subscription.cancel(); } } else { completeMultipartUploadIfFinished(asyncRequestBodyInFlight.decrementAndGet()); @@ -206,7 +214,7 @@ private Optional validatePart(AsyncRequestBody asyncRequestB } if (currentPartSize != partSize) { - return Optional.of(contentLengthMismatchForPart(partSize, currentPartSize)); + return Optional.of(contentLengthMismatchForPart(partSize, currentPartSize, currentPartNum)); } return Optional.empty(); } diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartUploadHelper.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartUploadHelper.java index d25d5b6fa7fa..d7c988c16e55 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartUploadHelper.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartUploadHelper.java @@ -162,11 +162,12 @@ static SdkClientException contentLengthMissingForPart(int currentPartNum) { return SdkClientException.create("Content length is missing on the AsyncRequestBody for part number " + currentPartNum); } - static SdkClientException contentLengthMismatchForPart(long expected, long actual) { + static SdkClientException contentLengthMismatchForPart(long expected, long actual, int partNum) { return SdkClientException.create(String.format("Content length must not be greater than " - + "part size. Expected: %d, Actual: %d", + + "part size. Expected: %d, Actual: %d, partNum: %d", expected, - actual)); + actual, + partNum)); } static SdkClientException partNumMismatch(int expectedNumParts, int actualNumParts) { diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithKnownContentLengthHelper.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithKnownContentLengthHelper.java index 04690677c92b..15f15767db54 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithKnownContentLengthHelper.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithKnownContentLengthHelper.java @@ -186,7 +186,7 @@ private void splitAndSubscribe(MpuRequestContext mpuRequestContext, CompletableF attachSubscriberToObservable(subscriber, mpuRequestContext.request().left()); mpuRequestContext.request().right() - .split(b -> b.chunkSizeInBytes(mpuRequestContext.partSize()) + .splitV2(b -> b.chunkSizeInBytes(mpuRequestContext.partSize()) .bufferSizeInBytes(maxMemoryUsageInBytes)) .subscribe(subscriber); } diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithUnknownContentLengthHelper.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithUnknownContentLengthHelper.java index 520625ad90b0..73a0c0d39cfc 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithUnknownContentLengthHelper.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithUnknownContentLengthHelper.java @@ -33,6 +33,7 @@ import org.reactivestreams.Subscription; import software.amazon.awssdk.annotations.SdkInternalApi; import software.amazon.awssdk.core.async.AsyncRequestBody; +import software.amazon.awssdk.core.async.ClosableAsyncRequestBody; import software.amazon.awssdk.core.async.SdkPublisher; import software.amazon.awssdk.core.async.listener.PublisherListener; import software.amazon.awssdk.core.exception.SdkClientException; @@ -53,12 +54,10 @@ public final class UploadWithUnknownContentLengthHelper { private static final Logger log = Logger.loggerFor(UploadWithUnknownContentLengthHelper.class); - private final S3AsyncClient s3AsyncClient; private final long partSizeInBytes; private final GenericMultipartHelper genericMultipartHelper; private final long maxMemoryUsageInBytes; - private final long multipartUploadThresholdInBytes; private final MultipartUploadHelper multipartUploadHelper; @@ -66,13 +65,11 @@ public UploadWithUnknownContentLengthHelper(S3AsyncClient s3AsyncClient, long partSizeInBytes, long multipartUploadThresholdInBytes, long maxMemoryUsageInBytes) { - this.s3AsyncClient = s3AsyncClient; this.partSizeInBytes = partSizeInBytes; this.genericMultipartHelper = new GenericMultipartHelper<>(s3AsyncClient, SdkPojoConversionUtils::toAbortMultipartUploadRequest, SdkPojoConversionUtils::toPutObjectResponse); this.maxMemoryUsageInBytes = maxMemoryUsageInBytes; - this.multipartUploadThresholdInBytes = multipartUploadThresholdInBytes; this.multipartUploadHelper = new MultipartUploadHelper(s3AsyncClient, multipartUploadThresholdInBytes, maxMemoryUsageInBytes); } @@ -81,8 +78,8 @@ public CompletableFuture uploadObject(PutObjectRequest putObj AsyncRequestBody asyncRequestBody) { CompletableFuture returnFuture = new CompletableFuture<>(); - SdkPublisher splitAsyncRequestBodyResponse = - asyncRequestBody.split(b -> b.chunkSizeInBytes(partSizeInBytes) + SdkPublisher splitAsyncRequestBodyResponse = + asyncRequestBody.splitV2(b -> b.chunkSizeInBytes(partSizeInBytes) .bufferSizeInBytes(maxMemoryUsageInBytes)); splitAsyncRequestBodyResponse.subscribe(new UnknownContentLengthAsyncRequestBodySubscriber(partSizeInBytes, @@ -91,11 +88,7 @@ public CompletableFuture uploadObject(PutObjectRequest putObj return returnFuture; } - private class UnknownContentLengthAsyncRequestBodySubscriber implements Subscriber { - /** - * Indicates whether this is the first async request body or not. - */ - private final AtomicBoolean isFirstAsyncRequestBody = new AtomicBoolean(true); + private class UnknownContentLengthAsyncRequestBodySubscriber implements Subscriber { /** * Indicates whether CreateMultipartUpload has been initiated or not @@ -161,12 +154,13 @@ public void onSubscribe(Subscription s) { } @Override - public void onNext(AsyncRequestBody asyncRequestBody) { + public void onNext(ClosableAsyncRequestBody asyncRequestBody) { if (isDone) { return; } int currentPartNum = partNumber.incrementAndGet(); - log.trace(() -> "Received asyncRequestBody " + asyncRequestBody.contentLength()); + log.debug(() -> String.format("Received asyncRequestBody for part number %d with length %s", currentPartNum, + asyncRequestBody.contentLength())); asyncRequestBodyInFlight.incrementAndGet(); Optional sdkClientException = validatePart(asyncRequestBody, currentPartNum); @@ -177,14 +171,6 @@ public void onNext(AsyncRequestBody asyncRequestBody) { return; } - if (isFirstAsyncRequestBody.compareAndSet(true, false)) { - log.trace(() -> "Received first async request body"); - // If this is the first AsyncRequestBody received, request another one because we don't know if there is more - firstRequestBody = asyncRequestBody; - subscription.request(1); - return; - } - // If there are more than 1 AsyncRequestBodies, then we know we need to upload this // object using MPU if (createMultipartUploadInitiated.compareAndSet(false, true)) { @@ -201,8 +187,7 @@ public void onNext(AsyncRequestBody asyncRequestBody) { uploadId = createMultipartUploadResponse.uploadId(); log.debug(() -> "Initiated a new multipart upload, uploadId: " + uploadId); - sendUploadPartRequest(uploadId, firstRequestBody, 1); - sendUploadPartRequest(uploadId, asyncRequestBody, 2); + sendUploadPartRequest(uploadId, asyncRequestBody, currentPartNum); // We need to complete the uploadIdFuture *after* the first two requests have been sent uploadIdFuture.complete(uploadId); @@ -224,14 +209,14 @@ private Optional validatePart(AsyncRequestBody asyncRequestB Long contentLengthCurrentPart = contentLength.get(); if (contentLengthCurrentPart > partSizeInBytes) { - return Optional.of(contentLengthMismatchForPart(partSizeInBytes, contentLengthCurrentPart)); + return Optional.of(contentLengthMismatchForPart(partSizeInBytes, contentLengthCurrentPart, currentPartNum)); } return Optional.empty(); } private void sendUploadPartRequest(String uploadId, - AsyncRequestBody asyncRequestBody, + ClosableAsyncRequestBody asyncRequestBody, int currentPartNum) { Long contentLengthCurrentPart = asyncRequestBody.contentLength().get(); this.contentLength.getAndAdd(contentLengthCurrentPart); @@ -240,15 +225,17 @@ private void sendUploadPartRequest(String uploadId, .sendIndividualUploadPartRequest(uploadId, completedParts::add, futures, uploadPart(asyncRequestBody, currentPartNum), progressListener) .whenComplete((r, t) -> { + asyncRequestBody.close(); if (t != null) { if (failureActionInitiated.compareAndSet(false, true)) { + subscription.cancel(); multipartUploadHelper.failRequestsElegantly(futures, t, uploadId, returnFuture, putObjectRequest); } } else { completeMultipartUploadIfFinish(asyncRequestBodyInFlight.decrementAndGet()); } }); - synchronized (this) { + synchronized (subscription) { subscription.request(1); }; } diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/KnownContentLengthAsyncRequestBodySubscriberTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/KnownContentLengthAsyncRequestBodySubscriberTest.java index 4faf9d4a04b0..d2bd9d55dfb7 100644 --- a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/KnownContentLengthAsyncRequestBodySubscriberTest.java +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/KnownContentLengthAsyncRequestBodySubscriberTest.java @@ -17,7 +17,6 @@ import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyList; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; @@ -25,11 +24,9 @@ import java.io.IOException; import java.util.Collection; -import java.util.HashMap; import java.util.Map; import java.util.Optional; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.function.Function; import java.util.stream.Collectors; @@ -39,9 +36,9 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.mockito.ArgumentCaptor; -import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; import software.amazon.awssdk.core.async.AsyncRequestBody; +import software.amazon.awssdk.core.async.ClosableAsyncRequestBody; import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.services.s3.S3AsyncClient; import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadResponse; @@ -60,7 +57,7 @@ public class KnownContentLengthAsyncRequestBodySubscriberTest { private static final int TOTAL_NUM_PARTS = 4; private static final String UPLOAD_ID = "1234"; private static RandomTempFile testFile; - + private AsyncRequestBody asyncRequestBody; private PutObjectRequest putObjectRequest; private S3AsyncClient s3AsyncClient; @@ -114,7 +111,7 @@ void validatePart_withPartSizeExceedingLimit_shouldFailRequest() { void validateLastPartSize_withIncorrectSize_shouldFailRequest() { long expectedLastPartSize = MPU_CONTENT_SIZE % PART_SIZE; long incorrectLastPartSize = expectedLastPartSize + 1; - + KnownContentLengthAsyncRequestBodySubscriber lastPartSubscriber = createSubscriber(createDefaultMpuRequestContext()); lastPartSubscriber.onSubscribe(subscription); @@ -130,12 +127,12 @@ void validateLastPartSize_withIncorrectSize_shouldFailRequest() { @Test void validateTotalPartNum_receivedMoreParts_shouldFail() { long expectedLastPartSize = MPU_CONTENT_SIZE % PART_SIZE; - + KnownContentLengthAsyncRequestBodySubscriber lastPartSubscriber = createSubscriber(createDefaultMpuRequestContext()); lastPartSubscriber.onSubscribe(subscription); for (int i = 0; i < TOTAL_NUM_PARTS - 1; i++) { - AsyncRequestBody regularPart = createMockAsyncRequestBody(PART_SIZE); + ClosableAsyncRequestBody regularPart = createMockAsyncRequestBody(PART_SIZE); when(multipartUploadHelper.sendIndividualUploadPartRequest(eq(UPLOAD_ID), any(), any(), any(), any())) .thenReturn(CompletableFuture.completedFuture(null)); lastPartSubscriber.onNext(regularPart); @@ -157,7 +154,7 @@ void validateLastPartSize_withCorrectSize_shouldNotFail() { subscriber.onSubscribe(subscription); for (int i = 0; i < TOTAL_NUM_PARTS - 1; i++) { - AsyncRequestBody regularPart = createMockAsyncRequestBody(PART_SIZE); + ClosableAsyncRequestBody regularPart = createMockAsyncRequestBody(PART_SIZE); when(multipartUploadHelper.sendIndividualUploadPartRequest(eq(UPLOAD_ID), any(), any(), any(), any())) .thenReturn(CompletableFuture.completedFuture(null)); subscriber.onNext(regularPart); @@ -175,7 +172,7 @@ void validateLastPartSize_withCorrectSize_shouldNotFail() { void pause_withOngoingCompleteMpuFuture_shouldReturnTokenAndCancelFuture() { CompletableFuture completeMpuFuture = new CompletableFuture<>(); int numExistingParts = 2; - + S3ResumeToken resumeToken = testPauseScenario(numExistingParts, completeMpuFuture); verifyResumeToken(resumeToken, numExistingParts); @@ -187,7 +184,7 @@ void pause_withCompletedCompleteMpuFuture_shouldReturnNullToken() { CompletableFuture completeMpuFuture = CompletableFuture.completedFuture(CompleteMultipartUploadResponse.builder().build()); int numExistingParts = 2; - + S3ResumeToken resumeToken = testPauseScenario(numExistingParts, completeMpuFuture); assertThat(resumeToken).isNull(); @@ -196,15 +193,15 @@ void pause_withCompletedCompleteMpuFuture_shouldReturnNullToken() { @Test void pause_withUninitiatedCompleteMpuFuture_shouldReturnToken() { int numExistingParts = 2; - + S3ResumeToken resumeToken = testPauseScenario(numExistingParts, null); verifyResumeToken(resumeToken, numExistingParts); } - - private S3ResumeToken testPauseScenario(int numExistingParts, + + private S3ResumeToken testPauseScenario(int numExistingParts, CompletableFuture completeMpuFuture) { - KnownContentLengthAsyncRequestBodySubscriber subscriber = + KnownContentLengthAsyncRequestBodySubscriber subscriber = createSubscriber(createMpuRequestContextWithExistingParts(numExistingParts)); when(multipartUploadHelper.completeMultipartUpload(any(CompletableFuture.class), any(String.class), @@ -246,14 +243,14 @@ private KnownContentLengthAsyncRequestBodySubscriber createSubscriber(MpuRequest return new KnownContentLengthAsyncRequestBodySubscriber(mpuRequestContext, returnFuture, multipartUploadHelper); } - private AsyncRequestBody createMockAsyncRequestBody(long contentLength) { - AsyncRequestBody mockBody = mock(AsyncRequestBody.class); + private ClosableAsyncRequestBody createMockAsyncRequestBody(long contentLength) { + ClosableAsyncRequestBody mockBody = mock(ClosableAsyncRequestBody.class); when(mockBody.contentLength()).thenReturn(Optional.of(contentLength)); return mockBody; } - private AsyncRequestBody createMockAsyncRequestBodyWithEmptyContentLength() { - AsyncRequestBody mockBody = mock(AsyncRequestBody.class); + private ClosableAsyncRequestBody createMockAsyncRequestBodyWithEmptyContentLength() { + ClosableAsyncRequestBody mockBody = mock(ClosableAsyncRequestBody.class); when(mockBody.contentLength()).thenReturn(Optional.empty()); return mockBody; } diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/S3MultipartClientPutObjectWiremockTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/S3MultipartClientPutObjectWiremockTest.java index 90e14dcff2dd..ee3bf01c3d54 100644 --- a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/S3MultipartClientPutObjectWiremockTest.java +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/S3MultipartClientPutObjectWiremockTest.java @@ -18,14 +18,20 @@ import static com.github.tomakehurst.wiremock.client.WireMock.aResponse; import static com.github.tomakehurst.wiremock.client.WireMock.anyUrl; import static com.github.tomakehurst.wiremock.client.WireMock.delete; +import static com.github.tomakehurst.wiremock.client.WireMock.matching; import static com.github.tomakehurst.wiremock.client.WireMock.post; import static com.github.tomakehurst.wiremock.client.WireMock.put; +import static com.github.tomakehurst.wiremock.client.WireMock.putRequestedFor; import static com.github.tomakehurst.wiremock.client.WireMock.stubFor; import static com.github.tomakehurst.wiremock.client.WireMock.urlEqualTo; +import static com.github.tomakehurst.wiremock.client.WireMock.verify; import static org.assertj.core.api.Assertions.assertThatThrownBy; +import com.github.tomakehurst.wiremock.client.ResponseDefinitionBuilder; +import com.github.tomakehurst.wiremock.http.Fault; import com.github.tomakehurst.wiremock.junit5.WireMockRuntimeInfo; import com.github.tomakehurst.wiremock.junit5.WireMockTest; +import com.github.tomakehurst.wiremock.stubbing.Scenario; import io.reactivex.rxjava3.core.Flowable; import java.io.InputStream; import java.net.URI; @@ -46,20 +52,17 @@ import org.reactivestreams.Subscriber; import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; +import software.amazon.awssdk.core.SdkBytes; import software.amazon.awssdk.core.async.AsyncRequestBody; -import software.amazon.awssdk.core.async.AsyncRequestBodySplitConfiguration; import software.amazon.awssdk.core.async.BlockingInputStreamAsyncRequestBody; -import software.amazon.awssdk.core.async.SdkPublisher; -import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.http.crt.AwsCrtAsyncHttpClient; import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.services.s3.S3AsyncClient; import software.amazon.awssdk.services.s3.model.PutObjectResponse; import software.amazon.awssdk.services.s3.model.S3Exception; -import software.amazon.awssdk.utils.async.SimplePublisher; @WireMockTest -@Timeout(100) +@Timeout(120) public class S3MultipartClientPutObjectWiremockTest { private static final String BUCKET = "Example-Bucket"; @@ -71,25 +74,21 @@ public class S3MultipartClientPutObjectWiremockTest { + ""; private S3AsyncClient s3AsyncClient; - public static Stream invalidAsyncRequestBodies() { + public static Stream retryableErrorTestCase() { return Stream.of( - Arguments.of("knownContentLength_nullPartSize", new TestPublisherWithIncorrectSplitImpl(20L, null), - "Content length is missing on the AsyncRequestBody for part number"), - Arguments.of("unknownContentLength_nullPartSize", new TestPublisherWithIncorrectSplitImpl(null, null), - "Content length is missing on the AsyncRequestBody for part number"), - Arguments.of("knownContentLength_partSizeIncorrect", new TestPublisherWithIncorrectSplitImpl(20L, 11L), - "Content length must not be greater than part size"), - Arguments.of("unknownContentLength_partSizeIncorrect", new TestPublisherWithIncorrectSplitImpl(null, 11L), - "Content length must not be greater than part size"), - Arguments.of("knownContentLength_sendMoreParts", new TestPublisherWithIncorrectSplitImpl(20L, 10L, 3), - "The number of parts divided is not equal to the expected number of parts"), - Arguments.of("knownContentLength_sendFewerParts", new TestPublisherWithIncorrectSplitImpl(20L, 10L, 1), - "The number of parts divided is not equal to the expected number of parts")); + Arguments.of("unknownContentLength_failOfConnectionReset_shouldRetry", null, + aResponse().withFault(Fault.CONNECTION_RESET_BY_PEER)), + Arguments.of("unknownContentLength_failOf500_shouldRetry", null, + aResponse().withStatus(500)), + Arguments.of("knownContentLength_failOfConnectionReset_shouldRetry", 20L, + aResponse().withFault(Fault.CONNECTION_RESET_BY_PEER)), + Arguments.of("knownContentLength_failOf500_shouldRetry", 20L, + aResponse().withStatus(500)) + ); } @BeforeEach public void setup(WireMockRuntimeInfo wiremock) { - stubFailedPutObjectCalls(); s3AsyncClient = S3AsyncClient.builder() .region(Region.US_EAST_1) .endpointOverride(URI.create("http://localhost:" + wiremock.getHttpPort())) @@ -101,22 +100,17 @@ public void setup(WireMockRuntimeInfo wiremock) { .build(); } - private void stubFailedPutObjectCalls() { + private void stubPutObject404Calls() { stubFor(post(anyUrl()).willReturn(aResponse().withStatus(200).withBody(CREATE_MULTIPART_PAYLOAD))); stubFor(put(anyUrl()).willReturn(aResponse().withStatus(404))); stubFor(put(urlEqualTo("/Example-Bucket/Example-Object?partNumber=1&uploadId=string")).willReturn(aResponse().withStatus(200))); stubFor(delete(anyUrl()).willReturn(aResponse().withStatus(200))); } - private void stubSuccessfulPutObjectCalls() { - stubFor(post(anyUrl()).willReturn(aResponse().withStatus(200).withBody(CREATE_MULTIPART_PAYLOAD))); - stubFor(put(anyUrl()).willReturn(aResponse().withStatus(200))); - } - - // https://github.com/aws/aws-sdk-java-v2/issues/4801 @Test void uploadWithUnknownContentLength_onePartFails_shouldCancelUpstream() { + stubPutObject404Calls(); BlockingInputStreamAsyncRequestBody blockingInputStreamAsyncRequestBody = AsyncRequestBody.forBlockingInputStream(null); CompletableFuture putObjectResponse = s3AsyncClient.putObject( r -> r.bucket(BUCKET).key(KEY), blockingInputStreamAsyncRequestBody); @@ -132,6 +126,7 @@ void uploadWithUnknownContentLength_onePartFails_shouldCancelUpstream() { @Test void uploadWithKnownContentLength_onePartFails_shouldCancelUpstream() { + stubPutObject404Calls(); BlockingInputStreamAsyncRequestBody blockingInputStreamAsyncRequestBody = AsyncRequestBody.forBlockingInputStream(1024L * 20); // must be larger than the buffer used in // InputStreamConsumingPublisher to trigger the error @@ -147,19 +142,56 @@ void uploadWithKnownContentLength_onePartFails_shouldCancelUpstream() { assertThatThrownBy(() -> putObjectResponse.join()).hasRootCauseInstanceOf(S3Exception.class); } - @ParameterizedTest(name = "{index} {0}") - @MethodSource("invalidAsyncRequestBodies") - void uploadWithIncorrectAsyncRequestBodySplit_contentLengthMismatch_shouldThrowException(String description, - TestPublisherWithIncorrectSplitImpl asyncRequestBody, - String errorMsg) { - stubSuccessfulPutObjectCalls(); - CompletableFuture putObjectResponse = s3AsyncClient.putObject( - r -> r.bucket(BUCKET).key(KEY), asyncRequestBody); + @ParameterizedTest + @MethodSource("retryableErrorTestCase") + void mpu_partsFailOfRetryableError_shouldRetry(String description, + Long contentLength, + ResponseDefinitionBuilder responseDefinitionBuilder) { + stubUploadPartFailsInitialAttemptCalls(responseDefinitionBuilder); + List buffers = new ArrayList<>(); + buffers.add(SdkBytes.fromUtf8String(RandomStringUtils.randomAscii(10)).asByteBuffer()); + buffers.add(SdkBytes.fromUtf8String(RandomStringUtils.randomAscii(10)).asByteBuffer()); + AsyncRequestBody asyncRequestBody = new AsyncRequestBody() { + @Override + public Optional contentLength() { + return Optional.ofNullable(contentLength); + } + + @Override + public void subscribe(Subscriber s) { + Flowable.fromIterable(buffers).subscribe(s); + } + }; + + s3AsyncClient.putObject(b -> b.bucket(BUCKET).key(KEY), asyncRequestBody).join(); + + verify(2, putRequestedFor(anyUrl()).withQueryParam("partNumber", matching(String.valueOf(1)))); + verify(2, putRequestedFor(anyUrl()).withQueryParam("partNumber", matching(String.valueOf(2)))); + } + - assertThatThrownBy(() -> putObjectResponse.join()).hasMessageContaining(errorMsg) - .hasRootCauseInstanceOf(SdkClientException.class); + private void stubUploadPartFailsInitialAttemptCalls(ResponseDefinitionBuilder responseDefinitionBuilder) { + stubFor(post(anyUrl()).willReturn(aResponse().withStatus(200).withBody(CREATE_MULTIPART_PAYLOAD))); + stubUploadFailsInitialAttemptCalls(1, responseDefinitionBuilder); + stubUploadFailsInitialAttemptCalls(2, responseDefinitionBuilder); } + private void stubUploadFailsInitialAttemptCalls(int partNumber, ResponseDefinitionBuilder responseDefinitionBuilder) { + stubFor(put(anyUrl()) + .withQueryParam("partNumber", matching(String.valueOf(partNumber))) + .inScenario(String.valueOf(partNumber)) + .whenScenarioStateIs(Scenario.STARTED) + .willReturn(responseDefinitionBuilder) + .willSetStateTo("SecondAttempt" + partNumber)); + + stubFor(put(anyUrl()) + .withQueryParam("partNumber", matching(String.valueOf(partNumber))) + .inScenario(String.valueOf(partNumber)) + .whenScenarioStateIs("SecondAttempt" + partNumber) + .willReturn(aResponse().withStatus(200))); + } + + private InputStream createUnlimitedInputStream() { return new InputStream() { @Override @@ -168,65 +200,5 @@ public int read() { } }; } - - private static class TestPublisherWithIncorrectSplitImpl implements AsyncRequestBody { - private SimplePublisher simplePublisher = new SimplePublisher<>(); - private Long totalSize; - private Long partSize; - private Integer numParts; - - private TestPublisherWithIncorrectSplitImpl(Long totalSize, Long partSize) { - this.totalSize = totalSize; - this.partSize = partSize; - } - - private TestPublisherWithIncorrectSplitImpl(Long totalSize, long partSize, int numParts) { - this.totalSize = totalSize; - this.partSize = partSize; - this.numParts = numParts; - } - - @Override - public Optional contentLength() { - return Optional.ofNullable(totalSize); - } - - @Override - public void subscribe(Subscriber s) { - simplePublisher.subscribe(s); - } - - @Override - public SdkPublisher split(AsyncRequestBodySplitConfiguration splitConfiguration) { - List requestBodies = new ArrayList<>(); - int numAsyncRequestBodies = numParts == null ? 1 : numParts; - for (int i = 0; i < numAsyncRequestBodies; i++) { - requestBodies.add(new TestAsyncRequestBody(partSize)); - } - - return SdkPublisher.adapt(Flowable.fromArray(requestBodies.toArray(new AsyncRequestBody[requestBodies.size()]))); - } - } - - private static class TestAsyncRequestBody implements AsyncRequestBody { - private Long partSize; - private SimplePublisher simplePublisher = new SimplePublisher<>(); - - public TestAsyncRequestBody(Long partSize) { - this.partSize = partSize; - } - - @Override - public Optional contentLength() { - return Optional.ofNullable(partSize); - } - - @Override - public void subscribe(Subscriber s) { - simplePublisher.subscribe(s); - simplePublisher.send(ByteBuffer.wrap( - RandomStringUtils.randomAscii(Math.toIntExact(partSize)).getBytes())); - simplePublisher.complete(); - } - } } + diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithUnknownContentLengthHelperTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithUnknownContentLengthHelperTest.java index 972f0b86241a..e5c96bc6abed 100644 --- a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithUnknownContentLengthHelperTest.java +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/UploadWithUnknownContentLengthHelperTest.java @@ -31,6 +31,7 @@ import java.io.FileNotFoundException; import java.io.IOException; import java.io.InputStream; +import java.nio.ByteBuffer; import java.util.List; import java.util.Optional; import java.util.concurrent.CompletableFuture; @@ -46,6 +47,7 @@ import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; import software.amazon.awssdk.core.async.AsyncRequestBody; +import software.amazon.awssdk.core.async.ClosableAsyncRequestBody; import software.amazon.awssdk.core.async.BlockingInputStreamAsyncRequestBody; import software.amazon.awssdk.core.async.SdkPublisher; import software.amazon.awssdk.core.exception.SdkClientException; @@ -114,14 +116,14 @@ void upload_blockingInputStream_shouldInOrder() throws FileNotFoundException { @Test void uploadObject_withMissingContentLength_shouldFailRequest() { - AsyncRequestBody asyncRequestBody = createMockAsyncRequestBodyWithEmptyContentLength(); + ClosableAsyncRequestBody asyncRequestBody = createMockAsyncRequestBodyWithEmptyContentLength(); CompletableFuture future = setupAndTriggerUploadFailure(asyncRequestBody); verifyFailureWithMessage(future, "Content length is missing on the AsyncRequestBody for part number"); } @Test void uploadObject_withPartSizeExceedingLimit_shouldFailRequest() { - AsyncRequestBody asyncRequestBody = createMockAsyncRequestBody(PART_SIZE + 1); + ClosableAsyncRequestBody asyncRequestBody = createMockAsyncRequestBody(PART_SIZE + 1); CompletableFuture future = setupAndTriggerUploadFailure(asyncRequestBody); verifyFailureWithMessage(future, "Content length must not be greater than part size"); } @@ -139,27 +141,27 @@ private List createCompletedParts(int totalNumParts) { .collect(Collectors.toList()); } - private AsyncRequestBody createMockAsyncRequestBody(long contentLength) { - AsyncRequestBody mockBody = mock(AsyncRequestBody.class); + private ClosableAsyncRequestBody createMockAsyncRequestBody(long contentLength) { + ClosableAsyncRequestBody mockBody = mock(ClosableAsyncRequestBody.class); when(mockBody.contentLength()).thenReturn(Optional.of(contentLength)); return mockBody; } - private AsyncRequestBody createMockAsyncRequestBodyWithEmptyContentLength() { - AsyncRequestBody mockBody = mock(AsyncRequestBody.class); + private ClosableAsyncRequestBody createMockAsyncRequestBodyWithEmptyContentLength() { + ClosableAsyncRequestBody mockBody = mock(ClosableAsyncRequestBody.class); when(mockBody.contentLength()).thenReturn(Optional.empty()); return mockBody; } - private CompletableFuture setupAndTriggerUploadFailure(AsyncRequestBody asyncRequestBody) { - SdkPublisher mockPublisher = mock(SdkPublisher.class); - when(asyncRequestBody.split(any(Consumer.class))).thenReturn(mockPublisher); + private CompletableFuture setupAndTriggerUploadFailure(ClosableAsyncRequestBody asyncRequestBody) { + SdkPublisher mockPublisher = mock(SdkPublisher.class); + when(asyncRequestBody.splitV2(any(Consumer.class))).thenReturn(mockPublisher); - ArgumentCaptor> subscriberCaptor = ArgumentCaptor.forClass(Subscriber.class); + ArgumentCaptor> subscriberCaptor = ArgumentCaptor.forClass(Subscriber.class); CompletableFuture future = helper.uploadObject(createPutObjectRequest(), asyncRequestBody); verify(mockPublisher).subscribe(subscriberCaptor.capture()); - Subscriber subscriber = subscriberCaptor.getValue(); + Subscriber subscriber = subscriberCaptor.getValue(); Subscription subscription = mock(Subscription.class); subscriber.onSubscribe(subscription); diff --git a/services/s3/src/test/resources/log4j2.properties b/services/s3/src/test/resources/log4j2.properties index ad5cb8e79a64..8f3afbf09abe 100644 --- a/services/s3/src/test/resources/log4j2.properties +++ b/services/s3/src/test/resources/log4j2.properties @@ -35,4 +35,4 @@ rootLogger.appenderRef.stdout.ref = ConsoleAppender #logger.apache.level = debug # #logger.netty.name = io.netty.handler.logging -#logger.netty.level = debug +#logger.netty.level = debug \ No newline at end of file