Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ public static AsyncAdviceScope start(HttpRequest request) {
}

public CompletableFuture<HttpResponse<?>> end(
@Nullable Throwable throwable, @Nullable CompletableFuture<HttpResponse<?>> future) {
@Nullable Throwable throwable, CompletableFuture<HttpResponse<?>> future) {
if (callDepth.decrementAndGet() > 0 || scope == null) {
// async end nested call
return future;
Expand All @@ -167,8 +167,8 @@ public CompletableFuture<HttpResponse<?>> end(
instrumenter().end(context, request, null, throwable);
return future;
}
future = future.whenComplete(new ResponseConsumer(instrumenter(), context, request));
return CompletableFutureWrapper.wrap(future, parentContext);
return CompletableFutureWrapper.wrap(future, parentContext)
.whenComplete(new ResponseConsumer(instrumenter(), context, request));
}
}

Expand All @@ -182,7 +182,7 @@ public static AsyncAdviceScope methodEnter(
@AssignReturned.ToReturned
@Advice.OnMethodExit(onThrowable = Throwable.class, suppress = Throwable.class)
public static CompletableFuture<HttpResponse<?>> methodExit(
@Advice.Return @Nullable CompletableFuture<HttpResponse<?>> future,
@Advice.Return CompletableFuture<HttpResponse<?>> future,
@Advice.Thrown @Nullable Throwable throwable,
@Advice.Enter @Nullable AsyncAdviceScope scope) {
return scope == null ? future : scope.end(throwable, future);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,15 @@
* This class is internal and is hence not for public use. Its APIs are unstable and can change at
* any time.
*/
public final class CompletableFutureWrapper {
public final class CompletableFutureWrapper<T> extends CompletableFuture<T> {
private final CompletableFuture<?> future;

private CompletableFutureWrapper() {}
private CompletableFutureWrapper(CompletableFuture<?> future) {
this.future = future;
}

public static <T> CompletableFuture<T> wrap(CompletableFuture<T> future, Context context) {
CompletableFuture<T> result = new CompletableFuture<>();
CompletableFuture<T> result = new CompletableFutureWrapper<>(future);
future.whenComplete(
(T value, Throwable throwable) -> {
try (Scope ignored = context.makeCurrent()) {
Expand All @@ -32,4 +35,16 @@ public static <T> CompletableFuture<T> wrap(CompletableFuture<T> future, Context

return result;
}

@Override
public <U> CompletableFuture<U> newIncompleteFuture() {
return new CompletableFutureWrapper<>(future);
}

@Override
public boolean cancel(boolean mayInterruptIfRunning) {
boolean result = super.cancel(mayInterruptIfRunning);
future.cancel(mayInterruptIfRunning);
return result;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,9 @@ private <T> CompletableFuture<HttpResponse<T>> traceAsync(
instrumenter.end(context, request, null, t);
throw t;
}
future = future.whenComplete(new ResponseConsumer(instrumenter, context, request));
future = CompletableFutureWrapper.wrap(future, parentContext);
future =
CompletableFutureWrapper.wrap(future, parentContext)
.whenComplete(new ResponseConsumer(instrumenter, context, request));
return future;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,22 @@

package io.opentelemetry.instrumentation.javahttpclient;

import static io.opentelemetry.api.common.AttributeKey.stringKey;
import static io.opentelemetry.sdk.testing.assertj.OpenTelemetryAssertions.equalTo;
import static io.opentelemetry.semconv.NetworkAttributes.NETWORK_PROTOCOL_VERSION;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

import io.opentelemetry.api.common.AttributeKey;
import io.opentelemetry.api.trace.Span;
import io.opentelemetry.api.trace.SpanKind;
import io.opentelemetry.instrumentation.testing.junit.http.AbstractHttpClientTest;
import io.opentelemetry.instrumentation.testing.junit.http.HttpClientResult;
import io.opentelemetry.instrumentation.testing.junit.http.HttpClientTestOptions;
import io.opentelemetry.sdk.trace.data.StatusData;
import io.opentelemetry.semconv.ErrorAttributes;
import io.opentelemetry.semconv.HttpAttributes;
import io.opentelemetry.semconv.ServerAttributes;
import io.opentelemetry.semconv.UrlAttributes;
import java.io.IOException;
import java.net.URI;
import java.net.http.HttpClient;
Expand All @@ -19,7 +29,11 @@
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CancellationException;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;

public abstract class AbstractJavaHttpClientTest extends AbstractHttpClientTest<HttpRequest> {

Expand Down Expand Up @@ -106,4 +120,76 @@ protected void configure(HttpClientTestOptions.Builder optionsBuilder) {
return attributes;
});
}

@SuppressWarnings("Interruption") // test calls CompletableFuture.cancel with true
@Test
void cancelRequest() throws InterruptedException {
boolean isJdk11 = "11".equals(System.getProperty("java.specification.version"));
String method = "GET";
URI uri = resolveAddress("/long-request");

CompletableFuture<String> future =
testing.runWithSpan(
"parent",
() -> {
HttpRequest request =
HttpRequest.newBuilder()
.uri(uri)
.method(method, HttpRequest.BodyPublishers.noBody())
.header("delay", String.valueOf(TimeUnit.SECONDS.toMillis(5)))
.build();
return client
.sendAsync(request, HttpResponse.BodyHandlers.ofString())
.thenApply(HttpResponse::body)
.whenComplete(
(response, throwable) ->
testing.runWithSpan(
"child",
() -> {
if (throwable != null && throwable.getCause() != null) {
Span.current()
.setAttribute(
"throwable", throwable.getCause().getClass().getName());
}
}))
// this stage is only added to trigger the whenComplete stage when this stage gets
// cancelled
.exceptionally(ex -> "cancelled");
});

// sleep a bit to let the request start
Thread.sleep(1_000);
future.cancel(true);
assertThatThrownBy(future::get).isInstanceOf(CancellationException.class);

testing.waitAndAssertTraces(
trace ->
trace.hasSpansSatisfyingExactly(
span -> span.hasName("parent").hasNoParent(),
span ->
span.hasName("GET")
.hasKind(SpanKind.CLIENT)
.hasParent(trace.getSpan(0))
.hasStatus(StatusData.error())
.hasAttributesSatisfyingExactly(
equalTo(UrlAttributes.URL_FULL, uri.toString()),
equalTo(ServerAttributes.SERVER_ADDRESS, uri.getHost()),
equalTo(ServerAttributes.SERVER_PORT, uri.getPort()),
equalTo(HttpAttributes.HTTP_REQUEST_METHOD, method),
equalTo(
ErrorAttributes.ERROR_TYPE, CancellationException.class.getName())),
span ->
span.hasName("test-http-server")
.hasKind(SpanKind.SERVER)
.hasParent(trace.getSpan(1))
// jdk 11 does not cancel the request on the server side so the request
// succeeds
.hasStatus(isJdk11 ? StatusData.unset() : StatusData.error()),
span ->
span.hasName("child")
.hasParent(trace.getSpan(0))
.hasAttributesSatisfyingExactly(
equalTo(
stringKey("throwable"), CancellationException.class.getName()))));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
import static io.opentelemetry.testing.internal.armeria.common.MediaType.PLAIN_TEXT_UTF_8;

import io.opentelemetry.api.OpenTelemetry;
import io.opentelemetry.api.trace.Span;
import io.opentelemetry.api.trace.SpanBuilder;
import io.opentelemetry.api.trace.StatusCode;
import io.opentelemetry.api.trace.Tracer;
import io.opentelemetry.context.Context;
import io.opentelemetry.instrumentation.test.server.http.RequestContextGetter;
Expand Down Expand Up @@ -137,7 +139,7 @@ protected void configure(ServerBuilder sb) throws Exception {
throw new AssertionError((Object) ("more than one " + field + " header present"));
}
}
SpanBuilder span =
SpanBuilder spanBuilder =
tracer
.spanBuilder("test-http-server")
.setSpanKind(SERVER)
Expand All @@ -149,9 +151,20 @@ protected void configure(ServerBuilder sb) throws Exception {

String traceRequestId = req.headers().get("test-request-id");
if (traceRequestId != null) {
span.setAttribute("test.request.id", Integer.parseInt(traceRequestId));
spanBuilder.setAttribute("test.request.id", Integer.parseInt(traceRequestId));
}
span.startSpan().end();
Span span = spanBuilder.startSpan();
ctx.log()
.whenComplete()
.thenAccept(
log -> {
Throwable error = log.responseCause();
if (error != null) {
span.recordException(error);
span.setStatus(StatusCode.ERROR);
}
span.end();
});

// this header is set by java http client http/2 tests
// we delay the response a bit to ensure that client can send the full request before
Expand Down