|
| 1 | +// Package upstream provides utilities for propagating HTTP headers |
| 2 | +// from incoming API requests to outbound toolset HTTP calls. |
| 3 | +package upstream |
| 4 | + |
| 5 | +import ( |
| 6 | + "context" |
| 7 | + "fmt" |
| 8 | + "net/http" |
| 9 | + "regexp" |
| 10 | + "strings" |
| 11 | + |
| 12 | + "github.com/dop251/goja" |
| 13 | +) |
| 14 | + |
| 15 | +type contextKey struct{} |
| 16 | + |
| 17 | +// WithHeaders returns a new context carrying the given HTTP headers. |
| 18 | +func WithHeaders(ctx context.Context, h http.Header) context.Context { |
| 19 | + return context.WithValue(ctx, contextKey{}, h) |
| 20 | +} |
| 21 | + |
| 22 | +// HeadersFromContext retrieves upstream HTTP headers from the context. |
| 23 | +// Returns nil if no headers are present. |
| 24 | +func HeadersFromContext(ctx context.Context) http.Header { |
| 25 | + h, _ := ctx.Value(contextKey{}).(http.Header) |
| 26 | + return h |
| 27 | +} |
| 28 | + |
| 29 | +// Handler wraps an http.Handler to store the incoming HTTP request |
| 30 | +// headers in the request context for downstream toolset forwarding. |
| 31 | +func Handler(next http.Handler) http.Handler { |
| 32 | + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| 33 | + ctx := WithHeaders(r.Context(), r.Header.Clone()) |
| 34 | + next.ServeHTTP(w, r.WithContext(ctx)) |
| 35 | + }) |
| 36 | +} |
| 37 | + |
| 38 | +// NewHeaderTransport wraps an http.RoundTripper to set custom headers on |
| 39 | +// every outbound request. Header values may contain ${headers.NAME} |
| 40 | +// placeholders that are resolved at request time from upstream headers |
| 41 | +// stored in the request context. |
| 42 | +func NewHeaderTransport(base http.RoundTripper, headers map[string]string) http.RoundTripper { |
| 43 | + return &headerTransport{base: base, headers: headers} |
| 44 | +} |
| 45 | + |
| 46 | +type headerTransport struct { |
| 47 | + base http.RoundTripper |
| 48 | + headers map[string]string |
| 49 | +} |
| 50 | + |
| 51 | +func (t *headerTransport) RoundTrip(req *http.Request) (*http.Response, error) { |
| 52 | + req = req.Clone(req.Context()) |
| 53 | + for key, value := range ResolveHeaders(req.Context(), t.headers) { |
| 54 | + req.Header.Set(key, value) |
| 55 | + } |
| 56 | + return t.base.RoundTrip(req) |
| 57 | +} |
| 58 | + |
| 59 | +// ResolveHeaders resolves ${headers.NAME} placeholders in header values |
| 60 | +// using upstream headers from the context. Header names in the placeholder |
| 61 | +// are case-insensitive, matching HTTP header convention. |
| 62 | +// |
| 63 | +// For example, given the config header: |
| 64 | +// |
| 65 | +// Authorization: ${headers.Authorization} |
| 66 | +// |
| 67 | +// and an upstream request with "Authorization: Bearer token", the resolved |
| 68 | +// value will be "Bearer token". |
| 69 | +func ResolveHeaders(ctx context.Context, headers map[string]string) map[string]string { |
| 70 | + if len(headers) == 0 { |
| 71 | + return headers |
| 72 | + } |
| 73 | + |
| 74 | + upstream := HeadersFromContext(ctx) |
| 75 | + if upstream == nil { |
| 76 | + return headers |
| 77 | + } |
| 78 | + |
| 79 | + vm := goja.New() |
| 80 | + _ = vm.Set("headers", vm.NewDynamicObject(headerAccessor(func(name string) goja.Value { |
| 81 | + return vm.ToValue(upstream.Get(name)) |
| 82 | + }))) |
| 83 | + |
| 84 | + resolved := make(map[string]string, len(headers)) |
| 85 | + for k, v := range headers { |
| 86 | + resolved[k] = expandTemplate(vm, v) |
| 87 | + } |
| 88 | + return resolved |
| 89 | +} |
| 90 | + |
| 91 | +// headerAccessor implements [goja.DynamicObject] for case-insensitive |
| 92 | +// HTTP header lookups. |
| 93 | +type headerAccessor func(string) goja.Value |
| 94 | + |
| 95 | +func (h headerAccessor) Get(k string) goja.Value { return h(k) } |
| 96 | +func (headerAccessor) Set(string, goja.Value) bool { return false } |
| 97 | +func (headerAccessor) Has(string) bool { return true } |
| 98 | +func (headerAccessor) Delete(string) bool { return false } |
| 99 | +func (headerAccessor) Keys() []string { return nil } |
| 100 | + |
| 101 | +// headerPlaceholderRe matches ${headers.NAME} and captures the header |
| 102 | +// name so we can rewrite it to bracket notation for the JS runtime. |
| 103 | +var headerPlaceholderRe = regexp.MustCompile(`\$\{\s*headers\.([^}]+)\}`) |
| 104 | + |
| 105 | +// expandTemplate evaluates a string as a JavaScript template literal, |
| 106 | +// resolving any ${...} expressions via the goja runtime. |
| 107 | +// Before evaluation it rewrites ${headers.NAME} to ${headers["NAME"]} |
| 108 | +// so that header names containing hyphens (e.g. X-Request-Id) are |
| 109 | +// accessed correctly. |
| 110 | +func expandTemplate(vm *goja.Runtime, text string) string { |
| 111 | + if !strings.Contains(text, "${") { |
| 112 | + return text |
| 113 | + } |
| 114 | + |
| 115 | + // Rewrite dotted header access to bracket notation so names with |
| 116 | + // hyphens work: ${headers.X-Req-Id} → ${headers["X-Req-Id"]} |
| 117 | + text = headerPlaceholderRe.ReplaceAllStringFunc(text, func(m string) string { |
| 118 | + parts := headerPlaceholderRe.FindStringSubmatch(m) |
| 119 | + name := strings.TrimSpace(parts[1]) |
| 120 | + return `${headers["` + name + `"]}` |
| 121 | + }) |
| 122 | + |
| 123 | + escaped := strings.ReplaceAll(text, "\\", "\\\\") |
| 124 | + escaped = strings.ReplaceAll(escaped, "`", "\\`") |
| 125 | + script := "`" + escaped + "`" |
| 126 | + |
| 127 | + v, err := vm.RunString(script) |
| 128 | + if err != nil { |
| 129 | + return text |
| 130 | + } |
| 131 | + if v == nil || v.Export() == nil { |
| 132 | + return "" |
| 133 | + } |
| 134 | + return fmt.Sprintf("%v", v.Export()) |
| 135 | +} |
0 commit comments