Skip to content

Avoid extra byte array copying when downloading to memory with AsyncResponseTransformer #6355

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .changes/next-release/feature-AWSSDKforJavav2-dd9f8bf.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"type": "feature",
"category": "AWS SDK for Java v2",
"contributor": "",
"description": "Avoid extra byte array copying when downloading to memory with AsyncResponseTransformer"
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import java.io.InputStream;
import java.io.UncheckedIOException;
import java.nio.ByteBuffer;
import java.util.Arrays;
import software.amazon.awssdk.annotations.SdkPublicApi;
import software.amazon.awssdk.core.async.AsyncResponseTransformer;
Expand Down Expand Up @@ -68,6 +69,48 @@ public static <ResponseT> ResponseBytes<ResponseT> fromByteArrayUnsafe(ResponseT
return new ResponseBytes<>(response, bytes);
}

/**
 * Create {@link ResponseBytes} from a {@link ByteBuffer} with minimal copying. This method attempts to avoid
 * copying data when possible, but introduces concurrency risks in specific scenarios.
 *
 * <p><b>Behavior by buffer type:</b>
 * <ul>
 *     <li><b>Array-backed ByteBuffer (perfect match):</b> When the buffer's remaining bytes span the entire backing
 *     array (array offset + position = 0, remaining = array.length), the array is shared <b>without</b> copying.
 *     This introduces the same concurrency risks as {@link #fromByteArrayUnsafe(Object, byte[])}: modifications to
 *     the original backing array will affect the returned {@link ResponseBytes}.</li>
 *     <li><b>Array-backed ByteBuffer (partial):</b> When the buffer represents only a portion of the backing array,
 *     the remaining bytes are copied to a new array. No concurrency risks.</li>
 *     <li><b>Buffer without an accessible array</b> (for example a direct or read-only buffer): The remaining bytes
 *     are always copied to a new heap array. No concurrency risks.</li>
 * </ul>
 *
 * <p>The buffer's position, limit and mark are never modified by this operation; bytes are read through a
 * {@link ByteBuffer#duplicate() duplicate} when a relative read is required.
 *
 * <p>As the method name implies, this is unsafe in the first scenario. Use a safe alternative unless you're
 * sure you know the risks.
 *
 * @param response the unmarshalled response object to associate with the bytes.
 * @param buffer the buffer whose remaining bytes (position to limit) become the content.
 * @return a {@link ResponseBytes} holding {@code response} and the buffer's remaining bytes.
 */
public static <ResponseT> ResponseBytes<ResponseT> fromByteBufferUnsafe(ResponseT response, ByteBuffer buffer) {
    if (buffer.hasArray()) {
        byte[] backingArray = buffer.array();
        int offset = buffer.arrayOffset() + buffer.position();
        int length = buffer.remaining();
        if (offset == 0 && length == backingArray.length) {
            // Perfect match - share the backing array directly, no copy
            return new ResponseBytes<>(response, backingArray);
        }
        // Only part of the backing array is visible through the buffer - copy just that range
        return new ResponseBytes<>(response, Arrays.copyOfRange(backingArray, offset, offset + length));
    }

    // No accessible backing array (direct or read-only buffer) - must copy to a heap array.
    // Read via a duplicate so the caller's position is never touched, not even transiently.
    byte[] copy = new byte[buffer.remaining()];
    buffer.duplicate().get(copy);
    return new ResponseBytes<>(response, copy);
}

