diff --git a/core/http-auth-aws/src/main/java/software/amazon/awssdk/http/auth/aws/internal/signer/chunkedencoding/ChunkedEncodedPublisher.java b/core/http-auth-aws/src/main/java/software/amazon/awssdk/http/auth/aws/internal/signer/chunkedencoding/ChunkedEncodedPublisher.java new file mode 100644 index 000000000000..91ed18caa91d --- /dev/null +++ b/core/http-auth-aws/src/main/java/software/amazon/awssdk/http/auth/aws/internal/signer/chunkedencoding/ChunkedEncodedPublisher.java @@ -0,0 +1,250 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.http.auth.aws.internal.signer.chunkedencoding; + +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.function.Function; +import java.util.stream.Collectors; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscriber; +import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.utils.Pair; +import software.amazon.awssdk.utils.async.AddingTrailingDataSubscriber; +import software.amazon.awssdk.utils.async.DelegatingSubscriber; +import software.amazon.awssdk.utils.async.FlatteningSubscriber; +import software.amazon.awssdk.utils.internal.MappingSubscriber; + +/** + * An implementation of chunk-transfer encoding, but by wrapping a {@link Publisher} of {@link ByteBuffer}. 
This implementation + * supports chunk-headers, chunk-extensions. + *

+ * Per RFC-7230, a chunk-transfer encoded message is + * defined as: + *

+ *     chunked-body   = *chunk
+ *                      last-chunk
+ *                      trailer-part
+ *                      CRLF
+ *     chunk          = chunk-size [ chunk-ext ] CRLF
+ *                      chunk-data CRLF
+ *     chunk-size     = 1*HEXDIG
+ *     last-chunk     = 1*("0") [ chunk-ext ] CRLF
+ *     chunk-data     = 1*OCTET ; a sequence of chunk-size octets
+ *
+ *     chunk-ext      = *( ";" chunk-ext-name [ "=" chunk-ext-val ] )
+ *     chunk-ext-name = token
+ *     chunk-ext-val  = token / quoted-string
+ * 
+ * + * @see ChunkedEncodedInputStream + */ +@SdkInternalApi +public class ChunkedEncodedPublisher implements Publisher { + private static final byte[] CRLF = {'\r', '\n'}; + private static final byte SEMICOLON = ';'; + private static final byte EQUALS = '='; + + private final Publisher wrapped; + private final List extensions = new ArrayList<>(); + private final int chunkSize; + private ByteBuffer chunkBuffer; + private final boolean addEmptyTrailingChunk; + + public ChunkedEncodedPublisher(Builder b) { + this.wrapped = b.publisher; + this.chunkSize = b.chunkSize; + this.extensions.addAll(b.extensions); + this.addEmptyTrailingChunk = b.addEmptyTrailingChunk; + } + + @Override + public void subscribe(Subscriber subscriber) { + Publisher> chunked = chunk(wrapped); + Publisher> trailingAdded = addTrailingChunks(chunked); + Publisher flattened = flatten(trailingAdded); + Publisher encoded = map(flattened, this::encodeChunk); + + encoded.subscribe(subscriber); + } + + public static Builder builder() { + return new Builder(); + } + + private Iterable> getTrailingChunks() { + List trailing = new ArrayList<>(); + + if (chunkBuffer != null) { + chunkBuffer.flip(); + if (chunkBuffer.hasRemaining()) { + trailing.add(chunkBuffer); + } + } + + if (addEmptyTrailingChunk) { + trailing.add(ByteBuffer.allocate(0)); + } + + return Collections.singletonList(trailing); + } + + private Publisher> chunk(Publisher upstream) { + return subscriber -> { + upstream.subscribe(new ChunkingSubscriber(subscriber)); + }; + } + + private Publisher flatten(Publisher> upstream) { + return subscriber -> upstream.subscribe(new FlatteningSubscriber<>(subscriber)); + } + + public Publisher> addTrailingChunks(Publisher> upstream) { + return subscriber -> { + upstream.subscribe(new AddingTrailingDataSubscriber<>(subscriber, this::getTrailingChunks)); + }; + } + + public Publisher map(Publisher upstream, Function mapper) { + return subscriber -> upstream.subscribe(MappingSubscriber.create(subscriber, 
mapper)); + } + + // TODO: Trailing checksum + private ByteBuffer encodeChunk(ByteBuffer byteBuffer) { + int contentLen = byteBuffer.remaining(); + byte[] chunkSizeHex = Integer.toHexString(contentLen).getBytes(StandardCharsets.UTF_8); + + List> chunkExtensions = this.extensions.stream() + .map(e -> { + ByteBuffer duplicate = byteBuffer.duplicate(); + return e.get(duplicate); + }).collect(Collectors.toList()); + + int extensionsLength = calculateExtensionsLength(chunkExtensions); + + int encodedLen = chunkSizeHex.length + extensionsLength + CRLF.length + contentLen + CRLF.length; + + ByteBuffer encoded = ByteBuffer.allocate(encodedLen); + encoded.put(chunkSizeHex); + + chunkExtensions.forEach(p -> { + encoded.put(SEMICOLON); + encoded.put(p.left()); + if (p.right() != null && p.right().length > 0) { + encoded.put(EQUALS); + encoded.put(p.right()); + } + }); + + encoded.put(CRLF); + encoded.put(byteBuffer); + encoded.put(CRLF); + + encoded.flip(); + + return encoded; + } + + private int calculateExtensionsLength(List> chunkExtensions) { + return chunkExtensions.stream() + .mapToInt(p -> { + int keyLen = p.left().length; + byte[] value = p.right(); + if (value.length > 0) { + return 1 + keyLen + 1 + value.length; // ';ext-key=ext-value' + } + // ';ext-key + return 1 + keyLen; + }).sum(); + } + + private class ChunkingSubscriber extends DelegatingSubscriber> { + protected ChunkingSubscriber(Subscriber> subscriber) { + super(subscriber); + } + + @Override + public void onNext(ByteBuffer byteBuffer) { + if (chunkBuffer == null) { + chunkBuffer = ByteBuffer.allocate(chunkSize); + } + + long totalBufferedBytes = (long) chunkBuffer.position() + byteBuffer.remaining(); + int nBufferedChunks = (int) (totalBufferedBytes / chunkSize); + + List chunks = new ArrayList<>(nBufferedChunks); + + if (nBufferedChunks > 0) { + for (int i = 0; i < nBufferedChunks; i++) { + ByteBuffer slice = byteBuffer.slice(); + int maxBytesToCopy = Math.min(chunkBuffer.remaining(), slice.remaining()); 
+ slice.limit(maxBytesToCopy); + + chunkBuffer.put(slice); + if (!chunkBuffer.hasRemaining()) { + chunkBuffer.flip(); + chunks.add(chunkBuffer); + chunkBuffer = ByteBuffer.allocate(chunkSize); + } + + byteBuffer.position(byteBuffer.position() + maxBytesToCopy); + } + + if (byteBuffer.hasRemaining()) { + chunkBuffer.put(byteBuffer); + } + } else { + chunkBuffer.put(byteBuffer); + } + + subscriber.onNext(chunks); + } + } + + public static class Builder { + private Publisher publisher; + private int chunkSize; + private boolean addEmptyTrailingChunk; + private final List extensions = new ArrayList<>(); + + public Builder publisher(Publisher publisher) { + this.publisher = publisher; + return this; + } + + public Builder chunkSize(int chunkSize) { + this.chunkSize = chunkSize; + return this; + } + + public Builder addEmptyTrailingChunk(boolean addEmptyTrailingChunk) { + this.addEmptyTrailingChunk = addEmptyTrailingChunk; + return this; + } + + public Builder addExtension(ChunkExtensionProvider extension) { + this.extensions.add(extension); + return this; + } + + public ChunkedEncodedPublisher build() { + return new ChunkedEncodedPublisher(this); + } + } +} diff --git a/core/http-auth-aws/src/test/java/software/amazon/awssdk/http/auth/aws/internal/signer/chunkedencoding/ChunkedEncodedPublisherTckTest.java b/core/http-auth-aws/src/test/java/software/amazon/awssdk/http/auth/aws/internal/signer/chunkedencoding/ChunkedEncodedPublisherTckTest.java new file mode 100644 index 000000000000..e539e5fe8ad4 --- /dev/null +++ b/core/http-auth-aws/src/test/java/software/amazon/awssdk/http/auth/aws/internal/signer/chunkedencoding/ChunkedEncodedPublisherTckTest.java @@ -0,0 +1,70 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. 
+ * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.http.auth.aws.internal.signer.chunkedencoding; + +import io.reactivex.Flowable; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; +import org.reactivestreams.Publisher; +import org.reactivestreams.tck.PublisherVerification; +import org.reactivestreams.tck.TestEnvironment;

