Skip to content

Commit 8c23618

Browse files
authored
Buffer if content provider provided w/o len (#5837)
This updates the `RequestBody.fromContentProvider(ContentStreamProvider, String)` method so that the underlying implementation buffers the contents of the stream in memory during the first pass through the stream. Once the stream has been read fully, subsequent calls to `newStream()` are served from the buffered data.
1 parent b1b5216 commit 8c23618

File tree

4 files changed

+255
-2
lines changed

4 files changed

+255
-2
lines changed
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
{
2+
"type": "feature",
3+
"category": "AWS SDK for Java v2",
4+
"contributor": "",
5+
"description": "Buffer input data from ContentStreamProvider to avoid the need to reread the stream after calculating its length."
6+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
/*
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License").
5+
* You may not use this file except in compliance with the License.
6+
* A copy of the License is located at
7+
*
8+
* http://aws.amazon.com/apache2.0
9+
*
10+
* or in the "license" file accompanying this file. This file is distributed
11+
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
12+
* express or implied. See the License for the specific language governing
13+
* permissions and limitations under the License.
14+
*/
15+
16+
package software.amazon.awssdk.core.internal.sync;
17+
18+
import static software.amazon.awssdk.utils.FunctionalUtils.invokeSafely;
19+
20+
import java.io.BufferedInputStream;
21+
import java.io.ByteArrayInputStream;
22+
import java.io.IOException;
23+
import java.io.InputStream;
24+
import software.amazon.awssdk.annotations.NotThreadSafe;
25+
import software.amazon.awssdk.annotations.SdkInternalApi;
26+
import software.amazon.awssdk.http.ContentStreamProvider;
27+
import software.amazon.awssdk.utils.IoUtils;
28+
29+
/**
30+
* {@code ContentStreamProvider} implementation that buffers the data stream data to memory as it's read. Once the underlying
31+
* stream is read fully, all subsequent calls to {@link #newStream()} will use the buffered data.
32+
*/
33+
@SdkInternalApi
34+
@NotThreadSafe
35+
public final class BufferingContentStreamProvider implements ContentStreamProvider {
36+
private final ContentStreamProvider delegate;
37+
private InputStream bufferedStream;
38+
39+
private byte[] bufferedStreamData;
40+
private int count;
41+
42+
public BufferingContentStreamProvider(ContentStreamProvider delegate) {
43+
this.delegate = delegate;
44+
}
45+
46+
@Override
47+
public InputStream newStream() {
48+
if (bufferedStreamData != null) {
49+
return new ByteArrayInputStream(bufferedStreamData, 0, this.count);
50+
}
51+
52+
if (bufferedStream == null) {
53+
InputStream delegateStream = delegate.newStream();
54+
bufferedStream = new BufferStream(delegateStream);
55+
IoUtils.markStreamWithMaxReadLimit(bufferedStream, Integer.MAX_VALUE);
56+
}
57+
58+
invokeSafely(bufferedStream::reset);
59+
return bufferedStream;
60+
}
61+
62+
private class BufferStream extends BufferedInputStream {
63+
BufferStream(InputStream in) {
64+
super(in);
65+
}
66+
67+
@Override
68+
public synchronized int read() throws IOException {
69+
int read = super.read();
70+
if (read < 0) {
71+
saveBuffer();
72+
}
73+
return read;
74+
}
75+
76+
@Override
77+
public synchronized int read(byte[] b, int off, int len) throws IOException {
78+
int read = super.read(b, off, len);
79+
if (read < 0) {
80+
saveBuffer();
81+
}
82+
return read;
83+
}
84+
85+
private void saveBuffer() {
86+
if (bufferedStreamData == null) {
87+
IoUtils.closeQuietlyV2(in, null);
88+
BufferingContentStreamProvider.this.bufferedStreamData = this.buf;
89+
BufferingContentStreamProvider.this.count = this.count;
90+
}
91+
}
92+
}
93+
94+
}

core/sdk-core/src/main/java/software/amazon/awssdk/core/sync/RequestBody.java

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import java.util.Arrays;
3232
import java.util.Optional;
3333
import software.amazon.awssdk.annotations.SdkPublicApi;
34+
import software.amazon.awssdk.core.internal.sync.BufferingContentStreamProvider;
3435
import software.amazon.awssdk.core.internal.sync.FileContentStreamProvider;
3536
import software.amazon.awssdk.core.internal.util.Mimetype;
3637
import software.amazon.awssdk.core.io.ReleasableInputStream;
@@ -220,15 +221,26 @@ public static RequestBody fromContentProvider(ContentStreamProvider provider, lo
220221
}
221222

222223
/**
223-
* Creates a {@link RequestBody} from the given {@link ContentStreamProvider}.
224+
* Creates a {@link RequestBody} from the given {@link ContentStreamProvider} when the content length is unknown. If you
225+
* are able to provide the content length at creation time, consider using {@link #fromInputStream(InputStream, long)} or
226+
* {@link #fromContentProvider(ContentStreamProvider, long, String)} to negate the need to read through the stream to find
227+
* the content length.
228+
* <p>
229+
* Important: Be aware that this overload requires the SDK to buffer the entirety of your content stream to compute the
230+
* content length. This will cause increased memory usage.
231+
* <p>
232+
* If you are using this in conjunction with S3 and want to upload a stream with an unknown content length, you can refer
233+
* to S3's documentation for
234+
* <a href="https://docs.aws.amazon.com/AmazonS3/latest/API/s3_example_s3_Scenario_UploadStream_section.html">alternative
235+
* methods</a>.
224236
*
225237
* @param provider The content provider.
226238
* @param mimeType The MIME type of the content.
227239
*
228240
* @return The created {@code RequestBody}.
229241
*/
230242
public static RequestBody fromContentProvider(ContentStreamProvider provider, String mimeType) {
231-
return new RequestBody(provider, null, mimeType);
243+
return new RequestBody(new BufferingContentStreamProvider(provider), null, mimeType);
232244
}
233245

234246
/**
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
/*
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License").
5+
* You may not use this file except in compliance with the License.
6+
* A copy of the License is located at
7+
*
8+
* http://aws.amazon.com/apache2.0
9+
*
10+
* or in the "license" file accompanying this file. This file is distributed
11+
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
12+
* express or implied. See the License for the specific language governing
13+
* permissions and limitations under the License.
14+
*/
15+
16+
package software.amazon.awssdk.core.internal.sync;
17+
18+
import static org.assertj.core.api.Assertions.assertThat;
19+
import static org.mockito.ArgumentMatchers.any;
20+
import static org.mockito.ArgumentMatchers.anyInt;
21+
22+
import java.io.ByteArrayInputStream;
23+
import java.io.IOException;
24+
import java.io.InputStream;
25+
import java.nio.charset.StandardCharsets;
26+
import org.junit.jupiter.api.BeforeEach;
27+
import org.junit.jupiter.api.Test;
28+
import org.mockito.Mockito;
29+
import software.amazon.awssdk.checksums.DefaultChecksumAlgorithm;
30+
import software.amazon.awssdk.checksums.SdkChecksum;
31+
import software.amazon.awssdk.core.sync.RequestBody;
32+
import software.amazon.awssdk.utils.BinaryUtils;
33+
import software.amazon.awssdk.utils.IoUtils;
34+
35+
class BufferingContentStreamProviderTest {
36+
private static final SdkChecksum CRC32 = SdkChecksum.forAlgorithm(DefaultChecksumAlgorithm.CRC32);
37+
private static final byte[] TEST_DATA = "BufferingContentStreamProviderTest".getBytes(StandardCharsets.UTF_8);
38+
private static final String TEST_DATA_CHECKSUM = "f9ed1825";
39+
40+
private RequestBody requestBody;
41+
42+
@BeforeEach
43+
void setup() {
44+
ByteArrayInputStream stream = new ByteArrayInputStream(TEST_DATA);
45+
requestBody = RequestBody.fromContentProvider(() -> stream, "text/plain");
46+
}
47+
48+
@Test
49+
void newStream_alwaysStartsAtBeginning() {
50+
String stream1Crc32 = getCrc32(requestBody.contentStreamProvider().newStream());
51+
String stream2Crc32 = getCrc32(requestBody.contentStreamProvider().newStream());
52+
53+
assertThat(stream1Crc32).isEqualTo(TEST_DATA_CHECKSUM);
54+
assertThat(stream2Crc32).isEqualTo(TEST_DATA_CHECKSUM);
55+
}
56+
57+
@Test
58+
void newStream_buffersSkippedBytes() throws IOException {
59+
InputStream stream1 = requestBody.contentStreamProvider().newStream();
60+
61+
assertThat(stream1.skip(Long.MAX_VALUE)).isEqualTo(TEST_DATA.length);
62+
63+
String stream2Crc32 = getCrc32(requestBody.contentStreamProvider().newStream());
64+
65+
assertThat(stream2Crc32).isEqualTo(TEST_DATA_CHECKSUM);
66+
}
67+
68+
@Test
69+
void newStream_oneByteReads_dataBufferedCorrectly() throws IOException {
70+
InputStream stream = requestBody.contentStreamProvider().newStream();
71+
int read;
72+
do {
73+
read = stream.read();
74+
} while (read != -1);
75+
76+
assertThat(getCrc32(requestBody.contentStreamProvider().newStream())).isEqualTo(TEST_DATA_CHECKSUM);
77+
}
78+
79+
@Test
80+
void newStream_wholeArrayReads_dataBufferedCorrectly() throws IOException {
81+
InputStream stream = requestBody.contentStreamProvider().newStream();
82+
int read;
83+
byte[] buff = new byte[32];
84+
do {
85+
read = stream.read(buff);
86+
} while (read != -1);
87+
88+
assertThat(getCrc32(requestBody.contentStreamProvider().newStream())).isEqualTo(TEST_DATA_CHECKSUM);
89+
}
90+
91+
@Test
92+
void newStream_offsetArrayReads_dataBufferedCorrectly() throws IOException {
93+
InputStream stream = requestBody.contentStreamProvider().newStream();
94+
int read;
95+
byte[] buff = new byte[32];
96+
do {
97+
read = stream.read(buff, 0, 32);
98+
} while (read != -1);
99+
100+
assertThat(getCrc32(requestBody.contentStreamProvider().newStream())).isEqualTo(TEST_DATA_CHECKSUM);
101+
}
102+
103+
@Test
104+
void newStream_closeClosesDelegateStream() throws IOException {
105+
InputStream stream = Mockito.spy(new ByteArrayInputStream(TEST_DATA));
106+
requestBody = RequestBody.fromContentProvider(() -> stream, "text/plain");
107+
requestBody.contentStreamProvider().newStream().close();
108+
109+
Mockito.verify(stream).close();
110+
}
111+
112+
@Test
113+
void newStream_allDataBuffered_closesDelegateStream() throws IOException {
114+
InputStream delegateStream = Mockito.spy(new ByteArrayInputStream(TEST_DATA));
115+
116+
requestBody = RequestBody.fromContentProvider(() -> delegateStream, "text/plain");
117+
118+
IoUtils.drainInputStream(requestBody.contentStreamProvider().newStream());
119+
Mockito.verify(delegateStream, Mockito.atLeast(1)).read(any(), anyInt(), anyInt());
120+
Mockito.verify(delegateStream).close();
121+
122+
IoUtils.drainInputStream(requestBody.contentStreamProvider().newStream());
123+
Mockito.verifyNoMoreInteractions(delegateStream);
124+
}
125+
126+
private static String getCrc32(InputStream inputStream) {
127+
byte[] buff = new byte[1024];
128+
int read;
129+
130+
CRC32.reset();
131+
try {
132+
while ((read = inputStream.read(buff)) != -1) {
133+
CRC32.update(buff, 0, read);
134+
}
135+
} catch (IOException e) {
136+
throw new RuntimeException(e);
137+
}
138+
139+
return BinaryUtils.toHex(CRC32.getChecksumBytes());
140+
}
141+
}

0 commit comments

Comments
 (0)