Skip to content

Commit e800908

Browse files
authored
Add ability to specify custom headers (#453)
* Add ability to specify custom headers * .github: Update golangci-lint
1 parent d294166 commit e800908

File tree

2 files changed

+57
-2
lines changed

2 files changed

+57
-2
lines changed

.github/workflows/go-lint.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ concurrency:
1717

1818
env:
1919
# renovate: datasource=go depName=github.com/golangci/golangci-lint
20-
GOLANGCI_LINT_VERSION: v1.64.8
20+
GOLANGCI_LINT_VERSION: v2.6.2
2121

2222
jobs:
2323
skip-check:
@@ -58,6 +58,6 @@ jobs:
5858
check-latest: true
5959

6060
- name: golangci-lint
61-
uses: golangci/golangci-lint-action@3cfe3a4abbb849e10058ce4af15d205b6da42804 # v3.7.1
61+
uses: golangci/golangci-lint-action@1e7e51e771db61008b38414a730f564565cf7c20 # v9.2.0
6262
with:
6363
version: ${{ env.GOLANGCI_LINT_VERSION }}

main.go

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package main
33
import (
44
"context"
55
"flag"
6+
"fmt"
67
"log"
78
"net/http"
89
"net/http/pprof"
@@ -31,6 +32,7 @@ func main() {
3132
vaultTokenPath := flag.String("vault-token-path", "parca-load/token", "The path in Vault to find the parca-load token")
3233
vaultRole := flag.String("vault-role", "parca-load", "The role name of parca-load in Vault")
3334
clientTimeout := flag.Duration("client-timeout", 10*time.Second, "Timeout for requests to the Parca instance")
35+
customHeadersStr := flag.String("headers", "", "Comma-separated custom headers in the format 'key=value,key2=value2' to attach to requests")
3436

3537
queryInterval := flag.Duration("query-interval", 5*time.Second, "The time interval between queries to the Parca instance")
3638
queryRangeStr := flag.String("query-range", "15m,12h,168h", "Comma-separated time durations for query")
@@ -83,12 +85,20 @@ func main() {
8385
log.Fatalf("parse time range string error: %v", err)
8486
}
8587

88+
customHeaders, err := parseHeaders(*customHeadersStr)
89+
if err != nil {
90+
log.Fatalf("parse custom headers error: %v", err)
91+
}
92+
8693
clientOptions := []connect.ClientOption{
8794
connect.WithGRPCWeb(),
8895
}
8996
if *token != "" {
9097
clientOptions = append(clientOptions, connect.WithInterceptors(&bearerTokenInterceptor{token: *token}))
9198
}
99+
if len(customHeaders) > 0 {
100+
clientOptions = append(clientOptions, connect.WithInterceptors(&customHeadersInterceptor{headers: customHeaders}))
101+
}
92102

93103
client := queryv1alpha1connect.NewQueryServiceClient(
94104
&http.Client{Timeout: *clientTimeout},
@@ -171,6 +181,27 @@ func (i *bearerTokenInterceptor) WrapStreamingHandler(handler connect.StreamingH
171181
return handler
172182
}
173183

184+
type customHeadersInterceptor struct {
185+
headers map[string]string
186+
}
187+
188+
func (i *customHeadersInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
189+
return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) {
190+
for key, value := range i.headers {
191+
req.Header().Set(key, value)
192+
}
193+
return next(ctx, req)
194+
}
195+
}
196+
197+
func (i *customHeadersInterceptor) WrapStreamingClient(client connect.StreamingClientFunc) connect.StreamingClientFunc {
198+
return client
199+
}
200+
201+
func (i *customHeadersInterceptor) WrapStreamingHandler(handler connect.StreamingHandlerFunc) connect.StreamingHandlerFunc {
202+
return handler
203+
}
204+
174205
func parseTimeRanges(input string) ([]time.Duration, error) {
175206
parts := strings.Split(input, ",")
176207
durations := make([]time.Duration, len(parts))
@@ -185,3 +216,27 @@ func parseTimeRanges(input string) ([]time.Duration, error) {
185216

186217
return durations, nil
187218
}
219+
220+
func parseHeaders(input string) (map[string]string, error) {
221+
if input == "" {
222+
return nil, nil
223+
}
224+
225+
headers := make(map[string]string)
226+
pairs := strings.Split(input, ",")
227+
228+
for _, pair := range pairs {
229+
parts := strings.SplitN(strings.TrimSpace(pair), "=", 2)
230+
if len(parts) != 2 {
231+
return nil, fmt.Errorf("invalid header format: %s (expected key=value)", pair)
232+
}
233+
key := strings.TrimSpace(parts[0])
234+
value := strings.TrimSpace(parts[1])
235+
if key == "" {
236+
return nil, fmt.Errorf("empty header key in: %s", pair)
237+
}
238+
headers[key] = value
239+
}
240+
241+
return headers, nil
242+
}

0 commit comments

Comments
 (0)