/**
* @return the unmarshalled response object from the service.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,15 @@
public final class ByteArrayAsyncResponseTransformer<ResponseT> implements
AsyncResponseTransformer<ResponseT, ResponseBytes<ResponseT>> {

private volatile CompletableFuture<byte[]> cf;
private volatile CompletableFuture<ByteBuffer> cf;
private volatile ResponseT response;

@Override
public CompletableFuture<ResponseBytes<ResponseT>> prepare() {
cf = new CompletableFuture<>();
// Using fromByteArrayUnsafe() to avoid unnecessary extra copying of byte array. The data writing has completed and the
// byte array will not be further modified so this is safe
return cf.thenApply(arr -> ResponseBytes.fromByteArrayUnsafe(response, arr));
// Using fromByteBufferUnsafe() to avoid unnecessary extra copying of byte array. The data writing has completed and the
// byte buffer will not be further modified so this is safe
return cf.thenApply(buffer -> ResponseBytes.fromByteBufferUnsafe(response, buffer));
}

@Override
Expand All @@ -73,13 +73,11 @@ public String name() {
}

static class BaosSubscriber implements Subscriber<ByteBuffer> {
private final CompletableFuture<byte[]> resultFuture;

private ByteArrayOutputStream baos = new ByteArrayOutputStream();

private final CompletableFuture<ByteBuffer> resultFuture;
private DirectAccessByteArrayOutputStream directAccessOutputStream = new DirectAccessByteArrayOutputStream();
private Subscription subscription;

BaosSubscriber(CompletableFuture<byte[]> resultFuture) {
BaosSubscriber(CompletableFuture<ByteBuffer> resultFuture) {
this.resultFuture = resultFuture;
}

Expand All @@ -95,19 +93,38 @@ public void onSubscribe(Subscription s) {

@Override
public void onNext(ByteBuffer byteBuffer) {
invokeSafely(() -> baos.write(BinaryUtils.copyBytesFrom(byteBuffer)));
subscription.request(1);
invokeSafely(() -> {
if (byteBuffer.hasArray()) {
directAccessOutputStream.write(byteBuffer.array(), byteBuffer.arrayOffset() + byteBuffer.position(),
byteBuffer.remaining());
} else {
directAccessOutputStream.write(BinaryUtils.copyBytesFrom(byteBuffer));
}
});
}

@Override
public void onError(Throwable throwable) {
baos = null;
directAccessOutputStream = null;
resultFuture.completeExceptionally(throwable);
}

@Override
public void onComplete() {
resultFuture.complete(baos.toByteArray());
resultFuture.complete(directAccessOutputStream.toByteBuffer());
}
}

/**
 * A {@link ByteArrayOutputStream} variant that can hand out its internal storage as a {@link ByteBuffer}
 * instead of copying it the way {@code toByteArray()} does.
 */
static class DirectAccessByteArrayOutputStream extends ByteArrayOutputStream {

