diff --git a/contrib/mark3labs/mcp-go/README.md b/contrib/mark3labs/mcp-go/README.md index eb27083d6a..d23ab18c9d 100644 --- a/contrib/mark3labs/mcp-go/README.md +++ b/contrib/mark3labs/mcp-go/README.md @@ -2,6 +2,8 @@ This integration provides Datadog tracing for the [mark3labs/mcp-go](https://github.com/mark3labs/mcp-go) library. +Both hooks and middleware are used. + ## Usage ```go @@ -12,15 +14,19 @@ import ( ) func main() { - tracer.Start() if err := tracer.Start(); err != nil { log.Fatal(err) } defer tracer.Stop() - srv := server.NewMCPServer("my-server", "1.0.0", - server.WithToolHandlerMiddleware(mcpgotrace.NewToolHandlerMiddleware())) - _ = srv + // Add tracing to your server hooks + hooks := &server.Hooks{} + cleanup := mcpgotrace.AddServerHooks(hooks) + defer cleanup() + + srv := server.NewMCPServer("my-server", "1.0.0", + server.WithHooks(hooks), + server.WithToolHandlerMiddleware(mcpgotrace.NewToolHandlerMiddleware())) } ``` @@ -28,3 +34,4 @@ func main() { The integration automatically traces: - **Tool calls**: Creates LLMObs tool spans with input/output annotation for all tool invocations +- **Session initialization**: Create LLMObs task spans for session initialization, including client information. \ No newline at end of file diff --git a/contrib/mark3labs/mcp-go/example_test.go b/contrib/mark3labs/mcp-go/example_test.go index ff83a474c5..01a537253f 100644 --- a/contrib/mark3labs/mcp-go/example_test.go +++ b/contrib/mark3labs/mcp-go/example_test.go @@ -15,7 +15,13 @@ func Example() { tracer.Start() defer tracer.Stop() + // Create server hooks and add Datadog tracing + hooks := &server.Hooks{} + cleanup := mcpgotrace.AddServerHooks(hooks) + defer cleanup() + srv := server.NewMCPServer("my-server", "1.0.0", + server.WithHooks(hooks), server.WithToolHandlerMiddleware(mcpgotrace.NewToolHandlerMiddleware())) _ = srv } diff --git a/contrib/mark3labs/mcp-go/go.mod b/contrib/mark3labs/mcp-go/go.mod index a5563ab3dc..5bc77d6de9 100644 --- a/contrib/mark3labs/mcp-go/go.mod +++ b/contrib/mark3labs/mcp-go/go.mod @@ -4,6 +4,7 @@ go 1.24.0 require ( github.com/DataDog/dd-trace-go/v2 v2.4.0-dev + github.com/jellydator/ttlcache/v3 v3.4.0 github.com/mark3labs/mcp-go v0.42.0 github.com/stretchr/testify v1.11.1 ) @@ -86,6 +87,7 @@ require ( golang.org/x/exp v0.0.0-20250408133849-7e4ce0ab07d0 // indirect golang.org/x/mod v0.28.0 // indirect golang.org/x/net v0.44.0 // indirect + golang.org/x/sync v0.17.0 // indirect golang.org/x/sys v0.36.0 // indirect golang.org/x/text v0.30.0 // indirect golang.org/x/time v0.12.0 // indirect diff --git a/contrib/mark3labs/mcp-go/go.sum b/contrib/mark3labs/mcp-go/go.sum index d721f9a061..467e29ae0f 100644 --- a/contrib/mark3labs/mcp-go/go.sum +++ b/contrib/mark3labs/mcp-go/go.sum @@ -90,6 +90,8 @@ github.com/hashicorp/go-version v1.7.0 h1:5tqGy27NaOTB8yJKUZELlFAS/LTKJkrmONwQKe github.com/hashicorp/go-version v1.7.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA= github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E= github.com/invopop/jsonschema v0.13.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0= +github.com/jellydator/ttlcache/v3 v3.4.0 h1:YS4P125qQS0tNhtL6aeYkheEaB/m8HCqdMMP4mnWdTY= +github.com/jellydator/ttlcache/v3 v3.4.0/go.mod h1:Hw9EgjymziQD3yGsQdf1FqFdpp7YjFMd4Srg5EJlgD4= github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= @@ -273,6 +275,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug= +golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= diff --git a/contrib/mark3labs/mcp-go/mcpgo.go b/contrib/mark3labs/mcp-go/mcpgo.go index 2c0c1589c9..1ec512b10f 100644 --- a/contrib/mark3labs/mcp-go/mcpgo.go +++ b/contrib/mark3labs/mcp-go/mcpgo.go @@ -8,9 +8,11 @@ package mcpgo // import "github.com/DataDog/dd-trace-go/contrib/mark3labs/mcp-go import ( "context" "encoding/json" + "time" "github.com/DataDog/dd-trace-go/v2/instrumentation" "github.com/DataDog/dd-trace-go/v2/llmobs" + "github.com/jellydator/ttlcache/v3" "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" @@ -22,6 +24,27 @@ func init() { instr = instrumentation.Load(instrumentation.PackageMark3LabsMcpGo) } +type hooks struct { + spanCache *ttlcache.Cache[any, llmobs.Span] +} + +type textIOAnnotator interface { + AnnotateTextIO(input, output string, opts ...llmobs.AnnotateOption) +} + +// AddServerHooks appends Datadog tracing hooks to an existing server.Hooks object. +// Returns a cleanup function that should be called upon server shutdown. +func AddServerHooks(hooks *server.Hooks) func() { + ddHooks := newHooks() + hooks.AddBeforeInitialize(ddHooks.onBeforeInitialize) + hooks.AddAfterInitialize(ddHooks.onAfterInitialize) + hooks.AddOnError(ddHooks.onError) + + return func() { + ddHooks.stop() + } +} + func NewToolHandlerMiddleware() server.ToolHandlerMiddleware { return func(next server.ToolHandlerFunc) server.ToolHandlerFunc { return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { @@ -48,3 +71,63 @@ func NewToolHandlerMiddleware() server.ToolHandlerMiddleware { } } } + +func newHooks() *hooks { + spanCache := ttlcache.New[any, llmobs.Span]( + ttlcache.WithTTL[any, llmobs.Span](5 * time.Minute), + ) + spanCache.OnEviction(func(ctx context.Context, reason ttlcache.EvictionReason, item *ttlcache.Item[any, llmobs.Span]) { + if span := item.Value(); span != nil { + if reason == ttlcache.EvictionReasonExpired { + span.Finish() + } + } + }) + go spanCache.Start() + + return &hooks{ + spanCache: spanCache, + } +} + +func (h *hooks) onBeforeInitialize(ctx context.Context, id any, request *mcp.InitializeRequest) { + taskSpan, _ := llmobs.StartTaskSpan(ctx, "mcp.initialize", llmobs.WithIntegration("mark3labs/mcp-go")) + h.spanCache.Set(id, taskSpan, ttlcache.DefaultTTL) +} + +func (h *hooks) onAfterInitialize(ctx context.Context, id any, request *mcp.InitializeRequest, result *mcp.InitializeResult) { + finishSpanWithIO(h, id, request, result) +} + +func (h *hooks) onError(ctx context.Context, id any, method mcp.MCPMethod, message any, err error) { + if method == mcp.MethodInitialize { + if item := h.spanCache.Get(id); item != nil { + span := item.Value() + if annotator, ok := span.(textIOAnnotator); ok { + inputJSON, _ := json.Marshal(message) + annotator.AnnotateTextIO(string(inputJSON), err.Error()) + span.Finish(llmobs.WithError(err)) + } + h.spanCache.Delete(id) + } + } +} + +func (h *hooks) stop() { + h.spanCache.Stop() +} + +func finishSpanWithIO[Req any, Res any](h *hooks, id any, request Req, result Res) { + if item := h.spanCache.Get(id); item != nil { + span := item.Value() + if annotator, ok := span.(textIOAnnotator); ok { + inputJSON, _ := json.Marshal(request) + resultJSON, _ := json.Marshal(result) + outputText := string(resultJSON) + + annotator.AnnotateTextIO(string(inputJSON), outputText) + span.Finish() + } + h.spanCache.Delete(id) + } +} diff --git a/contrib/mark3labs/mcp-go/mcpgo_test.go b/contrib/mark3labs/mcp-go/mcpgo_test.go index a4aad3ce9b..743ee0795e 100644 --- a/contrib/mark3labs/mcp-go/mcpgo_test.go +++ b/contrib/mark3labs/mcp-go/mcpgo_test.go @@ -29,6 +29,74 @@ func TestNewToolHandlerMiddleware(t *testing.T) { assert.NotNil(t, middleware) } +func TestAddServerHooks(t *testing.T) { + mt := mocktracer.Start() + defer mt.Stop() + + serverHooks := &server.Hooks{} + cleanup := AddServerHooks(serverHooks) + defer cleanup() + + assert.Len(t, serverHooks.OnBeforeInitialize, 1) + assert.Len(t, serverHooks.OnAfterInitialize, 1) + assert.Len(t, serverHooks.OnError, 1) +} + +// Integration Tests + +func TestIntegrationSessionInitialize(t *testing.T) { + tt := testTracer(t) + defer tt.Stop() + + hooks := &server.Hooks{} + cleanup := AddServerHooks(hooks) + defer cleanup() + + srv := server.NewMCPServer("test-server", "1.0.0", + server.WithHooks(hooks)) + + ctx := context.Background() + initRequest := `{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"test-client","version":"1.0.0"}}}` + + response := srv.HandleMessage(ctx, []byte(initRequest)) + assert.NotNil(t, response) + + responseBytes, err := json.Marshal(response) + require.NoError(t, err) + + var resp map[string]interface{} + err = json.Unmarshal(responseBytes, &resp) + require.NoError(t, err) + assert.Equal(t, "2.0", resp["jsonrpc"]) + assert.Equal(t, float64(1), resp["id"]) + assert.NotNil(t, resp["result"]) + + spans := tt.WaitForLLMObsSpans(t, 1) + require.Len(t, spans, 1) + + taskSpan := spans[0] + assert.Equal(t, "mcp.initialize", taskSpan.Name) + assert.Equal(t, "task", taskSpan.Meta["span.kind"]) + + assert.Contains(t, taskSpan.Meta, "input") + assert.Contains(t, taskSpan.Meta, "output") + + inputMeta := taskSpan.Meta["input"] + assert.NotNil(t, inputMeta) + inputJSON, err := json.Marshal(inputMeta) + require.NoError(t, err) + inputStr := string(inputJSON) + assert.Contains(t, inputStr, "2024-11-05") + assert.Contains(t, inputStr, "test-client") + + outputMeta := taskSpan.Meta["output"] + assert.NotNil(t, outputMeta) + outputJSON, err := json.Marshal(outputMeta) + require.NoError(t, err) + outputStr := string(outputJSON) + assert.Contains(t, outputStr, "serverInfo") +} + func TestIntegrationToolCallSuccess(t *testing.T) { tt := testTracer(t) defer tt.Stop()