Skip to content

Commit 07343ef

Browse files
committed
Support ${headers.NAME} syntax to forward upstream API headers to toolsets
Add a new pkg/upstream package that allows toolset header values to reference incoming API request headers using ${headers.NAME} placeholders. For example, a toolset config like: headers: Authorization: ${headers.Authorization} will resolve the Authorization value at request time from the upstream HTTP request that triggered the agent. The middleware (Echo and ConnectRPC) stores the incoming request headers in the context. At tool-call time, header values containing ${headers.X} are resolved from that context. Static header values without placeholders are unaffected. Assisted-By: cagent
1 parent af20cc4 commit 07343ef

File tree

8 files changed

+298
-37
lines changed

8 files changed

+298
-37
lines changed

pkg/connectrpc/server.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import (
2424
"github.com/docker/cagent/pkg/server"
2525
"github.com/docker/cagent/pkg/session"
2626
"github.com/docker/cagent/pkg/tools"
27+
"github.com/docker/cagent/pkg/upstream"
2728
)
2829

2930
// Server implements the Connect-RPC AgentService.
@@ -44,7 +45,8 @@ func (s *Server) Handler() http.Handler {
4445

4546
path, handler := cagentv1connect.NewAgentServiceHandler(s)
4647
mux.Handle(path, handler)
47-
return h2c.NewHandler(mux, &http2.Server{})
48+
49+
return upstream.Handler(h2c.NewHandler(mux, &http2.Server{}))
4850
}
4951

5052
// Serve starts the Connect-RPC server on the given listener.

pkg/server/server.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import (
1717
"github.com/docker/cagent/pkg/api"
1818
"github.com/docker/cagent/pkg/config"
1919
"github.com/docker/cagent/pkg/session"
20+
"github.com/docker/cagent/pkg/upstream"
2021
)
2122