/**
 * Runs the Reactive Streams TCK publisher verification against {@link ChunkedEncodedPublisher}.
 */
public class ChunkedEncodedPublisherTckTest extends PublisherVerification<ByteBuffer> {
    private static final int INPUT_STREAM_ELEMENT_SIZE = 64;
    private static final int CHUNK_SIZE = 16 * 1024;

    public ChunkedEncodedPublisherTckTest() {
        super(new TestEnvironment());
    }

    @Override
    public Publisher<ByteBuffer> createPublisher(long l) {
        return createChunkedPublisher(l);
    }

    // No failed publisher is provided for this verification.
    @Override
    public Publisher<ByteBuffer> createFailedPublisher() {
        return null;
    }

    @Override
    public long maxElementsFromPublisher() {
        return 512;
    }

    /**
     * Builds a publisher that emits exactly {@code chunksToProduce} encoded chunks. The input is made of
     * {@code INPUT_STREAM_ELEMENT_SIZE}-byte elements whose total size is a whole multiple of {@code CHUNK_SIZE},
     * and no empty trailing chunk is requested, so nothing beyond the full chunks is emitted.
     */
    private Publisher<ByteBuffer> createChunkedPublisher(long chunksToProduce) {
        // max of 8 MiB
        long totalSize = chunksToProduce * CHUNK_SIZE;

        int totalElements = (int) (totalSize / INPUT_STREAM_ELEMENT_SIZE);

        byte[] content = new byte[INPUT_STREAM_ELEMENT_SIZE];

        List<ByteBuffer> elements = new ArrayList<>(totalElements);
        for (int i = 0; i < totalElements; i++) {
            // Every element is an independent view (own position/limit) over the same zeroed array.
            elements.add(ByteBuffer.wrap(content));
        }

        Publisher<ByteBuffer> inputPublisher = Flowable.fromIterable(elements);

        return ChunkedEncodedPublisher.builder()
                                      .chunkSize(CHUNK_SIZE)
                                      .publisher(inputPublisher)
                                      .addEmptyTrailingChunk(false)
                                      .build();
    }
}
diff --git a/core/http-auth-aws/src/test/java/software/amazon/awssdk/http/auth/aws/internal/signer/chunkedencoding/ChunkedEncodedPublisherTest.java 
b/core/http-auth-aws/src/test/java/software/amazon/awssdk/http/auth/aws/internal/signer/chunkedencoding/ChunkedEncodedPublisherTest.java new file mode 100644 index 000000000000..84bedcaea7b9 --- /dev/null +++ b/core/http-auth-aws/src/test/java/software/amazon/awssdk/http/auth/aws/internal/signer/chunkedencoding/ChunkedEncodedPublisherTest.java @@ -0,0 +1,326 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.http.auth.aws.internal.signer.chunkedencoding; + +import static org.assertj.core.api.AssertionsForClassTypes.assertThat; + +import io.reactivex.Flowable; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.List; +import java.util.PrimitiveIterator; +import java.util.Random; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Mockito; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscriber; +import software.amazon.awssdk.checksums.DefaultChecksumAlgorithm; +import software.amazon.awssdk.checksums.SdkChecksum; +import software.amazon.awssdk.utils.Pair;

