Skip to content

Commit 220b90c

Browse files
authored
mcp: add configured header attribtues to metrics and spans (#1342)
**Description** The LLM metrics and trace providers allow for the configuration of custom headers to be included as metric attributes and span attributes. The MCP implementation, however, did not honor those values. This PR updates the MCP metrics and tracer configurations to also take this configuration into account. * The tracer method to create a span now receives the request headers. * To avoid having to propagate request headers all the way down inside the MCP Proxy so that they can be taken into account when recording metrics, I've opted to instantiate an MCPProxy on each request, and have it set up with the request attributes from the beginning. This allows reusing the metrics implementation in a clean way. **Related Issues/PRs (if applicable)** Fixes #1335 Fixes #1348 **Special notes for reviewers (if applicable)** N/A --------- Signed-off-by: Ignasi Barrera <[email protected]>
1 parent 3341eb8 commit 220b90c

19 files changed

+418
-170
lines changed

cmd/extproc/mainlib/main.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ func Main(ctx context.Context, args []string, stderr io.Writer) (err error) {
233233
chatCompletionMetrics := metrics.NewChatCompletion(meter, metricsRequestHeaderAttributes)
234234
completionMetrics := metrics.NewCompletion(meter, metricsRequestHeaderAttributes)
235235
embeddingsMetrics := metrics.NewEmbeddings(meter, metricsRequestHeaderAttributes)
236-
mcpMetrics := metrics.NewMCP(meter)
236+
mcpMetrics := metrics.NewMCP(meter, metricsRequestHeaderAttributes)
237237

238238
tracing, err := tracing.NewTracingFromEnv(ctx, os.Stdout, spanRequestHeaderAttributes)
239239
if err != nil {
@@ -264,13 +264,13 @@ func Main(ctx context.Context, args []string, stderr io.Writer) (err error) {
264264
seed, fallbackSeed, _ := strings.Cut(flags.mcpSessionEncryptionSeed, ",")
265265
mcpSessionCrypto := mcpproxy.DefaultSessionCrypto(seed, fallbackSeed)
266266
var mcpProxyMux *http.ServeMux
267-
var mcpProxy *mcpproxy.MCPProxy
268-
mcpProxy, mcpProxyMux, err = mcpproxy.NewMCPProxy(l.With("component", "mcp-proxy"), mcpMetrics,
267+
var mcpProxyConfig *mcpproxy.ProxyConfig
268+
mcpProxyConfig, mcpProxyMux, err = mcpproxy.NewMCPProxy(l.With("component", "mcp-proxy"), mcpMetrics,
269269
tracing.MCPTracer(), mcpSessionCrypto)
270270
if err != nil {
271271
return fmt.Errorf("failed to create MCP proxy: %w", err)
272272
}
273-
if err = extproc.StartConfigWatcher(ctx, flags.configPath, mcpProxy, l, time.Second*5); err != nil {
273+
if err = extproc.StartConfigWatcher(ctx, flags.configPath, mcpProxyConfig, l, time.Second*5); err != nil {
274274
return fmt.Errorf("failed to start config watcher: %w", err)
275275
}
276276

internal/lang/maps.go

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
// Copyright Envoy AI Gateway Authors
2+
// SPDX-License-Identifier: Apache-2.0
3+
// The full text of the Apache license is available in the LICENSE file at
4+
// the root of the repo.
5+
6+
package lang
7+
8+
import (
9+
"fmt"
10+
"maps"
11+
"slices"
12+
"strings"
13+
)
14+
15+
// CaseInsensitiveValue retrieves a value from the meta map in a case-insensitive manner.
16+
// If the same key is present in different cases, the first one in alphabetical order
17+
// that matches is returned.
18+
// If the key is not found, it returns an empty string.
19+
func CaseInsensitiveValue(m map[string]any, key string) string {
20+
if m == nil {
21+
return ""
22+
}
23+
24+
if v, ok := m[key]; ok {
25+
return fmt.Sprintf("%v", v)
26+
}
27+
28+
keys := slices.Sorted(maps.Keys(m))
29+
for _, k := range keys {
30+
if strings.EqualFold(k, key) {
31+
return fmt.Sprintf("%v", m[k])
32+
}
33+
}
34+
35+
return ""
36+
}

internal/lang/maps_test.go

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
// Copyright Envoy AI Gateway Authors
2+
// SPDX-License-Identifier: Apache-2.0
3+
// The full text of the Apache license is available in the LICENSE file at
4+
// the root of the repo.
5+
6+
package lang
7+
8+
import "testing"
9+
10+
func TestCaseInsensitiveValue(t *testing.T) {
11+
tests := []struct {
12+
name string
13+
m map[string]any
14+
key string
15+
want string
16+
}{
17+
{
18+
name: "nil map",
19+
m: nil,
20+
key: "anything",
21+
want: "",
22+
},
23+
{
24+
name: "exact match returns value",
25+
m: map[string]any{"Foo": "bar", "foo": "should-not-be-used"},
26+
key: "Foo",
27+
want: "bar",
28+
},
29+
{
30+
name: "case-insensitive match when exact not present",
31+
m: map[string]any{"FOO": "baz"},
32+
key: "foo",
33+
want: "baz",
34+
},
35+
{
36+
name: "multiple case variants - alphabetical first chosen",
37+
m: map[string]any{"ALPHA": 2, "Alpha": 1},
38+
key: "alpha",
39+
want: "2", // ALPHA is alphabetically first
40+
},
41+
{
42+
name: "nil value formatted",
43+
m: map[string]any{"key": nil},
44+
key: "key",
45+
want: "<nil>",
46+
},
47+
}
48+
49+
for _, tc := range tests {
50+
t.Run(tc.name, func(t *testing.T) {
51+
got := CaseInsensitiveValue(tc.m, tc.key)
52+
if got != tc.want {
53+
t.Fatalf("CaseInsensitiveValue(%v, %q) = %q; want %q", tc.m, tc.key, got, tc.want)
54+
}
55+
})
56+
}
57+
}

internal/mcpproxy/handlers.go

Lines changed: 57 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ func (m *MCPProxy) servePOST(w http.ResponseWriter, r *http.Request) {
107107
errType metrics.MCPErrorType
108108
requestMethod string
109109
span tracing.MCPSpan
110+
params mcp.Params
110111
)
111112
defer func() {
112113
if m.l.Enabled(ctx, slog.LevelDebug) {
@@ -119,17 +120,17 @@ func (m *MCPProxy) servePOST(w http.ResponseWriter, r *http.Request) {
119120
if span != nil {
120121
span.EndSpanOnError(string(errType), err)
121122
}
122-
m.metrics.RecordMethodErrorCount(ctx)
123-
m.metrics.RecordRequestErrorDuration(ctx, &startAt, errType)
123+
m.metrics.RecordMethodErrorCount(ctx, params)
124+
m.metrics.RecordRequestErrorDuration(ctx, &startAt, errType, params)
124125
return
125126
}
126127

127128
if span != nil {
128129
span.EndSpan()
129130
}
130-
m.metrics.RecordRequestDuration(ctx, &startAt)
131+
m.metrics.RecordRequestDuration(ctx, &startAt, params)
131132
// TODO: should we special case when this request is "Response" where method is empty?
132-
m.metrics.RecordMethodCount(ctx, requestMethod)
133+
m.metrics.RecordMethodCount(ctx, requestMethod, params)
133134
}()
134135
if sessionID := r.Header.Get(sessionIDHeader); sessionID != "" {
135136
s, err = m.sessionFromID(secureClientToGatewaySessionID(sessionID), secureClientToGatewayEventID(r.Header.Get(lastEventIDHeader)))
@@ -189,37 +190,37 @@ func (m *MCPProxy) servePOST(w http.ResponseWriter, r *http.Request) {
189190

190191
switch msg.Method {
191192
case "notifications/roots/list_changed":
192-
p := &mcp.RootsListChangedParams{}
193-
span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, p)
193+
params = &mcp.RootsListChangedParams{}
194+
span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, params, r.Header)
194195
if err != nil {
195196
errType = metrics.MCPErrorInvalidParam
196197
onErrorResponse(w, http.StatusBadRequest, "invalid params")
197198
return
198199
}
199200
err = m.handleNotificationsRootsListChanged(ctx, s, w, msg, span)
200201
case "completion/complete":
201-
p := &mcp.CompleteParams{}
202-
span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, p)
202+
params = &mcp.CompleteParams{}
203+
span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, params, r.Header)
203204
if err != nil {
204205
errType = metrics.MCPErrorInvalidParam
205206
onErrorResponse(w, http.StatusBadRequest, "invalid params")
206207
return
207208
}
208-
err = m.handleCompletionComplete(ctx, s, w, msg, p, span)
209+
err = m.handleCompletionComplete(ctx, s, w, msg, params.(*mcp.CompleteParams), span)
209210
case "notifications/progress":
210-
m.metrics.RecordProgress(ctx)
211-
p := &mcp.ProgressNotificationParams{}
212-
span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, p)
211+
params = &mcp.ProgressNotificationParams{}
212+
span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, params, r.Header)
213+
m.metrics.RecordProgress(ctx, params)
213214
if err != nil {
214215
errType = metrics.MCPErrorInvalidParam
215216
onErrorResponse(w, http.StatusBadRequest, "invalid params")
216217
return
217218
}
218-
err = m.handleClientToServerNotificationsProgress(ctx, s, w, msg, p, span)
219+
err = m.handleClientToServerNotificationsProgress(ctx, s, w, msg, params.(*mcp.ProgressNotificationParams), span)
219220
case "initialize":
220221
// The very first request from the client to establish a session.
221-
p := &mcp.InitializeParams{}
222-
span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, p)
222+
params = &mcp.InitializeParams{}
223+
span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, params, r.Header)
223224
if err != nil {
224225
errType = metrics.MCPErrorInvalidParam
225226
m.l.Error("Failed to unmarshal initialize params", slog.String("error", err.Error()))
@@ -235,107 +236,107 @@ func (m *MCPProxy) servePOST(w http.ResponseWriter, r *http.Request) {
235236
onErrorResponse(w, http.StatusInternalServerError, "missing route header")
236237
return
237238
}
238-
err = m.handleInitializeRequest(ctx, w, msg, p, route, extractSubject(r), span)
239+
err = m.handleInitializeRequest(ctx, w, msg, params.(*mcp.InitializeParams), route, extractSubject(r), span)
239240
case "notifications/initialized":
240241
// According to the MCP spec, when the server receives a JSON-RPC response or notification from the client
241242
// and accepts it, the server MUST return HTTP 202 Accepted with an empty body.
242243
// https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#sending-messages-to-the-server
243244
w.WriteHeader(http.StatusAccepted)
244245
case "logging/setLevel":
245-
p := &mcp.SetLoggingLevelParams{}
246-
span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, p)
246+
params = &mcp.SetLoggingLevelParams{}
247+
span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, params, r.Header)
247248
if err != nil {
248249
errType = metrics.MCPErrorInvalidParam
249250
m.l.Error("Failed to unmarshal set logging level params", slog.String("error", err.Error()))
250251
onErrorResponse(w, http.StatusBadRequest, "invalid set logging level params")
251252
return
252253
}
253-
err = m.handleSetLoggingLevel(ctx, s, w, msg, p, span)
254+
err = m.handleSetLoggingLevel(ctx, s, w, msg, params.(*mcp.SetLoggingLevelParams), span)
254255
case "ping":
255256
// Ping is intentionally not traced as it's a lightweight health check.
256257
err = m.handlePing(ctx, w, msg)
257258
case "prompts/list":
258-
p := &mcp.ListPromptsParams{}
259-
span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, p)
259+
params = &mcp.ListPromptsParams{}
260+
span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, params, r.Header)
260261
if err != nil {
261262
errType = metrics.MCPErrorInvalidParam
262263
onErrorResponse(w, http.StatusBadRequest, "invalid params")
263264
return
264265
}
265-
err = m.handlePromptListRequest(ctx, s, w, msg, p, span)
266+
err = m.handlePromptListRequest(ctx, s, w, msg, params.(*mcp.ListPromptsParams), span)
266267
case "prompts/get":
267-
p := &mcp.GetPromptParams{}
268-
span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, p)
268+
params = &mcp.GetPromptParams{}
269+
span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, params, r.Header)
269270
if err != nil {
270271
errType = metrics.MCPErrorInvalidParam
271272
onErrorResponse(w, http.StatusBadRequest, "invalid params")
272273
return
273274
}
274-
err = m.handlePromptGetRequest(ctx, s, w, msg, p)
275+
err = m.handlePromptGetRequest(ctx, s, w, msg, params.(*mcp.GetPromptParams))
275276
case "tools/call":
276-
p := &mcp.CallToolParams{}
277-
span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, p)
277+
params = &mcp.CallToolParams{}
278+
span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, params, r.Header)
278279
if err != nil {
279280
errType = metrics.MCPErrorInvalidParam
280281
m.l.Error("Failed to unmarshal params", slog.String("method", msg.Method), slog.String("error", err.Error()))
281282
onErrorResponse(w, http.StatusBadRequest, "invalid params")
282283
return
283284
}
284-
err = m.handleToolCallRequest(ctx, s, w, msg, p, span)
285+
err = m.handleToolCallRequest(ctx, s, w, msg, params.(*mcp.CallToolParams), span)
285286
case "tools/list":
286-
p := &mcp.ListToolsParams{}
287-
span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, p)
287+
params = &mcp.ListToolsParams{}
288+
span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, params, r.Header)
288289
if err != nil {
289290
errType = metrics.MCPErrorInvalidParam
290291
onErrorResponse(w, http.StatusBadRequest, "invalid params")
291292
return
292293
}
293-
err = m.handleToolsListRequest(ctx, s, w, msg, p, span)
294+
err = m.handleToolsListRequest(ctx, s, w, msg, params.(*mcp.ListToolsParams), span)
294295
case "resources/list":
295-
p := &mcp.ListResourcesParams{}
296-
span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, p)
296+
params = &mcp.ListResourcesParams{}
297+
span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, params, r.Header)
297298
if err != nil {
298299
errType = metrics.MCPErrorInvalidParam
299300
onErrorResponse(w, http.StatusBadRequest, "invalid params")
300301
return
301302
}
302-
err = m.handleResourceListRequest(ctx, s, w, msg, p, span)
303+
err = m.handleResourceListRequest(ctx, s, w, msg, params.(*mcp.ListResourcesParams), span)
303304
case "resources/read":
304-
p := &mcp.ReadResourceParams{}
305-
span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, p)
305+
params = &mcp.ReadResourceParams{}
306+
span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, params, r.Header)
306307
if err != nil {
307308
errType = metrics.MCPErrorInvalidParam
308309
onErrorResponse(w, http.StatusBadRequest, "invalid params")
309310
return
310311
}
311-
err = m.handleResourceReadRequest(ctx, s, w, msg, p)
312+
err = m.handleResourceReadRequest(ctx, s, w, msg, params.(*mcp.ReadResourceParams))
312313
case "resources/templates/list":
313-
p := &mcp.ListResourceTemplatesParams{}
314-
span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, p)
314+
params = &mcp.ListResourceTemplatesParams{}
315+
span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, params, r.Header)
315316
if err != nil {
316317
errType = metrics.MCPErrorInvalidParam
317318
onErrorResponse(w, http.StatusBadRequest, "invalid params")
318319
return
319320
}
320-
err = m.handleResourcesTemplatesListRequest(ctx, s, w, msg, p, span)
321+
err = m.handleResourcesTemplatesListRequest(ctx, s, w, msg, params.(*mcp.ListResourceTemplatesParams), span)
321322
case "resources/subscribe":
322-
p := &mcp.SubscribeParams{}
323-
span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, p)
323+
params = &mcp.SubscribeParams{}
324+
span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, params, r.Header)
324325
if err != nil {
325326
errType = metrics.MCPErrorInvalidParam
326327
onErrorResponse(w, http.StatusBadRequest, "invalid params")
327328
return
328329
}
329-
err = m.handleResourcesSubscribeRequest(ctx, s, w, msg, p, span)
330+
err = m.handleResourcesSubscribeRequest(ctx, s, w, msg, params.(*mcp.SubscribeParams), span)
330331
case "resources/unsubscribe":
331-
p := &mcp.UnsubscribeParams{}
332-
span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, p)
332+
params = &mcp.UnsubscribeParams{}
333+
span, err = parseParamsAndMaybeStartSpan(ctx, m, msg, params, r.Header)
333334
if err != nil {
334335
errType = metrics.MCPErrorInvalidParam
335336
onErrorResponse(w, http.StatusBadRequest, "invalid params")
336337
return
337338
}
338-
err = m.handleResourcesUnsubscribeRequest(ctx, s, w, msg, p, span)
339+
err = m.handleResourcesUnsubscribeRequest(ctx, s, w, msg, params.(*mcp.UnsubscribeParams), span)
339340
case "notifications/cancelled":
340341
// The responsibility of cancelling the operation on server side is optional, so we just ignore it for now.
341342
// https://modelcontextprotocol.io/specification/2025-06-18/basic/utilities/cancellation#behavior-requirements
@@ -371,8 +372,7 @@ func errorType(err error) metrics.MCPErrorType {
371372

372373
// handleInitializeRequest handles the "initialize" JSON-RPC method.
373374
func (m *MCPProxy) handleInitializeRequest(ctx context.Context, w http.ResponseWriter, req *jsonrpc.Request, p *mcp.InitializeParams, route, subject string, span tracing.MCPSpan) error {
374-
m.metrics.RecordClientCapabilities(ctx, p.Capabilities)
375-
375+
m.metrics.RecordClientCapabilities(ctx, p.Capabilities, p)
376376
s, err := m.newSession(ctx, p, route, subject, span)
377377
if err != nil {
378378
m.l.Error("failed to create new session", slog.String("error", err.Error()))
@@ -789,19 +789,23 @@ func (m *MCPProxy) recordResponse(ctx context.Context, rawMsg jsonrpc.Message) {
789789
case "notifications/resources/list_changed":
790790
case "notifications/resources/updated":
791791
case "notifications/progress":
792-
m.metrics.RecordProgress(ctx)
792+
params := &mcp.ProgressNotificationParams{}
793+
if err := json.Unmarshal(msg.Params, &params); err != nil {
794+
m.l.Error("Failed to unmarshal params", slog.String("method", msg.Method), slog.String("error", err.Error()))
795+
}
796+
m.metrics.RecordProgress(ctx, params)
793797
case "notifications/message":
794798
case "notifications/tools/list_changed":
795799
case "roots/list":
796800
case "sampling/createMessage":
797801
case "elicitation/create":
798802
default:
799803
knownMethod = false
800-
m.metrics.RecordMethodErrorCount(ctx)
804+
m.metrics.RecordMethodErrorCount(ctx, nil)
801805
m.l.Warn("Unsupported MCP request method from server", slog.String("method", msg.Method))
802806
}
803807
if knownMethod {
804-
m.metrics.RecordMethodCount(ctx, msg.Method)
808+
m.metrics.RecordMethodCount(ctx, msg.Method, nil)
805809
}
806810
default:
807811
m.l.Warn("unexpected message type in MCP response", slog.Any("message", msg))
@@ -1223,7 +1227,7 @@ func sendToAllBackendsAndAggregateResponsesImpl[responseType any](ctx context.Co
12231227
}
12241228

12251229
// parseParamsAndMaybeStartSpan parses the params from the JSON-RPC request and starts a tracing span if params is non-nil.
1226-
func parseParamsAndMaybeStartSpan[paramType mcp.Params](ctx context.Context, m *MCPProxy, req *jsonrpc.Request, p paramType) (tracing.MCPSpan, error) {
1230+
func parseParamsAndMaybeStartSpan[paramType mcp.Params](ctx context.Context, m *MCPProxy, req *jsonrpc.Request, p paramType, headers http.Header) (tracing.MCPSpan, error) {
12271231
if req.Params == nil {
12281232
return nil, nil
12291233
}
@@ -1233,7 +1237,7 @@ func parseParamsAndMaybeStartSpan[paramType mcp.Params](ctx context.Context, m *
12331237
return nil, err
12341238
}
12351239

1236-
span := m.tracer.StartSpanAndInjectMeta(ctx, req, p)
1240+
span := m.tracer.StartSpanAndInjectMeta(ctx, req, p, headers)
12371241
return span, nil
12381242
}
12391243

0 commit comments

Comments
 (0)