Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .changes/next-release/feature-AWSSDKforJavav2-a91bebb.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"type": "feature",
"category": "AWS SDK for Java v2",
"contributor": "",
"description": "Adds timeouts to ResponsePublisher and ResponseInputStream to close connection if response not consumed"
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,33 @@

package software.amazon.awssdk.core;

import java.io.IOException;
import java.io.InputStream;
import java.time.Duration;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import software.amazon.awssdk.annotations.SdkPublicApi;
import software.amazon.awssdk.annotations.SdkTestInternalApi;
import software.amazon.awssdk.core.io.SdkFilterInputStream;
import software.amazon.awssdk.http.Abortable;
import software.amazon.awssdk.http.AbortableInputStream;
import software.amazon.awssdk.utils.IoUtils;
import software.amazon.awssdk.utils.Logger;
import software.amazon.awssdk.utils.Validate;

/**
* Input stream that provides access to the unmarshalled POJO response returned by the service in addition to the streamed
* contents. This input stream should be closed after all data has been read from the stream.
*
* <p>
* <b>NOTE:</b> You must read this stream promptly to avoid automatic cancellation. The default timeout for reading is 60
* seconds, which starts when the response stream is ready. If {@link #read()} is not invoked before the timeout, the stream will
* automatically abort to prevent resource leakage.
* <p>
* The timeout can be customized by passing a {@link Duration} to the constructor, or disabled entirely by
* passing {@link Duration#ZERO} or a negative {@link Duration}.
* <p>
* Note about the Apache http client: This input stream can be used to leverage a feature of the Apache http client where
* connections are released back to the connection pool to be reused. As such, calling {@link ResponseInputStream#close() close}
Expand All @@ -43,19 +59,37 @@
@SdkPublicApi
public final class ResponseInputStream<ResponseT> extends SdkFilterInputStream implements Abortable {

private static final Logger log = Logger.loggerFor(ResponseInputStream.class);
private static final Duration DEFAULT_TIMEOUT = Duration.ofSeconds(60);
private final ResponseT response;
private final Abortable abortable;
private ScheduledFuture<?> timeoutTask;
private volatile boolean hasRead = false;

public ResponseInputStream(ResponseT resp, AbortableInputStream in) {
this(resp, in, null);
}

public ResponseInputStream(ResponseT resp, AbortableInputStream in, Duration timeout) {
super(in);
this.response = Validate.paramNotNull(resp, "response");
this.abortable = Validate.paramNotNull(in, "abortableInputStream");

Duration resolvedTimeout = timeout != null ? timeout : DEFAULT_TIMEOUT;
scheduleTimeoutTask(resolvedTimeout);
}

public ResponseInputStream(ResponseT resp, InputStream in) {
this(resp, in, null);
}

public ResponseInputStream(ResponseT resp, InputStream in, Duration timeout) {
super(in);
this.response = Validate.paramNotNull(resp, "response");
this.abortable = in instanceof Abortable ? (Abortable) in : null;

Duration resolvedTimeout = timeout != null ? timeout : DEFAULT_TIMEOUT;
scheduleTimeoutTask(resolvedTimeout);
}

/**
Expand All @@ -65,15 +99,77 @@ public ResponseT response() {
return response;
}

@Override
public int read() throws IOException {
cancelTimeoutTask();
return super.read();
}

@Override
public int read(byte[] b) throws IOException {
cancelTimeoutTask();
return super.read(b);
}

@Override
public int read(byte[] b, int off, int len) throws IOException {
cancelTimeoutTask();
return super.read(b, off, len);
}

private void cancelTimeoutTask() {
if (!hasRead && timeoutTask != null) {
timeoutTask.cancel(false);
}
hasRead = true;
}

private void scheduleTimeoutTask(Duration timeout) {
if (timeout.equals(Duration.ZERO) || timeout.isNegative()) {
return;
}

long timeoutInMillis = timeout.toMillis();
timeoutTask = TimeoutScheduler.INSTANCE.schedule(() -> {
if (!hasRead) {
log.debug(() -> String.format("InputStream was not read before timeout of [%d] milliseconds, aborting "
+ "stream and closing connection.", timeoutInMillis));
abort();
}
}, timeoutInMillis, TimeUnit.MILLISECONDS);
}

private static final class TimeoutScheduler {
static final ScheduledExecutorService INSTANCE =
Executors.newScheduledThreadPool(1, r -> {
Thread t = new Thread(r, "response-input-stream-timeout-scheduler");
t.setDaemon(true);
return t;
});
}

/**
* Close the underlying connection, dropping all remaining data in the stream, and not leaving the
* connection open to be used for future requests.
*/
@Override
public void abort() {
if (timeoutTask != null) {
timeoutTask.cancel(false);
}
if (abortable != null) {
abortable.abort();
}
IoUtils.closeQuietly(in, null);
IoUtils.closeQuietlyV2(in, log);
}

