diff --git a/dd-java-agent/instrumentation/netty-4.1/src/main/java/datadog/trace/instrumentation/netty41/NettyChannelPipelineInstrumentation.java b/dd-java-agent/instrumentation/netty-4.1/src/main/java/datadog/trace/instrumentation/netty41/NettyChannelPipelineInstrumentation.java index 4cfe703e194..51445fed833 100644 --- a/dd-java-agent/instrumentation/netty-4.1/src/main/java/datadog/trace/instrumentation/netty41/NettyChannelPipelineInstrumentation.java +++ b/dd-java-agent/instrumentation/netty-4.1/src/main/java/datadog/trace/instrumentation/netty41/NettyChannelPipelineInstrumentation.java @@ -23,8 +23,8 @@ import datadog.trace.instrumentation.netty41.server.HttpServerResponseTracingHandler; import datadog.trace.instrumentation.netty41.server.HttpServerTracingHandler; import datadog.trace.instrumentation.netty41.server.MaybeBlockResponseHandler; -import datadog.trace.instrumentation.netty41.server.websocket.WebSocketServerRequestTracingHandler; -import datadog.trace.instrumentation.netty41.server.websocket.WebSocketServerResponseTracingHandler; +import datadog.trace.instrumentation.netty41.server.websocket.WebSocketServerInboundTracingHandler; +import datadog.trace.instrumentation.netty41.server.websocket.WebSocketServerOutboundTracingHandler; import datadog.trace.instrumentation.netty41.server.websocket.WebSocketServerTracingHandler; import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelPipeline; @@ -34,6 +34,8 @@ import io.netty.handler.codec.http.HttpResponseDecoder; import io.netty.handler.codec.http.HttpResponseEncoder; import io.netty.handler.codec.http.HttpServerCodec; +import io.netty.handler.codec.http.websocketx.WebSocketFrameDecoder; +import io.netty.handler.codec.http.websocketx.WebSocketFrameEncoder; import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler; import io.netty.util.Attribute; import net.bytebuddy.asm.Advice; @@ -82,8 +84,8 @@ public String[] helperClassNames() { packageName + ".server.HttpServerTracingHandler", packageName + ".server.MaybeBlockResponseHandler", packageName + ".server.websocket.WebSocketServerTracingHandler", - packageName + ".server.websocket.WebSocketServerResponseTracingHandler", - packageName + ".server.websocket.WebSocketServerRequestTracingHandler", + packageName + ".server.websocket.WebSocketServerOutboundTracingHandler", + packageName + ".server.websocket.WebSocketServerInboundTracingHandler", packageName + ".NettyHttp2Helper", packageName + ".NettyPipelineHelper", }; @@ -162,23 +164,31 @@ public static void addHandler( HttpServerResponseTracingHandler.INSTANCE, MaybeBlockResponseHandler.INSTANCE); } else if (handler instanceof WebSocketServerProtocolHandler) { - if (InstrumenterConfig.get().isWebsocketTracingEnabled()) { - if (pipeline.get(HttpServerTracingHandler.class) != null) { - NettyPipelineHelper.addHandlerAfter( - pipeline, "HttpServerTracingHandler#0", new WebSocketServerTracingHandler()); + if (InstrumenterConfig.get().isWebsocketTracingEnabled() + && pipeline.get(HttpServerTracingHandler.class) != null) { + // remove single websocket handler if added before + if (pipeline.get(WebSocketServerInboundTracingHandler.class) != null) { + pipeline.remove(WebSocketServerInboundTracingHandler.class); } - if (pipeline.get(HttpServerRequestTracingHandler.class) != null) { - NettyPipelineHelper.addHandlerAfter( - pipeline, - "HttpServerRequestTracingHandler#0", - WebSocketServerRequestTracingHandler.INSTANCE); - } - if (pipeline.get(HttpServerResponseTracingHandler.class) != null) { - NettyPipelineHelper.addHandlerAfter( - pipeline, - "HttpServerResponseTracingHandler#0", - WebSocketServerResponseTracingHandler.INSTANCE); + if (pipeline.get(WebSocketServerOutboundTracingHandler.class) != null) { + pipeline.remove(WebSocketServerOutboundTracingHandler.class); } + NettyPipelineHelper.addHandlerAfter( + pipeline, + pipeline.get(HttpServerTracingHandler.class), + new WebSocketServerTracingHandler()); + } + } else if (handler instanceof WebSocketFrameDecoder) { + if (InstrumenterConfig.get().isWebsocketTracingEnabled() + && pipeline.get(WebSocketServerTracingHandler.class) == null) { + NettyPipelineHelper.addHandlerAfter( + pipeline, handler, WebSocketServerInboundTracingHandler.INSTANCE); + } + } else if (handler instanceof WebSocketFrameEncoder) { + if (InstrumenterConfig.get().isWebsocketTracingEnabled() + && pipeline.get(WebSocketServerTracingHandler.class) == null) { + NettyPipelineHelper.addHandlerAfter( + pipeline, handler, WebSocketServerOutboundTracingHandler.INSTANCE); } } // Client pipeline handlers diff --git a/dd-java-agent/instrumentation/netty-4.1/src/main/java/datadog/trace/instrumentation/netty41/server/websocket/WebSocketServerRequestTracingHandler.java b/dd-java-agent/instrumentation/netty-4.1/src/main/java/datadog/trace/instrumentation/netty41/server/websocket/WebSocketServerInboundTracingHandler.java similarity index 96% rename from dd-java-agent/instrumentation/netty-4.1/src/main/java/datadog/trace/instrumentation/netty41/server/websocket/WebSocketServerRequestTracingHandler.java rename to dd-java-agent/instrumentation/netty-4.1/src/main/java/datadog/trace/instrumentation/netty41/server/websocket/WebSocketServerInboundTracingHandler.java index 5556af0d2b2..882a826e15e 100644 --- a/dd-java-agent/instrumentation/netty-4.1/src/main/java/datadog/trace/instrumentation/netty41/server/websocket/WebSocketServerRequestTracingHandler.java +++ b/dd-java-agent/instrumentation/netty-4.1/src/main/java/datadog/trace/instrumentation/netty41/server/websocket/WebSocketServerInboundTracingHandler.java @@ -17,13 +17,12 @@ import io.netty.handler.codec.http.websocketx.WebSocketFrame; @ChannelHandler.Sharable -public class WebSocketServerRequestTracingHandler extends ChannelInboundHandlerAdapter { - public static WebSocketServerRequestTracingHandler INSTANCE = - new WebSocketServerRequestTracingHandler(); +public class WebSocketServerInboundTracingHandler extends ChannelInboundHandlerAdapter { + public static WebSocketServerInboundTracingHandler INSTANCE = + new WebSocketServerInboundTracingHandler(); @Override public void channelRead(ChannelHandlerContext ctx, Object frame) { - if (frame instanceof WebSocketFrame) { Channel channel = ctx.channel(); HandlerContext.Receiver receiverContext = diff --git a/dd-java-agent/instrumentation/netty-4.1/src/main/java/datadog/trace/instrumentation/netty41/server/websocket/WebSocketServerResponseTracingHandler.java b/dd-java-agent/instrumentation/netty-4.1/src/main/java/datadog/trace/instrumentation/netty41/server/websocket/WebSocketServerOutboundTracingHandler.java similarity index 96% rename from dd-java-agent/instrumentation/netty-4.1/src/main/java/datadog/trace/instrumentation/netty41/server/websocket/WebSocketServerResponseTracingHandler.java rename to dd-java-agent/instrumentation/netty-4.1/src/main/java/datadog/trace/instrumentation/netty41/server/websocket/WebSocketServerOutboundTracingHandler.java index cc073f6aa1b..e500c2357f5 100644 --- a/dd-java-agent/instrumentation/netty-4.1/src/main/java/datadog/trace/instrumentation/netty41/server/websocket/WebSocketServerResponseTracingHandler.java +++ b/dd-java-agent/instrumentation/netty-4.1/src/main/java/datadog/trace/instrumentation/netty41/server/websocket/WebSocketServerOutboundTracingHandler.java @@ -17,9 +17,9 @@ import io.netty.handler.codec.http.websocketx.WebSocketFrame; @ChannelHandler.Sharable -public class WebSocketServerResponseTracingHandler extends ChannelOutboundHandlerAdapter { - public static WebSocketServerResponseTracingHandler INSTANCE = - new WebSocketServerResponseTracingHandler(); +public class WebSocketServerOutboundTracingHandler extends ChannelOutboundHandlerAdapter { + public static WebSocketServerOutboundTracingHandler INSTANCE = + new WebSocketServerOutboundTracingHandler(); @Override public void write(ChannelHandlerContext ctx, Object frame, ChannelPromise promise) diff --git a/dd-java-agent/instrumentation/netty-4.1/src/main/java/datadog/trace/instrumentation/netty41/server/websocket/WebSocketServerTracingHandler.java b/dd-java-agent/instrumentation/netty-4.1/src/main/java/datadog/trace/instrumentation/netty41/server/websocket/WebSocketServerTracingHandler.java index 8f6f4b2e6c4..fdf48ea82b5 100644 --- a/dd-java-agent/instrumentation/netty-4.1/src/main/java/datadog/trace/instrumentation/netty41/server/websocket/WebSocketServerTracingHandler.java +++ b/dd-java-agent/instrumentation/netty-4.1/src/main/java/datadog/trace/instrumentation/netty41/server/websocket/WebSocketServerTracingHandler.java @@ -4,11 +4,11 @@ public class WebSocketServerTracingHandler extends CombinedChannelDuplexHandler< - WebSocketServerRequestTracingHandler, WebSocketServerResponseTracingHandler> { + WebSocketServerInboundTracingHandler, WebSocketServerOutboundTracingHandler> { public WebSocketServerTracingHandler() { super( - WebSocketServerRequestTracingHandler.INSTANCE, - WebSocketServerResponseTracingHandler.INSTANCE); + WebSocketServerInboundTracingHandler.INSTANCE, + WebSocketServerOutboundTracingHandler.INSTANCE); } } diff --git a/dd-java-agent/instrumentation/spring-webflux-5/src/bootTest/groovy/SpringWebfluxTest.groovy b/dd-java-agent/instrumentation/spring-webflux-5/src/bootTest/groovy/SpringWebfluxTest.groovy index b81d625d1c1..3e5400d71c1 100644 --- a/dd-java-agent/instrumentation/spring-webflux-5/src/bootTest/groovy/SpringWebfluxTest.groovy +++ b/dd-java-agent/instrumentation/spring-webflux-5/src/bootTest/groovy/SpringWebfluxTest.groovy @@ -1,5 +1,6 @@ import datadog.trace.agent.test.AgentTestRunner import datadog.trace.agent.test.asserts.TraceAssert +import datadog.trace.agent.test.base.OkHttpWebsocketClient import datadog.trace.api.DDSpanTypes import datadog.trace.api.DDTags import datadog.trace.bootstrap.instrumentation.api.Tags @@ -9,6 +10,9 @@ import dd.trace.instrumentation.springwebflux.server.EchoHandlerFunction import dd.trace.instrumentation.springwebflux.server.FooModel import dd.trace.instrumentation.springwebflux.server.SpringWebFluxTestApplication import dd.trace.instrumentation.springwebflux.server.TestController +import dd.trace.instrumentation.springwebflux.server.WsHandler +import net.bytebuddy.utility.RandomString +import org.springframework.beans.factory.annotation.Autowired import org.springframework.boot.test.context.SpringBootTest import org.springframework.boot.test.context.TestConfiguration import org.springframework.boot.web.embedded.netty.NettyReactiveWebServerFactory @@ -21,6 +25,10 @@ import org.springframework.web.reactive.function.client.WebClient import org.springframework.web.server.ResponseStatusException import reactor.core.publisher.Mono +import static datadog.trace.agent.test.base.HttpServerTest.websocketCloseSpan +import static datadog.trace.agent.test.base.HttpServerTest.websocketReceiveSpan +import static datadog.trace.agent.test.base.HttpServerTest.websocketSendSpan + @SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT, classes = [SpringWebFluxTestApplication, ForceNettyAutoConfiguration], properties = "server.http2.enabled=true") @@ -40,13 +48,22 @@ class SpringWebfluxTest extends AgentTestRunner { @LocalServerPort int port - WebClient client = WebClient.builder().clientConnector (new ReactorClientHttpConnector()).build() + @Autowired + private WsHandler wsHandler + + WebClient client = WebClient.builder().clientConnector(new ReactorClientHttpConnector()).build() @Override boolean useStrictTraceWrites() { false } + @Override + protected void configurePreAgent() { + super.configurePreAgent() + injectSysConfig("trace.websocket.messages.enabled", "true") + } + def "Basic GET test #testName"() { setup: String url = "http://localhost:$port$urlPath" @@ -61,7 +78,7 @@ class SpringWebfluxTest extends AgentTestRunner { sortSpansByStart() trace(2) { clientSpan(it, null, "http.request", "spring-webflux-client", "GET", URI.create(url)) - traceParent = clientSpan(it, span(0), "netty.client.request", "netty-client", "GET", URI.create(url)) + traceParent = clientSpan(it, span(0), "netty.client.request", "netty-client", "GET", URI.create(url)) } trace(2) { span { @@ -142,7 +159,7 @@ class SpringWebfluxTest extends AgentTestRunner { def traceParent trace(2) { clientSpan(it, null, "http.request", "spring-webflux-client", "GET", URI.create(url)) - traceParent = clientSpan(it, span(0), "netty.client.request", "netty-client", "GET", URI.create(url)) + traceParent = clientSpan(it, span(0), "netty.client.request", "netty-client", "GET", URI.create(url)) } trace(3) { span { @@ -237,7 +254,7 @@ class SpringWebfluxTest extends AgentTestRunner { def traceParent trace(2) { clientSpan(it, null, "http.request", "spring-webflux-client", "GET", URI.create(url)) - traceParent = clientSpan(it, span(0), "netty.client.request", "netty-client", "GET", URI.create(url)) + traceParent = clientSpan(it, span(0), "netty.client.request", "netty-client", "GET", URI.create(url)) } trace(3) { span { @@ -285,7 +302,7 @@ class SpringWebfluxTest extends AgentTestRunner { def traceParent trace(2) { clientSpan(it, null, "http.request", "spring-webflux-client", "GET", URI.create(url), 404, true) - traceParent = clientSpan(it, span(0), "netty.client.request", "netty-client", "GET", URI.create(url), 404, true) + traceParent = clientSpan(it, span(0), "netty.client.request", "netty-client", "GET", URI.create(url), 404, true) } trace(2) { span { @@ -331,7 +348,7 @@ class SpringWebfluxTest extends AgentTestRunner { String url = "http://localhost:$port/echo" when: - def response = client.post().uri(url).body(BodyInserters.fromPublisher(Mono.just(echoString),String)).exchange().block() + def response = client.post().uri(url).body(BodyInserters.fromPublisher(Mono.just(echoString), String)).exchange().block() then: response.statusCode().value() == 202 @@ -341,7 +358,7 @@ class SpringWebfluxTest extends AgentTestRunner { def traceParent trace(2) { clientSpan(it, null, "http.request", "spring-webflux-client", "POST", URI.create(url), 202) - traceParent = clientSpan(it, span(0), "netty.client.request", "netty-client", "POST", URI.create(url), 202) + traceParent = clientSpan(it, span(0), "netty.client.request", "netty-client", "POST", URI.create(url), 202) } trace(3) { span { @@ -406,7 +423,7 @@ class SpringWebfluxTest extends AgentTestRunner { def traceParent trace(2) { clientSpan(it, null, "http.request", "spring-webflux-client", "GET", URI.create(url), 500) - traceParent = clientSpan(it, span(0), "netty.client.request", "netty-client", "GET", URI.create(url), 500) + traceParent = clientSpan(it, span(0), "netty.client.request", "netty-client", "GET", URI.create(url), 500) } trace(2) { span { @@ -495,7 +512,7 @@ class SpringWebfluxTest extends AgentTestRunner { trace(2) { sortSpansByStart() clientSpan(it, null, "http.request", "spring-webflux-client", "GET", URI.create(url), 307) - traceParent1 = clientSpan(it, span(0), "netty.client.request", "netty-client", "GET", URI.create(url), 307) + traceParent1 = clientSpan(it, span(0), "netty.client.request", "netty-client", "GET", URI.create(url), 307) } trace(2) { @@ -540,7 +557,7 @@ class SpringWebfluxTest extends AgentTestRunner { trace(2) { sortSpansByStart() clientSpan(it, null, "http.request", "spring-webflux-client", "GET", URI.create(finalUrl)) - traceParent2 = clientSpan(it, span(0), "netty.client.request", "netty-client", "GET", URI.create(finalUrl)) + traceParent2 = clientSpan(it, span(0), "netty.client.request", "netty-client", "GET", URI.create(finalUrl)) } trace(2) { sortSpansByStart() @@ -599,7 +616,7 @@ class SpringWebfluxTest extends AgentTestRunner { def traceParent trace(2) { clientSpan(it, null, "http.request", "spring-webflux-client", "GET", URI.create(url)) - traceParent = clientSpan(it, span(0), "netty.client.request", "netty-client", "GET", URI.create(url)) + traceParent = clientSpan(it, span(0), "netty.client.request", "netty-client", "GET", URI.create(url)) } trace(2) { span { @@ -660,6 +677,73 @@ class SpringWebfluxTest extends AgentTestRunner { "annotation API delayed response" | "/foo-delayed" | "/foo-delayed" | "getFooDelayed" | new FooModel(3L, "delayed").toString() } + def 'test websocket server receive #msgType message of size #size and #chunks chunks'() { + when: + String url = "http://localhost:$port/websocket" + def wsClient = new OkHttpWebsocketClient() + wsClient.connect(url) + wsHandler.awaitConnected() + if (message instanceof String) { + wsClient.send(message as String) + } else { + wsClient.send(message as byte[]) + } + wsHandler.awaitExchangeComplete() + wsClient.close(1001, "goodbye") + + then: + assertTraces(3, { + DDSpan handshake + trace(2) { + sortSpansByStart() + handshake = span(0) + span { + resourceName "GET /websocket" + operationName "netty.request" + spanType DDSpanTypes.HTTP_SERVER + tags { + "$Tags.COMPONENT" "netty" + "$Tags.SPAN_KIND" Tags.SPAN_KIND_SERVER + "$Tags.PEER_HOST_IPV4" "127.0.0.1" + "$Tags.PEER_PORT" Integer + "$Tags.HTTP_URL" url + "$Tags.HTTP_HOSTNAME" "localhost" + "$Tags.HTTP_METHOD" "GET" + "$Tags.HTTP_STATUS" 101 + "$Tags.HTTP_USER_AGENT" String + "$Tags.HTTP_CLIENT_IP" "127.0.0.1" + "$Tags.HTTP_ROUTE" "/websocket" + defaultTags() + } + } + span { + resourceName "WsHandler.handle" + operationName "WsHandler.handle" + spanType DDSpanTypes.HTTP_SERVER + childOfPrevious() + tags { + "$Tags.COMPONENT" "spring-webflux-controller" + "$Tags.SPAN_KIND" Tags.SPAN_KIND_SERVER + "handler.type" WsHandler.getName() + defaultTags() + } + } + } + trace(2) { + sortSpansByStart() + websocketReceiveSpan(it, handshake, msgType, size, chunks) + websocketSendSpan(it, handshake, msgType, size, chunks) + } + trace(1) { + websocketCloseSpan(it, handshake, false, 1001, "goodbye") + } + }) + where: + message | msgType | chunks | size + RandomString.make(10) | "text" | 1 | 10 + RandomString.make(20).getBytes("UTF-8") | "binary" | 1 | 20 + } + def clientSpan( TraceAssert trace, Object parentSpan, diff --git a/dd-java-agent/instrumentation/spring-webflux-5/src/bootTest/groovy/dd/trace/instrumentation/springwebflux/server/SpringWebFluxTestApplication.groovy b/dd-java-agent/instrumentation/spring-webflux-5/src/bootTest/groovy/dd/trace/instrumentation/springwebflux/server/SpringWebFluxTestApplication.groovy index aa3b3135500..1c8f5c0ec15 100644 --- a/dd-java-agent/instrumentation/spring-webflux-5/src/bootTest/groovy/dd/trace/instrumentation/springwebflux/server/SpringWebFluxTestApplication.groovy +++ b/dd-java-agent/instrumentation/spring-webflux-5/src/bootTest/groovy/dd/trace/instrumentation/springwebflux/server/SpringWebFluxTestApplication.groovy @@ -5,11 +5,15 @@ import org.springframework.boot.autoconfigure.SpringBootApplication import org.springframework.context.annotation.Bean import org.springframework.http.MediaType import org.springframework.stereotype.Component +import org.springframework.web.reactive.HandlerMapping import org.springframework.web.reactive.function.BodyInserters import org.springframework.web.reactive.function.server.HandlerFunction import org.springframework.web.reactive.function.server.RouterFunction import org.springframework.web.reactive.function.server.ServerRequest import org.springframework.web.reactive.function.server.ServerResponse +import org.springframework.web.reactive.handler.SimpleUrlHandlerMapping +import org.springframework.web.reactive.socket.WebSocketHandler +import org.springframework.web.reactive.socket.server.support.WebSocketHandlerAdapter import reactor.core.publisher.Mono import java.time.Duration @@ -26,6 +30,22 @@ class SpringWebFluxTestApplication { return route(POST("/echo"), new EchoHandlerFunction(echoHandler)) } + @Bean + WebSocketHandlerAdapter webSocketHandlerAdapter() { + return new WebSocketHandlerAdapter() + } + + @Bean + HandlerMapping wsHandlerMapping(WsHandler wsHandler) { + Map map = new HashMap<>() + map.put("/websocket", wsHandler) + + SimpleUrlHandlerMapping handlerMapping = new SimpleUrlHandlerMapping() + handlerMapping.setOrder(1) + handlerMapping.setUrlMap(map) + return handlerMapping + } + @Bean RouterFunction greetRouterFunction(GreetingHandler greetingHandler) { return route(GET("/greet"), new HandlerFunction() { diff --git a/dd-java-agent/instrumentation/spring-webflux-5/src/bootTest/groovy/dd/trace/instrumentation/springwebflux/server/WsHandler.java b/dd-java-agent/instrumentation/spring-webflux-5/src/bootTest/groovy/dd/trace/instrumentation/springwebflux/server/WsHandler.java new file mode 100644 index 00000000000..89cc8f37998 --- /dev/null +++ b/dd-java-agent/instrumentation/spring-webflux-5/src/bootTest/groovy/dd/trace/instrumentation/springwebflux/server/WsHandler.java @@ -0,0 +1,73 @@ +package dd.trace.instrumentation.springwebflux.server; + +import org.springframework.stereotype.Component; +import org.springframework.web.reactive.socket.WebSocketHandler; +import org.springframework.web.reactive.socket.WebSocketMessage; +import org.springframework.web.reactive.socket.WebSocketSession; +import reactor.core.publisher.Mono; + +@Component +public class WsHandler implements WebSocketHandler { + + volatile WebSocketSession activeSession; + + @Override + public Mono handle(WebSocketSession webSocketSession) { + // An echo server the closes after the first echoed message + activeSession = webSocketSession; + synchronized (this) { + notifyAll(); + } + return webSocketSession + .receive() + .map( + msg -> { + if (msg.getType() == WebSocketMessage.Type.TEXT) { + return webSocketSession.textMessage(msg.getPayloadAsText()); + } + byte[] bytes = new byte[msg.getPayload().readableByteCount()]; + msg.getPayload().read(bytes); + return webSocketSession.binaryMessage( + dataBufferFactory -> dataBufferFactory.wrap(bytes)); + }) + .flatMap( + msg -> + webSocketSession + .send(Mono.just(msg)) + .doFinally( + ignored -> { + synchronized (this) { + activeSession = null; + notifyAll(); + } + })) + .then(); + } + + public void awaitExchangeComplete() { + synchronized (this) { + if (activeSession == null) { + return; + } + try { + wait(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + } + } + + public void awaitConnected() { + synchronized (this) { + if (activeSession != null) { + return; + } + try { + wait(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + } + assert activeSession != null; + } +} diff --git a/dd-java-agent/instrumentation/spring-webflux-6/src/bootTest/groovy/SpringWebfluxTest.groovy b/dd-java-agent/instrumentation/spring-webflux-6/src/bootTest/groovy/SpringWebfluxTest.groovy index dc1e6d5eda9..c4a504e1d8f 100644 --- a/dd-java-agent/instrumentation/spring-webflux-6/src/bootTest/groovy/SpringWebfluxTest.groovy +++ b/dd-java-agent/instrumentation/spring-webflux-6/src/bootTest/groovy/SpringWebfluxTest.groovy @@ -1,3 +1,11 @@ +import datadog.trace.agent.test.base.OkHttpWebsocketClient +import dd.trace.instrumentation.springwebflux.server.WsHandler +import net.bytebuddy.utility.RandomString +import org.springframework.beans.factory.annotation.Autowired + +import static datadog.trace.agent.test.base.HttpServerTest.websocketCloseSpan +import static datadog.trace.agent.test.base.HttpServerTest.websocketReceiveSpan +import static datadog.trace.agent.test.base.HttpServerTest.websocketSendSpan import static datadog.trace.api.config.TraceInstrumentationConfig.TRACE_ANNOTATION_ASYNC import datadog.trace.agent.test.AgentTestRunner @@ -38,12 +46,16 @@ class SpringWebfluxHttp11Test extends AgentTestRunner { @LocalServerPort int port - WebClient client = WebClient.builder().clientConnector (new ReactorClientHttpConnector(buildClient())).build() + @Autowired + WsHandler wsHandler + + WebClient client = WebClient.builder().clientConnector(new ReactorClientHttpConnector(buildClient())).build() @Override protected void configurePreAgent() { super.configurePreAgent() injectSysConfig(TRACE_ANNOTATION_ASYNC, "true") + injectSysConfig("trace.websocket.messages.enabled", "true") } def "Basic GET test #testName"() { @@ -60,7 +72,7 @@ class SpringWebfluxHttp11Test extends AgentTestRunner { sortSpansByStart() trace(2) { clientSpan(it, null, "http.request", "spring-webflux-client", "GET", URI.create(url)) - traceParent = clientSpan(it, span(0), "netty.client.request", "netty-client", "GET", URI.create(url)) + traceParent = clientSpan(it, span(0), "netty.client.request", "netty-client", "GET", URI.create(url)) } trace(2) { span { @@ -141,7 +153,7 @@ class SpringWebfluxHttp11Test extends AgentTestRunner { def traceParent trace(2) { clientSpan(it, null, "http.request", "spring-webflux-client", "GET", URI.create(url)) - traceParent = clientSpan(it, span(0), "netty.client.request", "netty-client", "GET", URI.create(url)) + traceParent = clientSpan(it, span(0), "netty.client.request", "netty-client", "GET", URI.create(url)) } trace(3) { span { @@ -235,7 +247,7 @@ class SpringWebfluxHttp11Test extends AgentTestRunner { def traceParent trace(2) { clientSpan(it, null, "http.request", "spring-webflux-client", "GET", URI.create(url)) - traceParent = clientSpan(it, span(0), "netty.client.request", "netty-client", "GET", URI.create(url)) + traceParent = clientSpan(it, span(0), "netty.client.request", "netty-client", "GET", URI.create(url)) } trace(3) { span { @@ -283,7 +295,7 @@ class SpringWebfluxHttp11Test extends AgentTestRunner { def traceParent trace(2) { clientSpan(it, null, "http.request", "spring-webflux-client", "GET", URI.create(url), 404, true) - traceParent = clientSpan(it, span(0), "netty.client.request", "netty-client", "GET", URI.create(url), 404, true) + traceParent = clientSpan(it, span(0), "netty.client.request", "netty-client", "GET", URI.create(url), 404, true) } trace(2) { span { @@ -329,7 +341,7 @@ class SpringWebfluxHttp11Test extends AgentTestRunner { String url = "http://localhost:$port/echo" when: - def response = client.post().uri(url).body(BodyInserters.fromPublisher(Mono.just(echoString),String)).exchange().block() + def response = client.post().uri(url).body(BodyInserters.fromPublisher(Mono.just(echoString), String)).exchange().block() then: response.statusCode().value() == 202 @@ -339,7 +351,7 @@ class SpringWebfluxHttp11Test extends AgentTestRunner { def traceParent trace(2) { clientSpan(it, null, "http.request", "spring-webflux-client", "POST", URI.create(url), 202) - traceParent = clientSpan(it, span(0), "netty.client.request", "netty-client", "POST", URI.create(url), 202) + traceParent = clientSpan(it, span(0), "netty.client.request", "netty-client", "POST", URI.create(url), 202) } trace(3) { span { @@ -404,7 +416,7 @@ class SpringWebfluxHttp11Test extends AgentTestRunner { def traceParent trace(2) { clientSpan(it, null, "http.request", "spring-webflux-client", "GET", URI.create(url), 500) - traceParent = clientSpan(it, span(0), "netty.client.request", "netty-client", "GET", URI.create(url), 500) + traceParent = clientSpan(it, span(0), "netty.client.request", "netty-client", "GET", URI.create(url), 500) } trace(2) { span { @@ -493,7 +505,7 @@ class SpringWebfluxHttp11Test extends AgentTestRunner { trace(2) { sortSpansByStart() clientSpan(it, null, "http.request", "spring-webflux-client", "GET", URI.create(url), 307) - traceParent1 = clientSpan(it, span(0), "netty.client.request", "netty-client", "GET", URI.create(url), 307) + traceParent1 = clientSpan(it, span(0), "netty.client.request", "netty-client", "GET", URI.create(url), 307) } trace(2) { @@ -538,7 +550,7 @@ class SpringWebfluxHttp11Test extends AgentTestRunner { trace(2) { sortSpansByStart() clientSpan(it, null, "http.request", "spring-webflux-client", "GET", URI.create(finalUrl)) - traceParent2 = clientSpan(it, span(0), "netty.client.request", "netty-client", "GET", URI.create(finalUrl)) + traceParent2 = clientSpan(it, span(0), "netty.client.request", "netty-client", "GET", URI.create(finalUrl)) } trace(2) { sortSpansByStart() @@ -597,7 +609,7 @@ class SpringWebfluxHttp11Test extends AgentTestRunner { def traceParent trace(2) { clientSpan(it, null, "http.request", "spring-webflux-client", "GET", URI.create(url)) - traceParent = clientSpan(it, span(0), "netty.client.request", "netty-client", "GET", URI.create(url)) + traceParent = clientSpan(it, span(0), "netty.client.request", "netty-client", "GET", URI.create(url)) } trace(2) { span { @@ -672,7 +684,7 @@ class SpringWebfluxHttp11Test extends AgentTestRunner { sortSpansByStart() trace(2) { clientSpan(it, null, "http.request", "spring-webflux-client", "GET", URI.create(url), null, false, null, false, - [ "message":"The subscription was cancelled", "event":"cancelled"]) + ["message": "The subscription was cancelled", "event": "cancelled"]) traceParent = clientSpan(it, span(0), "netty.client.request", "netty-client", "GET", URI.create(url), null) } trace(2) { @@ -711,6 +723,74 @@ class SpringWebfluxHttp11Test extends AgentTestRunner { } } + def 'test websocket server receive #msgType message of size #size and #chunks chunks'() { + when: + String url = "http://localhost:$port/websocket" + def wsClient = new OkHttpWebsocketClient() + wsClient.connect(url) + wsHandler.awaitConnected() + if (message instanceof String) { + wsClient.send(message as String) + } else { + wsClient.send(message as byte[]) + } + wsHandler.awaitExchangeComplete() + wsClient.close(1001, "goodbye") + + then: + assertTraces(3, { + DDSpan handshake + trace(2) { + sortSpansByStart() + handshake = span(0) + span { + resourceName "GET /websocket" + operationName "netty.request" + spanType DDSpanTypes.HTTP_SERVER + tags { + "$Tags.COMPONENT" "netty" + "$Tags.SPAN_KIND" Tags.SPAN_KIND_SERVER + "$Tags.PEER_HOST_IPV4" "127.0.0.1" + "$Tags.PEER_PORT" Integer + "$Tags.HTTP_URL" url + "$Tags.HTTP_HOSTNAME" "localhost" + "$Tags.HTTP_METHOD" "GET" + "$Tags.HTTP_STATUS" 101 + "$Tags.HTTP_USER_AGENT" String + "$Tags.HTTP_CLIENT_IP" "127.0.0.1" + "$Tags.HTTP_ROUTE" "/websocket" + defaultTags() + } + } + span { + resourceName "WsHandler.handle" + operationName "WsHandler.handle" + spanType DDSpanTypes.HTTP_SERVER + childOfPrevious() + tags { + "$Tags.COMPONENT" "spring-webflux-controller" + "$Tags.SPAN_KIND" Tags.SPAN_KIND_SERVER + "handler.type" WsHandler.getName() + defaultTags() + } + } + } + trace(2) { + sortSpansByStart() + websocketReceiveSpan(it, handshake, msgType, size, chunks) + websocketSendSpan(it, handshake, msgType, size, chunks) + } + trace(1) { + websocketCloseSpan(it, handshake, false, 1001, "goodbye") + } + }) + where: + message | msgType | chunks | size + RandomString.make(10) | "text" | 1 | 10 + RandomString.make(20).getBytes("UTF-8") | "binary" | 1 | 20 + } + + def clientSpan( TraceAssert trace, Object parentSpan, diff --git a/dd-java-agent/instrumentation/spring-webflux-6/src/bootTest/groovy/dd/trace/instrumentation/springwebflux/server/SpringWebFluxTestApplication.groovy b/dd-java-agent/instrumentation/spring-webflux-6/src/bootTest/groovy/dd/trace/instrumentation/springwebflux/server/SpringWebFluxTestApplication.groovy index aa3b3135500..bba1477d577 100644 --- a/dd-java-agent/instrumentation/spring-webflux-6/src/bootTest/groovy/dd/trace/instrumentation/springwebflux/server/SpringWebFluxTestApplication.groovy +++ b/dd-java-agent/instrumentation/spring-webflux-6/src/bootTest/groovy/dd/trace/instrumentation/springwebflux/server/SpringWebFluxTestApplication.groovy @@ -5,11 +5,14 @@ import org.springframework.boot.autoconfigure.SpringBootApplication import org.springframework.context.annotation.Bean import org.springframework.http.MediaType import org.springframework.stereotype.Component +import org.springframework.web.reactive.HandlerMapping import org.springframework.web.reactive.function.BodyInserters import org.springframework.web.reactive.function.server.HandlerFunction import org.springframework.web.reactive.function.server.RouterFunction import org.springframework.web.reactive.function.server.ServerRequest import org.springframework.web.reactive.function.server.ServerResponse +import org.springframework.web.reactive.handler.SimpleUrlHandlerMapping +import org.springframework.web.reactive.socket.WebSocketHandler import reactor.core.publisher.Mono import java.time.Duration @@ -21,6 +24,17 @@ import static org.springframework.web.reactive.function.server.RouterFunctions.r @SpringBootApplication class SpringWebFluxTestApplication { + @Bean + HandlerMapping wsHandlerMapping(WsHandler wsHandler) { + Map map = new HashMap<>() + map.put("/websocket", wsHandler) + + SimpleUrlHandlerMapping handlerMapping = new SimpleUrlHandlerMapping() + handlerMapping.setOrder(1) + handlerMapping.setUrlMap(map) + return handlerMapping + } + @Bean RouterFunction echoRouterFunction(EchoHandler echoHandler) { return route(POST("/echo"), new EchoHandlerFunction(echoHandler)) diff --git a/dd-java-agent/instrumentation/spring-webflux-6/src/bootTest/groovy/dd/trace/instrumentation/springwebflux/server/WsHandler.java b/dd-java-agent/instrumentation/spring-webflux-6/src/bootTest/groovy/dd/trace/instrumentation/springwebflux/server/WsHandler.java new file mode 100644 index 00000000000..89cc8f37998 --- /dev/null +++ b/dd-java-agent/instrumentation/spring-webflux-6/src/bootTest/groovy/dd/trace/instrumentation/springwebflux/server/WsHandler.java @@ -0,0 +1,73 @@ +package dd.trace.instrumentation.springwebflux.server; + +import org.springframework.stereotype.Component; +import org.springframework.web.reactive.socket.WebSocketHandler; +import org.springframework.web.reactive.socket.WebSocketMessage; +import org.springframework.web.reactive.socket.WebSocketSession; +import reactor.core.publisher.Mono; + +@Component +public class WsHandler implements WebSocketHandler { + + volatile WebSocketSession activeSession; + + @Override + public Mono handle(WebSocketSession webSocketSession) { + // An echo server the closes after the first echoed message + activeSession = webSocketSession; + synchronized (this) { + notifyAll(); + } + return webSocketSession + .receive() + .map( + msg -> { + if (msg.getType() == WebSocketMessage.Type.TEXT) { + return webSocketSession.textMessage(msg.getPayloadAsText()); + } + byte[] bytes = new byte[msg.getPayload().readableByteCount()]; + msg.getPayload().read(bytes); + return webSocketSession.binaryMessage( + dataBufferFactory -> dataBufferFactory.wrap(bytes)); + }) + .flatMap( + msg -> + webSocketSession + .send(Mono.just(msg)) + .doFinally( + ignored -> { + synchronized (this) { + activeSession = null; + notifyAll(); + } + })) + .then(); + } + + public void awaitExchangeComplete() { + synchronized (this) { + if (activeSession == null) { + return; + } + try { + wait(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + } + } + + public void awaitConnected() { + synchronized (this) { + if (activeSession != null) { + return; + } + try { + wait(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + } + assert activeSession != null; + } +}