Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion client.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,14 @@ func CookieJar(jar http.CookieJar) ClientParam {
}
}

// WithHeader sets default HTTP header for this client.
func WithHeader(headers http.Header) ClientParam {
return func(c *Client) error {
c.header = headers
return nil
}
}

// SanitizerEnabled will enable the input sanitizer which passes the URL
// path through a strict whitelist.
func SanitizerEnabled(sanitizerEnabled bool) ClientParam {
Expand All @@ -116,6 +124,8 @@ type Client struct {
auth fmt.Stringer
// jar is a set of cookies passed with requests
jar http.CookieJar
// header is the http.Header applied on every request.
header http.Header
// newTracer creates new request tracer
newTracer NewTracer
// sanitizerEnabled will enable the input sanitizer which passes the URL
Expand Down Expand Up @@ -147,6 +157,9 @@ func NewClient(addr, v string, params ...ClientParam) (*Client, error) {
if c.newTracer == nil {
c.newTracer = NewNopTracer
}
if c.header == nil {
c.header = make(http.Header)
}
return c, nil
}

Expand Down Expand Up @@ -181,6 +194,7 @@ func (c *Client) submitForm(ctx context.Context, method string, endpoint string,
if err != nil {
return nil, err
}
req.Header = c.header.Clone()
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
c.addAuth(req)
return c.client.Do(req)
Expand Down Expand Up @@ -211,6 +225,7 @@ func (c *Client) submitForm(ctx context.Context, method string, endpoint string,
if err != nil {
return nil, err
}
req.Header = c.header.Clone()
req.Header.Set("Content-Type",
fmt.Sprintf(`multipart/form-data;boundary="%v"`, writer.Boundary()))
c.addAuth(req)
Expand Down Expand Up @@ -261,6 +276,7 @@ func (c *Client) submitJSON(ctx context.Context, method string, endpoint string,
if err != nil {
return nil, err
}
req.Header = c.header.Clone()
req.Header.Set("Content-Type", "application/json")
c.addAuth(req)
tracer.Start(req)
Expand Down Expand Up @@ -311,6 +327,7 @@ func (c *Client) Delete(ctx context.Context, endpoint string) (*Response, error)
if err != nil {
return nil, err
}
req.Header = c.header.Clone()
c.addAuth(req)
tracer.Start(req)
return c.client.Do(req)
Expand Down Expand Up @@ -354,6 +371,7 @@ func (c *Client) Get(ctx context.Context, endpoint string, params url.Values) (*
if err != nil {
return nil, err
}
req.Header = c.header.Clone()
c.addAuth(req)
tracer.Start(req)
return c.client.Do(req)
Expand Down Expand Up @@ -382,6 +400,7 @@ func (c *Client) GetFile(ctx context.Context, endpoint string, params url.Values
if err != nil {
return nil, err
}
req.Header = c.header.Clone()
c.addAuth(req)
tracer := c.newTracer()
tracer.Start(req)
Expand Down Expand Up @@ -479,7 +498,7 @@ func (c *Client) writeWithPipe(endpoint string, vals url.Values, buffers ...file
r.Close()
return nil, err
}

req.Header = c.header.Clone()
c.addAuth(req)
req.Header.Set("Content-Type", fmt.Sprintf(`multipart/form-data;boundary="%v"`, writer.Boundary()))
return c.client.Do(req)
Expand Down
40 changes: 35 additions & 5 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,11 +144,14 @@ func TestPostPutPatchJSON(t *testing.T) {
var data interface{}
var user, pass string
var method string
var userAgent string

expectedUserAgent := "api/1.0.0"
ch := make(chan error, 1)
srv := serveHandler(func(w http.ResponseWriter, r *http.Request) {
var ok bool
method = r.Method
userAgent = r.UserAgent()

user, pass, ok = r.BasicAuth()
if !ok {
Expand All @@ -160,7 +163,9 @@ func TestPostPutPatchJSON(t *testing.T) {
})
t.Cleanup(srv.Close)

clt := newC(srv.URL, "v1", BasicAuth("user", "pass"))
header := make(http.Header)
header.Set("User-Agent", expectedUserAgent)
clt := newC(srv.URL, "v1", BasicAuth("user", "pass"), WithHeader(header))

values := map[string]interface{}{"hello": "there"}
_, err := clt.PostJSON(context.Background(), clt.Endpoint("a", "b"), values)
Expand Down Expand Up @@ -191,6 +196,7 @@ func TestPostPutPatchJSON(t *testing.T) {

require.NoError(t, <-ch)

require.Equal(t, expectedUserAgent, userAgent)
require.Equal(t, http.MethodPatch, method)
require.Equal(t, "user", user)
require.Equal(t, "pass", pass)
Expand All @@ -200,16 +206,22 @@ func TestPostPutPatchJSON(t *testing.T) {
func TestDelete(t *testing.T) {
var method string
var user, pass string
var userAgent string

expectedUserAgent := "api/1.0.0"
srv := serveHandler(func(w http.ResponseWriter, r *http.Request) {
user, pass, _ = r.BasicAuth()
method = r.Method
userAgent = r.UserAgent()
})
t.Cleanup(srv.Close)

clt := newC(srv.URL, "v1", BasicAuth("user", "pass"))
header := make(http.Header)
header.Set("User-Agent", expectedUserAgent)
clt := newC(srv.URL, "v1", BasicAuth("user", "pass"), WithHeader(header))
re, err := clt.Delete(context.Background(), clt.Endpoint("a", "b"))
require.NoError(t, err)
require.Equal(t, expectedUserAgent, userAgent)
require.Equal(t, http.MethodDelete, method)
require.Equal(t, http.StatusOK, re.Code())
require.Equal(t, "user", user)
Expand Down Expand Up @@ -241,16 +253,22 @@ func TestDeleteP(t *testing.T) {
func TestGet(t *testing.T) {
var method string
var query url.Values
var userAgent string
expectedUserAgent := "api/1.0.0"
srv := serveHandler(func(w http.ResponseWriter, r *http.Request) {
method = r.Method
query = r.URL.Query()
userAgent = r.UserAgent()
})
t.Cleanup(srv.Close)

clt := newC(srv.URL, "v1")
header := make(http.Header)
header.Set("User-Agent", expectedUserAgent)
clt := newC(srv.URL, "v1", WithHeader(header))
values := url.Values{"q": []string{"1", "2"}}
_, err := clt.Get(context.Background(), clt.Endpoint("a", "b"), values)
require.NoError(t, err)
require.Equal(t, expectedUserAgent, userAgent)
require.Equal(t, http.MethodGet, method)
require.Equal(t, values, query)
}
Expand All @@ -274,10 +292,13 @@ func TestGetFile(t *testing.T) {
require.NoError(t, err)

var user, pass string
var userAgent string
expectedUserAgent := "api/1.0.0"
ch := make(chan error, 1)
srv := serveHandler(func(w http.ResponseWriter, r *http.Request) {
var ok bool
user, pass, ok = r.BasicAuth()
userAgent = r.UserAgent()
if !ok {
ch <- errors.New("basic auth headers invalid")
return
Expand All @@ -289,7 +310,9 @@ func TestGetFile(t *testing.T) {
})
t.Cleanup(srv.Close)

clt := newC(srv.URL, "v1", BasicAuth("user", "pass"))
header := make(http.Header)
header.Set("User-Agent", expectedUserAgent)
clt := newC(srv.URL, "v1", BasicAuth("user", "pass"), WithHeader(header))
f, err := clt.GetFile(context.Background(), clt.Endpoint("download"), url.Values{})
require.NoError(t, err)
defer f.Close()
Expand All @@ -298,6 +321,7 @@ func TestGetFile(t *testing.T) {

data, err := io.ReadAll(f.Body())
require.NoError(t, err)
require.Equal(t, expectedUserAgent, userAgent)
require.Equal(t, "hello there", string(data))
require.Equal(t, "file.txt", f.FileName())
require.Equal(t, user, "user")
Expand Down Expand Up @@ -444,11 +468,14 @@ func testPostMultipartForm(t *testing.T, files []File, expected [][]byte) {
var method string
var data [][]byte
var user, pass string
var userAgent string

expectedUserAgent := "api/1.0.0"
ch := make(chan error, 100)
srv := serveHandler(func(w http.ResponseWriter, r *http.Request) {
defer close(ch)
user, pass, _ = r.BasicAuth()
userAgent = r.UserAgent()

u = r.URL
if err := r.ParseMultipartForm(64 << 20); err != nil {
Expand Down Expand Up @@ -485,7 +512,9 @@ func testPostMultipartForm(t *testing.T, files []File, expected [][]byte) {
})
t.Cleanup(srv.Close)

clt := newC(srv.URL, "v1", BasicAuth("user", "pass"))
header := make(http.Header)
header.Set("User-Agent", expectedUserAgent)
clt := newC(srv.URL, "v1", BasicAuth("user", "pass"), WithHeader(header))
values := url.Values{"a": []string{"b"}}
out, err := clt.PostForm(
context.Background(),
Expand All @@ -503,6 +532,7 @@ func testPostMultipartForm(t *testing.T, files []File, expected [][]byte) {
require.Equal(t, "hello back", string(out.Bytes()))
require.Equal(t, "/v1/a/b", u.String())

require.Equal(t, expectedUserAgent, userAgent)
require.Equal(t, http.MethodPost, method)
require.Equal(t, values, params)
require.Equal(t, expected, data)
Expand Down
Loading