Skip to content

Commit 1bc1569

Browse files
authored
Merge pull request #32 from gravitational/vapopov/add-default-headers-for-client
Add default headers for the client
2 parents 28981ff + 983817c commit 1bc1569

File tree

2 files changed

+55
-6
lines changed

2 files changed

+55
-6
lines changed

client.go

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,14 @@ func CookieJar(jar http.CookieJar) ClientParam {
9494
}
9595
}
9696

97+
// WithHeader sets default HTTP header for this client.
98+
func WithHeader(headers http.Header) ClientParam {
99+
return func(c *Client) error {
100+
c.header = headers
101+
return nil
102+
}
103+
}
104+
97105
// SanitizerEnabled will enable the input sanitizer which passes the URL
98106
// path through a strict whitelist.
99107
func SanitizerEnabled(sanitizerEnabled bool) ClientParam {
@@ -116,6 +124,8 @@ type Client struct {
116124
auth fmt.Stringer
117125
// jar is a set of cookies passed with requests
118126
jar http.CookieJar
127+
// header is the http.Header applied on every request.
128+
header http.Header
119129
// newTracer creates new request tracer
120130
newTracer NewTracer
121131
// sanitizerEnabled will enable the input sanitizer which passes the URL
@@ -147,6 +157,9 @@ func NewClient(addr, v string, params ...ClientParam) (*Client, error) {
147157
if c.newTracer == nil {
148158
c.newTracer = NewNopTracer
149159
}
160+
if c.header == nil {
161+
c.header = make(http.Header)
162+
}
150163
return c, nil
151164
}
152165

@@ -181,6 +194,7 @@ func (c *Client) submitForm(ctx context.Context, method string, endpoint string,
181194
if err != nil {
182195
return nil, err
183196
}
197+
req.Header = c.header.Clone()
184198
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
185199
c.addAuth(req)
186200
return c.client.Do(req)
@@ -211,6 +225,7 @@ func (c *Client) submitForm(ctx context.Context, method string, endpoint string,
211225
if err != nil {
212226
return nil, err
213227
}
228+
req.Header = c.header.Clone()
214229
req.Header.Set("Content-Type",
215230
fmt.Sprintf(`multipart/form-data;boundary="%v"`, writer.Boundary()))
216231
c.addAuth(req)
@@ -261,6 +276,7 @@ func (c *Client) submitJSON(ctx context.Context, method string, endpoint string,
261276
if err != nil {
262277
return nil, err
263278
}
279+
req.Header = c.header.Clone()
264280
req.Header.Set("Content-Type", "application/json")
265281
c.addAuth(req)
266282
tracer.Start(req)
@@ -311,6 +327,7 @@ func (c *Client) Delete(ctx context.Context, endpoint string) (*Response, error)
311327
if err != nil {
312328
return nil, err
313329
}
330+
req.Header = c.header.Clone()
314331
c.addAuth(req)
315332
tracer.Start(req)
316333
return c.client.Do(req)
@@ -354,6 +371,7 @@ func (c *Client) Get(ctx context.Context, endpoint string, params url.Values) (*
354371
if err != nil {
355372
return nil, err
356373
}
374+
req.Header = c.header.Clone()
357375
c.addAuth(req)
358376
tracer.Start(req)
359377
return c.client.Do(req)
@@ -382,6 +400,7 @@ func (c *Client) GetFile(ctx context.Context, endpoint string, params url.Values
382400
if err != nil {
383401
return nil, err
384402
}
403+
req.Header = c.header.Clone()
385404
c.addAuth(req)
386405
tracer := c.newTracer()
387406
tracer.Start(req)
@@ -479,7 +498,7 @@ func (c *Client) writeWithPipe(endpoint string, vals url.Values, buffers ...file
479498
r.Close()
480499
return nil, err
481500
}
482-
501+
req.Header = c.header.Clone()
483502
c.addAuth(req)
484503
req.Header.Set("Content-Type", fmt.Sprintf(`multipart/form-data;boundary="%v"`, writer.Boundary()))
485504
return c.client.Do(req)

client_test.go

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -144,11 +144,14 @@ func TestPostPutPatchJSON(t *testing.T) {
144144
var data interface{}
145145
var user, pass string
146146
var method string
147+
var userAgent string
147148

149+
expectedUserAgent := "api/1.0.0"
148150
ch := make(chan error, 1)
149151
srv := serveHandler(func(w http.ResponseWriter, r *http.Request) {
150152
var ok bool
151153
method = r.Method
154+
userAgent = r.UserAgent()
152155

153156
user, pass, ok = r.BasicAuth()
154157
if !ok {
@@ -160,7 +163,9 @@ func TestPostPutPatchJSON(t *testing.T) {
160163
})
161164
t.Cleanup(srv.Close)
162165

163-
clt := newC(srv.URL, "v1", BasicAuth("user", "pass"))
166+
header := make(http.Header)
167+
header.Set("User-Agent", expectedUserAgent)
168+
clt := newC(srv.URL, "v1", BasicAuth("user", "pass"), WithHeader(header))
164169

165170
values := map[string]interface{}{"hello": "there"}
166171
_, err := clt.PostJSON(context.Background(), clt.Endpoint("a", "b"), values)
@@ -191,6 +196,7 @@ func TestPostPutPatchJSON(t *testing.T) {
191196

192197
require.NoError(t, <-ch)
193198

199+
require.Equal(t, expectedUserAgent, userAgent)
194200
require.Equal(t, http.MethodPatch, method)
195201
require.Equal(t, "user", user)
196202
require.Equal(t, "pass", pass)
@@ -200,16 +206,22 @@ func TestPostPutPatchJSON(t *testing.T) {
200206
func TestDelete(t *testing.T) {
201207
var method string
202208
var user, pass string
209+
var userAgent string
203210

211+
expectedUserAgent := "api/1.0.0"
204212
srv := serveHandler(func(w http.ResponseWriter, r *http.Request) {
205213
user, pass, _ = r.BasicAuth()
206214
method = r.Method
215+
userAgent = r.UserAgent()
207216
})
208217
t.Cleanup(srv.Close)
209218

210-
clt := newC(srv.URL, "v1", BasicAuth("user", "pass"))
219+
header := make(http.Header)
220+
header.Set("User-Agent", expectedUserAgent)
221+
clt := newC(srv.URL, "v1", BasicAuth("user", "pass"), WithHeader(header))
211222
re, err := clt.Delete(context.Background(), clt.Endpoint("a", "b"))
212223
require.NoError(t, err)
224+
require.Equal(t, expectedUserAgent, userAgent)
213225
require.Equal(t, http.MethodDelete, method)
214226
require.Equal(t, http.StatusOK, re.Code())
215227
require.Equal(t, "user", user)
@@ -241,16 +253,22 @@ func TestDeleteP(t *testing.T) {
241253
func TestGet(t *testing.T) {
242254
var method string
243255
var query url.Values
256+
var userAgent string
257+
expectedUserAgent := "api/1.0.0"
244258
srv := serveHandler(func(w http.ResponseWriter, r *http.Request) {
245259
method = r.Method
246260
query = r.URL.Query()
261+
userAgent = r.UserAgent()
247262
})
248263
t.Cleanup(srv.Close)
249264

250-
clt := newC(srv.URL, "v1")
265+
header := make(http.Header)
266+
header.Set("User-Agent", expectedUserAgent)
267+
clt := newC(srv.URL, "v1", WithHeader(header))
251268
values := url.Values{"q": []string{"1", "2"}}
252269
_, err := clt.Get(context.Background(), clt.Endpoint("a", "b"), values)
253270
require.NoError(t, err)
271+
require.Equal(t, expectedUserAgent, userAgent)
254272
require.Equal(t, http.MethodGet, method)
255273
require.Equal(t, values, query)
256274
}
@@ -274,10 +292,13 @@ func TestGetFile(t *testing.T) {
274292
require.NoError(t, err)
275293

276294
var user, pass string
295+
var userAgent string
296+
expectedUserAgent := "api/1.0.0"
277297
ch := make(chan error, 1)
278298
srv := serveHandler(func(w http.ResponseWriter, r *http.Request) {
279299
var ok bool
280300
user, pass, ok = r.BasicAuth()
301+
userAgent = r.UserAgent()
281302
if !ok {
282303
ch <- errors.New("basic auth headers invalid")
283304
return
@@ -289,7 +310,9 @@ func TestGetFile(t *testing.T) {
289310
})
290311
t.Cleanup(srv.Close)
291312

292-
clt := newC(srv.URL, "v1", BasicAuth("user", "pass"))
313+
header := make(http.Header)
314+
header.Set("User-Agent", expectedUserAgent)
315+
clt := newC(srv.URL, "v1", BasicAuth("user", "pass"), WithHeader(header))
293316
f, err := clt.GetFile(context.Background(), clt.Endpoint("download"), url.Values{})
294317
require.NoError(t, err)
295318
defer f.Close()
@@ -298,6 +321,7 @@ func TestGetFile(t *testing.T) {
298321

299322
data, err := io.ReadAll(f.Body())
300323
require.NoError(t, err)
324+
require.Equal(t, expectedUserAgent, userAgent)
301325
require.Equal(t, "hello there", string(data))
302326
require.Equal(t, "file.txt", f.FileName())
303327
require.Equal(t, user, "user")
@@ -444,11 +468,14 @@ func testPostMultipartForm(t *testing.T, files []File, expected [][]byte) {
444468
var method string
445469
var data [][]byte
446470
var user, pass string
471+
var userAgent string
447472

473+
expectedUserAgent := "api/1.0.0"
448474
ch := make(chan error, 100)
449475
srv := serveHandler(func(w http.ResponseWriter, r *http.Request) {
450476
defer close(ch)
451477
user, pass, _ = r.BasicAuth()
478+
userAgent = r.UserAgent()
452479

453480
u = r.URL
454481
if err := r.ParseMultipartForm(64 << 20); err != nil {
@@ -485,7 +512,9 @@ func testPostMultipartForm(t *testing.T, files []File, expected [][]byte) {
485512
})
486513
t.Cleanup(srv.Close)
487514

488-
clt := newC(srv.URL, "v1", BasicAuth("user", "pass"))
515+
header := make(http.Header)
516+
header.Set("User-Agent", expectedUserAgent)
517+
clt := newC(srv.URL, "v1", BasicAuth("user", "pass"), WithHeader(header))
489518
values := url.Values{"a": []string{"b"}}
490519
out, err := clt.PostForm(
491520
context.Background(),
@@ -503,6 +532,7 @@ func testPostMultipartForm(t *testing.T, files []File, expected [][]byte) {
503532
require.Equal(t, "hello back", string(out.Bytes()))
504533
require.Equal(t, "/v1/a/b", u.String())
505534

535+
require.Equal(t, expectedUserAgent, userAgent)
506536
require.Equal(t, http.MethodPost, method)
507537
require.Equal(t, values, params)
508538
require.Equal(t, expected, data)

0 commit comments

Comments
 (0)