Skip to content

Commit 290d6bd

Browse files
renaynayfjl
andauthored
rpc: add SetHeader method to Client (#21392)
Resolves #20163 Co-authored-by: Felix Lange <[email protected]>
1 parent 9c2ac6f commit 290d6bd

File tree

3 files changed

+79
-10
lines changed

3 files changed

+79
-10
lines changed

rpc/client.go

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ type Client struct {
8585

8686
// writeConn is used for writing to the connection on the caller's goroutine. It should
8787
// only be accessed outside of dispatch, with the write lock held. The write lock is
88-
// taken by sending on requestOp and released by sending on sendDone.
88+
// taken by sending on reqInit and released by sending on reqSent.
8989
writeConn jsonWriter
9090

9191
// for dispatch
@@ -260,6 +260,19 @@ func (c *Client) Close() {
260260
}
261261
}
262262

263+
// SetHeader adds a custom HTTP header to the client's requests.
264+
// This method only works for clients using HTTP, it doesn't have
265+
// any effect for clients using another transport.
266+
func (c *Client) SetHeader(key, value string) {
267+
if !c.isHTTP {
268+
return
269+
}
270+
conn := c.writeConn.(*httpConn)
271+
conn.mu.Lock()
272+
conn.headers.Set(key, value)
273+
conn.mu.Unlock()
274+
}
275+
263276
// Call performs a JSON-RPC call with the given arguments and unmarshals into
264277
// result if no error occurred.
265278
//

rpc/client_test.go

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import (
2626
"os"
2727
"reflect"
2828
"runtime"
29+
"strings"
2930
"sync"
3031
"testing"
3132
"time"
@@ -429,6 +430,42 @@ func TestClientNotificationStorm(t *testing.T) {
429430
doTest(23000, true)
430431
}
431432

433+
func TestClientSetHeader(t *testing.T) {
434+
var gotHeader bool
435+
srv := newTestServer()
436+
httpsrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
437+
if r.Header.Get("test") == "ok" {
438+
gotHeader = true
439+
}
440+
srv.ServeHTTP(w, r)
441+
}))
442+
defer httpsrv.Close()
443+
defer srv.Stop()
444+
445+
client, err := Dial(httpsrv.URL)
446+
if err != nil {
447+
t.Fatal(err)
448+
}
449+
defer client.Close()
450+
451+
client.SetHeader("test", "ok")
452+
if _, err := client.SupportedModules(); err != nil {
453+
t.Fatal(err)
454+
}
455+
if !gotHeader {
456+
t.Fatal("client did not set custom header")
457+
}
458+
459+
// Check that Content-Type can be replaced.
460+
client.SetHeader("content-type", "application/x-garbage")
461+
_, err = client.SupportedModules()
462+
if err == nil {
463+
t.Fatal("no error for invalid content-type header")
464+
} else if !strings.Contains(err.Error(), "Unsupported Media Type") {
465+
t.Fatalf("error is not related to content-type: %q", err)
466+
}
467+
}
468+
432469
func TestClientHTTP(t *testing.T) {
433470
server := newTestServer()
434471
defer server.Stop()

rpc/http.go

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import (
2626
"io/ioutil"
2727
"mime"
2828
"net/http"
29+
"net/url"
2930
"sync"
3031
"time"
3132
)
@@ -40,9 +41,11 @@ var acceptedContentTypes = []string{contentType, "application/json-rpc", "applic
4041

4142
type httpConn struct {
4243
client *http.Client
43-
req *http.Request
44+
url string
4445
closeOnce sync.Once
4546
closeCh chan interface{}
47+
mu sync.Mutex // protects headers
48+
headers http.Header
4649
}
4750

4851
// httpConn is treated specially by Client.
@@ -51,7 +54,7 @@ func (hc *httpConn) writeJSON(context.Context, interface{}) error {
5154
}
5255

5356
func (hc *httpConn) remoteAddr() string {
54-
return hc.req.URL.String()
57+
return hc.url
5558
}
5659

5760
func (hc *httpConn) readBatch() ([]*jsonrpcMessage, bool, error) {
@@ -102,16 +105,24 @@ var DefaultHTTPTimeouts = HTTPTimeouts{
102105
// DialHTTPWithClient creates a new RPC client that connects to an RPC server over HTTP
103106
// using the provided HTTP Client.
104107
func DialHTTPWithClient(endpoint string, client *http.Client) (*Client, error) {
105-
req, err := http.NewRequest(http.MethodPost, endpoint, nil)
108+
// Sanity check URL so we don't end up with a client that will fail every request.
109+
_, err := url.Parse(endpoint)
106110
if err != nil {
107111
return nil, err
108112
}
109-
req.Header.Set("Content-Type", contentType)
110-
req.Header.Set("Accept", contentType)
111113

112114
initctx := context.Background()
115+
headers := make(http.Header, 2)
116+
headers.Set("accept", contentType)
117+
headers.Set("content-type", contentType)
113118
return newClient(initctx, func(context.Context) (ServerCodec, error) {
114-
return &httpConn{client: client, req: req, closeCh: make(chan interface{})}, nil
119+
hc := &httpConn{
120+
client: client,
121+
headers: headers,
122+
url: endpoint,
123+
closeCh: make(chan interface{}),
124+
}
125+
return hc, nil
115126
})
116127
}
117128

@@ -131,7 +142,7 @@ func (c *Client) sendHTTP(ctx context.Context, op *requestOp, msg interface{}) e
131142
if respBody != nil {
132143
buf := new(bytes.Buffer)
133144
if _, err2 := buf.ReadFrom(respBody); err2 == nil {
134-
return fmt.Errorf("%v %v", err, buf.String())
145+
return fmt.Errorf("%v: %v", err, buf.String())
135146
}
136147
}
137148
return err
@@ -166,10 +177,18 @@ func (hc *httpConn) doRequest(ctx context.Context, msg interface{}) (io.ReadClos
166177
if err != nil {
167178
return nil, err
168179
}
169-
req := hc.req.WithContext(ctx)
170-
req.Body = ioutil.NopCloser(bytes.NewReader(body))
180+
req, err := http.NewRequestWithContext(ctx, "POST", hc.url, ioutil.NopCloser(bytes.NewReader(body)))
181+
if err != nil {
182+
return nil, err
183+
}
171184
req.ContentLength = int64(len(body))
172185

186+
// set headers
187+
hc.mu.Lock()
188+
req.Header = hc.headers.Clone()
189+
hc.mu.Unlock()
190+
191+
// do request
173192
resp, err := hc.client.Do(req)
174193
if err != nil {
175194
return nil, err

0 commit comments

Comments
 (0)