diff --git a/core/http-auth-aws/src/main/java/software/amazon/awssdk/http/auth/aws/internal/signer/chunkedencoding/ChunkedEncodedPublisher.java b/core/http-auth-aws/src/main/java/software/amazon/awssdk/http/auth/aws/internal/signer/chunkedencoding/ChunkedEncodedPublisher.java
index b4805a78dca2..d0196b48f5a3 100644
--- a/core/http-auth-aws/src/main/java/software/amazon/awssdk/http/auth/aws/internal/signer/chunkedencoding/ChunkedEncodedPublisher.java
+++ b/core/http-auth-aws/src/main/java/software/amazon/awssdk/http/auth/aws/internal/signer/chunkedencoding/ChunkedEncodedPublisher.java
@@ -26,14 +26,16 @@
import org.reactivestreams.Subscriber;
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.utils.Pair;
+import software.amazon.awssdk.utils.Validate;
import software.amazon.awssdk.utils.async.AddingTrailingDataSubscriber;
+import software.amazon.awssdk.utils.async.ContentLengthAwareSubscriber;
import software.amazon.awssdk.utils.async.DelegatingSubscriber;
import software.amazon.awssdk.utils.async.FlatteningSubscriber;
import software.amazon.awssdk.utils.internal.MappingSubscriber;
/**
* An implementation of chunk-transfer encoding, but by wrapping a {@link Publisher} of {@link ByteBuffer}. This implementation
- * supports chunk-headers, chunk-extensions.
+ * supports chunk-headers, chunk-extensions, and trailer-part.
*
* Per RFC-7230, a chunk-transfer encoded message is
* defined as:
@@ -66,6 +68,7 @@ public class ChunkedEncodedPublisher implements Publisher {
private static final byte COMMA = ',';
private final Publisher wrapped;
+ private final long contentLength;
private final List extensions = new ArrayList<>();
private final List trailers = new ArrayList<>();
private final int chunkSize;
@@ -74,6 +77,7 @@ public class ChunkedEncodedPublisher implements Publisher {
public ChunkedEncodedPublisher(Builder b) {
this.wrapped = b.publisher;
+ this.contentLength = Validate.notNull(b.contentLength, "contentLength must not be null");
this.chunkSize = b.chunkSize;
this.extensions.addAll(b.extensions);
this.trailers.addAll(b.trailers);
@@ -82,7 +86,8 @@ public ChunkedEncodedPublisher(Builder b) {
@Override
public void subscribe(Subscriber super ByteBuffer> subscriber) {
- Publisher> chunked = chunk(wrapped);
+ Publisher lengthEnforced = limitLength(wrapped, contentLength);
+ Publisher> chunked = chunk(lengthEnforced);
Publisher> trailingAdded = addTrailingChunks(chunked);
Publisher flattened = flatten(trailingAdded);
Publisher encoded = map(flattened, this::encodeChunk);
@@ -111,6 +116,10 @@ private Iterable> getTrailingChunks() {
return Collections.singletonList(trailing);
}
+ private Publisher limitLength(Publisher publisher, long length) {
+ return subscriber -> publisher.subscribe(new ContentLengthAwareSubscriber(subscriber, length));
+ }
+
private Publisher> chunk(Publisher upstream) {
return subscriber -> {
upstream.subscribe(new ChunkingSubscriber(subscriber));
@@ -153,8 +162,7 @@ private ByteBuffer encodeChunk(ByteBuffer byteBuffer) {
}
int trailerLen = trailerData.stream()
- // + 2 for each CRLF that ends the header-field
- .mapToInt(t -> t.remaining() + 2)
+ .mapToInt(t -> t.remaining() + CRLF.length)
.sum();
int encodedLen = chunkSizeHex.length + extensionsLength + CRLF.length + contentLen + trailerLen + CRLF.length;
@@ -188,11 +196,11 @@ private ByteBuffer encodeChunk(ByteBuffer byteBuffer) {
encoded.put(t);
encoded.put(CRLF);
});
+ // empty line ends the request body
encoded.put(CRLF);
}
encoded.flip();
-
return encoded;
}
@@ -294,6 +302,7 @@ public void onNext(ByteBuffer byteBuffer) {
public static class Builder {
private Publisher publisher;
+ private Long contentLength;
private int chunkSize;
private boolean addEmptyTrailingChunk;
private final List extensions = new ArrayList<>();
@@ -304,6 +313,15 @@ public Builder publisher(Publisher publisher) {
return this;
}
+ public Publisher publisher() {
+ return publisher;
+ }
+
+ public Builder contentLength(long contentLength) {
+ this.contentLength = contentLength;
+ return this;
+ }
+
public Builder chunkSize(int chunkSize) {
this.chunkSize = chunkSize;
return this;
@@ -324,6 +342,10 @@ public Builder addTrailer(TrailerProvider trailerProvider) {
return this;
}
+ public List trailers() {
+ return trailers;
+ }
+
public ChunkedEncodedPublisher build() {
return new ChunkedEncodedPublisher(this);
}
diff --git a/core/http-auth-aws/src/test/java/software/amazon/awssdk/http/auth/aws/internal/signer/chunkedencoding/ChunkedEncodedPublisherTest.java b/core/http-auth-aws/src/test/java/software/amazon/awssdk/http/auth/aws/internal/signer/chunkedencoding/ChunkedEncodedPublisherTest.java
index 909ae0289616..7f62802ecd1e 100644
--- a/core/http-auth-aws/src/test/java/software/amazon/awssdk/http/auth/aws/internal/signer/chunkedencoding/ChunkedEncodedPublisherTest.java
+++ b/core/http-auth-aws/src/test/java/software/amazon/awssdk/http/auth/aws/internal/signer/chunkedencoding/ChunkedEncodedPublisherTest.java
@@ -18,6 +18,7 @@
import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
import io.reactivex.Flowable;
+import io.reactivex.subscribers.TestSubscriber;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
@@ -26,6 +27,7 @@
import java.util.List;
import java.util.PrimitiveIterator;
import java.util.Random;
+import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.junit.jupiter.api.BeforeEach;
@@ -56,6 +58,7 @@ public void subscribe_publisherEmpty_onlyProducesTrailer() {
.addTrailer(() -> Pair.of("foo", Collections.singletonList("1")))
.addTrailer(() -> Pair.of("bar", Collections.singletonList("2")))
.addEmptyTrailingChunk(true)
+ .contentLength(0)
.build();
List chunks = getAllElements(build);
@@ -73,12 +76,14 @@ public void subscribe_publisherEmpty_onlyProducesTrailer() {
@Test
void subscribe_trailerProviderPresent_trailerPartAdded() {
+ int contentLength = 8;
TestPublisher upstream = randomPublisherOfLength(8);
TrailerProvider trailerProvider = new StaticTrailerProvider("foo", "bar");
ChunkedEncodedPublisher chunkedPublisher = ChunkedEncodedPublisher.builder()
.publisher(upstream)
+ .contentLength(contentLength)
.chunkSize(CHUNK_SIZE)
.addEmptyTrailingChunk(true)
.addTrailer(trailerProvider)
@@ -93,12 +98,14 @@ void subscribe_trailerProviderPresent_trailerPartAdded() {
@Test
void subscribe_trailerProviderPresent_multipleValues_trailerPartAdded() {
- TestPublisher upstream = randomPublisherOfLength(8);
+ int contentLength = 8;
+ TestPublisher upstream = randomPublisherOfLength(contentLength);
TrailerProvider trailerProvider = new StaticTrailerProvider("foo", Arrays.asList("bar1", "bar2", "bar3"));
ChunkedEncodedPublisher chunkedPublisher = ChunkedEncodedPublisher.builder()
.publisher(upstream)
+ .contentLength(contentLength)
.chunkSize(CHUNK_SIZE)
.addEmptyTrailingChunk(true)
.addTrailer(trailerProvider)
@@ -113,7 +120,8 @@ void subscribe_trailerProviderPresent_multipleValues_trailerPartAdded() {
@Test
void subscribe_trailerProviderPresent_onlyInvokedOnce() {
- TestPublisher upstream = randomPublisherOfLength(8);
+ int contentLength = 8;
+ TestPublisher upstream = randomPublisherOfLength(contentLength);
TrailerProvider trailerProvider = Mockito.spy(new StaticTrailerProvider("foo", "bar"));
@@ -121,6 +129,7 @@ void subscribe_trailerProviderPresent_onlyInvokedOnce() {
.publisher(upstream)
.addEmptyTrailingChunk(true)
.chunkSize(CHUNK_SIZE)
+ .contentLength(contentLength)
.addTrailer(trailerProvider).build();
getAllElements(chunkedPublisher);
@@ -130,13 +139,15 @@ void subscribe_trailerProviderPresent_onlyInvokedOnce() {
@Test
void subscribe_trailerPresent_trailerFormattedCorrectly() {
- TestPublisher testPublisher = randomPublisherOfLength(32);
+ int contentLength = 32;
+ TestPublisher testPublisher = randomPublisherOfLength(contentLength);
TrailerProvider trailerProvider = new StaticTrailerProvider("foo", "bar");
ChunkedEncodedPublisher chunkedPublisher = newChunkedBuilder(testPublisher)
.addTrailer(trailerProvider)
.addEmptyTrailingChunk(true)
+ .contentLength(contentLength)
.build();
List chunks = getAllElements(chunkedPublisher);
@@ -152,11 +163,13 @@ void subscribe_trailerPresent_trailerFormattedCorrectly() {
@Test
void subscribe_wrappedDoesNotFillBuffer_allDataInSingleChunk() {
- ByteBuffer element = ByteBuffer.wrap("hello world".getBytes(StandardCharsets.UTF_8));
+ byte[] content = "hello world".getBytes(StandardCharsets.UTF_8);
+ ByteBuffer element = ByteBuffer.wrap(content);
Flowable upstream = Flowable.just(element.duplicate());
ChunkedEncodedPublisher publisher = ChunkedEncodedPublisher.builder()
.chunkSize(CHUNK_SIZE)
+ .contentLength(content.length)
.publisher(upstream)
.build();
@@ -169,7 +182,8 @@ void subscribe_wrappedDoesNotFillBuffer_allDataInSingleChunk() {
@Test
void subscribe_extensionHasNoValue_formattedCorrectly() {
- TestPublisher testPublisher = randomPublisherOfLength(8);
+ int contentLength = 8;
+ TestPublisher testPublisher = randomPublisherOfLength(contentLength);
ChunkExtensionProvider extensionProvider = new StaticExtensionProvider("foo", "");
@@ -178,6 +192,7 @@ void subscribe_extensionHasNoValue_formattedCorrectly() {
.publisher(testPublisher)
.addExtension(extensionProvider)
.chunkSize(CHUNK_SIZE)
+ .contentLength(contentLength)
.build();
List chunks = getAllElements(chunkPublisher);
@@ -187,11 +202,13 @@ void subscribe_extensionHasNoValue_formattedCorrectly() {
@Test
void subscribe_multipleExtensions_formattedCorrectly() {
- TestPublisher testPublisher = randomPublisherOfLength(8);
+ int contentLength = 8;
+ TestPublisher testPublisher = randomPublisherOfLength(contentLength);
ChunkedEncodedPublisher.Builder chunkPublisher =
ChunkedEncodedPublisher.builder()
.publisher(testPublisher)
+ .contentLength(contentLength)
.chunkSize(CHUNK_SIZE);
Stream.of("1", "2", "3")
@@ -207,10 +224,12 @@ void subscribe_multipleExtensions_formattedCorrectly() {
void subscribe_randomElementSizes_dataChunkedCorrectly() {
for (int i = 0; i < 512; ++i) {
int nChunks = 24;
- TestPublisher byteBufferPublisher = randomPublisherOfLength(CHUNK_SIZE * 24);
+ int contentLength = nChunks * CHUNK_SIZE;
+ TestPublisher byteBufferPublisher = randomPublisherOfLength(contentLength);
ChunkedEncodedPublisher chunkedPublisher = ChunkedEncodedPublisher.builder()
.publisher(byteBufferPublisher)
+ .contentLength(contentLength)
.chunkSize(CHUNK_SIZE)
.build();
@@ -232,7 +251,8 @@ void subscribe_randomElementSizes_dataChunkedCorrectly() {
void subscribe_randomElementSizes_chunksHaveExtensions_dataChunkedCorrectly() {
for (int i = 0; i < 512; ++i) {
int nChunks = 24;
- TestPublisher byteBufferPublisher = randomPublisherOfLength(CHUNK_SIZE * 24);
+ int contentLength = CHUNK_SIZE * nChunks;
+ TestPublisher byteBufferPublisher = randomPublisherOfLength(contentLength);
StaticExtensionProvider extensionProvider = Mockito.spy(new StaticExtensionProvider("foo", "bar"));
@@ -240,6 +260,7 @@ void subscribe_randomElementSizes_chunksHaveExtensions_dataChunkedCorrectly() {
.publisher(byteBufferPublisher)
.addExtension(extensionProvider)
.chunkSize(CHUNK_SIZE)
+ .contentLength(contentLength)
.build();
List chunks = getAllElements(chunkedPublisher);
@@ -264,12 +285,14 @@ void subscribe_randomElementSizes_chunksHaveExtensions_dataChunkedCorrectly() {
@Test
void subscribe_addTrailingChunkTrue_trailingChunkAdded() {
- TestPublisher testPublisher = randomPublisherOfLength(CHUNK_SIZE * 2);
+ int contentLength = CHUNK_SIZE * 2;
+ TestPublisher testPublisher = randomPublisherOfLength(contentLength);
ChunkedEncodedPublisher chunkedPublisher = ChunkedEncodedPublisher.builder()
.publisher(testPublisher)
.chunkSize(CHUNK_SIZE)
.addEmptyTrailingChunk(true)
+ .contentLength(contentLength)
.build();
List chunks = getAllElements(chunkedPublisher);
@@ -285,7 +308,12 @@ void subscribe_addTrailingChunkTrue_upstreamEmpty_trailingChunkAdded() {
Publisher empty = Flowable.empty();
ChunkedEncodedPublisher chunkedPublisher =
- ChunkedEncodedPublisher.builder().publisher(empty).chunkSize(CHUNK_SIZE).addEmptyTrailingChunk(true).build();
+ ChunkedEncodedPublisher.builder()
+ .publisher(empty)
+ .chunkSize(CHUNK_SIZE)
+ .addEmptyTrailingChunk(true)
+ .contentLength(0)
+ .build();
List chunks = getAllElements(chunkedPublisher);
@@ -297,10 +325,15 @@ void subscribe_extensionsPresent_extensionsInvokedForEachChunk() {
ChunkExtensionProvider mockProvider = Mockito.spy(new StaticExtensionProvider("foo", "bar"));
int nChunks = 16;
- TestPublisher elements = randomPublisherOfLength(nChunks * CHUNK_SIZE);
+ int contentLength = CHUNK_SIZE * nChunks;
+ TestPublisher elements = randomPublisherOfLength(contentLength);
- ChunkedEncodedPublisher chunkPublisher =
- ChunkedEncodedPublisher.builder().publisher(elements).chunkSize(CHUNK_SIZE).addExtension(mockProvider).build();
+ ChunkedEncodedPublisher chunkPublisher = ChunkedEncodedPublisher.builder()
+ .publisher(elements)
+ .contentLength(contentLength)
+ .chunkSize(CHUNK_SIZE)
+ .addExtension(mockProvider)
+ .build();
List chunks = getAllElements(chunkPublisher);
@@ -316,6 +349,28 @@ void subscribe_extensionsPresent_extensionsInvokedForEachChunk() {
}
}
+ @Test
+ void subscribe_wrappedExceedsContentLength_dataTruncatedToLength() {
+ int contentLength = CHUNK_SIZE * 4 - 1;
+ TestPublisher elements = randomPublisherOfLength(contentLength * 2);
+
+ TestSubscriber subscriber = new TestSubscriber<>();
+ ChunkedEncodedPublisher chunkPublisher = newChunkedBuilder(elements).contentLength(contentLength)
+ .build();
+
+ chunkPublisher.subscribe(subscriber);
+
+ subscriber.awaitTerminalEvent(30, TimeUnit.SECONDS);
+
+ int totalRemaining = subscriber.values()
+ .stream()
+ .map(this::stripEncoding)
+ .mapToInt(ByteBuffer::remaining)
+ .sum();
+
+ assertThat(totalRemaining).isEqualTo(contentLength);
+ }
+
private static ChunkedEncodedPublisher.Builder newChunkedBuilder(Publisher publisher) {
return ChunkedEncodedPublisher.builder().publisher(publisher).chunkSize(CHUNK_SIZE);
}
@@ -401,6 +456,10 @@ private ByteBuffer stripEncoding(ByteBuffer chunk) {
return stripped;
}
+ private long totalRemaining(List buffers) {
+ return buffers.stream().mapToLong(ByteBuffer::remaining).sum();
+ }
+
private static class TestPublisher implements Publisher {
private final Publisher wrapped;
private final byte[] wrappedChecksum;
diff --git a/utils/src/main/java/software/amazon/awssdk/utils/async/ContentLengthAwareSubscriber.java b/utils/src/main/java/software/amazon/awssdk/utils/async/ContentLengthAwareSubscriber.java
new file mode 100644
index 000000000000..04d009174a20
--- /dev/null
+++ b/utils/src/main/java/software/amazon/awssdk/utils/async/ContentLengthAwareSubscriber.java
@@ -0,0 +1,77 @@
+/*
+ * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License").
+ * You may not use this file except in compliance with the License.
+ * A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0
+ *
+ * or in the "license" file accompanying this file. This file is distributed
+ * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
+ * express or implied. See the License for the specific language governing
+ * permissions and limitations under the License.
+ */
+
+package software.amazon.awssdk.utils.async;
+
+import java.nio.ByteBuffer;
+import org.reactivestreams.Subscriber;
+import org.reactivestreams.Subscription;
+import software.amazon.awssdk.annotations.SdkProtectedApi;
+
+/**
+ * Decorator subscriber that limits the number of bytes sent to the wrapped subscriber to at most {@code contentLength}. Once
+ * the given content length is reached, the upstream subscription is cancelled, and the wrapped subscriber is completed.
+ */
+@SdkProtectedApi
+public final class ContentLengthAwareSubscriber implements Subscriber {
+ private final Subscriber super ByteBuffer> subscriber;
+ private Subscription subscription;
+ private boolean subscriptionCancelled;
+ private long remaining;
+
+ public ContentLengthAwareSubscriber(Subscriber super ByteBuffer> subscriber, long contentLength) {
+ this.subscriber = subscriber;
+ this.remaining = contentLength;
+ }
+
+ @Override
+ public void onSubscribe(Subscription subscription) {
+ if (subscription == null) {
+ throw new NullPointerException("subscription must not be null");
+ }
+ this.subscription = subscription;
+ subscriber.onSubscribe(subscription);
+ }
+
+ @Override
+ public void onNext(ByteBuffer byteBuffer) {
+ if (remaining > 0) {
+ long bytesToRead = Math.min(remaining, byteBuffer.remaining());
+ // cast is safe, min of long and int is <= max_int
+ byteBuffer.limit(byteBuffer.position() + (int) bytesToRead);
+ remaining -= bytesToRead;
+ subscriber.onNext(byteBuffer);
+ }
+
+ if (remaining == 0 && !subscriptionCancelled) {
+ subscriptionCancelled = true;
+ subscription.cancel();
+ onComplete();
+ }
+ }
+
+ @Override
+ public void onError(Throwable throwable) {
+ if (throwable == null) {
+ throw new NullPointerException("throwable cannot be null");
+ }
+ subscriber.onError(throwable);
+ }
+
+ @Override
+ public void onComplete() {
+ subscriber.onComplete();
+ }
+}
diff --git a/utils/src/test/java/software/amazon/awssdk/testutils/PublisherUtils.java b/utils/src/test/java/software/amazon/awssdk/testutils/PublisherUtils.java
new file mode 100644
index 000000000000..45dc9d63f22b
--- /dev/null
+++ b/utils/src/test/java/software/amazon/awssdk/testutils/PublisherUtils.java
@@ -0,0 +1,50 @@
+/*
+ * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License").
+ * You may not use this file except in compliance with the License.
+ * A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0
+ *
+ * or in the "license" file accompanying this file. This file is distributed
+ * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
+ * express or implied. See the License for the specific language governing
+ * permissions and limitations under the License.
+ */
+
+package software.amazon.awssdk.testutils;
+
+import io.reactivex.Flowable;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.PrimitiveIterator;
+import java.util.Random;
+import org.reactivestreams.Publisher;
+
+public final class PublisherUtils {
+ private static final Random RNG = new Random();
+
+ private PublisherUtils() {
+ }
+
+ public static Publisher randomPublisherOfLength(int bytes, int min, int max) {
+ List elements = new ArrayList<>();
+
+ PrimitiveIterator.OfInt sizeIter = RNG.ints(min, max).iterator();
+
+ while (bytes > 0) {
+ int elementSize = sizeIter.next();
+ elementSize = Math.min(elementSize, bytes);
+
+ bytes -= elementSize;
+
+ byte[] elementContent = new byte[elementSize];
+ RNG.nextBytes(elementContent);
+ elements.add(ByteBuffer.wrap(elementContent));
+ }
+
+ return Flowable.fromIterable(elements);
+ }
+}
diff --git a/utils/src/test/java/software/amazon/awssdk/utils/async/ContentLengthAwareSubscriberTckTest.java b/utils/src/test/java/software/amazon/awssdk/utils/async/ContentLengthAwareSubscriberTckTest.java
new file mode 100644
index 000000000000..e2e419bb3004
--- /dev/null
+++ b/utils/src/test/java/software/amazon/awssdk/utils/async/ContentLengthAwareSubscriberTckTest.java
@@ -0,0 +1,39 @@
+/*
+ * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License").
+ * You may not use this file except in compliance with the License.
+ * A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0
+ *
+ * or in the "license" file accompanying this file. This file is distributed
+ * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
+ * express or implied. See the License for the specific language governing
+ * permissions and limitations under the License.
+ */
+
+package software.amazon.awssdk.utils.async;
+
+import io.reactivex.subscribers.TestSubscriber;
+import java.nio.ByteBuffer;
+import org.reactivestreams.Subscriber;
+import org.reactivestreams.tck.SubscriberBlackboxVerification;
+import org.reactivestreams.tck.TestEnvironment;
+
+public class ContentLengthAwareSubscriberTckTest extends SubscriberBlackboxVerification {
+
+ public ContentLengthAwareSubscriberTckTest() {
+ super(new TestEnvironment());
+ }
+
+ @Override
+ public Subscriber createSubscriber() {
+ return new ContentLengthAwareSubscriber(new TestSubscriber<>(), 16);
+ }
+
+ @Override
+ public ByteBuffer createElement(int i) {
+ return ByteBuffer.wrap(Integer.toString(i).getBytes());
+ }
+}
diff --git a/utils/src/test/java/software/amazon/awssdk/utils/async/ContentLengthAwareSubscriberTest.java b/utils/src/test/java/software/amazon/awssdk/utils/async/ContentLengthAwareSubscriberTest.java
new file mode 100644
index 000000000000..9f8a982b467f
--- /dev/null
+++ b/utils/src/test/java/software/amazon/awssdk/utils/async/ContentLengthAwareSubscriberTest.java
@@ -0,0 +1,187 @@
+/*
+ * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License").
+ * You may not use this file except in compliance with the License.
+ * A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0
+ *
+ * or in the "license" file accompanying this file. This file is distributed
+ * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
+ * express or implied. See the License for the specific language governing
+ * permissions and limitations under the License.
+ */
+
+package software.amazon.awssdk.utils.async;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static software.amazon.awssdk.testutils.PublisherUtils.randomPublisherOfLength;
+
+import io.reactivex.subscribers.TestSubscriber;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicLong;
+import org.junit.jupiter.api.Test;
+import org.mockito.Mockito;
+import org.reactivestreams.Publisher;
+import org.reactivestreams.Subscriber;
+import org.reactivestreams.Subscription;
+
+public class ContentLengthAwareSubscriberTest {
+ @Test
+ void subscribe_upstreamExceedsContentLength_correctlyTruncates() {
+ long contentLength = 64;
+ Publisher upstream = randomPublisherOfLength(8192, 8, 16);
+
+ TestSubscriber subscriber = new TestSubscriber<>();
+
+ ContentLengthAwareSubscriber lengthAwareSubscriber = new ContentLengthAwareSubscriber(subscriber, contentLength);
+ upstream.subscribe(lengthAwareSubscriber);
+
+ assertThat(totalRemaining(subscriber.values())).isEqualTo(contentLength);
+ }
+
+ @Test
+ void subscribe_upstreamHasExactlyContentLength_signalsComplete() {
+ long contentLength = 8192;
+ Publisher upstream = randomPublisherOfLength((int) contentLength, 8, 16);
+
+ TestSubscriber subscriber = new TestSubscriber<>();
+ ContentLengthAwareSubscriber lengthAwareSubscriber = new ContentLengthAwareSubscriber(subscriber, contentLength);
+ upstream.subscribe(lengthAwareSubscriber);
+
+ subscriber.assertComplete();
+ assertThat(totalRemaining(subscriber.values())).isEqualTo(contentLength);
+ }
+
+ @Test
+ void subscribe_upstreamExceedsContentLength_request1BufferAtATime_correctlyTruncates() throws Exception {
+ long contentLength = 8192;
+
+ Publisher upstream = randomPublisherOfLength((int) contentLength * 2, 8, 16);
+
+ CompletableFuture subscriberFinished = new CompletableFuture<>();
+ List buffers = new ArrayList<>();
+
+ Subscriber testSubscriber = new Subscriber() {
+ private Subscription subscription;
+
+ @Override
+ public void onSubscribe(Subscription subscription) {
+ this.subscription = subscription;
+ this.subscription.request(1);
+ }
+
+ @Override
+ public void onNext(ByteBuffer byteBuffer) {
+ buffers.add(byteBuffer);
+ subscription.request(1);
+ }
+
+ @Override
+ public void onError(Throwable throwable) {
+ subscriberFinished.completeExceptionally(throwable);
+ }
+
+ @Override
+ public void onComplete() {
+ subscriberFinished.complete(null);
+ }
+ };
+
+ testSubscriber = Mockito.spy(testSubscriber);
+
+ ContentLengthAwareSubscriber lengthAwareSubscriber = new ContentLengthAwareSubscriber(testSubscriber, contentLength);
+ upstream.subscribe(lengthAwareSubscriber);
+
+ subscriberFinished.get(1, TimeUnit.MINUTES);
+ Mockito.verify(testSubscriber, Mockito.times(1)).onComplete();
+ assertThat(totalRemaining(buffers)).isEqualTo(contentLength);
+ }
+
+ @Test
+ void subscribe_upstreamExceedsContentLength_upstreamSubscriptionCancelledAfterContentLengthReached() {
+ long contentLength = 64;
+ Publisher upstream = randomPublisherOfLength((int) contentLength * 4, 8, 16);
+
+ TestSubscriber testSubscriber = new TestSubscriber<>();
+ ContentLengthAwareSubscriber lengthAwareSubscriber = new ContentLengthAwareSubscriber(testSubscriber, contentLength);
+ SubscriptionWrappingSubscriber subscriptionWrappingSubscriber = new SubscriptionWrappingSubscriber(lengthAwareSubscriber);
+ upstream.subscribe(subscriptionWrappingSubscriber);
+
+ testSubscriber.assertComplete();
+ assertThat(subscriptionWrappingSubscriber.wrappedSubscription.cancelInvocations.get()).isEqualTo(1L);
+ assertThat(totalRemaining(testSubscriber.values())).isEqualTo(contentLength);
+ }
+
+ @Test
+ void subscribe_upstreamHasContentAndContentLength0_signalsComplete() {
+ Publisher upstream = randomPublisherOfLength(128, 8, 16);
+
+ TestSubscriber testSubscriber = new TestSubscriber<>();
+ ContentLengthAwareSubscriber lengthAwareSubscriber = new ContentLengthAwareSubscriber(testSubscriber, 0L);
+ upstream.subscribe(lengthAwareSubscriber);
+
+ testSubscriber.awaitTerminalEvent(5, TimeUnit.SECONDS);
+ testSubscriber.assertComplete();
+ assertThat(testSubscriber.values()).isEmpty();
+ }
+
+ private static class TestSubscription implements Subscription {
+ private final Subscription wrapped;
+ private final AtomicLong cancelInvocations = new AtomicLong();
+
+ private TestSubscription(Subscription wrapped) {
+ this.wrapped = wrapped;
+ }
+
+ @Override
+ public void request(long l) {
+ this.wrapped.request(l);
+ }
+
+ @Override
+ public void cancel() {
+ cancelInvocations.incrementAndGet();
+ this.wrapped.cancel();
+ }
+ }
+
+ private long totalRemaining(List buffers) {
+ return buffers.stream().mapToLong(ByteBuffer::remaining).sum();
+ }
+
+ private static class SubscriptionWrappingSubscriber implements Subscriber {
+ private final Subscriber wrapped;
+ private TestSubscription wrappedSubscription;
+
+ private SubscriptionWrappingSubscriber(Subscriber wrapped) {
+ this.wrapped = wrapped;
+ }
+
+ @Override
+ public void onSubscribe(Subscription subscription) {
+ this.wrappedSubscription = new TestSubscription(subscription);
+ this.wrapped.onSubscribe(this.wrappedSubscription);
+ }
+
+ @Override
+ public void onNext(ByteBuffer byteBuffer) {
+ this.wrapped.onNext(byteBuffer);
+ }
+
+ @Override
+ public void onError(Throwable throwable) {
+ this.wrapped.onError(throwable);
+ }
+
+ @Override
+ public void onComplete() {
+ this.wrapped.onComplete();
+ }
+ }
+}