Skip to content

Commit 601cda7

Browse files
authored
Buffer if custom content provider stream (#5841)
* Buffer if custom content provider stream This updates the `RequestBody.fromContentProvider(ContentStreamProvider,long,String)` overload so that the underlying implementation buffers the contents of the stream in memory during the first pass through the stream. This is a follow-up to #5837. * Review changes
1 parent c137348 commit 601cda7

File tree

4 files changed

+145
-39
lines changed

4 files changed

+145
-39
lines changed
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
{
2+
"type": "feature",
3+
"category": "AWS SDK for Java v2",
4+
"contributor": "",
5+
"description": "Buffer input data from ContentStreamProvider in cases where content length is known."
6+
}

core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/sync/BufferingContentStreamProvider.java

Lines changed: 44 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -34,19 +34,21 @@
3434
@NotThreadSafe
3535
public final class BufferingContentStreamProvider implements ContentStreamProvider {
3636
private final ContentStreamProvider delegate;
37-
private InputStream bufferedStream;
37+
private final Long expectedLength;
38+
private BufferStream bufferedStream;
3839

3940
private byte[] bufferedStreamData;
4041
private int count;
4142

42-
public BufferingContentStreamProvider(ContentStreamProvider delegate) {
43+
public BufferingContentStreamProvider(ContentStreamProvider delegate, Long expectedLength) {
4344
this.delegate = delegate;
45+
this.expectedLength = expectedLength;
4446
}
4547

4648
@Override
4749
public InputStream newStream() {
4850
if (bufferedStreamData != null) {
49-
return new ByteArrayInputStream(bufferedStreamData, 0, this.count);
51+
return new ByteArrayStream(bufferedStreamData, 0, this.count);
5052
}
5153

5254
if (bufferedStream == null) {
@@ -59,36 +61,57 @@ public InputStream newStream() {
5961
return bufferedStream;
6062
}
6163

62-
private class BufferStream extends BufferedInputStream {
64+
class ByteArrayStream extends ByteArrayInputStream {
65+
66+
ByteArrayStream(byte[] buf, int offset, int length) {
67+
super(buf, offset, length);
68+
}
69+
70+
@Override
71+
public void close() throws IOException {
72+
super.close();
73+
bufferedStream.close();
74+
}
75+
}
76+
77+
class BufferStream extends BufferedInputStream {
6378
BufferStream(InputStream in) {
6479
super(in);
6580
}
6681

67-
@Override
68-
public synchronized int read() throws IOException {
69-
int read = super.read();
70-
if (read < 0) {
71-
saveBuffer();
72-
}
73-
return read;
82+
public byte[] getBuf() {
83+
return this.buf;
84+
}
85+
86+
public int getCount() {
87+
return this.count;
7488
}
7589

7690
@Override
77-
public synchronized int read(byte[] b, int off, int len) throws IOException {
78-
int read = super.read(b, off, len);
79-
if (read < 0) {
91+
public void close() throws IOException {
92+
// We only want to close the underlying stream if we're confident all its data is buffered. In some cases, the
93+
// stream might be closed before we read everything, and we want to avoid closing in these cases if the request
94+
// body is being reused.
95+
if (!hasExpectedLength() || expectedLengthReached()) {
8096
saveBuffer();
97+
super.close();
8198
}
82-
return read;
8399
}
100+
}
84101

85-
private void saveBuffer() {
86-
if (bufferedStreamData == null) {
87-
IoUtils.closeQuietlyV2(in, null);
88-
BufferingContentStreamProvider.this.bufferedStreamData = this.buf;
89-
BufferingContentStreamProvider.this.count = this.count;
90-
}
102+
private void saveBuffer() {
103+
if (bufferedStreamData == null) {
104+
this.bufferedStreamData = bufferedStream.getBuf();
105+
this.count = bufferedStream.getCount();
91106
}
92107
}
93108

109+
private boolean expectedLengthReached() {
110+
return bufferedStream.getCount() >= expectedLength;
111+
}
112+
113+
private boolean hasExpectedLength() {
114+
return this.expectedLength != null;
115+
}
116+
94117
}

core/sdk-core/src/main/java/software/amazon/awssdk/core/sync/RequestBody.java

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -131,12 +131,13 @@ public static RequestBody fromFile(File file) {
131131
public static RequestBody fromInputStream(InputStream inputStream, long contentLength) {
132132
IoUtils.markStreamWithMaxReadLimit(inputStream);
133133
InputStream nonCloseable = nonCloseableInputStream(inputStream);
134-
return fromContentProvider(() -> {
134+
ContentStreamProvider provider = () -> {
135135
if (nonCloseable.markSupported()) {
136136
invokeSafely(nonCloseable::reset);
137137
}
138138
return nonCloseable;
139-
}, contentLength, Mimetype.MIMETYPE_OCTET_STREAM);
139+
};
140+
return new RequestBody(provider, contentLength, Mimetype.MIMETYPE_OCTET_STREAM);
140141
}
141142

142143
/**
@@ -209,6 +210,14 @@ public static RequestBody empty() {
209210

210211
/**
211212
* Creates a {@link RequestBody} from the given {@link ContentStreamProvider}.
213+
* <p>
214+
* Important: Be aware that this implementation requires buffering the contents of the {@code ContentStreamProvider}, which can
215+
* cause increased memory usage.
216+
* <p>
217+
* If you are using this in conjunction with S3 and want to upload a stream with an unknown content length, you can refer
218+
* to S3's documentation for
219+
* <a href="https://docs.aws.amazon.com/AmazonS3/latest/API/s3_example_s3_Scenario_UploadStream_section.html">alternative
220+
* methods</a>.
212221
*
213222
* @param provider The content provider.
214223
* @param contentLength The content length.
@@ -217,17 +226,14 @@ public static RequestBody empty() {
217226
* @return The created {@code RequestBody}.
218227
*/
219228
public static RequestBody fromContentProvider(ContentStreamProvider provider, long contentLength, String mimeType) {
220-
return new RequestBody(provider, contentLength, mimeType);
229+
return new RequestBody(new BufferingContentStreamProvider(provider, contentLength), contentLength, mimeType);
221230
}
222231

223232
/**
224-
* Creates a {@link RequestBody} from the given {@link ContentStreamProvider} when the content length is unknown. If you
225-
* are able to provide the content length at creation time, consider using {@link #fromInputStream(InputStream, long)} or
226-
* {@link #fromContentProvider(ContentStreamProvider, long, String)} to negate the need to read through the stream to find
227-
* the content length.
233+
* Creates a {@link RequestBody} from the given {@link ContentStreamProvider} when the content length is unknown.
228234
* <p>
229-
* Important: Be aware that this override requires the SDK to buffer the entirety of your content stream to compute the
230-
* content length. This will cause increased memory usage.
235+
* Important: Be aware that this implementation requires buffering the contents of the {@code ContentStreamProvider}, which can
236+
* cause increased memory usage.
231237
* <p>
232238
* If you are using this in conjunction with S3 and want to upload a stream with an unknown content length, you can refer
233239
to S3's documentation for
@@ -240,7 +246,7 @@ public static RequestBody fromContentProvider(ContentStreamProvider provider, lo
240246
* @return The created {@code RequestBody}.
241247
*/
242248
public static RequestBody fromContentProvider(ContentStreamProvider provider, String mimeType) {
243-
return new RequestBody(new BufferingContentStreamProvider(provider), null, mimeType);
249+
return new RequestBody(new BufferingContentStreamProvider(provider, null), null, mimeType);
244250
}
245251

246252
/**
@@ -254,7 +260,7 @@ private static RequestBody fromBytesDirect(byte[] bytes) {
254260
* Creates a {@link RequestBody} using the specified bytes (without copying).
255261
*/
256262
private static RequestBody fromBytesDirect(byte[] bytes, String mimetype) {
257-
return fromContentProvider(() -> new ByteArrayInputStream(bytes), bytes.length, mimetype);
263+
return new RequestBody(() -> new ByteArrayInputStream(bytes), (long) bytes.length, mimetype);
258264
}
259265

260266
private static InputStream nonCloseableInputStream(InputStream inputStream) {

core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/sync/BufferingContentStreamProviderTest.java

Lines changed: 78 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,12 @@
1616
package software.amazon.awssdk.core.internal.sync;
1717

1818
import static org.assertj.core.api.Assertions.assertThat;
19-
import static org.mockito.ArgumentMatchers.any;
20-
import static org.mockito.ArgumentMatchers.anyInt;
2119

2220
import java.io.ByteArrayInputStream;
2321
import java.io.IOException;
2422
import java.io.InputStream;
2523
import java.nio.charset.StandardCharsets;
24+
import java.util.Random;
2625
import org.junit.jupiter.api.BeforeEach;
2726
import org.junit.jupiter.api.Test;
2827
import org.mockito.Mockito;
@@ -110,17 +109,89 @@ void newStream_closeClosesDelegateStream() throws IOException {
110109
}
111110

112111
@Test
113-
void newStream_allDataBuffered_closesDelegateStream() throws IOException {
112+
public void newStream_delegateStreamClosedOnBufferingStreamClose() throws IOException {
114113
InputStream delegateStream = Mockito.spy(new ByteArrayInputStream(TEST_DATA));
115114

116115
requestBody = RequestBody.fromContentProvider(() -> delegateStream, "text/plain");
117116

118-
IoUtils.drainInputStream(requestBody.contentStreamProvider().newStream());
119-
Mockito.verify(delegateStream, Mockito.atLeast(1)).read(any(), anyInt(), anyInt());
117+
InputStream stream = requestBody.contentStreamProvider().newStream();
118+
IoUtils.drainInputStream(stream);
119+
stream.close();
120+
120121
Mockito.verify(delegateStream).close();
122+
}
123+
124+
@Test
125+
public void newStream_lengthKnown_readUpToLengthThenClosed_newStreamUsesBufferedData() throws IOException {
126+
ByteArrayInputStream stream = new ByteArrayInputStream(TEST_DATA);
127+
requestBody = RequestBody.fromContentProvider(() -> stream, TEST_DATA.length, "text/plain");
128+
129+
int totalRead = 0;
130+
int read;
131+
132+
InputStream stream1 = requestBody.contentStreamProvider().newStream();
133+
do {
134+
read = stream1.read();
135+
if (read != -1) {
136+
++totalRead;
137+
}
138+
} while (read != -1);
139+
140+
assertThat(totalRead).isEqualTo(TEST_DATA.length);
141+
142+
stream1.close();
143+
144+
assertThat(requestBody.contentStreamProvider().newStream())
145+
.isInstanceOf(BufferingContentStreamProvider.ByteArrayStream.class);
146+
}
147+
148+
@Test
149+
public void newStream_lengthKnown_partialRead_close_doesNotBufferData() throws IOException {
150+
// We need a large buffer because BufferedInputStream buffers data in chunks. If the buffer is small enough, a single
151+
// read() on the BufferedInputStream might actually buffer all the delegate's data.
152+
153+
byte[] newData = new byte[16536];
154+
new Random().nextBytes(newData);
155+
ByteArrayInputStream stream = new ByteArrayInputStream(newData);
156+
requestBody = RequestBody.fromContentProvider(() -> stream, newData.length, "text/plain");
157+
158+
InputStream stream1 = requestBody.contentStreamProvider().newStream();
159+
int read = stream1.read();
160+
assertThat(read).isNotEqualTo(-1);
161+
162+
stream1.close();
163+
164+
InputStream stream2 = requestBody.contentStreamProvider().newStream();
165+
assertThat(stream2).isInstanceOf(BufferingContentStreamProvider.BufferStream.class);
166+
167+
assertThat(getCrc32(stream2)).isEqualTo(getCrc32(new ByteArrayInputStream(newData)));
168+
}
169+
170+
@Test
171+
public void newStream_bufferedDataStreamPartialRead_closed_bufferedDataIsNotReplaced() throws IOException {
172+
byte[] newData = new byte[16536];
173+
new Random().nextBytes(newData);
174+
String newDataChecksum = getCrc32(new ByteArrayInputStream(newData));
175+
176+
ByteArrayInputStream stream = new ByteArrayInputStream(newData);
177+
178+
requestBody = RequestBody.fromContentProvider(() -> stream, "text/plain");
179+
InputStream stream1 = requestBody.contentStreamProvider().newStream();
180+
IoUtils.drainInputStream(stream1);
181+
stream1.close();
182+
183+
InputStream stream2 = requestBody.contentStreamProvider().newStream();
184+
assertThat(stream2).isInstanceOf(BufferingContentStreamProvider.ByteArrayStream.class);
185+
186+
int read = stream2.read();
187+
assertThat(read).isNotEqualTo(-1);
188+
189+
stream2.close();
190+
191+
InputStream stream3 = requestBody.contentStreamProvider().newStream();
192+
assertThat(stream3).isInstanceOf(BufferingContentStreamProvider.ByteArrayStream.class);
121193

122-
IoUtils.drainInputStream(requestBody.contentStreamProvider().newStream());
123-
Mockito.verifyNoMoreInteractions(delegateStream);
194+
assertThat(getCrc32(stream3)).isEqualTo(newDataChecksum);
124195
}
125196

126197
private static String getCrc32(InputStream inputStream) {

0 commit comments

Comments
 (0)