/**
 * Unit tests for {@link ChunkedEncodedPublisher}: chunk sizing, extension formatting and trailing-chunk handling.
 */
public class ChunkedEncodedPublisherTest {
    private static final int CHUNK_SIZE = 16 * 1024;
    private final Random RNG = new Random();
    private final SdkChecksum CRC32 = SdkChecksum.forAlgorithm(DefaultChecksumAlgorithm.CRC32);

    @BeforeEach
    public void setup() {
        CRC32.reset();
    }

    @Test
    void subscribe_wrappedDoesNotFillBuffer_allDataInSingleChunk() {
        ByteBuffer element = ByteBuffer.wrap("hello world".getBytes(StandardCharsets.UTF_8));
        Flowable<ByteBuffer> upstream = Flowable.just(element.duplicate());

        ChunkedEncodedPublisher publisher = ChunkedEncodedPublisher.builder()
                                                                   .chunkSize(CHUNK_SIZE)
                                                                   .publisher(upstream)
                                                                   .build();

        List<ByteBuffer> chunks = getAllElements(publisher);

        assertThat(chunks.size()).isEqualTo(1);
        assertThat(stripEncoding(chunks.get(0)))
            .isEqualTo(element);
    }

    @Test
    void subscribe_extensionHasNoValue_formattedCorrectly() {
        TestPublisher testPublisher = randomPublisherOfLength(8);

        ChunkExtensionProvider extensionProvider = new StaticExtensionProvider("foo", "");

        ChunkedEncodedPublisher chunkPublisher =
            ChunkedEncodedPublisher.builder()
                                   .publisher(testPublisher)
                                   .addExtension(extensionProvider)
                                   .chunkSize(CHUNK_SIZE)
                                   .build();

        List<ByteBuffer> chunks = getAllElements(chunkPublisher);

        assertThat(getHeaderAsString(chunks.get(0))).endsWith(";foo");
    }

