Skip to content

Commit 8df5bc5

Browse files
authored
Merge pull request #1725 from dgageot/headers
Support ${headers.NAME} syntax to forward upstream API headers to toolsets
2 parents 15ea9b4 + 07343ef commit 8df5bc5

File tree

8 files changed

+298
-37
lines changed

8 files changed

+298
-37
lines changed

pkg/connectrpc/server.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import (
2424
"github.com/docker/cagent/pkg/server"
2525
"github.com/docker/cagent/pkg/session"
2626
"github.com/docker/cagent/pkg/tools"
27+
"github.com/docker/cagent/pkg/upstream"
2728
)
2829

2930
// Server implements the Connect-RPC AgentService.
@@ -44,7 +45,8 @@ func (s *Server) Handler() http.Handler {
4445

4546
path, handler := cagentv1connect.NewAgentServiceHandler(s)
4647
mux.Handle(path, handler)
47-
return h2c.NewHandler(mux, &http2.Server{})
48+
49+
return upstream.Handler(h2c.NewHandler(mux, &http2.Server{}))
4850
}
4951

5052
// Serve starts the Connect-RPC server on the given listener.

pkg/server/server.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import (
1717
"github.com/docker/cagent/pkg/api"
1818
"github.com/docker/cagent/pkg/config"
1919
"github.com/docker/cagent/pkg/session"
20+
"github.com/docker/cagent/pkg/upstream"
2021
)
2122

