diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/ByteBuffersAsyncRequestBody.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/ByteBuffersAsyncRequestBody.java index c19ab8e245f8..a304d75ccf94 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/ByteBuffersAsyncRequestBody.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/ByteBuffersAsyncRequestBody.java @@ -16,21 +16,47 @@ package software.amazon.awssdk.core.internal.async; import java.nio.ByteBuffer; +import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; +import java.util.List; import java.util.Optional; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.annotations.SdkTestInternalApi; import software.amazon.awssdk.core.async.AsyncRequestBody; +import software.amazon.awssdk.core.exception.NonRetryableException; import software.amazon.awssdk.core.internal.util.Mimetype; +import software.amazon.awssdk.core.internal.util.NoopSubscription; import software.amazon.awssdk.utils.Logger; +import software.amazon.awssdk.utils.SdkAutoCloseable; +import software.amazon.awssdk.utils.Validate; /** * An implementation of {@link AsyncRequestBody} for providing data from the supplied {@link ByteBuffer} array. This is created * using static methods on {@link AsyncRequestBody} * + *

Subscription Behavior:

+ * + * + *

Resource Management:

+ * The body should be closed when no longer needed to free buffered data and notify active subscribers. + * Closing the body will: + * * @see AsyncRequestBody#fromBytes(byte[]) * @see AsyncRequestBody#fromBytesUnsafe(byte[]) * @see AsyncRequestBody#fromByteBuffer(ByteBuffer) @@ -40,17 +66,21 @@ * @see AsyncRequestBody#fromString(String) */ @SdkInternalApi -public final class ByteBuffersAsyncRequestBody implements AsyncRequestBody { +public final class ByteBuffersAsyncRequestBody implements AsyncRequestBody, SdkAutoCloseable { private static final Logger log = Logger.loggerFor(ByteBuffersAsyncRequestBody.class); private final String mimetype; private final Long length; - private final ByteBuffer[] buffers; + private List buffers; + private final Set subscriptions; + private final Object lock = new Object(); + private boolean closed; - private ByteBuffersAsyncRequestBody(String mimetype, Long length, ByteBuffer... buffers) { + private ByteBuffersAsyncRequestBody(String mimetype, Long length, List buffers) { this.mimetype = mimetype; - this.length = length; this.buffers = buffers; + this.length = length; + this.subscriptions = ConcurrentHashMap.newKeySet(); } @Override @@ -64,61 +94,25 @@ public String contentType() { } @Override - public void subscribe(Subscriber s) { - // As per rule 1.9 we must throw NullPointerException if the subscriber parameter is null - if (s == null) { - throw new NullPointerException("Subscription MUST NOT be null."); + public void subscribe(Subscriber subscriber) { + Validate.paramNotNull(subscriber, "subscriber"); + synchronized (lock) { + if (closed) { + subscriber.onSubscribe(new NoopSubscription(subscriber)); + subscriber.onError(NonRetryableException.create( + "AsyncRequestBody has been closed")); + return; + } } - // As per 2.13, this method must return normally (i.e. not throw). try { - s.onSubscribe( - new Subscription() { - private final AtomicInteger index = new AtomicInteger(0); - private final AtomicBoolean completed = new AtomicBoolean(false); - - @Override - public void request(long n) { - if (completed.get()) { - return; - } - - if (n > 0) { - int i = index.getAndIncrement(); - - if (buffers.length == 0 && completed.compareAndSet(false, true)) { - s.onComplete(); - } - - if (i >= buffers.length) { - return; - } - - long remaining = n; - - do { - ByteBuffer buffer = buffers[i]; - - s.onNext(buffer.asReadOnlyBuffer()); - remaining--; - } while (remaining > 0 && (i = index.getAndIncrement()) < buffers.length); - - if (i >= buffers.length - 1 && completed.compareAndSet(false, true)) { - s.onComplete(); - } - } else { - s.onError(new IllegalArgumentException("§3.9: non-positive requests are not allowed!")); - } - } - - @Override - public void cancel() { - completed.set(true); - } - } - ); + ReplayableByteBufferSubscription replayableByteBufferSubscription = + new ReplayableByteBufferSubscription(subscriber); + subscriber.onSubscribe(replayableByteBufferSubscription); + subscriptions.add(replayableByteBufferSubscription); } catch (Throwable ex) { - log.error(() -> s + " violated the Reactive Streams rule 2.13 by throwing an exception from onSubscribe.", ex); + log.error(() -> subscriber + " violated the Reactive Streams rule 2.13 by throwing an exception from onSubscribe.", + ex); } } @@ -127,34 +121,167 @@ public String body() { return BodyType.BYTES.getName(); } - public static ByteBuffersAsyncRequestBody of(ByteBuffer... buffers) { - long length = Arrays.stream(buffers) - .mapToLong(ByteBuffer::remaining) - .sum(); + public static ByteBuffersAsyncRequestBody of(List buffers) { + long length = buffers.stream() + .mapToLong(ByteBuffer::remaining) + .sum(); return new ByteBuffersAsyncRequestBody(Mimetype.MIMETYPE_OCTET_STREAM, length, buffers); } + public static ByteBuffersAsyncRequestBody of(ByteBuffer... buffers) { + return of(Arrays.asList(buffers)); + } + public static ByteBuffersAsyncRequestBody of(Long length, ByteBuffer... buffers) { - return new ByteBuffersAsyncRequestBody(Mimetype.MIMETYPE_OCTET_STREAM, length, buffers); + return new ByteBuffersAsyncRequestBody(Mimetype.MIMETYPE_OCTET_STREAM, length, Arrays.asList(buffers)); } public static ByteBuffersAsyncRequestBody of(String mimetype, ByteBuffer... buffers) { long length = Arrays.stream(buffers) .mapToLong(ByteBuffer::remaining) .sum(); - return new ByteBuffersAsyncRequestBody(mimetype, length, buffers); + return new ByteBuffersAsyncRequestBody(mimetype, length, Arrays.asList(buffers)); } public static ByteBuffersAsyncRequestBody of(String mimetype, Long length, ByteBuffer... buffers) { - return new ByteBuffersAsyncRequestBody(mimetype, length, buffers); + return new ByteBuffersAsyncRequestBody(mimetype, length, Arrays.asList(buffers)); } public static ByteBuffersAsyncRequestBody from(byte[] bytes) { return new ByteBuffersAsyncRequestBody(Mimetype.MIMETYPE_OCTET_STREAM, (long) bytes.length, - ByteBuffer.wrap(bytes)); + Collections.singletonList(ByteBuffer.wrap(bytes))); } public static ByteBuffersAsyncRequestBody from(String mimetype, byte[] bytes) { - return new ByteBuffersAsyncRequestBody(mimetype, (long) bytes.length, ByteBuffer.wrap(bytes)); + return new ByteBuffersAsyncRequestBody(mimetype, (long) bytes.length, + Collections.singletonList(ByteBuffer.wrap(bytes))); + } + + @Override + public void close() { + synchronized (lock) { + if (closed) { + return; + } + + closed = true; + buffers = new ArrayList<>(); + subscriptions.forEach(s -> s.notifyError(new IllegalStateException("The publisher has been closed"))); + subscriptions.clear(); + } + } + + @SdkTestInternalApi + public List bufferedData() { + return buffers; + } + + private class ReplayableByteBufferSubscription implements Subscription { + private final AtomicInteger index = new AtomicInteger(0); + private volatile boolean done; + private final AtomicBoolean processingRequest = new AtomicBoolean(false); + private Subscriber currentSubscriber; + private final AtomicLong outstandingDemand = new AtomicLong(); + + private ReplayableByteBufferSubscription(Subscriber subscriber) { + this.currentSubscriber = subscriber; + } + + @Override + public void request(long n) { + if (n <= 0) { + currentSubscriber.onError(new IllegalArgumentException("§3.9: non-positive requests are not allowed!")); + currentSubscriber = null; + return; + } + + if (done) { + return; + } + + if (buffers.size() == 0) { + currentSubscriber.onComplete(); + done = true; + subscriptions.remove(this); + return; + } + + outstandingDemand.updateAndGet(current -> { + if (Long.MAX_VALUE - current < n) { + return Long.MAX_VALUE; + } + + return current + n; + }); + processRequest(); + } + + private void processRequest() { + do { + if (!processingRequest.compareAndSet(false, true)) { + // Some other thread is processing the queue, so we don't need to. + return; + } + + try { + doProcessRequest(); + } catch (Throwable e) { + notifyError(new IllegalStateException("Encountered fatal error in publisher", e)); + subscriptions.remove(this); + break; + } finally { + processingRequest.set(false); + } + + } while (shouldProcessRequest()); + } + + private boolean shouldProcessRequest() { + return !done && outstandingDemand.get() > 0 && index.get() < buffers.size(); + } + + private void doProcessRequest() { + while (true) { + if (!shouldProcessRequest()) { + return; + } + + int currentIndex = this.index.getAndIncrement(); + + if (currentIndex >= buffers.size()) { + // This should never happen because shouldProcessRequest() ensures that index.get() < buffers.size() + // before incrementing. If this condition is true, it likely indicates a concurrency bug or that buffers + // was modified unexpectedly. This defensive check is here to catch such rare, unexpected situations. + notifyError(new IllegalStateException("Index out of bounds")); + subscriptions.remove(this); + return; + } + + ByteBuffer buffer = buffers.get(currentIndex); + currentSubscriber.onNext(buffer.asReadOnlyBuffer()); + outstandingDemand.decrementAndGet(); + + if (currentIndex == buffers.size() - 1) { + done = true; + currentSubscriber.onComplete(); + subscriptions.remove(this); + break; + } + } + } + + @Override + public void cancel() { + done = true; + subscriptions.remove(this); + } + + public void notifyError(Exception exception) { + if (currentSubscriber != null) { + done = true; + currentSubscriber.onError(exception); + currentSubscriber = null; + } + } } } diff --git a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/ByteBufferAsyncRequestBodyTckTest.java b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/ByteBufferAsyncRequestBodyTckTest.java new file mode 100644 index 000000000000..3d9df4b9e1e2 --- /dev/null +++ b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/ByteBufferAsyncRequestBodyTckTest.java @@ -0,0 +1,52 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.core.internal.async; + + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; +import org.apache.commons.lang3.RandomStringUtils; +import org.reactivestreams.Publisher; +import org.reactivestreams.tck.TestEnvironment; +import software.amazon.awssdk.core.SdkBytes; + +public class ByteBufferAsyncRequestBodyTckTest extends org.reactivestreams.tck.PublisherVerification { + public ByteBufferAsyncRequestBodyTckTest() { + super(new TestEnvironment()); + } + + @Override + public Publisher createPublisher(long elements) { + List buffers = new ArrayList<>(); + for (int i = 0; i < elements; i++) { + buffers.add(SdkBytes.fromUtf8String(RandomStringUtils.randomAscii(1024)).asByteBuffer()); + } + return ByteBuffersAsyncRequestBody.of(buffers.toArray(new ByteBuffer[0])); + } + + @Override + public Publisher createFailedPublisher() { + ByteBuffersAsyncRequestBody bufferingAsyncRequestBody = ByteBuffersAsyncRequestBody.of(ByteBuffer.wrap(RandomStringUtils.randomAscii(1024).getBytes())); + bufferingAsyncRequestBody.close(); + return bufferingAsyncRequestBody; + } + + public long maxElementsFromPublisher() { + return 100; + } + +} diff --git a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/ByteBuffersAsyncRequestBodyTest.java b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/ByteBuffersAsyncRequestBodyTest.java index e1035dc25b0c..c80784b26a90 100644 --- a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/ByteBuffersAsyncRequestBodyTest.java +++ b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/ByteBuffersAsyncRequestBodyTest.java @@ -15,58 +15,44 @@ package software.amazon.awssdk.core.internal.async; +import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.argThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.IntStream; +import org.apache.commons.lang3.RandomStringUtils; +import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; import org.reactivestreams.Subscriber; import org.reactivestreams.Subscription; import software.amazon.awssdk.core.async.AsyncRequestBody; +import software.amazon.awssdk.core.exception.NonRetryableException; import software.amazon.awssdk.utils.BinaryUtils; class ByteBuffersAsyncRequestBodyTest { - private static class TestSubscriber implements Subscriber { - private Subscription subscription; - private boolean onCompleteCalled = false; - private int callsToComplete = 0; - private final List publishedResults = Collections.synchronizedList(new ArrayList<>()); - - public void request(long n) { - subscription.request(n); - } - - @Override - public void onSubscribe(Subscription s) { - this.subscription = s; - } - - @Override - public void onNext(ByteBuffer byteBuffer) { - publishedResults.add(byteBuffer); - } - - @Override - public void onError(Throwable throwable) { - throw new IllegalStateException(throwable); - } + private static ExecutorService executor = Executors.newFixedThreadPool(10); - @Override - public void onComplete() { - onCompleteCalled = true; - callsToComplete++; - } + @AfterAll + static void tearDown() { + executor.shutdown(); } @Test @@ -214,4 +200,167 @@ public void staticFromBytesConstructorSetsLengthBasedOnArrayLength() { assertEquals(bytes.length, body.contentLength().get()); } + @Test + public void subscribe_whenBodyIsClosed_shouldNotifySubscriberWithError() { + ByteBuffersAsyncRequestBody body = ByteBuffersAsyncRequestBody.of(ByteBuffer.wrap("test".getBytes())); + body.close(); // Set closed to true + Subscriber mockSubscriber = mock(Subscriber.class); + + body.subscribe(mockSubscriber); + + verify(mockSubscriber).onSubscribe(any()); + verify(mockSubscriber).onError(argThat(e -> + e instanceof NonRetryableException && + e.getMessage().equals("AsyncRequestBody has been closed") + )); + } + + @Test + public void close_withActiveSubscriptions_shouldNotifyAllSubscribers() { + ByteBuffersAsyncRequestBody body = ByteBuffersAsyncRequestBody.of(ByteBuffer.wrap(RandomStringUtils.randomAscii(1024).getBytes(StandardCharsets.UTF_8))); + + Subscriber subscriber1 = mock(Subscriber.class); + Subscriber subscriber2 = mock(Subscriber.class); + Subscriber subscriber3 = mock(Subscriber.class); + + body.subscribe(subscriber1); + body.subscribe(subscriber2); + body.subscribe(subscriber3); + + body.close(); + + verify(subscriber1).onError(argThat(e -> + e instanceof IllegalStateException && + e.getMessage().contains("The publisher has been closed") + )); + verify(subscriber2).onError(argThat(e -> + e instanceof IllegalStateException && + e.getMessage().contains("The publisher has been closed") + )); + verify(subscriber3).onError(argThat(e -> + e instanceof IllegalStateException && + e.getMessage().contains("The publisher has been closed") + )); + } + + @Test + public void bufferedData_afterClose_shouldBeEmpty() { + ByteBuffersAsyncRequestBody body = ByteBuffersAsyncRequestBody.of( + ByteBuffer.wrap("test1".getBytes()), + ByteBuffer.wrap("test2".getBytes())); + + assertThat(body.bufferedData()).hasSize(2); + body.close(); + assertThat(body.bufferedData()).isEmpty(); + } + + @Test + public void concurrentSubscribeAndClose_shouldBeThreadSafe() throws InterruptedException { + ByteBuffersAsyncRequestBody body = ByteBuffersAsyncRequestBody.of( + ByteBuffer.wrap("test1".getBytes()), + ByteBuffer.wrap("test2".getBytes())); + + + CountDownLatch latch = new CountDownLatch(10); + AtomicInteger successfulSubscriptions = new AtomicInteger(0); + AtomicInteger errorNotifications = new AtomicInteger(0); + + // Start multiple threads subscribing + for (int i = 0; i < 10; i++) { + executor.submit(() -> { + try { + Subscriber subscriber = new Subscriber() { + @Override + public void onSubscribe(Subscription s) { + successfulSubscriptions.incrementAndGet(); + s.request(1); + } + + @Override + public void onNext(ByteBuffer byteBuffer) {} + + @Override + public void onError(Throwable t) { + errorNotifications.incrementAndGet(); + } + + @Override + public void onComplete() {} + }; + body.subscribe(subscriber); + } finally { + latch.countDown(); + } + }); + } + + // Close after a short delay + Thread.sleep(10); + body.close(); + + latch.await(5, TimeUnit.SECONDS); + executor.shutdown(); + + // All subscribers should have been notified + assertThat(successfulSubscriptions.get()).isEqualTo(10); + } + + @Test + public void subscription_readOnlyBuffers_shouldNotAffectOriginalData() { + ByteBuffer originalBuffer = ByteBuffer.wrap(RandomStringUtils.randomAscii(1024).getBytes()); + ByteBuffersAsyncRequestBody body = ByteBuffersAsyncRequestBody.of( + originalBuffer); + int originalPosition = originalBuffer.position(); + + Subscriber mockSubscriber = mock(Subscriber.class); + ArgumentCaptor subscriptionCaptor = ArgumentCaptor.forClass(Subscription.class); + ArgumentCaptor bufferCaptor = ArgumentCaptor.forClass(ByteBuffer.class); + + body.subscribe(mockSubscriber); + verify(mockSubscriber).onSubscribe(subscriptionCaptor.capture()); + + Subscription subscription = subscriptionCaptor.getValue(); + subscription.request(1); + + verify(mockSubscriber).onNext(bufferCaptor.capture()); + ByteBuffer receivedBuffer = bufferCaptor.getValue(); + byte[] bytes = BinaryUtils.copyBytesFrom(receivedBuffer); + + assertThat(receivedBuffer.isReadOnly()).isTrue(); + + assertThat(originalBuffer.position()).isEqualTo(originalPosition); + } + + private static class TestSubscriber implements Subscriber { + private Subscription subscription; + private boolean onCompleteCalled = false; + private int callsToComplete = 0; + private final List publishedResults = Collections.synchronizedList(new ArrayList<>()); + + public void request(long n) { + subscription.request(n); + } + + @Override + public void onSubscribe(Subscription s) { + this.subscription = s; + } + + @Override + public void onNext(ByteBuffer byteBuffer) { + publishedResults.add(byteBuffer); + } + + @Override + public void onError(Throwable throwable) { + throw new IllegalStateException(throwable); + } + + @Override + public void onComplete() { + onCompleteCalled = true; + callsToComplete++; + } + } + }