    @Test
    void subscribe_multipleExtensions_formattedCorrectly() {
        TestPublisher testPublisher = randomPublisherOfLength(8);

        ChunkedEncodedPublisher.Builder chunkPublisher =
            ChunkedEncodedPublisher.builder()
                                   .publisher(testPublisher)
                                   .chunkSize(CHUNK_SIZE);

        Stream.of("1", "2", "3")
              .map(v -> new StaticExtensionProvider("key" + v, "value" + v))
              .forEach(chunkPublisher::addExtension);

        List<ByteBuffer> chunks = getAllElements(chunkPublisher.build());

        chunks.forEach(chunk -> assertThat(getHeaderAsString(chunk)).endsWith(";key1=value1;key2=value2;key3=value3"));
    }

    @Test
    void subscribe_randomElementSizes_dataChunkedCorrectly() {
        for (int i = 0; i < 512; ++i) {
            int nChunks = 24;
            TestPublisher byteBufferPublisher = randomPublisherOfLength(CHUNK_SIZE * nChunks);

            ChunkedEncodedPublisher chunkedPublisher = ChunkedEncodedPublisher.builder()
                                                                              .publisher(byteBufferPublisher)
                                                                              .chunkSize(CHUNK_SIZE)
                                                                              .build();

            List<ByteBuffer> chunks = getAllElements(chunkedPublisher);

            List<ByteBuffer> stripped = chunks.stream().map(this::stripEncoding).collect(Collectors.toList());
            assertThat(stripped.size()).isEqualTo(nChunks);

            stripped.forEach(chunk -> assertThat(chunk.remaining()).isEqualTo(CHUNK_SIZE));

            CRC32.reset();
            stripped.forEach(CRC32::update);

            assertThat(CRC32.getChecksumBytes()).isEqualTo(byteBufferPublisher.wrappedChecksum);
        }
    }

