Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/*
* Copyright The OpenTelemetry Authors
* SPDX-License-Identifier: Apache-2.0
*/

package io.opentelemetry.instrumentation.spring.webflux.v5_3;

import static java.util.Collections.emptyList;

import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.util.List;
import javax.annotation.Nullable;
import org.springframework.http.HttpHeaders;

class HeaderUtil {
@Nullable private static final MethodHandle GET_HEADERS;

static {
// since spring web 7.0
MethodHandle methodHandle =
findGetHeadersMethod(MethodType.methodType(List.class, String.class, List.class));
if (methodHandle == null) {
// up to spring web 7.0
methodHandle =
findGetHeadersMethod(MethodType.methodType(Object.class, Object.class, Object.class));
}
GET_HEADERS = methodHandle;
}

private static MethodHandle findGetHeadersMethod(MethodType methodType) {
try {
return MethodHandles.lookup().findVirtual(HttpHeaders.class, "getOrDefault", methodType);
} catch (Throwable t) {
return null;
}
}

@SuppressWarnings("unchecked") // casting MethodHandle.invoke result
static List<String> getHeader(HttpHeaders headers, String name) {
if (GET_HEADERS != null) {
try {
return (List<String>) GET_HEADERS.invoke(headers, name, emptyList());
} catch (Throwable t) {
// ignore
}
}
return emptyList();
}

private HeaderUtil() {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,13 @@
package io.opentelemetry.instrumentation.spring.webflux.v5_3;

import io.opentelemetry.instrumentation.api.semconv.http.HttpServerAttributesGetter;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.net.InetSocketAddress;
import java.util.Collections;
import java.util.List;
import javax.annotation.Nullable;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.web.reactive.HandlerMapping;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.util.pattern.PathPattern;
Expand All @@ -18,27 +21,85 @@ enum WebfluxServerHttpAttributesGetter
implements HttpServerAttributesGetter<ServerWebExchange, ServerWebExchange> {
INSTANCE;

private static final MethodHandle GET_RAW_STATUS_CODE;
private static final MethodHandle GET_STATUS_CODE;
private static final MethodHandle STATUS_CODE_VALUE;

static {
MethodHandle getRawStatusCode = null;
MethodHandle getStatusCode = null;
MethodHandle statusCodeValue = null;

MethodHandles.Lookup lookup = MethodHandles.publicLookup();

// up to webflux 7.0
try {
getRawStatusCode =
lookup.findVirtual(
ServerHttpResponse.class, "getRawStatusCode", MethodType.methodType(Integer.class));
} catch (Exception exception) {
// ignore
}

// since webflux 7.0
try {
Class<?> httpStatusCodeClass = Class.forName("org.springframework.http.HttpStatusCode");
getStatusCode =
lookup.findVirtual(
ServerHttpResponse.class,
"getStatusCode",
MethodType.methodType(httpStatusCodeClass));
statusCodeValue =
lookup.findVirtual(httpStatusCodeClass, "value", MethodType.methodType(int.class));
} catch (Exception exception) {
// ignore
}

GET_RAW_STATUS_CODE = getRawStatusCode;
GET_STATUS_CODE = getStatusCode;
STATUS_CODE_VALUE = statusCodeValue;
}

private static Integer getStatusCode(ServerHttpResponse response) {
if (GET_RAW_STATUS_CODE != null) {
try {
return (Integer) GET_RAW_STATUS_CODE.invoke(response);
} catch (Throwable e) {
// ignore
}
}
if (GET_STATUS_CODE != null && STATUS_CODE_VALUE != null) {
try {
Object statusCode = GET_STATUS_CODE.invoke(response);
return (Integer) STATUS_CODE_VALUE.invoke(statusCode);
} catch (Throwable e) {
// ignore
}
}
return null;
}

@Override
public String getHttpRequestMethod(ServerWebExchange request) {
return request.getRequest().getMethod().name();
}

@Override
public List<String> getHttpRequestHeader(ServerWebExchange request, String name) {
return request.getRequest().getHeaders().getOrDefault(name, Collections.emptyList());
return HeaderUtil.getHeader(request.getRequest().getHeaders(), name);
}

@Nullable
@Override
public Integer getHttpResponseStatusCode(
ServerWebExchange request, ServerWebExchange response, @Nullable Throwable error) {
return response.getResponse().getRawStatusCode();
return getStatusCode(response.getResponse());
}

@Override
public List<String> getHttpResponseHeader(
ServerWebExchange request, ServerWebExchange response, String name) {
return response.getResponse().getHeaders().getOrDefault(name, Collections.emptyList());
return HeaderUtil.getHeader(response.getResponse().getHeaders(), name);
}

@Nullable
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

import io.opentelemetry.context.propagation.TextMapGetter;
import java.util.Iterator;
import java.util.List;
import javax.annotation.Nullable;
import org.springframework.web.server.ServerWebExchange;

Expand All @@ -35,7 +34,6 @@ public Iterator<String> getAll(@Nullable ServerWebExchange exchange, String key)
if (exchange == null) {
return emptyIterator();
}
List<String> list = exchange.getRequest().getHeaders().get(key);
return list != null ? list.iterator() : emptyIterator();
return HeaderUtil.getHeader(exchange.getRequest().getHeaders(), key).iterator();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
plugins {
id("otel.java-conventions")
}

dependencies {
implementation("io.opentelemetry.javaagent:opentelemetry-testing-common")
compileOnly("org.springframework:spring-webflux:7.0.0")
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/*
* Copyright The OpenTelemetry Authors
* SPDX-License-Identifier: Apache-2.0
*/

package io.opentelemetry.instrumentation.spring.webflux.client;

import java.util.function.Consumer;
import java.util.function.Function;
import org.springframework.web.reactive.function.client.ClientResponse;
import org.springframework.web.reactive.function.client.WebClient;
import reactor.core.publisher.Mono;

class Webflux7Util {
static final boolean isWebflux7 = detectWebflux7();

private static boolean detectWebflux7() {
try {
WebClient.RequestBodySpec.class.getMethod("exchange");
return false;
} catch (NoSuchMethodException e) {
return true;
}
}

static Mono<ClientResponse> exchangeToMono(WebClient.RequestBodySpec request) {
return request.exchangeToMono(Mono::just);
}

static <T> T doRequest(
WebClient.RequestBodySpec request, Function<ClientResponse, Mono<T>> handler) {
return request.exchangeToMono(handler).block();
}

static int doRequest(WebClient.RequestBodySpec request) {
return doRequest(request, response -> Mono.just(response.statusCode().value()));
}

static int getStatusCode(ClientResponse response) {
return response.statusCode().value();
}

static void sendRequestWithCallback(
WebClient.RequestBodySpec request,
Consumer<Integer> callback,
Consumer<Throwable> errorCallback) {
request
.exchangeToMono(response -> Mono.just(response.statusCode().value()))
.subscribe(callback, errorCallback);
}

private Webflux7Util() {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ plugins {

dependencies {
implementation("io.opentelemetry.javaagent:opentelemetry-testing-common")
implementation(project(":instrumentation:spring:spring-webflux:spring-webflux-5.3:testing-webflux7"))

compileOnly("org.springframework:spring-webflux:5.0.0.RELEASE")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

import io.opentelemetry.api.common.AttributeKey;
import io.opentelemetry.api.trace.SpanKind;
import io.opentelemetry.context.Context;
import io.opentelemetry.context.Scope;
import io.opentelemetry.instrumentation.testing.junit.http.AbstractHttpClientTest;
import io.opentelemetry.instrumentation.testing.junit.http.HttpClientResult;
import io.opentelemetry.instrumentation.testing.junit.http.HttpClientTestOptions;
Expand All @@ -35,6 +37,7 @@
import org.springframework.http.HttpMethod;
import org.springframework.web.reactive.function.client.ClientResponse;
import org.springframework.web.reactive.function.client.WebClient;
import reactor.core.publisher.Mono;

public abstract class AbstractSpringWebfluxClientInstrumentationTest
extends AbstractHttpClientTest<WebClient.RequestBodySpec> {
Expand All @@ -58,6 +61,10 @@ public WebClient.RequestBodySpec buildRequest(
@Override
public int sendRequest(
WebClient.RequestBodySpec request, String method, URI uri, Map<String, String> headers) {
if (Webflux7Util.isWebflux7) {
return Webflux7Util.doRequest(request);
}

ClientResponse response = requireNonNull(request.exchange().block());
return getStatusCode(response);
}
Expand All @@ -69,11 +76,24 @@ public void sendRequestWithCallback(
URI uri,
Map<String, String> headers,
HttpClientResult httpClientResult) {
request
.exchange()
.subscribe(
response -> httpClientResult.complete(getStatusCode(response)),
httpClientResult::complete);
if (Webflux7Util.isWebflux7) {
// FIXME: context is not propagated to the callback, this needs to be fixed
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to look into this more thoroughly. As far as I can tell context is propagated with the agent but not with the library instrumentation.

Context context = Context.current();
Webflux7Util.sendRequestWithCallback(
request,
status -> {
try (Scope ignore = context.makeCurrent()) {
httpClientResult.complete(status);
}
},
httpClientResult::complete);
} else {
request
.exchange()
.subscribe(
response -> httpClientResult.complete(getStatusCode(response)),
httpClientResult::complete);
}
}

@Override
Expand Down Expand Up @@ -164,12 +184,17 @@ void shouldEndSpanOnMonoTimeout() {
() ->
testing.runWithSpan(
"parent",
() ->
buildRequest("GET", uri, emptyMap())
.exchange()
// apply Mono timeout that is way shorter than HTTP request timeout
.timeout(Duration.ofSeconds(1))
.block()));
() -> {
WebClient.RequestBodySpec request = buildRequest("GET", uri, emptyMap());
Mono<ClientResponse> mono;
if (Webflux7Util.isWebflux7) {
mono = Webflux7Util.exchangeToMono(request);
} else {
mono = request.exchange();
}
// apply Mono timeout that is way shorter than HTTP request timeout
return mono.timeout(Duration.ofSeconds(1)).block();
}));

testing.waitAndAssertTraces(
trace ->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.springframework.http.HttpMethod;
import org.springframework.web.reactive.function.client.ClientResponse;
import org.springframework.web.reactive.function.client.WebClient;
import reactor.core.publisher.Mono;

final class SpringWebfluxSingleConnection implements SingleConnection {

Expand Down Expand Up @@ -49,16 +50,31 @@ public int doRequest(String path, Map<String, String> headers) throws Exception
WebClient.RequestBodySpec request =
webClient.method(HttpMethod.GET).uri(uri).headers(h -> headers.forEach(h::add));

ClientResponse response = request.exchange().block();
// read response body, this seems to be needed to ensure that the connection can be reused
response.bodyToMono(String.class).block();
if (Webflux7Util.isWebflux7) {
return Webflux7Util.doRequest(
request,
response -> {
String responseId = response.headers().asHttpHeaders().getFirst(REQUEST_ID_HEADER);
if (!requestId.equals(responseId)) {
return Mono.error(
new IllegalStateException(
String.format(
"Received response with id %s, expected %s", responseId, requestId)));
}
return Mono.just(Webflux7Util.getStatusCode(response));
});
} else {
ClientResponse response = request.exchange().block();
// read response body, this seems to be needed to ensure that the connection can be reused
response.bodyToMono(String.class).block();

String responseId = response.headers().asHttpHeaders().getFirst(REQUEST_ID_HEADER);
if (!requestId.equals(responseId)) {
throw new IllegalStateException(
String.format("Received response with id %s, expected %s", responseId, requestId));
}
String responseId = response.headers().asHttpHeaders().getFirst(REQUEST_ID_HEADER);
if (!requestId.equals(responseId)) {
throw new IllegalStateException(
String.format("Received response with id %s, expected %s", responseId, requestId));
}

return response.statusCode().value();
return response.statusCode().value();
}
}
}
1 change: 1 addition & 0 deletions settings.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -641,6 +641,7 @@ include(":instrumentation:spring:spring-webflux:spring-webflux-5.0:javaagent")
include(":instrumentation:spring:spring-webflux:spring-webflux-5.0:testing")
include(":instrumentation:spring:spring-webflux:spring-webflux-5.3:library")
include(":instrumentation:spring:spring-webflux:spring-webflux-5.3:testing")
include(":instrumentation:spring:spring-webflux:spring-webflux-5.3:testing-webflux7")
include(":instrumentation:spring:spring-webmvc:spring-webmvc-3.1:javaagent")
include(":instrumentation:spring:spring-webmvc:spring-webmvc-3.1:wildfly-testing")
include(":instrumentation:spring:spring-webmvc:spring-webmvc-5.3:library")
Expand Down
Loading