    /**
     * Wraps the stream's internal array (not a copy) in a heap {@link ByteBuffer}. The returned buffer starts at
     * position 0 and its limit is the number of bytes written so far; its capacity is the length of the internal
     * array, which may be larger than the limit.
     */
    ByteBuffer toByteBuffer() {
        ByteBuffer written = ByteBuffer.wrap(buf);
        written.limit(count);
        return written;
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,14 @@

import static org.assertj.core.api.Assertions.assertThat;

import java.nio.ByteBuffer;
import org.junit.jupiter.api.Test;

public class ResponseBytesTest {
private static final Object OBJECT = new Object();
@Test
public void fromByteArrayCreatesCopy() {
byte[] input = new byte[] { 'a' };
byte[] input = {'a'};
byte[] output = ResponseBytes.fromByteArray(OBJECT, input).asByteArrayUnsafe();

input[0] = 'b';
Expand All @@ -32,7 +33,7 @@ public void fromByteArrayCreatesCopy() {

@Test
public void asByteArrayCreatesCopy() {
byte[] input = new byte[] { 'a' };
byte[] input = {'a'};
byte[] output = ResponseBytes.fromByteArrayUnsafe(OBJECT, input).asByteArray();

input[0] = 'b';
Expand All @@ -41,9 +42,64 @@ public void asByteArrayCreatesCopy() {

@Test
public void fromByteArrayUnsafeAndAsByteArrayUnsafeDoNotCopy() {
byte[] input = new byte[] { 'a' };
byte[] input = {'a'};
byte[] output = ResponseBytes.fromByteArrayUnsafe(OBJECT, input).asByteArrayUnsafe();

assertThat(output).isSameAs(input);
}

@Test
public void fromByteBufferUnsafe_fullBuffer_doesNotCopy() {
    // A heap buffer wrapping the whole array: the array must be handed through as-is.
    byte[] backing = {'a'};
    ByteBuffer wrapped = ByteBuffer.wrap(backing);

    byte[] exposed = ResponseBytes.fromByteBufferUnsafe(OBJECT, wrapped).asByteArrayUnsafe();

    assertThat(wrapped.hasArray()).isTrue();
    assertThat(wrapped.isDirect()).isFalse();
    assertThat(exposed).isSameAs(backing);

    // Because the array is shared, a write to the source is visible through the result.
    backing[0] = 'b';
    assertThat(exposed[0]).isEqualTo((byte) 'b');
}

@Test
public void fromByteBufferUnsafe_directBuffer_createsCopy() {
    byte[] inputBytes = {'a'};
    ByteBuffer directBuffer = ByteBuffer.allocateDirect(1);
    directBuffer.put(inputBytes);
    directBuffer.flip();

    ResponseBytes<Object> responseBytes = ResponseBytes.fromByteBufferUnsafe(OBJECT, directBuffer);
    ByteBuffer outputBuffer = responseBytes.asByteBuffer();
    byte[] outputBytes = responseBytes.asByteArrayUnsafe();

    assertThat(directBuffer.hasArray()).isFalse();
    assertThat(directBuffer.isDirect()).isTrue();
    assertThat(outputBuffer.isDirect()).isFalse();
    assertThat(outputBytes).isEqualTo(inputBytes);
    assertThat(outputBytes).isNotSameAs(inputBytes);

    // The source buffer must be left exactly as the caller passed it in.
    assertThat(directBuffer.position()).isEqualTo(0);
    assertThat(directBuffer.remaining()).isEqualTo(1);

    // Mutate the direct buffer itself (absolute put, position untouched): this is the storage the result
    // was read from, so it is the one that proves the returned bytes are an independent copy.
    directBuffer.put(0, (byte) 'b');
    inputBytes[0] = 'b';
    assertThat(outputBytes[0]).isEqualTo((byte) 'a');
}

@Test
public void fromByteBufferUnsafe_bufferWithOffset_createsCopy() {
    byte[] inputBytes = "abcdefgh".getBytes();

    // Wrap only indices 2..4 of the array, i.e. "cde"
    ByteBuffer slicedBuffer = ByteBuffer.wrap(inputBytes, 2, 3);

    ResponseBytes<Object> responseBytes = ResponseBytes.fromByteBufferUnsafe(OBJECT, slicedBuffer);
    byte[] outputBytes = responseBytes.asByteArrayUnsafe();

    assertThat(slicedBuffer.hasArray()).isTrue();
    assertThat(outputBytes).isEqualTo("cde".getBytes());
    assertThat(outputBytes.length).isEqualTo(3);
    assertThat(outputBytes).isNotSameAs(inputBytes);

    // Overwrite the first byte *inside* the wrapped window (index 2, 'c'); a write outside the window
    // could never reach the result and would not demonstrate that the range was copied.
    inputBytes[2] = 'X';
    assertThat(outputBytes[0]).isEqualTo((byte) 'c');
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static software.amazon.awssdk.core.internal.async.SplittingPublisherTestUtils.verifyIndividualAsyncRequestBody;
import static software.amazon.awssdk.utils.FunctionalUtils.invokeSafely;

import java.io.ByteArrayOutputStream;
Expand All @@ -38,12 +37,9 @@
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
import software.amazon.awssdk.core.async.AsyncRequestBody;
import software.amazon.awssdk.core.async.SdkPublisher;
import software.amazon.awssdk.testutils.RandomTempFile;
import software.amazon.awssdk.utils.BinaryUtils;

Expand Down Expand Up @@ -236,7 +232,7 @@ public void changingFile_fileGetsDeleted_failsBecauseDeleted() throws Exception

@Test
public void positionNotZero_shouldReadFromPosition() throws Exception {
CompletableFuture<byte[]> future = new CompletableFuture<>();
CompletableFuture<ByteBuffer> future = new CompletableFuture<>();
long position = 20L;
AsyncRequestBody asyncRequestBody = FileAsyncRequestBody.builder()
.path(smallFile)
Expand All @@ -249,7 +245,9 @@ public void positionNotZero_shouldReadFromPosition() throws Exception {
asyncRequestBody.subscribe(baosSubscriber);
assertThat(asyncRequestBody.contentLength()).contains(80L);

byte[] bytes = future.get(1, TimeUnit.SECONDS);
ByteBuffer buffer = future.get(1, TimeUnit.SECONDS);
byte[] bytes = new byte[buffer.remaining()];
buffer.get(bytes);

byte[] expected = new byte[80];
try(FileInputStream inputStream = new FileInputStream(smallFile.toFile())) {
Expand All @@ -262,7 +260,7 @@ public void positionNotZero_shouldReadFromPosition() throws Exception {

@Test
public void bothPositionAndNumBytesToReadConfigured_shouldHonor() throws Exception {
CompletableFuture<byte[]> future = new CompletableFuture<>();
CompletableFuture<ByteBuffer> future = new CompletableFuture<>();
long position = 20L;
long numBytesToRead = 5L;
AsyncRequestBody asyncRequestBody = FileAsyncRequestBody.builder()
Expand All @@ -277,7 +275,9 @@ public void bothPositionAndNumBytesToReadConfigured_shouldHonor() throws Excepti
asyncRequestBody.subscribe(baosSubscriber);
assertThat(asyncRequestBody.contentLength()).contains(numBytesToRead);

byte[] bytes = future.get(1, TimeUnit.SECONDS);
ByteBuffer buffer = future.get(1, TimeUnit.SECONDS);
byte[] bytes = new byte[buffer.remaining()];
buffer.get(bytes);

byte[] expected = new byte[5];
try (FileInputStream inputStream = new FileInputStream(smallFile.toFile())) {
Expand All @@ -290,7 +290,7 @@ public void bothPositionAndNumBytesToReadConfigured_shouldHonor() throws Excepti

@Test
public void numBytesToReadConfigured_shouldHonor() throws Exception {
CompletableFuture<byte[]> future = new CompletableFuture<>();
CompletableFuture<ByteBuffer> future = new CompletableFuture<>();
AsyncRequestBody asyncRequestBody = FileAsyncRequestBody.builder()
.path(smallFile)
.numBytesToRead(5L)
Expand All @@ -302,7 +302,9 @@ public void numBytesToReadConfigured_shouldHonor() throws Exception {
asyncRequestBody.subscribe(baosSubscriber);
assertThat(asyncRequestBody.contentLength()).contains(5L);

byte[] bytes = future.get(1, TimeUnit.SECONDS);
ByteBuffer buffer = future.get(1, TimeUnit.SECONDS);
byte[] bytes = new byte[buffer.remaining()];
buffer.get(bytes);

byte[] expected = new byte[5];
try (FileInputStream inputStream = new FileInputStream(smallFile.toFile())) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,13 @@ void failedStream_completesExceptionally() {
}

private static String drainPublisherToStr(SdkPublisher<ByteBuffer> publisher) throws Exception {
CompletableFuture<byte[]> bodyFuture = new CompletableFuture<>();
CompletableFuture<ByteBuffer> bodyFuture = new CompletableFuture<>();
publisher.subscribe(new BaosSubscriber(bodyFuture));
byte[] body = bodyFuture.get();
return new String(body);

ByteBuffer buffer = bodyFuture.get();
byte[] bytes = new byte[buffer.remaining()];
buffer.get(bytes);

return new String(bytes);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,33 +15,26 @@

package software.amazon.awssdk.core.internal.async;

import static org.assertj.core.api.AssertionsForClassTypes.assertThat;

import java.io.File;
import java.io.FileInputStream;
import java.nio.ByteBuffer;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import org.assertj.core.api.Assertions;
import org.reactivestreams.Publisher;
import software.amazon.awssdk.core.async.AsyncRequestBody;
import software.amazon.awssdk.core.async.SdkPublisher;
import software.amazon.awssdk.core.internal.async.ByteArrayAsyncResponseTransformer;
import software.amazon.awssdk.core.internal.async.SplittingPublisherTest;

public final class SplittingPublisherTestUtils {

public static void verifyIndividualAsyncRequestBody(SdkPublisher<AsyncRequestBody> publisher,
Path file,
int chunkSize) throws Exception {

List<CompletableFuture<byte[]>> futures = new ArrayList<>();
List<CompletableFuture<ByteBuffer>> futures = new ArrayList<>();
publisher.subscribe(requestBody -> {
CompletableFuture<byte[]> baosFuture = new CompletableFuture<>();
CompletableFuture<ByteBuffer> baosFuture = new CompletableFuture<>();
ByteArrayAsyncResponseTransformer.BaosSubscriber subscriber =
new ByteArrayAsyncResponseTransformer.BaosSubscriber(baosFuture);
requestBody.subscribe(subscriber);
Expand All @@ -62,7 +55,10 @@ public static void verifyIndividualAsyncRequestBody(SdkPublisher<AsyncRequestBod
}
fileInputStream.skip(i * chunkSize);
fileInputStream.read(expected);
byte[] actualBytes = futures.get(i).join();
ByteBuffer actualByteBuffer = futures.get(i).join();
byte[] actualBytes = new byte[actualByteBuffer.remaining()];
actualByteBuffer.get(actualBytes);

Assertions.assertThat(actualBytes).isEqualTo(expected);
}
}
Expand Down
Loading