Skip to content

Commit 85fd463

Browse files
committed
implement our own, manageable h2 client conn pool
1 parent d4034c6 commit 85fd463

File tree

1 file changed

+333
-0
lines changed

1 file changed

+333
-0
lines changed

dialer/h2_client_conn_pool.go

Lines changed: 333 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,333 @@
1+
// Copyright 2015 The Go Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file (see https://cs.opensource.google/go/x/net/+/refs/tags/v0.43.0:LICENSE).
4+
5+
// Transport code's client connection pooling.
6+
7+
package dialer
8+
9+
import (
10+
"context"
11+
"crypto/tls"
12+
"errors"
13+
"fmt"
14+
"net"
15+
"net/http"
16+
"sync"
17+
18+
"golang.org/x/net/http/httpguts"
19+
"golang.org/x/net/http2"
20+
)
21+
22+
type clientConnPool struct {
23+
t *http2.Transport
24+
25+
mu sync.Mutex // TODO: maybe switch to RWMutex
26+
// TODO: add support for sharing conns based on cert names
27+
// (e.g. share conn for googleapis.com and appspot.com)
28+
conns map[string][]*http2.ClientConn // key is host:port
29+
dialing map[string]*dialCall // currently in-flight dials
30+
keys map[*http2.ClientConn][]string
31+
addConnCalls map[string]*addConnCall // in-flight addConnIfNeeded calls
32+
}
33+
34+
func (p *clientConnPool) GetClientConn(req *http.Request, addr string) (*http2.ClientConn, error) {
35+
return p.getClientConn(req, addr, dialOnMiss)
36+
}
37+
38+
const (
39+
dialOnMiss = true
40+
noDialOnMiss = false
41+
)
42+
43+
// isConnectionCloseRequest reports whether req should use its own
44+
// connection for a single request and then close the connection.
45+
func isConnectionCloseRequest(req *http.Request) bool {
46+
return req.Close || httpguts.HeaderValuesContainsToken(req.Header["Connection"], "close")
47+
}
48+
49+
func strSliceContains(ss []string, s string) bool {
50+
for _, v := range ss {
51+
if v == s {
52+
return true
53+
}
54+
}
55+
return false
56+
}
57+
58+
func (p *clientConnPool) newTLSConfig(host string) *tls.Config {
59+
cfg := new(tls.Config)
60+
if p.t.TLSClientConfig != nil {
61+
*cfg = *p.t.TLSClientConfig.Clone()
62+
}
63+
if !strSliceContains(cfg.NextProtos, http2.NextProtoTLS) {
64+
cfg.NextProtos = append([]string{http2.NextProtoTLS}, cfg.NextProtos...)
65+
}
66+
if cfg.ServerName == "" {
67+
cfg.ServerName = host
68+
}
69+
return cfg
70+
}
71+
72+
func (p *clientConnPool) dialClientConn(ctx context.Context, addr string) (*http2.ClientConn, error) {
73+
host, _, err := net.SplitHostPort(addr)
74+
if err != nil {
75+
return nil, err
76+
}
77+
tconn, err := p.dialTLS(ctx, "tcp", addr, p.newTLSConfig(host))
78+
if err != nil {
79+
return nil, err
80+
}
81+
return p.t.NewClientConn(tconn)
82+
}
83+
84+
func (p *clientConnPool) dialTLS(ctx context.Context, network, addr string, tlsCfg *tls.Config) (net.Conn, error) {
85+
if p.t.DialTLSContext != nil {
86+
return p.t.DialTLSContext(ctx, network, addr, tlsCfg)
87+
} else if p.t.DialTLS != nil {
88+
return p.t.DialTLS(network, addr, tlsCfg)
89+
}
90+
91+
tlsCn, err := dialTLSWithContext(ctx, network, addr, tlsCfg)
92+
if err != nil {
93+
return nil, err
94+
}
95+
state := tlsCn.ConnectionState()
96+
if p := state.NegotiatedProtocol; p != http2.NextProtoTLS {
97+
return nil, fmt.Errorf("http2: unexpected ALPN protocol %q; want %q", p, http2.NextProtoTLS)
98+
}
99+
if !state.NegotiatedProtocolIsMutual {
100+
return nil, errors.New("http2: could not negotiate protocol mutually")
101+
}
102+
return tlsCn, nil
103+
}
104+
105+
func dialTLSWithContext(ctx context.Context, network, addr string, cfg *tls.Config) (*tls.Conn, error) {
106+
dialer := &tls.Dialer{
107+
Config: cfg,
108+
}
109+
cn, err := dialer.DialContext(ctx, network, addr)
110+
if err != nil {
111+
return nil, err
112+
}
113+
tlsCn := cn.(*tls.Conn) // DialContext comment promises this will always succeed
114+
return tlsCn, nil
115+
}
116+
117+
func (p *clientConnPool) getClientConn(req *http.Request, addr string, dialOnMiss bool) (*http2.ClientConn, error) {
118+
// TODO(dneil): Dial a new connection when t.DisableKeepAlives is set?
119+
if isConnectionCloseRequest(req) && dialOnMiss {
120+
// It gets its own connection.
121+
cc, err := p.dialClientConn(req.Context(), addr)
122+
if err != nil {
123+
return nil, err
124+
}
125+
return cc, nil
126+
}
127+
for {
128+
p.mu.Lock()
129+
for _, cc := range p.conns[addr] {
130+
if cc.ReserveNewRequest() {
131+
p.mu.Unlock()
132+
return cc, nil
133+
}
134+
}
135+
if !dialOnMiss {
136+
p.mu.Unlock()
137+
return nil, http2.ErrNoCachedConn
138+
}
139+
call := p.getStartDialLocked(req.Context(), addr)
140+
p.mu.Unlock()
141+
<-call.done
142+
if shouldRetryDial(call, req) {
143+
continue
144+
}
145+
cc, err := call.res, call.err
146+
if err != nil {
147+
return nil, err
148+
}
149+
if cc.ReserveNewRequest() {
150+
return cc, nil
151+
}
152+
}
153+
}
154+
155+
// incomparable is a zero-width, non-comparable type. Adding it to a struct
156+
// makes that struct also non-comparable, and generally doesn't add
157+
// any size (as long as it's first).
158+
type incomparable [0]func()
159+
160+
// dialCall is an in-flight Transport dial call to a host.
161+
type dialCall struct {
162+
_ incomparable
163+
p *clientConnPool
164+
// the context associated with the request
165+
// that created this dialCall
166+
ctx context.Context
167+
done chan struct{} // closed when done
168+
res *http2.ClientConn // valid after done is closed
169+
err error // valid after done is closed
170+
}
171+
172+
// requires p.mu is held.
173+
func (p *clientConnPool) getStartDialLocked(ctx context.Context, addr string) *dialCall {
174+
if call, ok := p.dialing[addr]; ok {
175+
// A dial is already in-flight. Don't start another.
176+
return call
177+
}
178+
call := &dialCall{p: p, done: make(chan struct{}), ctx: ctx}
179+
if p.dialing == nil {
180+
p.dialing = make(map[string]*dialCall)
181+
}
182+
p.dialing[addr] = call
183+
go call.dial(call.ctx, addr)
184+
return call
185+
}
186+
187+
// run in its own goroutine.
188+
func (c *dialCall) dial(ctx context.Context, addr string) {
189+
c.res, c.err = c.p.dialClientConn(ctx, addr)
190+
191+
c.p.mu.Lock()
192+
delete(c.p.dialing, addr)
193+
if c.err == nil {
194+
c.p.addConnLocked(addr, c.res)
195+
}
196+
c.p.mu.Unlock()
197+
198+
close(c.done)
199+
}
200+
201+
// addConnIfNeeded makes a NewClientConn out of c if a connection for key doesn't
202+
// already exist. It coalesces concurrent calls with the same key.
203+
// This is used by the http1 Transport code when it creates a new connection. Because
204+
// the http1 Transport doesn't de-dup TCP dials to outbound hosts (because it doesn't know
205+
// the protocol), it can get into a situation where it has multiple TLS connections.
206+
// This code decides which ones live or die.
207+
// The return value used is whether c was used.
208+
// c is never closed.
209+
func (p *clientConnPool) addConnIfNeeded(key string, t *http2.Transport, c net.Conn) (used bool, err error) {
210+
p.mu.Lock()
211+
for _, cc := range p.conns[key] {
212+
if cc.CanTakeNewRequest() {
213+
p.mu.Unlock()
214+
return false, nil
215+
}
216+
}
217+
call, dup := p.addConnCalls[key]
218+
if !dup {
219+
if p.addConnCalls == nil {
220+
p.addConnCalls = make(map[string]*addConnCall)
221+
}
222+
call = &addConnCall{
223+
p: p,
224+
done: make(chan struct{}),
225+
}
226+
p.addConnCalls[key] = call
227+
go call.run(t, key, c)
228+
}
229+
p.mu.Unlock()
230+
231+
<-call.done
232+
if call.err != nil {
233+
return false, call.err
234+
}
235+
return !dup, nil
236+
}
237+
238+
type addConnCall struct {
239+
_ incomparable
240+
p *clientConnPool
241+
done chan struct{} // closed when done
242+
err error
243+
}
244+
245+
func (c *addConnCall) run(t *http2.Transport, key string, nc net.Conn) {
246+
cc, err := t.NewClientConn(nc)
247+
248+
p := c.p
249+
p.mu.Lock()
250+
if err != nil {
251+
c.err = err
252+
} else {
253+
p.addConnLocked(key, cc)
254+
}
255+
delete(p.addConnCalls, key)
256+
p.mu.Unlock()
257+
close(c.done)
258+
}
259+
260+
// p.mu must be held
261+
func (p *clientConnPool) addConnLocked(key string, cc *http2.ClientConn) {
262+
for _, v := range p.conns[key] {
263+
if v == cc {
264+
return
265+
}
266+
}
267+
if p.conns == nil {
268+
p.conns = make(map[string][]*http2.ClientConn)
269+
}
270+
if p.keys == nil {
271+
p.keys = make(map[*http2.ClientConn][]string)
272+
}
273+
p.conns[key] = append(p.conns[key], cc)
274+
p.keys[cc] = append(p.keys[cc], key)
275+
}
276+
277+
func (p *clientConnPool) MarkDead(cc *http2.ClientConn) {
278+
p.mu.Lock()
279+
defer p.mu.Unlock()
280+
for _, key := range p.keys[cc] {
281+
vv, ok := p.conns[key]
282+
if !ok {
283+
continue
284+
}
285+
newList := filterOutClientConn(vv, cc)
286+
if len(newList) > 0 {
287+
p.conns[key] = newList
288+
} else {
289+
delete(p.conns, key)
290+
}
291+
}
292+
delete(p.keys, cc)
293+
}
294+
295+
func filterOutClientConn(in []*http2.ClientConn, exclude *http2.ClientConn) []*http2.ClientConn {
296+
out := in[:0]
297+
for _, v := range in {
298+
if v != exclude {
299+
out = append(out, v)
300+
}
301+
}
302+
// If we filtered it out, zero out the last item to prevent
303+
// the GC from seeing it.
304+
if len(in) != len(out) {
305+
in[len(in)-1] = nil
306+
}
307+
return out
308+
}
309+
310+
// shouldRetryDial reports whether the current request should
311+
// retry dialing after the call finished unsuccessfully, for example
312+
// if the dial was canceled because of a context cancellation or
313+
// deadline expiry.
314+
func shouldRetryDial(call *dialCall, req *http.Request) bool {
315+
if call.err == nil {
316+
// No error, no need to retry
317+
return false
318+
}
319+
if call.ctx == req.Context() {
320+
// If the call has the same context as the request, the dial
321+
// should not be retried, since any cancellation will have come
322+
// from this request.
323+
return false
324+
}
325+
if !errors.Is(call.err, context.Canceled) && !errors.Is(call.err, context.DeadlineExceeded) {
326+
// If the call error is not because of a context cancellation or a deadline expiry,
327+
// the dial should not be retried.
328+
return false
329+
}
330+
// Only retry if the error is a context cancellation error or deadline expiry
331+
// and the context associated with the call was canceled or expired.
332+
return call.ctx.Err() != nil
333+
}

0 commit comments

Comments
 (0)