diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/ByteBuffersAsyncRequestBody.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/ByteBuffersAsyncRequestBody.java
index c19ab8e245f8..a304d75ccf94 100644
--- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/ByteBuffersAsyncRequestBody.java
+++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/ByteBuffersAsyncRequestBody.java
@@ -16,21 +16,47 @@
package software.amazon.awssdk.core.internal.async;
import java.nio.ByteBuffer;
+import java.util.ArrayList;
import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
import java.util.Optional;
+import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicLong;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
import software.amazon.awssdk.annotations.SdkInternalApi;
+import software.amazon.awssdk.annotations.SdkTestInternalApi;
import software.amazon.awssdk.core.async.AsyncRequestBody;
+import software.amazon.awssdk.core.exception.NonRetryableException;
import software.amazon.awssdk.core.internal.util.Mimetype;
+import software.amazon.awssdk.core.internal.util.NoopSubscription;
import software.amazon.awssdk.utils.Logger;
+import software.amazon.awssdk.utils.SdkAutoCloseable;
+import software.amazon.awssdk.utils.Validate;
/**
* An implementation of {@link AsyncRequestBody} for providing data from the supplied {@link ByteBuffer} array. This is created
* using static methods on {@link AsyncRequestBody}
*
+ *
Subscription Behavior:
+ *
+ * - Each subscriber receives a read-only view of the buffered data
+ * - Subscribers receive data independently based on their own demand signaling
+ * - If the body is closed, new subscribers will receive an error immediately
+ *
+ *
+ * Resource Management:
+ * The body should be closed when no longer needed to free buffered data and notify active subscribers.
+ * Closing the body will:
+ *
+ * - Clear all buffered data
+ * - Send error notifications to all active subscribers
+ * - Prevent new subscriptions
+ *
* @see AsyncRequestBody#fromBytes(byte[])
* @see AsyncRequestBody#fromBytesUnsafe(byte[])
* @see AsyncRequestBody#fromByteBuffer(ByteBuffer)
@@ -40,17 +66,21 @@
* @see AsyncRequestBody#fromString(String)
*/
@SdkInternalApi
-public final class ByteBuffersAsyncRequestBody implements AsyncRequestBody {
+public final class ByteBuffersAsyncRequestBody implements AsyncRequestBody, SdkAutoCloseable {
private static final Logger log = Logger.loggerFor(ByteBuffersAsyncRequestBody.class);
private final String mimetype;
private final Long length;
- private final ByteBuffer[] buffers;
+ private List buffers;
+ private final Set subscriptions;
+ private final Object lock = new Object();
+ private boolean closed;
- private ByteBuffersAsyncRequestBody(String mimetype, Long length, ByteBuffer... buffers) {
+ private ByteBuffersAsyncRequestBody(String mimetype, Long length, List buffers) {
this.mimetype = mimetype;
- this.length = length;
this.buffers = buffers;
+ this.length = length;
+ this.subscriptions = ConcurrentHashMap.newKeySet();
}
@Override
@@ -64,61 +94,25 @@ public String contentType() {
}
@Override
- public void subscribe(Subscriber super ByteBuffer> s) {
- // As per rule 1.9 we must throw NullPointerException if the subscriber parameter is null
- if (s == null) {
- throw new NullPointerException("Subscription MUST NOT be null.");
+ public void subscribe(Subscriber super ByteBuffer> subscriber) {
+ Validate.paramNotNull(subscriber, "subscriber");
+ synchronized (lock) {
+ if (closed) {
+ subscriber.onSubscribe(new NoopSubscription(subscriber));
+ subscriber.onError(NonRetryableException.create(
+ "AsyncRequestBody has been closed"));
+ return;
+ }
}
- // As per 2.13, this method must return normally (i.e. not throw).
try {
- s.onSubscribe(
- new Subscription() {
- private final AtomicInteger index = new AtomicInteger(0);
- private final AtomicBoolean completed = new AtomicBoolean(false);
-
- @Override
- public void request(long n) {
- if (completed.get()) {
- return;
- }
-
- if (n > 0) {
- int i = index.getAndIncrement();
-
- if (buffers.length == 0 && completed.compareAndSet(false, true)) {
- s.onComplete();
- }
-
- if (i >= buffers.length) {
- return;
- }
-
- long remaining = n;
-
- do {
- ByteBuffer buffer = buffers[i];
-
- s.onNext(buffer.asReadOnlyBuffer());
- remaining--;
- } while (remaining > 0 && (i = index.getAndIncrement()) < buffers.length);
-
- if (i >= buffers.length - 1 && completed.compareAndSet(false, true)) {
- s.onComplete();
- }
- } else {
- s.onError(new IllegalArgumentException("§3.9: non-positive requests are not allowed!"));
- }
- }
-
- @Override
- public void cancel() {
- completed.set(true);
- }
- }
- );
+ ReplayableByteBufferSubscription replayableByteBufferSubscription =
+ new ReplayableByteBufferSubscription(subscriber);
+ subscriber.onSubscribe(replayableByteBufferSubscription);
+ subscriptions.add(replayableByteBufferSubscription);
} catch (Throwable ex) {
- log.error(() -> s + " violated the Reactive Streams rule 2.13 by throwing an exception from onSubscribe.", ex);
+ log.error(() -> subscriber + " violated the Reactive Streams rule 2.13 by throwing an exception from onSubscribe.",
+ ex);
}
}
@@ -127,34 +121,167 @@ public String body() {
return BodyType.BYTES.getName();
}
- public static ByteBuffersAsyncRequestBody of(ByteBuffer... buffers) {
- long length = Arrays.stream(buffers)
- .mapToLong(ByteBuffer::remaining)
- .sum();
+ public static ByteBuffersAsyncRequestBody of(List buffers) {
+ long length = buffers.stream()
+ .mapToLong(ByteBuffer::remaining)
+ .sum();
return new ByteBuffersAsyncRequestBody(Mimetype.MIMETYPE_OCTET_STREAM, length, buffers);
}
+ public static ByteBuffersAsyncRequestBody of(ByteBuffer... buffers) {
+ return of(Arrays.asList(buffers));
+ }
+
public static ByteBuffersAsyncRequestBody of(Long length, ByteBuffer... buffers) {
- return new ByteBuffersAsyncRequestBody(Mimetype.MIMETYPE_OCTET_STREAM, length, buffers);
+ return new ByteBuffersAsyncRequestBody(Mimetype.MIMETYPE_OCTET_STREAM, length, Arrays.asList(buffers));
}
public static ByteBuffersAsyncRequestBody of(String mimetype, ByteBuffer... buffers) {
long length = Arrays.stream(buffers)
.mapToLong(ByteBuffer::remaining)
.sum();
- return new ByteBuffersAsyncRequestBody(mimetype, length, buffers);
+ return new ByteBuffersAsyncRequestBody(mimetype, length, Arrays.asList(buffers));
}
public static ByteBuffersAsyncRequestBody of(String mimetype, Long length, ByteBuffer... buffers) {
- return new ByteBuffersAsyncRequestBody(mimetype, length, buffers);
+ return new ByteBuffersAsyncRequestBody(mimetype, length, Arrays.asList(buffers));
}
public static ByteBuffersAsyncRequestBody from(byte[] bytes) {
return new ByteBuffersAsyncRequestBody(Mimetype.MIMETYPE_OCTET_STREAM, (long) bytes.length,
- ByteBuffer.wrap(bytes));
+ Collections.singletonList(ByteBuffer.wrap(bytes)));
}
public static ByteBuffersAsyncRequestBody from(String mimetype, byte[] bytes) {
- return new ByteBuffersAsyncRequestBody(mimetype, (long) bytes.length, ByteBuffer.wrap(bytes));
+ return new ByteBuffersAsyncRequestBody(mimetype, (long) bytes.length,
+ Collections.singletonList(ByteBuffer.wrap(bytes)));
+ }
+
+ @Override
+ public void close() {
+ synchronized (lock) {
+ if (closed) {
+ return;
+ }
+
+ closed = true;
+ buffers = new ArrayList<>();
+ subscriptions.forEach(s -> s.notifyError(new IllegalStateException("The publisher has been closed")));
+ subscriptions.clear();
+ }
+ }
+
+ @SdkTestInternalApi
+ public List bufferedData() {
+ return buffers;
+ }
+
+ private class ReplayableByteBufferSubscription implements Subscription {
+ private final AtomicInteger index = new AtomicInteger(0);
+ private volatile boolean done;
+ private final AtomicBoolean processingRequest = new AtomicBoolean(false);
+ private Subscriber super ByteBuffer> currentSubscriber;
+ private final AtomicLong outstandingDemand = new AtomicLong();
+
+ private ReplayableByteBufferSubscription(Subscriber super ByteBuffer> subscriber) {
+ this.currentSubscriber = subscriber;
+ }
+
+ @Override
+ public void request(long n) {
+ if (n <= 0) {
+ currentSubscriber.onError(new IllegalArgumentException("§3.9: non-positive requests are not allowed!"));
+ currentSubscriber = null;
+ return;
+ }
+
+ if (done) {
+ return;
+ }
+
+ if (buffers.size() == 0) {
+ currentSubscriber.onComplete();
+ done = true;
+ subscriptions.remove(this);
+ return;
+ }
+
+ outstandingDemand.updateAndGet(current -> {
+ if (Long.MAX_VALUE - current < n) {
+ return Long.MAX_VALUE;
+ }
+
+ return current + n;
+ });
+ processRequest();
+ }
+
+ private void processRequest() {
+ do {
+ if (!processingRequest.compareAndSet(false, true)) {
+ // Some other thread is processing the queue, so we don't need to.
+ return;
+ }
+
+ try {
+ doProcessRequest();
+ } catch (Throwable e) {
+ notifyError(new IllegalStateException("Encountered fatal error in publisher", e));
+ subscriptions.remove(this);
+ break;
+ } finally {
+ processingRequest.set(false);
+ }
+
+ } while (shouldProcessRequest());
+ }
+
+ private boolean shouldProcessRequest() {
+ return !done && outstandingDemand.get() > 0 && index.get() < buffers.size();
+ }
+
+ private void doProcessRequest() {
+ while (true) {
+ if (!shouldProcessRequest()) {
+ return;
+ }
+
+ int currentIndex = this.index.getAndIncrement();
+
+ if (currentIndex >= buffers.size()) {
+ // This should never happen because shouldProcessRequest() ensures that index.get() < buffers.size()
+ // before incrementing. If this condition is true, it likely indicates a concurrency bug or that buffers
+ // was modified unexpectedly. This defensive check is here to catch such rare, unexpected situations.
+ notifyError(new IllegalStateException("Index out of bounds"));
+ subscriptions.remove(this);
+ return;
+ }
+
+ ByteBuffer buffer = buffers.get(currentIndex);
+ currentSubscriber.onNext(buffer.asReadOnlyBuffer());
+ outstandingDemand.decrementAndGet();
+
+ if (currentIndex == buffers.size() - 1) {
+ done = true;
+ currentSubscriber.onComplete();
+ subscriptions.remove(this);
+ break;
+ }
+ }
+ }
+
+ @Override
+ public void cancel() {
+ done = true;
+ subscriptions.remove(this);
+ }
+
+ public void notifyError(Exception exception) {
+ if (currentSubscriber != null) {
+ done = true;
+ currentSubscriber.onError(exception);
+ currentSubscriber = null;
+ }
+ }
}
}
diff --git a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/ByteBufferAsyncRequestBodyTckTest.java b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/ByteBufferAsyncRequestBodyTckTest.java
new file mode 100644
index 000000000000..3d9df4b9e1e2
--- /dev/null
+++ b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/ByteBufferAsyncRequestBodyTckTest.java
@@ -0,0 +1,52 @@
+/*
+ * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License").
+ * You may not use this file except in compliance with the License.
+ * A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0
+ *
+ * or in the "license" file accompanying this file. This file is distributed
+ * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
+ * express or implied. See the License for the specific language governing
+ * permissions and limitations under the License.
+ */
+
+package software.amazon.awssdk.core.internal.async;
+
+
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.List;
+import org.apache.commons.lang3.RandomStringUtils;
+import org.reactivestreams.Publisher;
+import org.reactivestreams.tck.TestEnvironment;
+import software.amazon.awssdk.core.SdkBytes;
+
+public class ByteBufferAsyncRequestBodyTckTest extends org.reactivestreams.tck.PublisherVerification {
+ public ByteBufferAsyncRequestBodyTckTest() {
+ super(new TestEnvironment());
+ }
+
+ @Override
+ public Publisher createPublisher(long elements) {
+ List buffers = new ArrayList<>();
+ for (int i = 0; i < elements; i++) {
+ buffers.add(SdkBytes.fromUtf8String(RandomStringUtils.randomAscii(1024)).asByteBuffer());
+ }
+ return ByteBuffersAsyncRequestBody.of(buffers.toArray(new ByteBuffer[0]));
+ }
+
+ @Override
+ public Publisher createFailedPublisher() {
+ ByteBuffersAsyncRequestBody bufferingAsyncRequestBody = ByteBuffersAsyncRequestBody.of(ByteBuffer.wrap(RandomStringUtils.randomAscii(1024).getBytes()));
+ bufferingAsyncRequestBody.close();
+ return bufferingAsyncRequestBody;
+ }
+
+ public long maxElementsFromPublisher() {
+ return 100;
+ }
+
+}
diff --git a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/ByteBuffersAsyncRequestBodyTest.java b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/ByteBuffersAsyncRequestBodyTest.java
index e1035dc25b0c..c80784b26a90 100644
--- a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/ByteBuffersAsyncRequestBodyTest.java
+++ b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/ByteBuffersAsyncRequestBodyTest.java
@@ -15,58 +15,44 @@
package software.amazon.awssdk.core.internal.async;
+import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.argThat;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
+import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.IntStream;
+import org.apache.commons.lang3.RandomStringUtils;
+import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.Test;
+import org.mockito.ArgumentCaptor;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
import software.amazon.awssdk.core.async.AsyncRequestBody;
+import software.amazon.awssdk.core.exception.NonRetryableException;
import software.amazon.awssdk.utils.BinaryUtils;
class ByteBuffersAsyncRequestBodyTest {
- private static class TestSubscriber implements Subscriber {
- private Subscription subscription;
- private boolean onCompleteCalled = false;
- private int callsToComplete = 0;
- private final List publishedResults = Collections.synchronizedList(new ArrayList<>());
-
- public void request(long n) {
- subscription.request(n);
- }
-
- @Override
- public void onSubscribe(Subscription s) {
- this.subscription = s;
- }
-
- @Override
- public void onNext(ByteBuffer byteBuffer) {
- publishedResults.add(byteBuffer);
- }
-
- @Override
- public void onError(Throwable throwable) {
- throw new IllegalStateException(throwable);
- }
+ private static ExecutorService executor = Executors.newFixedThreadPool(10);
- @Override
- public void onComplete() {
- onCompleteCalled = true;
- callsToComplete++;
- }
+ @AfterAll
+ static void tearDown() {
+ executor.shutdown();
}
@Test
@@ -214,4 +200,167 @@ public void staticFromBytesConstructorSetsLengthBasedOnArrayLength() {
assertEquals(bytes.length, body.contentLength().get());
}
+ @Test
+ public void subscribe_whenBodyIsClosed_shouldNotifySubscriberWithError() {
+ ByteBuffersAsyncRequestBody body = ByteBuffersAsyncRequestBody.of(ByteBuffer.wrap("test".getBytes()));
+ body.close(); // Set closed to true
+ Subscriber