Skip to content

Commit 9481030

Browse files
committed
protocol: readHeader() respects previously called conn.SetReadDeadline(t)
1 parent 6994bcd commit 9481030

File tree

2 files changed

+168
-18
lines changed

2 files changed

+168
-18
lines changed

protocol.go

Lines changed: 57 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -150,8 +150,13 @@ func (p *Listener) Addr() net.Addr {
150150
return p.Listener.Addr()
151151
}
152152

153-
// NewConn is used to wrap a net.Conn that may be speaking
154-
// the proxy protocol into a proxyproto.Conn.
153+
// NewConn is used to wrap a net.Conn that may be speaking the PROXY protocol
154+
// into a proxyproto.Conn.
155+
//
156+
// NOTE: NewConn may interfere with previously set ReadDeadline on the provided net.Conn,
157+
// because it sets a temporary deadline when detecting and reading the PROXY protocol header.
158+
// If you need to enforce a specific ReadDeadline on the connection, be sure to call Conn.SetReadDeadline
159+
// again after NewConn returns, to restore your desired deadline.
155160
func NewConn(conn net.Conn, opts ...func(*Conn)) *Conn {
156161
// For v1 the header length is at most 108 bytes.
157162
// For v2 the header length is at most 52 bytes plus the length of the TLVs.
@@ -176,18 +181,20 @@ func NewConn(conn net.Conn, opts ...func(*Conn)) *Conn {
176181
// the initial scan. If there is an error parsing the header,
177182
// it is returned and the socket is closed.
178183
func (p *Conn) Read(b []byte) (int, error) {
179-
p.once.Do(func() {
180-
p.readErr = p.readHeader()
181-
})
182-
if p.readErr != nil {
183-
return 0, p.readErr
184+
// Ensure header processing runs at most once and surface any errors.
185+
if err := p.ensureHeaderProcessed(); err != nil {
186+
return 0, err
184187
}
185188

186189
return p.reader.Read(b)
187190
}
188191

189192
// Write wraps original conn.Write.
190193
func (p *Conn) Write(b []byte) (int, error) {
194+
// Ensure header processing has completed before writing.
195+
if err := p.ensureHeaderProcessed(); err != nil {
196+
return 0, err
197+
}
191198
return p.conn.Write(b)
192199
}
193200

@@ -199,7 +206,8 @@ func (p *Conn) Close() error {
199206
// ProxyHeader returns the proxy protocol header, if any. If an error occurs
200207
// while reading the proxy header, nil is returned.
201208
func (p *Conn) ProxyHeader() *Header {
202-
p.once.Do(func() { p.readErr = p.readHeader() })
209+
// Ensure header processing runs at most once.
210+
_ = p.ensureHeaderProcessed()
203211
return p.header
204212
}
205213

@@ -210,7 +218,8 @@ func (p *Conn) ProxyHeader() *Header {
210218
// from the proxy header even if the proxy header itself is
211219
// syntactically correct.
212220
func (p *Conn) LocalAddr() net.Addr {
213-
p.once.Do(func() { p.readErr = p.readHeader() })
221+
// Ensure header processing runs at most once.
222+
_ = p.ensureHeaderProcessed()
214223
if p.header == nil || p.header.Command.IsLocal() || p.readErr != nil {
215224
return p.conn.LocalAddr()
216225
}
@@ -225,7 +234,8 @@ func (p *Conn) LocalAddr() net.Addr {
225234
// from the proxy header even if the proxy header itself is
226235
// syntactically correct.
227236
func (p *Conn) RemoteAddr() net.Addr {
228-
p.once.Do(func() { p.readErr = p.readHeader() })
237+
// Ensure header processing runs at most once.
238+
_ = p.ensureHeaderProcessed()
229239
if p.header == nil || p.header.Command.IsLocal() || p.readErr != nil {
230240
return p.conn.RemoteAddr()
231241
}
@@ -291,11 +301,25 @@ func (p *Conn) SetWriteDeadline(t time.Time) error {
291301
// readHeader reads the proxy protocol header from the connection.
292302
func (p *Conn) readHeader() error {
293303
// If the connection's readHeaderTimeout is more than 0,
294-
// push our deadline back to now plus the timeout. This should only
295-
// run on the connection, as we don't want to override the previous
296-
// read deadline the user may have used.
304+
// apply a temporary deadline without extending a user-configured
305+
// deadline. If the user has no deadline, we use now + timeout.
297306
if p.readHeaderTimeout > 0 {
298-
if err := p.conn.SetReadDeadline(time.Now().Add(p.readHeaderTimeout)); err != nil {
307+
var (
308+
storedDeadline time.Time
309+
hasDeadline bool
310+
)
311+
if t := p.readDeadline.Load(); t != nil {
312+
storedDeadline = t.(time.Time)
313+
hasDeadline = !storedDeadline.IsZero()
314+
}
315+
316+
headerDeadline := time.Now().Add(p.readHeaderTimeout)
317+
if hasDeadline && storedDeadline.Before(headerDeadline) {
318+
// Clamp to the user's earlier deadline to avoid extending it.
319+
headerDeadline = storedDeadline
320+
}
321+
322+
if err := p.conn.SetReadDeadline(headerDeadline); err != nil {
299323
return err
300324
}
301325
}
@@ -304,7 +328,7 @@ func (p *Conn) readHeader() error {
304328

305329
// If the connection's readHeaderTimeout is more than 0, undo the change to the
306330
// deadline that we made above. Because we retain the readDeadline as part of our
307-
// SetReadDeadline override, we know the user's desired deadline so we use that.
331+
// SetReadDeadline override, we can restore the user's deadline (if any).
308332
// Therefore, we check whether the error is a net.Timeout and if it is, we decide
309333
// the proxy proto does not exist and set the error accordingly.
310334
if p.readHeaderTimeout > 0 {
@@ -352,8 +376,23 @@ func (p *Conn) readHeader() error {
352376
return err
353377
}
354378

379+
// ensureHeaderProcessed runs header processing once.
380+
func (p *Conn) ensureHeaderProcessed() error {
381+
p.once.Do(func() {
382+
p.readErr = p.readHeader()
383+
})
384+
if p.readErr != nil {
385+
return p.readErr
386+
}
387+
return nil
388+
}
389+
355390
// ReadFrom implements the io.ReaderFrom ReadFrom method.
356391
func (p *Conn) ReadFrom(r io.Reader) (int64, error) {
392+
// Ensure header processing has completed before reading/writing.
393+
if err := p.ensureHeaderProcessed(); err != nil {
394+
return 0, err
395+
}
357396
if rf, ok := p.conn.(io.ReaderFrom); ok {
358397
return rf.ReadFrom(r)
359398
}
@@ -362,9 +401,9 @@ func (p *Conn) ReadFrom(r io.Reader) (int64, error) {
362401

363402
// WriteTo implements io.WriterTo.
364403
func (p *Conn) WriteTo(w io.Writer) (int64, error) {
365-
p.once.Do(func() { p.readErr = p.readHeader() })
366-
if p.readErr != nil {
367-
return 0, p.readErr
404+
// Ensure header processing has completed before reading/writing.
405+
if err := p.ensureHeaderProcessed(); err != nil {
406+
return 0, err
368407
}
369408

370409
b := make([]byte, p.bufReader.Buffered())

protocol_test.go

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,117 @@ func TestUseWithReadHeaderTimeout(t *testing.T) {
223223
}
224224
}
225225

226+
func TestReadHeaderTimeoutRespectsEarlierDeadline(t *testing.T) {
227+
const (
228+
headerTimeout = 200 * time.Millisecond
229+
userTimeout = 60 * time.Millisecond
230+
tolerance = 100 * time.Millisecond
231+
)
232+
233+
l, err := net.Listen("tcp", testLocalhostRandomPort)
234+
if err != nil {
235+
t.Fatalf("err: %v", err)
236+
}
237+
238+
pl := &Listener{
239+
Listener: l,
240+
ReadHeaderTimeout: headerTimeout,
241+
Policy: func(_ net.Addr) (Policy, error) {
242+
// Use REQUIRE so a timeout is surfaced as ErrNoProxyProtocol.
243+
return REQUIRE, nil
244+
},
245+
}
246+
247+
type dialResult struct {
248+
conn net.Conn
249+
err error
250+
}
251+
252+
dialResultCh := make(chan dialResult, 1)
253+
go func() {
254+
conn, err := net.Dial("tcp", pl.Addr().String())
255+
dialResultCh <- dialResult{conn: conn, err: err}
256+
}()
257+
258+
conn, err := pl.Accept()
259+
if err != nil {
260+
t.Fatalf("err: %v", err)
261+
}
262+
t.Cleanup(func() {
263+
if closeErr := conn.Close(); closeErr != nil {
264+
t.Errorf("failed to close connection: %v", closeErr)
265+
}
266+
})
267+
268+
result := <-dialResultCh
269+
if result.err != nil {
270+
t.Fatalf("client error: %v", result.err)
271+
}
272+
t.Cleanup(func() {
273+
if closeErr := result.conn.Close(); closeErr != nil {
274+
t.Errorf("failed to close client connection: %v", closeErr)
275+
}
276+
})
277+
278+
// Set a shorter user deadline than the readHeaderTimeout and do not send data.
279+
if err := conn.SetReadDeadline(time.Now().Add(userTimeout)); err != nil {
280+
t.Fatalf("err: %v", err)
281+
}
282+
283+
start := time.Now()
284+
recv := make([]byte, 1)
285+
_, err = conn.Read(recv)
286+
elapsed := time.Since(start)
287+
288+
// The read should honor the earlier user deadline instead of waiting
289+
// for the longer readHeaderTimeout.
290+
if !errors.Is(err, ErrNoProxyProtocol) {
291+
t.Fatalf("expected ErrNoProxyProtocol, got: %v", err)
292+
}
293+
if elapsed > userTimeout+tolerance {
294+
t.Fatalf("read exceeded user deadline: elapsed=%v timeout=%v", elapsed, userTimeout)
295+
}
296+
}
297+
298+
func TestDeadlineSettersAfterHeaderProcessed(t *testing.T) {
299+
conn, peer := net.Pipe()
300+
t.Cleanup(func() {
301+
if closeErr := conn.Close(); closeErr != nil {
302+
t.Errorf("failed to close connection: %v", closeErr)
303+
}
304+
})
305+
t.Cleanup(func() {
306+
if closeErr := peer.Close(); closeErr != nil {
307+
t.Errorf("failed to close peer connection: %v", closeErr)
308+
}
309+
})
310+
311+
proxyConn := NewConn(conn)
312+
313+
// Ensure header processing completes by sending a non-PROXY byte
314+
// and reading it through the proxy connection.
315+
go func() {
316+
if _, err := peer.Write([]byte("x")); err != nil {
317+
t.Errorf("failed to write peer data: %v", err)
318+
}
319+
}()
320+
buf := make([]byte, 1)
321+
if _, err := proxyConn.Read(buf); err != nil {
322+
t.Fatalf("read failed: %v", err)
323+
}
324+
325+
deadline := time.Now().Add(time.Second)
326+
if err := proxyConn.SetDeadline(deadline); err != nil {
327+
t.Fatalf("unexpected SetDeadline error: %v", err)
328+
}
329+
if err := proxyConn.SetReadDeadline(deadline); err != nil {
330+
t.Fatalf("unexpected SetReadDeadline error: %v", err)
331+
}
332+
if err := proxyConn.SetWriteDeadline(deadline); err != nil {
333+
t.Fatalf("unexpected SetWriteDeadline error: %v", err)
334+
}
335+
}
336+
226337
func TestReadHeaderTimeoutIsReset(t *testing.T) {
227338
const timeout = time.Millisecond * 250
228339

0 commit comments

Comments
 (0)