@@ -3,6 +3,7 @@ package main
33import (
44 "context"
55 "flag"
6+ "fmt"
67 "log"
78 "net/http"
89 "net/http/pprof"
@@ -31,6 +32,7 @@ func main() {
3132 vaultTokenPath := flag .String ("vault-token-path" , "parca-load/token" , "The path in Vault to find the parca-load token" )
3233 vaultRole := flag .String ("vault-role" , "parca-load" , "The role name of parca-load in Vault" )
3334 clientTimeout := flag .Duration ("client-timeout" , 10 * time .Second , "Timeout for requests to the Parca instance" )
35+ customHeadersStr := flag .String ("headers" , "" , "Comma-separated custom headers in the format 'key=value,key2=value2' to attach to requests" )
3436
3537 queryInterval := flag .Duration ("query-interval" , 5 * time .Second , "The time interval between queries to the Parca instance" )
3638 queryRangeStr := flag .String ("query-range" , "15m,12h,168h" , "Comma-separated time durations for query" )
@@ -83,12 +85,20 @@ func main() {
8385 log .Fatalf ("parse time range string error: %v" , err )
8486 }
8587
88+ customHeaders , err := parseHeaders (* customHeadersStr )
89+ if err != nil {
90+ log .Fatalf ("parse custom headers error: %v" , err )
91+ }
92+
8693 clientOptions := []connect.ClientOption {
8794 connect .WithGRPCWeb (),
8895 }
8996 if * token != "" {
9097 clientOptions = append (clientOptions , connect .WithInterceptors (& bearerTokenInterceptor {token : * token }))
9198 }
99+ if len (customHeaders ) > 0 {
100+ clientOptions = append (clientOptions , connect .WithInterceptors (& customHeadersInterceptor {headers : customHeaders }))
101+ }
92102
93103 client := queryv1alpha1connect .NewQueryServiceClient (
94104 & http.Client {Timeout : * clientTimeout },
@@ -171,6 +181,27 @@ func (i *bearerTokenInterceptor) WrapStreamingHandler(handler connect.StreamingH
171181 return handler
172182}
173183
184+ type customHeadersInterceptor struct {
185+ headers map [string ]string
186+ }
187+
188+ func (i * customHeadersInterceptor ) WrapUnary (next connect.UnaryFunc ) connect.UnaryFunc {
189+ return func (ctx context.Context , req connect.AnyRequest ) (connect.AnyResponse , error ) {
190+ for key , value := range i .headers {
191+ req .Header ().Set (key , value )
192+ }
193+ return next (ctx , req )
194+ }
195+ }
196+
197+ func (i * customHeadersInterceptor ) WrapStreamingClient (client connect.StreamingClientFunc ) connect.StreamingClientFunc {
198+ return client
199+ }
200+
201+ func (i * customHeadersInterceptor ) WrapStreamingHandler (handler connect.StreamingHandlerFunc ) connect.StreamingHandlerFunc {
202+ return handler
203+ }
204+
174205func parseTimeRanges (input string ) ([]time.Duration , error ) {
175206 parts := strings .Split (input , "," )
176207 durations := make ([]time.Duration , len (parts ))
@@ -185,3 +216,27 @@ func parseTimeRanges(input string) ([]time.Duration, error) {
185216
186217 return durations , nil
187218}
219+
220+ func parseHeaders (input string ) (map [string ]string , error ) {
221+ if input == "" {
222+ return nil , nil
223+ }
224+
225+ headers := make (map [string ]string )
226+ pairs := strings .Split (input , "," )
227+
228+ for _ , pair := range pairs {
229+ parts := strings .SplitN (strings .TrimSpace (pair ), "=" , 2 )
230+ if len (parts ) != 2 {
231+ return nil , fmt .Errorf ("invalid header format: %s (expected key=value)" , pair )
232+ }
233+ key := strings .TrimSpace (parts [0 ])
234+ value := strings .TrimSpace (parts [1 ])
235+ if key == "" {
236+ return nil , fmt .Errorf ("empty header key in: %s" , pair )
237+ }
238+ headers [key ] = value
239+ }
240+
241+ return headers , nil
242+ }
0 commit comments