diff --git a/core/http-auth-aws/src/main/java/software/amazon/awssdk/http/auth/aws/internal/signer/chunkedencoding/ChunkedEncodedPublisher.java b/core/http-auth-aws/src/main/java/software/amazon/awssdk/http/auth/aws/internal/signer/chunkedencoding/ChunkedEncodedPublisher.java index 91ed18caa91d..b4805a78dca2 100644 --- a/core/http-auth-aws/src/main/java/software/amazon/awssdk/http/auth/aws/internal/signer/chunkedencoding/ChunkedEncodedPublisher.java +++ b/core/http-auth-aws/src/main/java/software/amazon/awssdk/http/auth/aws/internal/signer/chunkedencoding/ChunkedEncodedPublisher.java @@ -51,6 +51,8 @@ * chunk-ext = *( ";" chunk-ext-name [ "=" chunk-ext-val ] ) * chunk-ext-name = token * chunk-ext-val = token / quoted-string + * + * trailer-part = *( header-field CRLF ) * * * @see ChunkedEncodedInputStream @@ -60,9 +62,12 @@ public class ChunkedEncodedPublisher implements Publisher { private static final byte[] CRLF = {'\r', '\n'}; private static final byte SEMICOLON = ';'; private static final byte EQUALS = '='; + private static final byte COLON = ':'; + private static final byte COMMA = ','; private final Publisher wrapped; private final List extensions = new ArrayList<>(); + private final List trailers = new ArrayList<>(); private final int chunkSize; private ByteBuffer chunkBuffer; private final boolean addEmptyTrailingChunk; @@ -71,6 +76,7 @@ public ChunkedEncodedPublisher(Builder b) { this.wrapped = b.publisher; this.chunkSize = b.chunkSize; this.extensions.addAll(b.extensions); + this.trailers.addAll(b.trailers); this.addEmptyTrailingChunk = b.addEmptyTrailingChunk; } @@ -125,10 +131,9 @@ public Publisher map(Publisher upstream, Function upstream.subscribe(MappingSubscriber.create(subscriber, mapper)); } - // TODO: Trailing checksum private ByteBuffer encodeChunk(ByteBuffer byteBuffer) { int contentLen = byteBuffer.remaining(); - byte[] chunkSizeHex = Integer.toHexString(contentLen).getBytes(StandardCharsets.UTF_8); + byte[] chunkSizeHex = Integer.toHexString(byteBuffer.remaining()).getBytes(StandardCharsets.UTF_8); List> chunkExtensions = this.extensions.stream() .map(e -> { @@ -138,12 +143,30 @@ private ByteBuffer encodeChunk(ByteBuffer byteBuffer) { int extensionsLength = calculateExtensionsLength(chunkExtensions); - int encodedLen = chunkSizeHex.length + extensionsLength + CRLF.length + contentLen + CRLF.length; + boolean isTrailerChunk = contentLen == 0; + + List trailerData; + if (isTrailerChunk) { + trailerData = getTrailerData(); + } else { + trailerData = Collections.emptyList(); + } + + int trailerLen = trailerData.stream() + // + 2 for each CRLF that ends the header-field + .mapToInt(t -> t.remaining() + 2) + .sum(); + + int encodedLen = chunkSizeHex.length + extensionsLength + CRLF.length + contentLen + trailerLen + CRLF.length; + + if (isTrailerChunk) { + encodedLen += CRLF.length; + } ByteBuffer encoded = ByteBuffer.allocate(encodedLen); - encoded.put(chunkSizeHex); - chunkExtensions.forEach(p -> { + encoded.put(chunkSizeHex); // chunk-size + chunkExtensions.forEach(p -> { // chunk-ext encoded.put(SEMICOLON); encoded.put(p.left()); if (p.right() != null && p.right().length > 0) { @@ -151,11 +174,23 @@ private ByteBuffer encodeChunk(ByteBuffer byteBuffer) { encoded.put(p.right()); } }); - - encoded.put(CRLF); - encoded.put(byteBuffer); encoded.put(CRLF); + // chunk-data + if (byteBuffer.hasRemaining()) { + encoded.put(byteBuffer); + encoded.put(CRLF); + } + + if (isTrailerChunk) { + // trailer-part + trailerData.forEach(t -> { + encoded.put(t); + encoded.put(CRLF); + }); + encoded.put(CRLF); + } + encoded.flip(); return encoded; @@ -174,6 +209,46 @@ private int calculateExtensionsLength(List> chunkExtensions }).sum(); } + private List getTrailerData() { + List data = new ArrayList<>(); + + for (TrailerProvider provider : trailers) { + Pair> trailer = provider.get(); + + byte[] key = trailer.left().getBytes(StandardCharsets.UTF_8); + List values = trailer.right() + .stream().map(v -> v.getBytes(StandardCharsets.UTF_8)) + .collect(Collectors.toList()); + + if (values.isEmpty()) { + throw new RuntimeException(String.format("Trailing header '%s' has no values", trailer.left())); + } + + int valuesLen = values.stream().mapToInt(v -> v.length).sum(); + // name:value1,value2,.. + int size = key.length + + 1 // colon + + valuesLen + + values.size() - 1; // commas + + ByteBuffer trailerData = ByteBuffer.allocate(size); + + trailerData.put(key); + trailerData.put(COLON); + + for (int i = 0; i < values.size(); ++i) { + trailerData.put(values.get(i)); + if (i + 1 != values.size()) { + trailerData.put(COMMA); + } + } + + trailerData.flip(); + data.add(trailerData); + } + return data; + } + private class ChunkingSubscriber extends DelegatingSubscriber> { protected ChunkingSubscriber(Subscriber> subscriber) { super(subscriber); @@ -222,6 +297,7 @@ public static class Builder { private int chunkSize; private boolean addEmptyTrailingChunk; private final List extensions = new ArrayList<>(); + private final List trailers = new ArrayList<>(); public Builder publisher(Publisher publisher) { this.publisher = publisher; @@ -243,6 +319,11 @@ public Builder addExtension(ChunkExtensionProvider extension) { return this; } + public Builder addTrailer(TrailerProvider trailerProvider) { + this.trailers.add(trailerProvider); + return this; + } + public ChunkedEncodedPublisher build() { return new ChunkedEncodedPublisher(this); } diff --git a/core/http-auth-aws/src/test/java/software/amazon/awssdk/http/auth/aws/internal/signer/chunkedencoding/ChunkedEncodedPublisherTest.java b/core/http-auth-aws/src/test/java/software/amazon/awssdk/http/auth/aws/internal/signer/chunkedencoding/ChunkedEncodedPublisherTest.java index 84bedcaea7b9..909ae0289616 100644 --- a/core/http-auth-aws/src/test/java/software/amazon/awssdk/http/auth/aws/internal/signer/chunkedencoding/ChunkedEncodedPublisherTest.java +++ b/core/http-auth-aws/src/test/java/software/amazon/awssdk/http/auth/aws/internal/signer/chunkedencoding/ChunkedEncodedPublisherTest.java @@ -21,6 +21,8 @@ import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; import java.util.List; import java.util.PrimitiveIterator; import java.util.Random; @@ -46,6 +48,108 @@ public void setup() { CRC32.reset(); } + @Test + public void subscribe_publisherEmpty_onlyProducesTrailer() { + Publisher emptyPublisher = Flowable.empty(); + + ChunkedEncodedPublisher build = newChunkedBuilder(emptyPublisher) + .addTrailer(() -> Pair.of("foo", Collections.singletonList("1"))) + .addTrailer(() -> Pair.of("bar", Collections.singletonList("2"))) + .addEmptyTrailingChunk(true) + .build(); + + List chunks = getAllElements(build); + + assertThat(chunks.size()).isEqualTo(1); + + String trailerAsString = StandardCharsets.UTF_8.decode(chunks.get(0)).toString(); + + assertThat(trailerAsString).isEqualTo( + "0\r\n" + + "foo:1\r\n" + + "bar:2\r\n" + + "\r\n"); + } + + @Test + void subscribe_trailerProviderPresent_trailerPartAdded() { + TestPublisher upstream = randomPublisherOfLength(8); + + TrailerProvider trailerProvider = new StaticTrailerProvider("foo", "bar"); + + ChunkedEncodedPublisher chunkedPublisher = ChunkedEncodedPublisher.builder() + .publisher(upstream) + .chunkSize(CHUNK_SIZE) + .addEmptyTrailingChunk(true) + .addTrailer(trailerProvider) + .build(); + + List chunks = getAllElements(chunkedPublisher); + + String expectedTrailer = "foo:bar"; + String trailerAsString = StandardCharsets.UTF_8.decode(chunks.get(1)).toString().trim(); + assertThat(trailerAsString).endsWith(expectedTrailer); + } + + @Test + void subscribe_trailerProviderPresent_multipleValues_trailerPartAdded() { + TestPublisher upstream = randomPublisherOfLength(8); + + TrailerProvider trailerProvider = new StaticTrailerProvider("foo", Arrays.asList("bar1", "bar2", "bar3")); + + ChunkedEncodedPublisher chunkedPublisher = ChunkedEncodedPublisher.builder() + .publisher(upstream) + .chunkSize(CHUNK_SIZE) + .addEmptyTrailingChunk(true) + .addTrailer(trailerProvider) + .build(); + + List chunks = getAllElements(chunkedPublisher); + + String expectedTrailer = "foo:bar1,bar2,bar3"; + String trailerAsString = StandardCharsets.UTF_8.decode(chunks.get(1)).toString().trim(); + assertThat(trailerAsString).endsWith(expectedTrailer); + } + + @Test + void subscribe_trailerProviderPresent_onlyInvokedOnce() { + TestPublisher upstream = randomPublisherOfLength(8); + + TrailerProvider trailerProvider = Mockito.spy(new StaticTrailerProvider("foo", "bar")); + + ChunkedEncodedPublisher chunkedPublisher = ChunkedEncodedPublisher.builder() + .publisher(upstream) + .addEmptyTrailingChunk(true) + .chunkSize(CHUNK_SIZE) + .addTrailer(trailerProvider).build(); + + getAllElements(chunkedPublisher); + + Mockito.verify(trailerProvider, Mockito.times(1)).get(); + } + + @Test + void subscribe_trailerPresent_trailerFormattedCorrectly() { + TestPublisher testPublisher = randomPublisherOfLength(32); + + TrailerProvider trailerProvider = new StaticTrailerProvider("foo", "bar"); + + ChunkedEncodedPublisher chunkedPublisher = newChunkedBuilder(testPublisher) + .addTrailer(trailerProvider) + .addEmptyTrailingChunk(true) + .build(); + + List chunks = getAllElements(chunkedPublisher); + + ByteBuffer last = chunks.get(chunks.size() - 1); + + String expected = "0\r\n" + + "foo:bar\r\n" + + "\r\n"; + + assertThat(chunkAsString(last)).isEqualTo(expected); + } + @Test void subscribe_wrappedDoesNotFillBuffer_allDataInSingleChunk() { ByteBuffer element = ByteBuffer.wrap("hello world".getBytes(StandardCharsets.UTF_8)); @@ -212,6 +316,10 @@ void subscribe_extensionsPresent_extensionsInvokedForEachChunk() { } } + private static ChunkedEncodedPublisher.Builder newChunkedBuilder(Publisher publisher) { + return ChunkedEncodedPublisher.builder().publisher(publisher).chunkSize(CHUNK_SIZE); + } + private TestPublisher randomPublisherOfLength(int bytes) { List elements = new ArrayList<>(); @@ -239,6 +347,10 @@ private List getAllElements(Publisher publisher) { return Flowable.fromPublisher(publisher).toList().blockingGet(); } + private String chunkAsString(ByteBuffer chunk) { + return StandardCharsets.UTF_8.decode(chunk).toString(); + } + private String getHeaderAsString(ByteBuffer chunk) { return StandardCharsets.UTF_8.decode(getHeader(chunk)).toString(); } @@ -323,4 +435,24 @@ public Pair get(ByteBuffer chunk) { return Pair.of(key, value); } } + + private static class StaticTrailerProvider implements TrailerProvider { + private final String key; + private final List values; + + public StaticTrailerProvider(String key, String value) { + this.key = key; + this.values = Collections.singletonList(value); + } + + public StaticTrailerProvider(String key, List values) { + this.key = key; + this.values = values; + } + + @Override + public Pair> get() { + return Pair.of(key, values); + } + } }