2223
type Server struct {
@@ -27,6 +28,7 @@ type Server struct {
2728
func New(ctx context.Context, sessionStore session.Store, runConfig *config.RuntimeConfig, refreshInterval time.Duration, agentSources config.Sources) (*Server, error) {
2829
e := echo.New()
2930
e.Use(middleware.RequestLogger())
31+
e.Use(echo.WrapMiddleware(upstream.Handler))
3032

3133
s := &Server{
3234
e: e,

pkg/tools/a2a/a2a.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import (
1616

1717
"github.com/docker/cagent/pkg/httpclient"
1818
"github.com/docker/cagent/pkg/tools"
19+
"github.com/docker/cagent/pkg/upstream"
1920
)
2021

2122
// Toolset implements tools.ToolSet for A2A remote agents.
@@ -121,7 +122,8 @@ func (t *Toolset) Start(ctx context.Context) error {
121122

122123
// Use a longer timeout for the HTTP client since LLM responses can take a while.
123124
// The default a2a-go HTTP client has only a 5-second timeout which is too short.
124-
httpClient := httpclient.NewHTTPClient(httpclient.WithHeaders(t.headers))
125+
httpClient := httpclient.NewHTTPClient()
126+
httpClient.Transport = upstream.NewHeaderTransport(httpClient.Transport, t.headers)
125127

126128
client, err := a2aclient.NewFromCard(ctx, card, a2aclient.WithJSONRPCTransport(httpClient))
127129
if err != nil {

pkg/tools/builtin/api.go

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ import (
1414
"github.com/docker/cagent/pkg/config/latest"
1515
"github.com/docker/cagent/pkg/js"
1616
"github.com/docker/cagent/pkg/tools"
17-
"github.com/docker/cagent/pkg/useragent"
1817
)
1918

2019
type APITool struct {
@@ -66,15 +65,11 @@ func (t *APITool) callTool(ctx context.Context, toolCall tools.ToolCall) (*tools
6665
return nil, fmt.Errorf("failed to create request: %w", err)
6766
}
6867

69-
req.Header.Set("User-Agent", useragent.Header)
68+
setHeaders(req, t.config.Headers)
7069
if t.config.Method == http.MethodPost {
7170
req.Header.Set("Content-Type", "application/json")
7271
}
7372

74-
for key, value := range t.config.Headers {
75-
req.Header.Set(key, value)
76-
}
77-
7873
resp, err := client.Do(req)
7974
if err != nil {
8075
return nil, fmt.Errorf("request failed: %w", err)

pkg/tools/builtin/openapi.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
"github.com/getkin/kin-openapi/openapi3"
1616

1717
"github.com/docker/cagent/pkg/tools"
18+
"github.com/docker/cagent/pkg/upstream"
1819
"github.com/docker/cagent/pkg/useragent"
1920
)
2021

@@ -349,9 +350,11 @@ func sanitizeToolName(name string) string {
349350
}
350351

351352
// setHeaders sets the User-Agent and custom headers on an HTTP request.
353+
// Header values may contain ${headers.NAME} placeholders that are resolved
354+
// from upstream headers stored in the request context.
352355
func setHeaders(req *http.Request, headers map[string]string) {
353356
req.Header.Set("User-Agent", useragent.Header)
354-
for k, v := range headers {
357+
for k, v := range upstream.ResolveHeaders(req.Context(), headers) {
355358
req.Header.Set(k, v)
356359
}
357360
}

pkg/tools/mcp/remote.go

Lines changed: 12 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"github.com/modelcontextprotocol/go-sdk/mcp"
1212

1313
"github.com/docker/cagent/pkg/tools"
14+
"github.com/docker/cagent/pkg/upstream"
1415
)
1516

1617
type remoteMCPClient struct {
@@ -124,35 +125,11 @@ func (c *remoteMCPClient) Initialize(ctx context.Context, _ *mcp.InitializeReque
124125
return session.InitializeResult(), nil
125126
}
126127

127-
// headerTransport is a RoundTripper that adds custom headers to all requests
128-
type headerTransport struct {
129-
base http.RoundTripper
130-
headers map[string]string
131-
}
132-
133-
func (t *headerTransport) RoundTrip(req *http.Request) (*http.Response, error) {
134-
// Clone the request to avoid modifying the original
135-
req = req.Clone(req.Context())
136-
137-
// Add custom headers
138-
for key, value := range t.headers {
139-
req.Header.Set(key, value)
140-
}
141-
142-
return t.base.RoundTrip(req)
143-
}
144-
145-
// createHTTPClient creates an HTTP client with custom headers and OAuth support
128+
// createHTTPClient creates an HTTP client with custom headers and OAuth support.
129+
// Header values may contain ${headers.NAME} placeholders that are resolved
130+
// at request time from upstream headers stored in the request context.
146131
func (c *remoteMCPClient) createHTTPClient() *http.Client {
147-
transport := http.DefaultTransport
148-
149-
// Add custom headers first
150-
if len(c.headers) > 0 {
151-
transport = &headerTransport{
152-
base: transport,
153-
headers: c.headers,
154-
}
155-
}
132+
transport := c.headerTransport()
156133

157134
// Then wrap with OAuth support
158135
transport = &oauthTransport{
@@ -168,6 +145,13 @@ func (c *remoteMCPClient) createHTTPClient() *http.Client {
168145
}
169146
}
170147

148+
func (c *remoteMCPClient) headerTransport() http.RoundTripper {
149+
if len(c.headers) > 0 {
150+
return upstream.NewHeaderTransport(http.DefaultTransport, c.headers)
151+
}
152+
return http.DefaultTransport
153+
}
154+
171155
func (c *remoteMCPClient) Close(context.Context) error {
172156
c.mu.RLock()
173157
session := c.session

pkg/upstream/headers.go

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
// Package upstream provides utilities for propagating HTTP headers
2+
// from incoming API requests to outbound toolset HTTP calls.
3+
package upstream
4+
5+
import (
6+
"context"
7+
"fmt"
8+
"net/http"
9+
"regexp"
10+
"strings"
11+
12+
"github.com/dop251/goja"
13+
)
14+
15+
type contextKey struct{}
16+
17+
// WithHeaders returns a new context carrying the given HTTP headers.
18+
func WithHeaders(ctx context.Context, h http.Header) context.Context {
19+
return context.WithValue(ctx, contextKey{}, h)
20+
}
21+
22+
// HeadersFromContext retrieves upstream HTTP headers from the context.
23+
// Returns nil if no headers are present.
24+
func HeadersFromContext(ctx context.Context) http.Header {
25+
h, _ := ctx.Value(contextKey{}).(http.Header)
26+
return h
27+
}
28+
29+
// Handler wraps an http.Handler to store the incoming HTTP request
30+
// headers in the request context for downstream toolset forwarding.
31+
func Handler(next http.Handler) http.Handler {
32+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
33+
ctx := WithHeaders(r.Context(), r.Header.Clone())
34+
next.ServeHTTP(w, r.WithContext(ctx))
35+
})
36+
}
37+
38+
// NewHeaderTransport wraps an http.RoundTripper to set custom headers on
39+
// every outbound request. Header values may contain ${headers.NAME}
40+
// placeholders that are resolved at request time from upstream headers
41+
// stored in the request context.
42+
func NewHeaderTransport(base http.RoundTripper, headers map[string]string) http.RoundTripper {
43+
return &headerTransport{base: base, headers: headers}
44+
}
45+
46+
type headerTransport struct {
47+
base http.RoundTripper
48+
headers map[string]string
49+
}
50+
51+
func (t *headerTransport) RoundTrip(req *http.Request) (*http.Response, error) {
52+
req = req.Clone(req.Context())
53+
for key, value := range ResolveHeaders(req.Context(), t.headers) {
54+
req.Header.Set(key, value)
55+
}
56+
return t.base.RoundTrip(req)
57+
}
58+
59+
// ResolveHeaders resolves ${headers.NAME} placeholders in header values
60+
// using upstream headers from the context. Header names in the placeholder
61+
// are case-insensitive, matching HTTP header convention.
62+
//
63+
// For example, given the config header:
64+
//
65+
// Authorization: ${headers.Authorization}
66+
//
67+
// and an upstream request with "Authorization: Bearer token", the resolved
68+
// value will be "Bearer token".
69+
func ResolveHeaders(ctx context.Context, headers map[string]string) map[string]string {
70+
if len(headers) == 0 {
71+
return headers
72+
}
73+
74+
upstream := HeadersFromContext(ctx)
75+
if upstream == nil {
76+
return headers
77+
}
78+
79+
vm := goja.New()
80+
_ = vm.Set("headers", vm.NewDynamicObject(headerAccessor(func(name string) goja.Value {
81+
return vm.ToValue(upstream.Get(name))
82+
})))
83+
84+
resolved := make(map[string]string, len(headers))
85+
for k, v := range headers {
86+
resolved[k] = expandTemplate(vm, v)
87+
}
88+
return resolved
89+
}
90+
91+
// headerAccessor implements [goja.DynamicObject] for case-insensitive
92+
// HTTP header lookups.
93+
type headerAccessor func(string) goja.Value
94+
95+
func (h headerAccessor) Get(k string) goja.Value { return h(k) }
96+
func (headerAccessor) Set(string, goja.Value) bool { return false }
97+
func (headerAccessor) Has(string) bool { return true }
98+
func (headerAccessor) Delete(string) bool { return false }
99+
func (headerAccessor) Keys() []string { return nil }
100+
101+
// headerPlaceholderRe matches ${headers.NAME} and captures the header
102+
// name so we can rewrite it to bracket notation for the JS runtime.
103+
var headerPlaceholderRe = regexp.MustCompile(`\$\{\s*headers\.([^}]+)\}`)
104+
105+
// expandTemplate evaluates a string as a JavaScript template literal,
106+
// resolving any ${...} expressions via the goja runtime.
107+
// Before evaluation it rewrites ${headers.NAME} to ${headers["NAME"]}
108+
// so that header names containing hyphens (e.g. X-Request-Id) are
109+
// accessed correctly.
110+
func expandTemplate(vm *goja.Runtime, text string) string {
111+
if !strings.Contains(text, "${") {
112+
return text
113+
}
114+
115+
// Rewrite dotted header access to bracket notation so names with
116+
// hyphens work: ${headers.X-Req-Id} → ${headers["X-Req-Id"]}
117+
text = headerPlaceholderRe.ReplaceAllStringFunc(text, func(m string) string {
118+
parts := headerPlaceholderRe.FindStringSubmatch(m)
119+
name := strings.TrimSpace(parts[1])
120+
return `${headers["` + name + `"]}`
121+
})
122+
123+
escaped := strings.ReplaceAll(text, "\\", "\\\\")
124+
escaped = strings.ReplaceAll(escaped, "`", "\\`")
125+
script := "`" + escaped + "`"
126+
127+
v, err := vm.RunString(script)
128+
if err != nil {
129+
return text
130+
}
131+
if v == nil || v.Export() == nil {
132+
return ""
133+
}
134+
return fmt.Sprintf("%v", v.Export())
135+
}

0 commit comments

Comments
 (0)