@@ -150,8 +150,13 @@ func (p *Listener) Addr() net.Addr {
150150 return p .Listener .Addr ()
151151}
152152
153- // NewConn is used to wrap a net.Conn that may be speaking
154- // the proxy protocol into a proxyproto.Conn.
153+ // NewConn is used to wrap a net.Conn that may be speaking the PROXY protocol
154+ // into a proxyproto.Conn.
155+ //
156+ // NOTE: NewConn may interfere with previously set ReadDeadline on the provided net.Conn,
157+ // because it sets a temporary deadline when detecting and reading the PROXY protocol header.
158+ // If you need to enforce a specific ReadDeadline on the connection, be sure to call Conn.SetReadDeadline
159+ // again after NewConn returns, to restore your desired deadline.
155160func NewConn (conn net.Conn , opts ... func (* Conn )) * Conn {
156161 // For v1 the header length is at most 108 bytes.
157162 // For v2 the header length is at most 52 bytes plus the length of the TLVs.
@@ -176,18 +181,20 @@ func NewConn(conn net.Conn, opts ...func(*Conn)) *Conn {
176181// the initial scan. If there is an error parsing the header,
177182// it is returned and the socket is closed.
178183func (p * Conn ) Read (b []byte ) (int , error ) {
179- p .once .Do (func () {
180- p .readErr = p .readHeader ()
181- })
182- if p .readErr != nil {
183- return 0 , p .readErr
184+ // Ensure header processing runs at most once and surface any errors.
185+ if err := p .ensureHeaderProcessed (); err != nil {
186+ return 0 , err
184187 }
185188
186189 return p .reader .Read (b )
187190}
188191
189192// Write wraps original conn.Write.
190193func (p * Conn ) Write (b []byte ) (int , error ) {
194+ // Ensure header processing has completed before writing.
195+ if err := p .ensureHeaderProcessed (); err != nil {
196+ return 0 , err
197+ }
191198 return p .conn .Write (b )
192199}
193200
@@ -199,7 +206,8 @@ func (p *Conn) Close() error {
199206// ProxyHeader returns the proxy protocol header, if any. If an error occurs
200207// while reading the proxy header, nil is returned.
201208func (p * Conn ) ProxyHeader () * Header {
202- p .once .Do (func () { p .readErr = p .readHeader () })
209+ // Ensure header processing runs at most once.
210+ _ = p .ensureHeaderProcessed ()
203211 return p .header
204212}
205213
@@ -210,7 +218,8 @@ func (p *Conn) ProxyHeader() *Header {
210218// from the proxy header even if the proxy header itself is
211219// syntactically correct.
212220func (p * Conn ) LocalAddr () net.Addr {
213- p .once .Do (func () { p .readErr = p .readHeader () })
221+ // Ensure header processing runs at most once.
222+ _ = p .ensureHeaderProcessed ()
214223 if p .header == nil || p .header .Command .IsLocal () || p .readErr != nil {
215224 return p .conn .LocalAddr ()
216225 }
@@ -225,7 +234,8 @@ func (p *Conn) LocalAddr() net.Addr {
225234// from the proxy header even if the proxy header itself is
226235// syntactically correct.
227236func (p * Conn ) RemoteAddr () net.Addr {
228- p .once .Do (func () { p .readErr = p .readHeader () })
237+ // Ensure header processing runs at most once.
238+ _ = p .ensureHeaderProcessed ()
229239 if p .header == nil || p .header .Command .IsLocal () || p .readErr != nil {
230240 return p .conn .RemoteAddr ()
231241 }
@@ -291,11 +301,25 @@ func (p *Conn) SetWriteDeadline(t time.Time) error {
291301// readHeader reads the proxy protocol header from the connection.
292302func (p * Conn ) readHeader () error {
293303 // If the connection's readHeaderTimeout is more than 0,
294- // push our deadline back to now plus the timeout. This should only
295- // run on the connection, as we don't want to override the previous
296- // read deadline the user may have used.
304+ // apply a temporary deadline without extending a user-configured
305+ // deadline. If the user has no deadline, we use now + timeout.
297306 if p .readHeaderTimeout > 0 {
298- if err := p .conn .SetReadDeadline (time .Now ().Add (p .readHeaderTimeout )); err != nil {
307+ var (
308+ storedDeadline time.Time
309+ hasDeadline bool
310+ )
311+ if t := p .readDeadline .Load (); t != nil {
312+ storedDeadline = t .(time.Time )
313+ hasDeadline = ! storedDeadline .IsZero ()
314+ }
315+
316+ headerDeadline := time .Now ().Add (p .readHeaderTimeout )
317+ if hasDeadline && storedDeadline .Before (headerDeadline ) {
318+ // Clamp to the user's earlier deadline to avoid extending it.
319+ headerDeadline = storedDeadline
320+ }
321+
322+ if err := p .conn .SetReadDeadline (headerDeadline ); err != nil {
299323 return err
300324 }
301325 }
@@ -304,7 +328,7 @@ func (p *Conn) readHeader() error {
304328
305329 // If the connection's readHeaderTimeout is more than 0, undo the change to the
306330 // deadline that we made above. Because we retain the readDeadline as part of our
307- // SetReadDeadline override, we know the user's desired deadline so we use that .
331+ // SetReadDeadline override, we can restore the user's deadline (if any) .
308332 // Therefore, we check whether the error is a net.Timeout and if it is, we decide
309333 // the proxy proto does not exist and set the error accordingly.
310334 if p .readHeaderTimeout > 0 {
@@ -352,8 +376,23 @@ func (p *Conn) readHeader() error {
352376 return err
353377}
354378
379+ // ensureHeaderProcessed runs header processing once.
380+ func (p * Conn ) ensureHeaderProcessed () error {
381+ p .once .Do (func () {
382+ p .readErr = p .readHeader ()
383+ })
384+ if p .readErr != nil {
385+ return p .readErr
386+ }
387+ return nil
388+ }
389+
355390// ReadFrom implements the io.ReaderFrom ReadFrom method.
356391func (p * Conn ) ReadFrom (r io.Reader ) (int64 , error ) {
392+ // Ensure header processing has completed before reading/writing.
393+ if err := p .ensureHeaderProcessed (); err != nil {
394+ return 0 , err
395+ }
357396 if rf , ok := p .conn .(io.ReaderFrom ); ok {
358397 return rf .ReadFrom (r )
359398 }
@@ -362,9 +401,9 @@ func (p *Conn) ReadFrom(r io.Reader) (int64, error) {
362401
363402// WriteTo implements io.WriterTo.
364403func (p * Conn ) WriteTo (w io.Writer ) (int64 , error ) {
365- p . once . Do ( func () { p . readErr = p . readHeader () })
366- if p . readErr != nil {
367- return 0 , p . readErr
404+ // Ensure header processing has completed before reading/writing.
405+ if err := p . ensureHeaderProcessed (); err != nil {
406+ return 0 , err
368407 }
369408
370409 b := make ([]byte , p .bufReader .Buffered ())
0 commit comments