diff --git a/core/http-auth-aws/src/main/java/software/amazon/awssdk/http/auth/aws/internal/signer/chunkedencoding/ChunkedEncodedPublisher.java b/core/http-auth-aws/src/main/java/software/amazon/awssdk/http/auth/aws/internal/signer/chunkedencoding/ChunkedEncodedPublisher.java index b4805a78dca2..d0196b48f5a3 100644 --- a/core/http-auth-aws/src/main/java/software/amazon/awssdk/http/auth/aws/internal/signer/chunkedencoding/ChunkedEncodedPublisher.java +++ b/core/http-auth-aws/src/main/java/software/amazon/awssdk/http/auth/aws/internal/signer/chunkedencoding/ChunkedEncodedPublisher.java @@ -26,14 +26,16 @@ import org.reactivestreams.Subscriber; import software.amazon.awssdk.annotations.SdkInternalApi; import software.amazon.awssdk.utils.Pair; +import software.amazon.awssdk.utils.Validate; import software.amazon.awssdk.utils.async.AddingTrailingDataSubscriber; +import software.amazon.awssdk.utils.async.ContentLengthAwareSubscriber; import software.amazon.awssdk.utils.async.DelegatingSubscriber; import software.amazon.awssdk.utils.async.FlatteningSubscriber; import software.amazon.awssdk.utils.internal.MappingSubscriber; /** * An implementation of chunk-transfer encoding, but by wrapping a {@link Publisher} of {@link ByteBuffer}. This implementation - * supports chunk-headers, chunk-extensions. + * supports chunk-headers, chunk-extensions, and trailer-part. *

* Per RFC-7230, a chunk-transfer encoded message is * defined as: @@ -66,6 +68,7 @@ public class ChunkedEncodedPublisher implements Publisher { private static final byte COMMA = ','; private final Publisher wrapped; + private final long contentLength; private final List extensions = new ArrayList<>(); private final List trailers = new ArrayList<>(); private final int chunkSize; @@ -74,6 +77,7 @@ public class ChunkedEncodedPublisher implements Publisher { public ChunkedEncodedPublisher(Builder b) { this.wrapped = b.publisher; + this.contentLength = Validate.notNull(b.contentLength, "contentLength must not be null"); this.chunkSize = b.chunkSize; this.extensions.addAll(b.extensions); this.trailers.addAll(b.trailers); @@ -82,7 +86,8 @@ public ChunkedEncodedPublisher(Builder b) { @Override public void subscribe(Subscriber subscriber) { - Publisher> chunked = chunk(wrapped); + Publisher lengthEnforced = limitLength(wrapped, contentLength); + Publisher> chunked = chunk(lengthEnforced); Publisher> trailingAdded = addTrailingChunks(chunked); Publisher flattened = flatten(trailingAdded); Publisher encoded = map(flattened, this::encodeChunk); @@ -111,6 +116,10 @@ private Iterable> getTrailingChunks() { return Collections.singletonList(trailing); } + private Publisher limitLength(Publisher publisher, long length) { + return subscriber -> publisher.subscribe(new ContentLengthAwareSubscriber(subscriber, length)); + } + private Publisher> chunk(Publisher upstream) { return subscriber -> { upstream.subscribe(new ChunkingSubscriber(subscriber)); @@ -153,8 +162,7 @@ private ByteBuffer encodeChunk(ByteBuffer byteBuffer) { } int trailerLen = trailerData.stream() - // + 2 for each CRLF that ends the header-field - .mapToInt(t -> t.remaining() + 2) + .mapToInt(t -> t.remaining() + CRLF.length) .sum(); int encodedLen = chunkSizeHex.length + extensionsLength + CRLF.length + contentLen + trailerLen + CRLF.length; @@ -188,11 +196,11 @@ private ByteBuffer encodeChunk(ByteBuffer byteBuffer) { encoded.put(t); encoded.put(CRLF); }); + // empty line ends the request body encoded.put(CRLF); } encoded.flip(); - return encoded; } @@ -294,6 +302,7 @@ public void onNext(ByteBuffer byteBuffer) { public static class Builder { private Publisher publisher; + private Long contentLength; private int chunkSize; private boolean addEmptyTrailingChunk; private final List extensions = new ArrayList<>(); @@ -304,6 +313,15 @@ public Builder publisher(Publisher publisher) { return this; } + public Publisher publisher() { + return publisher; + } + + public Builder contentLength(long contentLength) { + this.contentLength = contentLength; + return this; + } + public Builder chunkSize(int chunkSize) { this.chunkSize = chunkSize; return this; @@ -324,6 +342,10 @@ public Builder addTrailer(TrailerProvider trailerProvider) { return this; } + public List trailers() { + return trailers; + } + public ChunkedEncodedPublisher build() { return new ChunkedEncodedPublisher(this); } diff --git a/core/http-auth-aws/src/test/java/software/amazon/awssdk/http/auth/aws/internal/signer/chunkedencoding/ChunkedEncodedPublisherTest.java b/core/http-auth-aws/src/test/java/software/amazon/awssdk/http/auth/aws/internal/signer/chunkedencoding/ChunkedEncodedPublisherTest.java index 909ae0289616..7f62802ecd1e 100644 --- a/core/http-auth-aws/src/test/java/software/amazon/awssdk/http/auth/aws/internal/signer/chunkedencoding/ChunkedEncodedPublisherTest.java +++ b/core/http-auth-aws/src/test/java/software/amazon/awssdk/http/auth/aws/internal/signer/chunkedencoding/ChunkedEncodedPublisherTest.java @@ -18,6 +18,7 @@ import static org.assertj.core.api.AssertionsForClassTypes.assertThat; import io.reactivex.Flowable; +import io.reactivex.subscribers.TestSubscriber; import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import java.util.ArrayList; @@ -26,6 +27,7 @@ import java.util.List; import java.util.PrimitiveIterator; import java.util.Random; +import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; import java.util.stream.Stream; import org.junit.jupiter.api.BeforeEach; @@ -56,6 +58,7 @@ public void subscribe_publisherEmpty_onlyProducesTrailer() { .addTrailer(() -> Pair.of("foo", Collections.singletonList("1"))) .addTrailer(() -> Pair.of("bar", Collections.singletonList("2"))) .addEmptyTrailingChunk(true) + .contentLength(0) .build(); List chunks = getAllElements(build); @@ -73,12 +76,14 @@ public void subscribe_publisherEmpty_onlyProducesTrailer() { @Test void subscribe_trailerProviderPresent_trailerPartAdded() { + int contentLength = 8; TestPublisher upstream = randomPublisherOfLength(8); TrailerProvider trailerProvider = new StaticTrailerProvider("foo", "bar"); ChunkedEncodedPublisher chunkedPublisher = ChunkedEncodedPublisher.builder() .publisher(upstream) + .contentLength(contentLength) .chunkSize(CHUNK_SIZE) .addEmptyTrailingChunk(true) .addTrailer(trailerProvider) @@ -93,12 +98,14 @@ void subscribe_trailerProviderPresent_trailerPartAdded() { @Test void subscribe_trailerProviderPresent_multipleValues_trailerPartAdded() { - TestPublisher upstream = randomPublisherOfLength(8); + int contentLength = 8; + TestPublisher upstream = randomPublisherOfLength(contentLength); TrailerProvider trailerProvider = new StaticTrailerProvider("foo", Arrays.asList("bar1", "bar2", "bar3")); ChunkedEncodedPublisher chunkedPublisher = ChunkedEncodedPublisher.builder() .publisher(upstream) + .contentLength(contentLength) .chunkSize(CHUNK_SIZE) .addEmptyTrailingChunk(true) .addTrailer(trailerProvider) @@ -113,7 +120,8 @@ void subscribe_trailerProviderPresent_multipleValues_trailerPartAdded() { @Test void subscribe_trailerProviderPresent_onlyInvokedOnce() { - TestPublisher upstream = randomPublisherOfLength(8); + int contentLength = 8; + TestPublisher upstream = randomPublisherOfLength(contentLength); TrailerProvider trailerProvider = Mockito.spy(new StaticTrailerProvider("foo", "bar")); @@ -121,6 +129,7 @@ void subscribe_trailerProviderPresent_onlyInvokedOnce() { .publisher(upstream) .addEmptyTrailingChunk(true) .chunkSize(CHUNK_SIZE) + .contentLength(contentLength) .addTrailer(trailerProvider).build(); getAllElements(chunkedPublisher); @@ -130,13 +139,15 @@ void subscribe_trailerProviderPresent_onlyInvokedOnce() { @Test void subscribe_trailerPresent_trailerFormattedCorrectly() { - TestPublisher testPublisher = randomPublisherOfLength(32); + int contentLength = 32; + TestPublisher testPublisher = randomPublisherOfLength(contentLength); TrailerProvider trailerProvider = new StaticTrailerProvider("foo", "bar"); ChunkedEncodedPublisher chunkedPublisher = newChunkedBuilder(testPublisher) .addTrailer(trailerProvider) .addEmptyTrailingChunk(true) + .contentLength(contentLength) .build(); List chunks = getAllElements(chunkedPublisher); @@ -152,11 +163,13 @@ void subscribe_trailerPresent_trailerFormattedCorrectly() { @Test void subscribe_wrappedDoesNotFillBuffer_allDataInSingleChunk() { - ByteBuffer element = ByteBuffer.wrap("hello world".getBytes(StandardCharsets.UTF_8)); + byte[] content = "hello world".getBytes(StandardCharsets.UTF_8); + ByteBuffer element = ByteBuffer.wrap(content); Flowable upstream = Flowable.just(element.duplicate()); ChunkedEncodedPublisher publisher = ChunkedEncodedPublisher.builder() .chunkSize(CHUNK_SIZE) + .contentLength(content.length) .publisher(upstream) .build(); @@ -169,7 +182,8 @@ void subscribe_wrappedDoesNotFillBuffer_allDataInSingleChunk() { @Test void subscribe_extensionHasNoValue_formattedCorrectly() { - TestPublisher testPublisher = randomPublisherOfLength(8); + int contentLength = 8; + TestPublisher testPublisher = randomPublisherOfLength(contentLength); ChunkExtensionProvider extensionProvider = new StaticExtensionProvider("foo", ""); @@ -178,6 +192,7 @@ void subscribe_extensionHasNoValue_formattedCorrectly() { .publisher(testPublisher) .addExtension(extensionProvider) .chunkSize(CHUNK_SIZE) + .contentLength(contentLength) .build(); List chunks = getAllElements(chunkPublisher); @@ -187,11 +202,13 @@ void subscribe_extensionHasNoValue_formattedCorrectly() { @Test void subscribe_multipleExtensions_formattedCorrectly() { - TestPublisher testPublisher = randomPublisherOfLength(8); + int contentLength = 8; + TestPublisher testPublisher = randomPublisherOfLength(contentLength); ChunkedEncodedPublisher.Builder chunkPublisher = ChunkedEncodedPublisher.builder() .publisher(testPublisher) + .contentLength(contentLength) .chunkSize(CHUNK_SIZE); Stream.of("1", "2", "3") @@ -207,10 +224,12 @@ void subscribe_multipleExtensions_formattedCorrectly() { void subscribe_randomElementSizes_dataChunkedCorrectly() { for (int i = 0; i < 512; ++i) { int nChunks = 24; - TestPublisher byteBufferPublisher = randomPublisherOfLength(CHUNK_SIZE * 24); + int contentLength = nChunks * CHUNK_SIZE; + TestPublisher byteBufferPublisher = randomPublisherOfLength(contentLength); ChunkedEncodedPublisher chunkedPublisher = ChunkedEncodedPublisher.builder() .publisher(byteBufferPublisher) + .contentLength(contentLength) .chunkSize(CHUNK_SIZE) .build(); @@ -232,7 +251,8 @@ void subscribe_randomElementSizes_dataChunkedCorrectly() { void subscribe_randomElementSizes_chunksHaveExtensions_dataChunkedCorrectly() { for (int i = 0; i < 512; ++i) { int nChunks = 24; - TestPublisher byteBufferPublisher = randomPublisherOfLength(CHUNK_SIZE * 24); + int contentLength = CHUNK_SIZE * nChunks; + TestPublisher byteBufferPublisher = randomPublisherOfLength(contentLength); StaticExtensionProvider extensionProvider = Mockito.spy(new StaticExtensionProvider("foo", "bar")); @@ -240,6 +260,7 @@ void subscribe_randomElementSizes_chunksHaveExtensions_dataChunkedCorrectly() { .publisher(byteBufferPublisher) .addExtension(extensionProvider) .chunkSize(CHUNK_SIZE) + .contentLength(contentLength) .build(); List chunks = getAllElements(chunkedPublisher); @@ -264,12 +285,14 @@ void subscribe_randomElementSizes_chunksHaveExtensions_dataChunkedCorrectly() { @Test void subscribe_addTrailingChunkTrue_trailingChunkAdded() { - TestPublisher testPublisher = randomPublisherOfLength(CHUNK_SIZE * 2); + int contentLength = CHUNK_SIZE * 2; + TestPublisher testPublisher = randomPublisherOfLength(contentLength); ChunkedEncodedPublisher chunkedPublisher = ChunkedEncodedPublisher.builder() .publisher(testPublisher) .chunkSize(CHUNK_SIZE) .addEmptyTrailingChunk(true) + .contentLength(contentLength) .build(); List chunks = getAllElements(chunkedPublisher); @@ -285,7 +308,12 @@ void subscribe_addTrailingChunkTrue_upstreamEmpty_trailingChunkAdded() { Publisher empty = Flowable.empty(); ChunkedEncodedPublisher chunkedPublisher = - ChunkedEncodedPublisher.builder().publisher(empty).chunkSize(CHUNK_SIZE).addEmptyTrailingChunk(true).build(); + ChunkedEncodedPublisher.builder() + .publisher(empty) + .chunkSize(CHUNK_SIZE) + .addEmptyTrailingChunk(true) + .contentLength(0) + .build(); List chunks = getAllElements(chunkedPublisher); @@ -297,10 +325,15 @@ void subscribe_extensionsPresent_extensionsInvokedForEachChunk() { ChunkExtensionProvider mockProvider = Mockito.spy(new StaticExtensionProvider("foo", "bar")); int nChunks = 16; - TestPublisher elements = randomPublisherOfLength(nChunks * CHUNK_SIZE); + int contentLength = CHUNK_SIZE * nChunks; + TestPublisher elements = randomPublisherOfLength(contentLength); - ChunkedEncodedPublisher chunkPublisher = - ChunkedEncodedPublisher.builder().publisher(elements).chunkSize(CHUNK_SIZE).addExtension(mockProvider).build(); + ChunkedEncodedPublisher chunkPublisher = ChunkedEncodedPublisher.builder() + .publisher(elements) + .contentLength(contentLength) + .chunkSize(CHUNK_SIZE) + .addExtension(mockProvider) + .build(); List chunks = getAllElements(chunkPublisher); @@ -316,6 +349,28 @@ void subscribe_extensionsPresent_extensionsInvokedForEachChunk() { } } + @Test + void subscribe_wrappedExceedsContentLength_dataTruncatedToLength() { + int contentLength = CHUNK_SIZE * 4 - 1; + TestPublisher elements = randomPublisherOfLength(contentLength * 2); + + TestSubscriber subscriber = new TestSubscriber<>(); + ChunkedEncodedPublisher chunkPublisher = newChunkedBuilder(elements).contentLength(contentLength) + .build(); + + chunkPublisher.subscribe(subscriber); + + subscriber.awaitTerminalEvent(30, TimeUnit.SECONDS); + + int totalRemaining = subscriber.values() + .stream() + .map(this::stripEncoding) + .mapToInt(ByteBuffer::remaining) + .sum(); + + assertThat(totalRemaining).isEqualTo(contentLength); + } + private static ChunkedEncodedPublisher.Builder newChunkedBuilder(Publisher publisher) { return ChunkedEncodedPublisher.builder().publisher(publisher).chunkSize(CHUNK_SIZE); } @@ -401,6 +456,10 @@ private ByteBuffer stripEncoding(ByteBuffer chunk) { return stripped; } + private long totalRemaining(List buffers) { + return buffers.stream().mapToLong(ByteBuffer::remaining).sum(); + } + private static class TestPublisher implements Publisher { private final Publisher wrapped; private final byte[] wrappedChecksum; diff --git a/utils/src/main/java/software/amazon/awssdk/utils/async/ContentLengthAwareSubscriber.java b/utils/src/main/java/software/amazon/awssdk/utils/async/ContentLengthAwareSubscriber.java new file mode 100644 index 000000000000..04d009174a20 --- /dev/null +++ b/utils/src/main/java/software/amazon/awssdk/utils/async/ContentLengthAwareSubscriber.java @@ -0,0 +1,77 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.utils.async; + +import java.nio.ByteBuffer; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import software.amazon.awssdk.annotations.SdkProtectedApi; + +/** + * Decorator subscriber that limits the number of bytes sent to the wrapped subscriber to at most {@code contentLength}. Once + * the given content length is reached, the upstream subscription is cancelled, and the wrapped subscriber is completed. + */ +@SdkProtectedApi +public final class ContentLengthAwareSubscriber implements Subscriber { + private final Subscriber subscriber; + private Subscription subscription; + private boolean subscriptionCancelled; + private long remaining; + + public ContentLengthAwareSubscriber(Subscriber subscriber, long contentLength) { + this.subscriber = subscriber; + this.remaining = contentLength; + } + + @Override + public void onSubscribe(Subscription subscription) { + if (subscription == null) { + throw new NullPointerException("subscription must not be null"); + } + this.subscription = subscription; + subscriber.onSubscribe(subscription); + } + + @Override + public void onNext(ByteBuffer byteBuffer) { + if (remaining > 0) { + long bytesToRead = Math.min(remaining, byteBuffer.remaining()); + // cast is safe, min of long and int is <= max_int + byteBuffer.limit(byteBuffer.position() + (int) bytesToRead); + remaining -= bytesToRead; + subscriber.onNext(byteBuffer); + } + + if (remaining == 0 && !subscriptionCancelled) { + subscriptionCancelled = true; + subscription.cancel(); + onComplete(); + } + } + + @Override + public void onError(Throwable throwable) { + if (throwable == null) { + throw new NullPointerException("throwable cannot be null"); + } + subscriber.onError(throwable); + } + + @Override + public void onComplete() { + subscriber.onComplete(); + } +} diff --git a/utils/src/test/java/software/amazon/awssdk/testutils/PublisherUtils.java b/utils/src/test/java/software/amazon/awssdk/testutils/PublisherUtils.java new file mode 100644 index 000000000000..45dc9d63f22b --- /dev/null +++ b/utils/src/test/java/software/amazon/awssdk/testutils/PublisherUtils.java @@ -0,0 +1,50 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.testutils; + +import io.reactivex.Flowable; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; +import java.util.PrimitiveIterator; +import java.util.Random; +import org.reactivestreams.Publisher; + +public final class PublisherUtils { + private static final Random RNG = new Random(); + + private PublisherUtils() { + } + + public static Publisher randomPublisherOfLength(int bytes, int min, int max) { + List elements = new ArrayList<>(); + + PrimitiveIterator.OfInt sizeIter = RNG.ints(min, max).iterator(); + + while (bytes > 0) { + int elementSize = sizeIter.next(); + elementSize = Math.min(elementSize, bytes); + + bytes -= elementSize; + + byte[] elementContent = new byte[elementSize]; + RNG.nextBytes(elementContent); + elements.add(ByteBuffer.wrap(elementContent)); + } + + return Flowable.fromIterable(elements); + } +} diff --git a/utils/src/test/java/software/amazon/awssdk/utils/async/ContentLengthAwareSubscriberTckTest.java b/utils/src/test/java/software/amazon/awssdk/utils/async/ContentLengthAwareSubscriberTckTest.java new file mode 100644 index 000000000000..e2e419bb3004 --- /dev/null +++ b/utils/src/test/java/software/amazon/awssdk/utils/async/ContentLengthAwareSubscriberTckTest.java @@ -0,0 +1,39 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.utils.async; + +import io.reactivex.subscribers.TestSubscriber; +import java.nio.ByteBuffer; +import org.reactivestreams.Subscriber; +import org.reactivestreams.tck.SubscriberBlackboxVerification; +import org.reactivestreams.tck.TestEnvironment; + +public class ContentLengthAwareSubscriberTckTest extends SubscriberBlackboxVerification { + + public ContentLengthAwareSubscriberTckTest() { + super(new TestEnvironment()); + } + + @Override + public Subscriber createSubscriber() { + return new ContentLengthAwareSubscriber(new TestSubscriber<>(), 16); + } + + @Override + public ByteBuffer createElement(int i) { + return ByteBuffer.wrap(Integer.toString(i).getBytes()); + } +} diff --git a/utils/src/test/java/software/amazon/awssdk/utils/async/ContentLengthAwareSubscriberTest.java b/utils/src/test/java/software/amazon/awssdk/utils/async/ContentLengthAwareSubscriberTest.java new file mode 100644 index 000000000000..9f8a982b467f --- /dev/null +++ b/utils/src/test/java/software/amazon/awssdk/utils/async/ContentLengthAwareSubscriberTest.java @@ -0,0 +1,187 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.utils.async; + +import static org.assertj.core.api.Assertions.assertThat; +import static software.amazon.awssdk.testutils.PublisherUtils.randomPublisherOfLength; + +import io.reactivex.subscribers.TestSubscriber; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLong; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; + +public class ContentLengthAwareSubscriberTest { + @Test + void subscribe_upstreamExceedsContentLength_correctlyTruncates() { + long contentLength = 64; + Publisher upstream = randomPublisherOfLength(8192, 8, 16); + + TestSubscriber subscriber = new TestSubscriber<>(); + + ContentLengthAwareSubscriber lengthAwareSubscriber = new ContentLengthAwareSubscriber(subscriber, contentLength); + upstream.subscribe(lengthAwareSubscriber); + + assertThat(totalRemaining(subscriber.values())).isEqualTo(contentLength); + } + + @Test + void subscribe_upstreamHasExactlyContentLength_signalsComplete() { + long contentLength = 8192; + Publisher upstream = randomPublisherOfLength((int) contentLength, 8, 16); + + TestSubscriber subscriber = new TestSubscriber<>(); + ContentLengthAwareSubscriber lengthAwareSubscriber = new ContentLengthAwareSubscriber(subscriber, contentLength); + upstream.subscribe(lengthAwareSubscriber); + + subscriber.assertComplete(); + assertThat(totalRemaining(subscriber.values())).isEqualTo(contentLength); + } + + @Test + void subscribe_upstreamExceedsContentLength_request1BufferAtATime_correctlyTruncates() throws Exception { + long contentLength = 8192; + + Publisher upstream = randomPublisherOfLength((int) contentLength * 2, 8, 16); + + CompletableFuture subscriberFinished = new CompletableFuture<>(); + List buffers = new ArrayList<>(); + + Subscriber testSubscriber = new Subscriber() { + private Subscription subscription; + + @Override + public void onSubscribe(Subscription subscription) { + this.subscription = subscription; + this.subscription.request(1); + } + + @Override + public void onNext(ByteBuffer byteBuffer) { + buffers.add(byteBuffer); + subscription.request(1); + } + + @Override + public void onError(Throwable throwable) { + subscriberFinished.completeExceptionally(throwable); + } + + @Override + public void onComplete() { + subscriberFinished.complete(null); + } + }; + + testSubscriber = Mockito.spy(testSubscriber); + + ContentLengthAwareSubscriber lengthAwareSubscriber = new ContentLengthAwareSubscriber(testSubscriber, contentLength); + upstream.subscribe(lengthAwareSubscriber); + + subscriberFinished.get(1, TimeUnit.MINUTES); + Mockito.verify(testSubscriber, Mockito.times(1)).onComplete(); + assertThat(totalRemaining(buffers)).isEqualTo(contentLength); + } + + @Test + void subscribe_upstreamExceedsContentLength_upstreamSubscriptionCancelledAfterContentLengthReached() { + long contentLength = 64; + Publisher upstream = randomPublisherOfLength((int) contentLength * 4, 8, 16); + + TestSubscriber testSubscriber = new TestSubscriber<>(); + ContentLengthAwareSubscriber lengthAwareSubscriber = new ContentLengthAwareSubscriber(testSubscriber, contentLength); + SubscriptionWrappingSubscriber subscriptionWrappingSubscriber = new SubscriptionWrappingSubscriber(lengthAwareSubscriber); + upstream.subscribe(subscriptionWrappingSubscriber); + + testSubscriber.assertComplete(); + assertThat(subscriptionWrappingSubscriber.wrappedSubscription.cancelInvocations.get()).isEqualTo(1L); + assertThat(totalRemaining(testSubscriber.values())).isEqualTo(contentLength); + } + + @Test + void subscribe_upstreamHasContentAndContentLength0_signalsComplete() { + Publisher upstream = randomPublisherOfLength(128, 8, 16); + + TestSubscriber testSubscriber = new TestSubscriber<>(); + ContentLengthAwareSubscriber lengthAwareSubscriber = new ContentLengthAwareSubscriber(testSubscriber, 0L); + upstream.subscribe(lengthAwareSubscriber); + + testSubscriber.awaitTerminalEvent(5, TimeUnit.SECONDS); + testSubscriber.assertComplete(); + assertThat(testSubscriber.values()).isEmpty(); + } + + private static class TestSubscription implements Subscription { + private final Subscription wrapped; + private final AtomicLong cancelInvocations = new AtomicLong(); + + private TestSubscription(Subscription wrapped) { + this.wrapped = wrapped; + } + + @Override + public void request(long l) { + this.wrapped.request(l); + } + + @Override + public void cancel() { + cancelInvocations.incrementAndGet(); + this.wrapped.cancel(); + } + } + + private long totalRemaining(List buffers) { + return buffers.stream().mapToLong(ByteBuffer::remaining).sum(); + } + + private static class SubscriptionWrappingSubscriber implements Subscriber { + private final Subscriber wrapped; + private TestSubscription wrappedSubscription; + + private SubscriptionWrappingSubscriber(Subscriber wrapped) { + this.wrapped = wrapped; + } + + @Override + public void onSubscribe(Subscription subscription) { + this.wrappedSubscription = new TestSubscription(subscription); + this.wrapped.onSubscribe(this.wrappedSubscription); + } + + @Override + public void onNext(ByteBuffer byteBuffer) { + this.wrapped.onNext(byteBuffer); + } + + @Override + public void onError(Throwable throwable) { + this.wrapped.onError(throwable); + } + + @Override + public void onComplete() { + this.wrapped.onComplete(); + } + } +}