    @Test
    void subscribe_randomElementSizes_chunksHaveExtensions_dataChunkedCorrectly() {
        for (int i = 0; i < 512; ++i) {
            int nChunks = 24;
            TestPublisher byteBufferPublisher = randomPublisherOfLength(CHUNK_SIZE * nChunks);

            StaticExtensionProvider extensionProvider = Mockito.spy(new StaticExtensionProvider("foo", "bar"));

            ChunkedEncodedPublisher chunkedPublisher = ChunkedEncodedPublisher.builder()
                                                                              .publisher(byteBufferPublisher)
                                                                              .addExtension(extensionProvider)
                                                                              .chunkSize(CHUNK_SIZE)
                                                                              .build();

            List<ByteBuffer> chunks = getAllElements(chunkedPublisher);

            // 0x4000 == CHUNK_SIZE
            chunks.forEach(c -> assertThat(getHeaderAsString(c)).isEqualTo("4000;foo=bar"));

            List<ByteBuffer> stripped = chunks.stream().map(this::stripEncoding).collect(Collectors.toList());

            assertThat(stripped.size()).isEqualTo(nChunks);

            stripped.forEach(chunk -> assertThat(chunk.remaining()).isEqualTo(CHUNK_SIZE));

            CRC32.reset();
            stripped.forEach(CRC32::update);

            assertThat(CRC32.getChecksumBytes()).isEqualTo(byteBufferPublisher.wrappedChecksum);
        }
    }

    @Test
    void subscribe_addTrailingChunkTrue_trailingChunkAdded() {
        TestPublisher testPublisher = randomPublisherOfLength(CHUNK_SIZE * 2);

        ChunkedEncodedPublisher chunkedPublisher = ChunkedEncodedPublisher.builder()
                                                                          .publisher(testPublisher)
                                                                          .chunkSize(CHUNK_SIZE)
                                                                          .addEmptyTrailingChunk(true)
                                                                          .build();

        List<ByteBuffer> chunks = getAllElements(chunkedPublisher);

        assertThat(chunks.size()).isEqualTo(3);

        ByteBuffer trailing = chunks.get(chunks.size() - 1);
        assertThat(stripEncoding(trailing).remaining()).isEqualTo(0);
    }

    @Test
    void subscribe_addTrailingChunkTrue_upstreamEmpty_trailingChunkAdded() {
        Publisher<ByteBuffer> empty = Flowable.empty();

        ChunkedEncodedPublisher chunkedPublisher =
            ChunkedEncodedPublisher.builder().publisher(empty).chunkSize(CHUNK_SIZE).addEmptyTrailingChunk(true).build();

        List<ByteBuffer> chunks = getAllElements(chunkedPublisher);

        assertThat(chunks.size()).isEqualTo(1);
    }

    @Test
    void subscribe_extensionsPresent_extensionsInvokedForEachChunk() {
        ChunkExtensionProvider mockProvider = Mockito.spy(new StaticExtensionProvider("foo", "bar"));

        int nChunks = 16;
        TestPublisher elements = randomPublisherOfLength(nChunks * CHUNK_SIZE);

        ChunkedEncodedPublisher chunkPublisher =
            ChunkedEncodedPublisher.builder().publisher(elements).chunkSize(CHUNK_SIZE).addExtension(mockProvider).build();

        List<ByteBuffer> chunks = getAllElements(chunkPublisher);

        ArgumentCaptor<ByteBuffer> chunkCaptor = ArgumentCaptor.forClass(ByteBuffer.class);

        Mockito.verify(mockProvider, Mockito.times(nChunks)).get(chunkCaptor.capture());
        List<ByteBuffer> extensionChunks = chunkCaptor.getAllValues();

        for (int i = 0; i < chunks.size(); ++i) {
            ByteBuffer chunk = chunks.get(i);
            ByteBuffer extensionChunk = extensionChunks.get(i);
            assertThat(stripEncoding(chunk)).isEqualTo(extensionChunk);
        }
    }

