Skip to content

Commit 314cbdf

Browse files
committed
added WithClient to allow users to specify their own http.Client
1 parent 5b70caf commit 314cbdf

File tree

4 files changed

+69
-10
lines changed

4 files changed

+69
-10
lines changed

client.go

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"encoding/json"
77
"fmt"
88
"io"
9+
"log"
910
"mime/multipart"
1011
"net/http"
1112

@@ -67,7 +68,12 @@ func (c *client) Do(ctx context.Context, request *Request, response interface{})
6768
}
6869
req.Header.Set("Content-Type", writer.FormDataContentType())
6970
req.Header.Set("Accept", "application/json")
70-
res, err := ctxhttp.Do(ctx, c.httpclient, req)
71+
client := c.httpclient
72+
if client == nil {
73+
client = http.DefaultClient
74+
}
75+
log.Println("The client is", client)
76+
res, err := ctxhttp.Do(ctx, client, req)
7177
if err != nil {
7278
return err
7379
}
@@ -86,6 +92,18 @@ func (c *client) Do(ctx context.Context, request *Request, response interface{})
8692
return nil
8793
}
8894

95+
// WithClient specifies the http.Client that requests will use.
96+
func WithClient(ctx context.Context, client *http.Client) context.Context {
97+
log.Printf("%+v", ctx)
98+
c, err := fromContext(ctx)
99+
if err != nil {
100+
// can't set it, fail silently
101+
return ctx
102+
}
103+
c.httpclient = client
104+
return ctx
105+
}
106+
89107
type graphErr struct {
90108
Message string
91109
}
@@ -111,9 +129,9 @@ func NewRequest(q string) *Request {
111129

112130
// Run executes the query and unmarshals the response into response.
113131
func (req *Request) Run(ctx context.Context, response interface{}) error {
114-
client := fromContext(ctx)
115-
if client == nil {
116-
return errors.New("inappropriate context")
132+
client, err := fromContext(ctx)
133+
if err != nil {
134+
return err
117135
}
118136
return client.Do(ctx, req, response)
119137
}

client_test.go

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,28 @@ import (
1313
"github.com/matryer/is"
1414
)
1515

16+
func TestWithClient(t *testing.T) {
17+
is := is.New(t)
18+
var calls int
19+
testClient := &http.Client{
20+
Transport: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
21+
calls++
22+
resp := &http.Response{
23+
Body: ioutil.NopCloser(strings.NewReader(`{"data":{"key":"value"}}`)),
24+
}
25+
return resp, nil
26+
}),
27+
}
28+
29+
ctx := NewContext(context.Background(), "")
30+
ctx = WithClient(ctx, testClient)
31+
32+
req := NewRequest(``)
33+
req.Run(ctx, nil)
34+
35+
is.Equal(calls, 1) // calls
36+
}
37+
1638
func TestDo(t *testing.T) {
1739
is := is.New(t)
1840
var calls int
@@ -171,3 +193,9 @@ func TestFile(t *testing.T) {
171193
is.NoErr(err)
172194

173195
}
196+
197+
type roundTripperFunc func(req *http.Request) (*http.Response, error)
198+
199+
func (fn roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
200+
return fn(req)
201+
}

context.go

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,14 @@
11
package graphql
22

3-
import "context"
3+
import (
4+
"context"
5+
6+
"github.com/pkg/errors"
7+
)
8+
9+
// errInappropriateContext is returned when the context has not been
10+
// configured with graphql.NewContext.
11+
var errInappropriateContext = errors.New("inappropriate context")
412

513
// contextKey provides unique keys for context values.
614
type contextKey string
@@ -10,9 +18,12 @@ var clientContextKey = contextKey("graphql client context")
1018

1119
// fromContext gets the client from the specified
1220
// Context.
13-
func fromContext(ctx context.Context) *client {
14-
c, _ := ctx.Value(clientContextKey).(*client)
15-
return c
21+
func fromContext(ctx context.Context) (*client, error) {
22+
c, ok := ctx.Value(clientContextKey).(*client)
23+
if !ok {
24+
return nil, errInappropriateContext
25+
}
26+
return c, nil
1627
}
1728

1829
// NewContext makes a new context.Context that enables requests.

context_test.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,12 @@ func TestNewContext(t *testing.T) {
1818
endpoint := "https://server.com/graphql"
1919
ctx = NewContext(ctx, endpoint)
2020

21-
vclient := fromContext(ctx)
21+
vclient, err := fromContext(ctx)
22+
is.NoErr(err)
2223
is.Equal(vclient.endpoint, endpoint)
2324

24-
vclient2 := fromContext(ctx)
25+
vclient2, err := fromContext(ctx)
26+
is.NoErr(err)
2527
is.Equal(vclient, vclient2)
2628

2729
is.Equal(ctx.Value(testContextKey), true) // normal context stuff should work

0 commit comments

Comments
 (0)