Skip to content

Commit d130170

Browse files
committed
protocol: allow per-listener and per-conn custom read buffer size
1 parent 3622fe1 commit d130170

File tree

2 files changed

+149
-11
lines changed

2 files changed

+149
-11
lines changed

protocol.go

Lines changed: 39 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -43,24 +43,33 @@ var (
4343
// Only one of Policy or ConnPolicy should be provided. If both are provided then
4444
// a panic would occur during accept.
4545
type Listener struct {
46+
// Listener is the underlying listener.
4647
Listener net.Listener
4748
// Deprecated: use ConnPolicyFunc instead. This will be removed in future release.
48-
Policy PolicyFunc
49-
ConnPolicy ConnPolicyFunc
50-
ValidateHeader Validator
49+
Policy PolicyFunc
50+
// ConnPolicy is the policy function for accepted connections.
51+
ConnPolicy ConnPolicyFunc
52+
// ValidateHeader is the validator function for the proxy header.
53+
ValidateHeader Validator
54+
// ReadHeaderTimeout is the timeout for reading the proxy header.
5155
ReadHeaderTimeout time.Duration
56+
// ReadBufferSize is the read buffer size for accepted connections. When > 0,
57+
// each accepted connection uses this size for proxy header detection; 0 means default.
58+
ReadBufferSize int
5259
}
5360

5461
// Conn is used to wrap and underlying connection which
5562
// may be speaking the Proxy Protocol. If it is, the RemoteAddr() will
5663
// return the address of the client instead of the proxy address. Each connection
5764
// will have its own readHeaderTimeout and readDeadline set by the Accept() call.
5865
type Conn struct {
59-
readDeadline atomic.Value // time.Time
60-
once sync.Once
61-
readErr error
62-
conn net.Conn
63-
bufReader *bufio.Reader
66+
readDeadline atomic.Value // time.Time
67+
once sync.Once
68+
readErr error
69+
conn net.Conn
70+
bufReader *bufio.Reader
71+
// bufferSize is set when the client overrides via WithBufferSize; nil means use default.
72+
bufferSize *int
6473
header *Header
6574
ProxyHeaderPolicy Policy
6675
Validate Validator
@@ -89,6 +98,22 @@ func SetReadHeaderTimeout(t time.Duration) func(*Conn) {
8998
}
9099
}
91100

101+
// WithBufferSize sets the size of the read buffer used for proxy header detection.
102+
// Values <= 0 are ignored and the default (256 bytes) is used. Values < 16 are
103+
// effectively 16 due to bufio's minimum. The default is tuned for typical proxy
104+
// protocol header lengths.
105+
func WithBufferSize(length int) func(*Conn) {
106+
return func(c *Conn) {
107+
if length <= 0 {
108+
return
109+
}
110+
p := new(int)
111+
*p = length
112+
c.bufferSize = p
113+
c.bufReader = bufio.NewReaderSize(c.conn, length)
114+
}
115+
}
116+
92117
// Accept waits for and returns the next valid connection to the listener.
93118
func (p *Listener) Accept() (net.Conn, error) {
94119
for {
@@ -130,11 +155,14 @@ func (p *Listener) Accept() (net.Conn, error) {
130155
}
131156
}
132157

133-
newConn := NewConn(
134-
conn,
158+
opts := []func(*Conn){
135159
WithPolicy(proxyHeaderPolicy),
136160
ValidateHeader(p.ValidateHeader),
137-
)
161+
}
162+
if p.ReadBufferSize > 0 {
163+
opts = append(opts, WithBufferSize(p.ReadBufferSize))
164+
}
165+
newConn := NewConn(conn, opts...)
138166

139167
// If the ReadHeaderTimeout for the listener is unset, use the default timeout.
140168
if p.ReadHeaderTimeout == 0 {

protocol_test.go

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,116 @@ func TestNewConnSetReadHeaderTimeoutIgnoresNegative(t *testing.T) {
264264
}
265265
}
266266

267+
func TestWithBufferSizePositive(t *testing.T) {
268+
conn, peer := net.Pipe()
269+
t.Cleanup(func() {
270+
_ = conn.Close()
271+
_ = peer.Close()
272+
})
273+
274+
proxyConn := NewConn(conn, WithBufferSize(4096))
275+
if proxyConn.bufferSize == nil {
276+
t.Fatalf("expected bufferSize to be set")
277+
}
278+
if *proxyConn.bufferSize != 4096 {
279+
t.Fatalf("expected bufferSize 4096, got %d", *proxyConn.bufferSize)
280+
}
281+
282+
go func() { _, _ = peer.Write([]byte("x")) }()
283+
buf := make([]byte, 1)
284+
if _, err := proxyConn.Read(buf); err != nil {
285+
t.Fatalf("read failed: %v", err)
286+
}
287+
if string(buf) != "x" {
288+
t.Fatalf("unexpected read: %q", buf)
289+
}
290+
}
291+
292+
func TestWithBufferSizeZeroOrNegative(t *testing.T) {
293+
for _, length := range []int{0, -1} {
294+
t.Run(fmt.Sprint(length), func(t *testing.T) {
295+
conn, peer := net.Pipe()
296+
t.Cleanup(func() {
297+
_ = conn.Close()
298+
_ = peer.Close()
299+
})
300+
301+
proxyConn := NewConn(conn, WithBufferSize(length))
302+
if proxyConn.bufferSize != nil {
303+
t.Fatalf("expected bufferSize to be nil for length %d", length)
304+
}
305+
306+
go func() { _, _ = peer.Write([]byte("y")) }()
307+
buf := make([]byte, 1)
308+
if _, err := proxyConn.Read(buf); err != nil {
309+
t.Fatalf("read failed: %v", err)
310+
}
311+
if string(buf) != "y" {
312+
t.Fatalf("unexpected read: %q", buf)
313+
}
314+
})
315+
}
316+
}
317+
318+
func TestListenerReadBufferSizeApplied(t *testing.T) {
319+
l, err := net.Listen("tcp", testLocalhostRandomPort)
320+
if err != nil {
321+
t.Fatalf("err: %v", err)
322+
}
323+
t.Cleanup(func() { _ = l.Close() })
324+
325+
pl := &Listener{Listener: l, ReadBufferSize: 4096}
326+
327+
go func() {
328+
c, _ := net.Dial("tcp", pl.Addr().String())
329+
if c != nil {
330+
_ = c.Close()
331+
}
332+
}()
333+
334+
conn, err := pl.Accept()
335+
if err != nil {
336+
t.Fatalf("Accept: %v", err)
337+
}
338+
t.Cleanup(func() { _ = conn.Close() })
339+
340+
proxyConn := conn.(*Conn)
341+
if proxyConn.bufferSize == nil {
342+
t.Fatalf("expected bufferSize to be set when Listener.ReadBufferSize > 0")
343+
}
344+
if *proxyConn.bufferSize != 4096 {
345+
t.Fatalf("expected bufferSize 4096, got %d", *proxyConn.bufferSize)
346+
}
347+
}
348+
349+
func TestListenerReadBufferSizeZeroUsesDefault(t *testing.T) {
350+
l, err := net.Listen("tcp", testLocalhostRandomPort)
351+
if err != nil {
352+
t.Fatalf("err: %v", err)
353+
}
354+
t.Cleanup(func() { _ = l.Close() })
355+
356+
pl := &Listener{Listener: l, ReadBufferSize: 0}
357+
358+
go func() {
359+
c, _ := net.Dial("tcp", pl.Addr().String())
360+
if c != nil {
361+
_ = c.Close()
362+
}
363+
}()
364+
365+
conn, err := pl.Accept()
366+
if err != nil {
367+
t.Fatalf("Accept: %v", err)
368+
}
369+
t.Cleanup(func() { _ = conn.Close() })
370+
371+
proxyConn := conn.(*Conn)
372+
if proxyConn.bufferSize != nil {
373+
t.Fatalf("expected bufferSize to be nil when Listener.ReadBufferSize is 0")
374+
}
375+
}
376+
267377
func TestReadHeaderTimeoutRespectsEarlierDeadline(t *testing.T) {
268378
const (
269379
headerTimeout = 200 * time.Millisecond

0 commit comments

Comments
 (0)