    /**
     * Creates a publisher of random content totalling {@code bytes} bytes, split into elements of random size
     * (16 to 8191 bytes, the last one possibly shorter), together with the CRC32 of that content.
     */
    private TestPublisher randomPublisherOfLength(int bytes) {
        List<ByteBuffer> elements = new ArrayList<>();

        PrimitiveIterator.OfInt sizeIter = RNG.ints(16, 8192).iterator();

        CRC32.reset();
        while (bytes > 0) {
            int elementSize = Math.min(sizeIter.nextInt(), bytes);
            bytes -= elementSize;

            byte[] elementContent = new byte[elementSize];
            RNG.nextBytes(elementContent);
            CRC32.update(elementContent);
            elements.add(ByteBuffer.wrap(elementContent));
        }

        Flowable<ByteBuffer> publisher = Flowable.fromIterable(elements);

        return new TestPublisher(publisher, CRC32.getChecksumBytes());
    }

    private List<ByteBuffer> getAllElements(Publisher<ByteBuffer> publisher) {
        return Flowable.fromPublisher(publisher).toList().blockingGet();
    }

    private String getHeaderAsString(ByteBuffer chunk) {
        return StandardCharsets.UTF_8.decode(getHeader(chunk)).toString();
    }

    /**
     * Returns a view of the chunk header line: everything before the first CRLF pair. Indices are absolute, so
     * the chunk is expected to start at position 0.
     *
     * @throws IllegalArgumentException if the chunk contains no CRLF
     */
    private ByteBuffer getHeader(ByteBuffer chunk) {
        ByteBuffer header = chunk.duplicate();

        // Look for the CR LF *pair*; a lone '\r' or '\n' must not terminate the header.
        for (int i = 0; i + 1 < header.limit(); ++i) {
            if (header.get(i) == '\r' && header.get(i + 1) == '\n') {
                header.limit(i);
                return header;
            }
        }

        throw new IllegalArgumentException("Encoded chunk does not contain a CRLF-terminated header");
    }

    /**
     * Returns a view of just the chunk data, using the hex length from the header (the part before any
     * {@code ';'}) to locate it.
     */
    private ByteBuffer stripEncoding(ByteBuffer chunk) {
        ByteBuffer header = getHeader(chunk);

        // The header is ASCII, so character offsets in the decoded string equal byte offsets.
        String headerLine = StandardCharsets.UTF_8.decode(header.duplicate()).toString();
        int semicolon = headerLine.indexOf(';');
        // Without extensions the whole line is the length.
        String lengthHex = semicolon >= 0 ? headerLine.substring(0, semicolon) : headerLine;

        int length = Integer.parseInt(lengthHex, 16);

        ByteBuffer stripped = chunk.duplicate();

        // Skip the header line and its CRLF.
        int chunkStart = header.remaining() + 2;
        stripped.position(chunkStart);
        stripped.limit(chunkStart + length);

        return stripped;
    }

    /**
     * A publisher paired with the checksum of all the bytes it will emit.
     */
    private static class TestPublisher implements Publisher<ByteBuffer> {
        private final Publisher<ByteBuffer> wrapped;
        private final byte[] wrappedChecksum;

        public TestPublisher(Publisher<ByteBuffer> wrapped, byte[] wrappedChecksum) {
            this.wrapped = wrapped;
            // Copy defensively; the caller reuses and resets the checksum instance.
            this.wrappedChecksum = new byte[wrappedChecksum.length];
            System.arraycopy(wrappedChecksum, 0, this.wrappedChecksum, 0, wrappedChecksum.length);
        }

        @Override
        public void subscribe(Subscriber<? super ByteBuffer> subscriber) {
            wrapped.subscribe(subscriber);
        }

        public byte[] wrappedChecksum() {
            return wrappedChecksum;
        }
    }

    /**
     * Extension provider that returns the same key/value for every chunk.
     */
    private static class StaticExtensionProvider implements ChunkExtensionProvider {
        private final byte[] key;
        private final byte[] value;

        public StaticExtensionProvider(String key, String value) {
            this.key = key.getBytes(StandardCharsets.UTF_8);
            this.value = value == null ? null : value.getBytes(StandardCharsets.UTF_8);
        }

        @Override
        public Pair<byte[], byte[]> get(ByteBuffer chunk) {
            return Pair.of(key, value);
        }
    }
}