@SdkTestInternalApi
public boolean hasTimeoutTask() {
return timeoutTask != null;
}

@SdkTestInternalApi
public boolean timeoutTaskDoneOrCancelled() {
return timeoutTask != null && timeoutTask.isDone();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.nio.file.Path;
import java.time.Duration;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.function.Consumer;
Expand Down Expand Up @@ -271,6 +272,10 @@ static <ResponseT> AsyncResponseTransformer<ResponseT, ResponseBytes<ResponseT>>
* other transformers, like {@link #toFile(Path)} and {@link #toBytes()}, which only have their {@link CompletableFuture}
* completed after the entire response body has finished streaming.
* <p>
* The publisher has a default timeout of 60 seconds that starts when the response body begins streaming. If no subscriber is
* registered within this time, the subscription will be automatically cancelled. Use {@link #toPublisher(Duration)} to
* specify a custom timeout.
* <p>
* You are responsible for subscribing to this publisher and managing the associated back-pressure. Therefore, this
* transformer is only recommended for advanced use cases.
* <p>
Expand All @@ -293,6 +298,34 @@ static <ResponseT extends SdkResponse> AsyncResponseTransformer<ResponseT, Respo
return new PublisherAsyncResponseTransformer<>();
}

/**
* Creates an {@link AsyncResponseTransformer} with a custom timeout that publishes the response body content through a
* {@link ResponsePublisher}, which is an {@link SdkPublisher} that also contains a reference to the {@link SdkResponse}
* returned by the service.
* <p>
* When this transformer is used with an async client, the {@link CompletableFuture} that the client returns will be completed
* once the {@link SdkResponse} is available and the response body <i>begins</i> streaming. This behavior differs from some
* other transformers, like {@link #toFile(Path)} and {@link #toBytes()}, which only have their {@link CompletableFuture}
* completed after the entire response body has finished streaming.
* <p>
* The timeout starts when the response body begins streaming. If no subscriber is registered within the specified timeout,
* the subscription will be automatically cancelled. To disable the timeout, pass {@link Duration#ZERO} or a negative
* {@link Duration}.
* <p>
* You are responsible for subscribing to this publisher and managing the associated back-pressure. Therefore, this
* transformer is only recommended for advanced use cases.
*
* @param timeout Maximum time to wait for subscription before cancelling. Use {@link Duration#ZERO} or a negative
* {@link Duration} to disable timeout.
* @param <ResponseT> Pojo response type.
* @return AsyncResponseTransformer instance.
* @see #toPublisher()
*/
static <ResponseT extends SdkResponse> AsyncResponseTransformer<ResponseT,
ResponsePublisher<ResponseT>> toPublisher(Duration timeout) {
return new PublisherAsyncResponseTransformer<>(timeout);
}

/**
* Creates an {@link AsyncResponseTransformer} that allows reading the response body content as an {@link InputStream}.
* <p>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,29 +16,55 @@
package software.amazon.awssdk.core.async;

import java.nio.ByteBuffer;
import java.time.Duration;
import java.util.Objects;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
import software.amazon.awssdk.annotations.SdkPublicApi;
import software.amazon.awssdk.annotations.SdkTestInternalApi;
import software.amazon.awssdk.core.SdkResponse;
import software.amazon.awssdk.utils.Logger;
import software.amazon.awssdk.utils.ToString;
import software.amazon.awssdk.utils.Validate;

/**
* An {@link SdkPublisher} that publishes response body content and also contains a reference to the {@link SdkResponse} returned
* by the service.
* <p>
* <b>NOTE:</b> You must subscribe to this publisher promptly to avoid automatic cancellation. The default timeout for
* subscribing is 60 seconds, which starts when the response body begins streaming. If {@link #subscribe(Subscriber)} is not
* invoked before the timeout, the publisher will automatically cancel the underlying subscription to prevent resource leakage.
* <p>
* The timeout can be customized by passing a {@link Duration} to the constructor, or disabled entirely by
* passing {@link Duration#ZERO} or a negative {@link Duration}.
*
* @param <ResponseT> Pojo response type.
* @see AsyncResponseTransformer#toPublisher()
*/
@SdkPublicApi
public final class ResponsePublisher<ResponseT extends SdkResponse> implements SdkPublisher<ByteBuffer> {

private static final Logger log = Logger.loggerFor(ResponsePublisher.class);
private static final Duration DEFAULT_TIMEOUT = Duration.ofSeconds(60);
private final ResponseT response;
private final SdkPublisher<ByteBuffer> publisher;
private ScheduledFuture<?> timeoutTask;
private volatile boolean subscribed = false;

public ResponsePublisher(ResponseT response, SdkPublisher<ByteBuffer> publisher) {
this(response, publisher, null);
}

public ResponsePublisher(ResponseT response, SdkPublisher<ByteBuffer> publisher, Duration timeout) {
this.response = Validate.paramNotNull(response, "response");
this.publisher = Validate.paramNotNull(publisher, "publisher");

Duration resolvedTimeout = timeout != null ? timeout : DEFAULT_TIMEOUT;
scheduleTimeoutTask(resolvedTimeout);
}

/**
Expand All @@ -50,9 +76,59 @@ public ResponseT response() {

@Override
public void subscribe(Subscriber<? super ByteBuffer> subscriber) {
subscribed = true;
if (timeoutTask != null) {
timeoutTask.cancel(false);
}

publisher.subscribe(subscriber);
}

private void scheduleTimeoutTask(Duration timeout) {
if (timeout.equals(Duration.ZERO) || timeout.isNegative()) {
return;
}

long timeoutInMillis = timeout.toMillis();
timeoutTask = TimeoutScheduler.INSTANCE.schedule(() -> {
if (!subscribed) {
log.debug(() -> String.format("Publisher was not consumed before timeout of [%d] milliseconds, cancelling "
+ "subscription and closing connection.", timeoutInMillis));

publisher.subscribe(new CancellingSubscriber());
}
}, timeoutInMillis, TimeUnit.MILLISECONDS);
}

private static final class TimeoutScheduler {
static final ScheduledExecutorService INSTANCE =
Executors.newScheduledThreadPool(1, r -> {
Thread t = new Thread(r, "response-publisher-timeout-scheduler");
t.setDaemon(true);
return t;
});
}

private static class CancellingSubscriber implements Subscriber<ByteBuffer> {

@Override
public void onSubscribe(Subscription s) {
s.cancel();
}

@Override
public void onNext(ByteBuffer b) {
}

@Override
public void onError(Throwable t) {
}

@Override
public void onComplete() {
}
}

@Override
public String toString() {
return ToString.builder("ResponsePublisher")
Expand Down Expand Up @@ -84,4 +160,14 @@ public int hashCode() {
result = 31 * result + (publisher != null ? publisher.hashCode() : 0);
return result;
}

@SdkTestInternalApi
public boolean hasTimeoutTask() {
return timeoutTask != null;
}

@SdkTestInternalApi
public boolean timeoutTaskDoneOrCancelled() {
return timeoutTask != null && timeoutTask.isDone();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package software.amazon.awssdk.core.internal.async;

import java.nio.ByteBuffer;
import java.time.Duration;
import java.util.concurrent.CompletableFuture;
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.core.SdkResponse;
Expand All @@ -35,6 +36,14 @@ public final class PublisherAsyncResponseTransformer<ResponseT extends SdkRespon

private volatile CompletableFuture<ResponsePublisher<ResponseT>> future;
private volatile ResponseT response;
private Duration timeout;

public PublisherAsyncResponseTransformer() {
}

public PublisherAsyncResponseTransformer(Duration timeout) {
this.timeout = timeout;
}

@Override
public CompletableFuture<ResponsePublisher<ResponseT>> prepare() {
Expand All @@ -50,7 +59,7 @@ public void onResponse(ResponseT response) {

@Override
public void onStream(SdkPublisher<ByteBuffer> publisher) {
future.complete(new ResponsePublisher<>(response, publisher));
future.complete(new ResponsePublisher<>(response, publisher, timeout));
}

@Override
Expand Down
Loading
Loading