2223
type Server struct {
@@ -27,6 +28,7 @@ type Server struct {
2728
func New(ctx context.Context, sessionStore session.Store, runConfig *config.RuntimeConfig, refreshInterval time.Duration, agentSources config.Sources) (*Server, error) {
2829
e := echo.New()
2930
e.Use(middleware.RequestLogger())
31+
e.Use(echo.WrapMiddleware(upstream.Handler))
3032

3133
s := &Server{
3234
e: e,

pkg/tools/a2a/a2a.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import (
1616

1717
"github.com/docker/cagent/pkg/httpclient"
1818
"github.com/docker/cagent/pkg/tools"
19+
"github.com/docker/cagent/pkg/upstream"
1920
)
2021

2122
// Toolset implements tools.ToolSet for A2A remote agents.
@@ -121,7 +122,8 @@ func (t *Toolset) Start(ctx context.Context) error {
121122

122123
// Use a longer timeout for the HTTP client since LLM responses can take a while.
123124
// The default a2a-go HTTP client has only a 5-second timeout which is too short.
124-
httpClient := httpclient.NewHTTPClient(httpclient.WithHeaders(t.headers))
125+
httpClient := httpclient.NewHTTPClient()
126+
httpClient.Transport = upstream.NewHeaderTransport(httpClient.Transport, t.headers)
125127

126128
client, err := a2aclient.NewFromCard(ctx, card, a2aclient.WithJSONRPCTransport(httpClient))
127129
if err != nil {

pkg/tools/builtin/api.go

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ import (
1414
"github.com/docker/cagent/pkg/config/latest"
1515
"github.com/docker/cagent/pkg/js"
1616
"github.com/docker/cagent/pkg/tools"
17-
"github.com/docker/cagent/pkg/useragent"
1817
)
1918

2019
type APITool struct {
@@ -66,15 +65,11 @@ func (t *APITool) callTool(ctx context.Context, toolCall tools.ToolCall) (*tools
6665
return nil, fmt.Errorf("failed to create request: %w", err)
6766
}
6867

69-
req.Header.Set("User-Agent", useragent.Header)
68+
setHeaders(req, t.config.Headers)
7069
if t.config.Method == http.MethodPost {
7170
req.Header.Set("Content-Type", "application/json")
7271
}
7372

74-
for key, value := range t.config.Headers {
75-
req.Header.Set(key, value)
76-
}
77-
7873
resp, err := client.Do(req)
7974
if err != nil {
8075
return nil, fmt.Errorf("request failed: %w", err)

pkg/tools/builtin/openapi.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
"github.com/getkin/kin-openapi/openapi3"
1616

1717
"github.com/docker/cagent/pkg/tools"
18+
"github.com/docker/cagent/pkg/upstream"
1819
"github.com/docker/cagent/pkg/useragent"
1920
)
2021

@@ -349,9 +350,11 @@ func sanitizeToolName(name string) string {
349350
}
350351

351352
// setHeaders sets the User-Agent and custom headers on an HTTP request.
353+
// Header values may contain ${headers.NAME} placeholders that are resolved
354+
// from upstream headers stored in the request context.
352355
func setHeaders(req *http.Request, headers map[string]string) {
353356
req.Header.Set("User-Agent", useragent.Header)
354-
for k, v := range headers {
357+
for k, v := range upstream.ResolveHeaders(req.Context(), headers) {
355358
req.Header.Set(k, v)
356359
}
357360
}

pkg/tools/mcp/remote.go

Lines changed: 12 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"github.com/modelcontextprotocol/go-sdk/mcp"
1212

1313
"github.com/docker/cagent/pkg/tools"
14+
"github.com/docker/cagent/pkg/upstream"
1415
)
1516

1617
type remoteMCPClient struct {
@@ -124,35 +125,11 @@ func (c *remoteMCPClient) Initialize(ctx context.Context, _ *mcp.InitializeReque
124125
return session.InitializeResult(), nil
125126
}
126127

127-
// headerTransport is a RoundTripper that adds custom headers to all requests
128-
type headerTransport struct {
129-
base http.RoundTripper
130-
headers map[string]string
131-
}
132-
133-
func (t *headerTransport) RoundTrip(req *http.Request) (*http.Response, error) {
134-
// Clone the request to avoid modifying the original
135-
req = req.Clone(req.Context())
136-
137-
// Add custom headers
138-
for key, value := range t.headers {
139-
req.Header.Set(key, value)
140-
}
141-
142-
return t.base.RoundTrip(req)
143-
}
144-
145-
// createHTTPClient creates an HTTP client with custom headers and OAuth support
128+
// createHTTPClient creates an HTTP client with custom headers and OAuth support.
129+
// Header values may contain ${headers.NAME} placeholders that are resolved
130+
// at request time from upstream headers stored in the request context.
146131
func (c *remoteMCPClient) createHTTPClient() *http.Client {
147-
transport := http.DefaultTransport
148-
149-
// Add custom headers first
150-
if len(c.headers) > 0 {
151-
transport = &headerTransport{
152-
base: transport,
153-
headers: c.headers,
154-
}
155-
}
132+
transport := c.headerTransport()
156133

157134
// Then wrap with OAuth support
158135
transport = &oauthTransport{
@@ -168,6 +145,13 @@ func (c *remoteMCPClient) createHTTPClient() *http.Client {
168145
}
169146
}
170147

148+
func (c *remoteMCPClient) headerTransport() http.RoundTripper {
149+
if len(c.headers) > 0 {
150+
return upstream.NewHeaderTransport(http.DefaultTransport, c.headers)
151+
}
152+
return http.DefaultTransport
153+
}
154+
171155
func (c *remoteMCPClient) Close(context.Context) error {
172156
c.mu.RLock()
173157
session := c.session

pkg/upstream/headers.go

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
// Package upstream provides utilities for propagating HTTP headers
2+
// from incoming API requests to outbound toolset HTTP calls.
3+
package upstream
4+
5+
import (
6+
"context"
7+
"fmt"
8+
"net/http"
9+
"regexp"
10+
"strings"
11+
12+
"github.com/dop251/goja"
13+
)
14+
15+
type contextKey struct{}
16+
17+
// WithHeaders returns a new context carrying the given HTTP headers.
18+
func WithHeaders(ctx context.Context, h http.Header) context.Context {
19+
return context.WithValue(ctx, contextKey{}, h)
20+
}
21+
22+
// HeadersFromContext retrieves upstream HTTP headers from the context.
23+
// Returns nil if no headers are present.
24+
func HeadersFromContext(ctx context.Context) http.Header {
25+
h, _ := ctx.Value(contextKey{}).(http.Header)
26+
return h
27+
}
28+
29+
// Handler wraps an http.Handler to store the incoming HTTP request
30+
// headers in the request context for downstream toolset forwarding.
31+
func Handler(next http.Handler) http.Handler {
32+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
33+
ctx := WithHeaders(r.Context(), r.Header.Clone())
34+
next.ServeHTTP(w, r.WithContext(ctx))
35+
})
36+
}
37+
38+
// NewHeaderTransport wraps an http.RoundTripper to set custom headers on
39+
// every outbound request. Header values may contain ${headers.NAME}
40+
// placeholders that are resolved at request time from upstream headers
41+
// stored in the request context.
42+
func NewHeaderTransport(base http.RoundTripper, headers map[string]string) http.RoundTripper {
43+
return &headerTransport{base: base, headers: headers}
44+
}
45+
46+
type headerTransport struct {
47+
base http.RoundTripper
48+
headers map[string]string
49+
}
50+
51+
func (t *headerTransport) RoundTrip(req *http.Request) (*http.Response, error) {
52+
req = req.Clone(req.Context())
53+
for key, value := range ResolveHeaders(req.Context(), t.headers) {
54+
req.Header.Set(key, value)
55+
}
56+
return t.base.RoundTrip(req)
57+
}
58+
59+
// ResolveHeaders resolves ${headers.NAME} placeholders in header values
60+
// using upstream headers from the context. Header names in the placeholder
61+
// are case-insensitive, matching HTTP header convention.
62+
//
63+
// For example, given the config header:
64+
//
65+
// Authorization: ${headers.Authorization}
66+
//
67+
// and an upstream request with "Authorization: Bearer token", the resolved
68+
// value will be "Bearer token".
69+
func ResolveHeaders(ctx context.Context, headers map[string]string) map[string]string {
70+
if len(headers) == 0 {
71+
return headers
72+
}
73+
74+
upstream := HeadersFromContext(ctx)
75+
if upstream == nil {
76+
return headers
77+
}
78+
79+
vm := goja.New()
80+
_ = vm.Set("headers", vm.NewDynamicObject(headerAccessor(func(name string) goja.Value {
81+
return vm.ToValue(upstream.Get(name))
82+
})))
83+
84+
resolved := make(map[string]string, len(headers))
85+
for k, v := range headers {
86+
resolved[k] = expandTemplate(vm, v)
87+
}
88+
return resolved
89+
}
90+
91+
// headerAccessor implements [goja.DynamicObject] for case-insensitive
92+
// HTTP header lookups.
93+
type headerAccessor func(string) goja.Value
94+
95+
func (h headerAccessor) Get(k string) goja.Value { return h(k) }
96+
func (headerAccessor) Set(string, goja.Value) bool { return false }
97+
func (headerAccessor) Has(string) bool { return true }
98+
func (headerAccessor) Delete(string) bool { return false }
99+
func (headerAccessor) Keys() []string { return nil }
100+
101+
// headerPlaceholderRe matches ${headers.NAME} and captures the header
102+
// name so we can rewrite it to bracket notation for the JS runtime.
103+
var headerPlaceholderRe = regexp.MustCompile(`\$\{\s*headers\.([^}]+)\}`)
104+
105+
// expandTemplate evaluates a string as a JavaScript template literal,
106+
// resolving any ${...} expressions via the goja runtime.
107+
// Before evaluation it rewrites ${headers.NAME} to ${headers["NAME"]}
108+
// so that header names containing hyphens (e.g. X-Request-Id) are
109+
// accessed correctly.
110+
func expandTemplate(vm *goja.Runtime, text string) string {
111+
if !strings.Contains(text, "${") {
112+
return text
113+
}
114+
115+
// Rewrite dotted header access to bracket notation so names with
116+
// hyphens work: ${headers.X-Req-Id} → ${headers["X-Req-Id"]}
117+
text = headerPlaceholderRe.ReplaceAllStringFunc(text, func(m string) string {
118+
parts := headerPlaceholderRe.FindStringSubmatch(m)
119+
name := strings.TrimSpace(parts[1])
120+
return `${headers["` + name + `"]}`
121+
})
122+
123+
escaped := strings.ReplaceAll(text, "\\", "\\\\")
124+
escaped = strings.ReplaceAll(escaped, "`", "\\`")
125+
script := "`" + escaped + "`"
126+
127+
v, err := vm.RunString(script)
128+
if err != nil {
129+
return text
130+
}
131+
if v == nil || v.Export() == nil {
132+
return ""
133+
}
134+
return fmt.Sprintf("%v", v.Export())
135+
}

0 commit comments

Comments
 (0)