Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions helper/http2/http2.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,11 +166,12 @@ func (srv *Server) serveConn(baseCtx context.Context, conn net.Conn) error {
}()

ctx := baseCtx
// We don't check if srv.h1.ConnContext is nil so http.Server works the same
// with or without this middleware.
// For more info, see https://github.com/pires/go-proxyproto/pull/140/changes#r2725568706.
if connCtx := srv.h1.ConnContext(ctx, conn); connCtx != nil {
ctx = connCtx
// Mirror net/http.Server ConnContext behavior.
if cc := srv.h1.ConnContext; cc != nil {
ctx = cc(ctx, conn)
if ctx == nil {
panic("ConnContext returned nil")
}
}

opts := http2.ServeConnOpts{Context: ctx, BaseConfig: srv.h1}
Expand Down
63 changes: 63 additions & 0 deletions helper/http2/http2_internal_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package http2

import (
"context"
"net"
"net/http"
"testing"
"time"

"github.com/pires/go-proxyproto"
)

// TestServeConn_ConnContextReturnsNil lives in package http2 (not http2_test) so
// it can call the unexported serveConn method directly and recover the panic in
// the same goroutine, which is not possible through the public Serve API because
// Serve spawns a new goroutine per connection.
func TestServeConn_ConnContextReturnsNil(t *testing.T) {
srv := NewServer(&http.Server{
ReadHeaderTimeout: 5 * time.Second,
Handler: http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {}),
ConnContext: func(_ context.Context, _ net.Conn) context.Context {
return nil
},
}, nil)

// Create a pipe and write a PROXY header with h2 ALPN to trigger the H2 path.
clientConn, serverConn := net.Pipe()
defer func() { _ = clientConn.Close() }()
defer func() { _ = serverConn.Close() }()

header := proxyproto.Header{
Version: 2,
Command: proxyproto.LOCAL,
TransportProtocol: proxyproto.UNSPEC,
}
if err := header.SetTLVs([]proxyproto.TLV{{
Type: proxyproto.PP2_TYPE_ALPN,
Value: []byte("h2"),
}}); err != nil {
t.Fatalf("failed to set TLVs: %v", err)
}

// Write the header in a goroutine because net.Pipe is synchronous.
go func() {
_, _ = header.WriteTo(clientConn)
_ = clientConn.Close()
}()

pConn := proxyproto.NewConn(serverConn)

defer func() {
r := recover()
if r == nil {
t.Fatal("expected panic from ConnContext returning nil")
}
msg, ok := r.(string)
if !ok || msg != "ConnContext returned nil" {
t.Fatalf("expected panic message 'ConnContext returned nil', got: %v", r)
}
}()

_ = srv.serveConn(context.Background(), pConn)
}
143 changes: 104 additions & 39 deletions helper/http2/http2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,34 +158,98 @@ func TestServer_h2_tls(t *testing.T) {
}
}

func newTestServer(t *testing.T) (addr string, server *http.Server) {
func TestServer_h1_nil_ConnContext(t *testing.T) {
addr, server := newTestServerWithoutConnContext(t)
t.Cleanup(func() {
if err := server.Close(); err != nil {
t.Errorf("failed to close server: %v", err)
}
})

resp, err := http.Get("http://" + addr)
if err != nil {
t.Fatalf("failed to perform HTTP request: %v", err)
}
if err := resp.Body.Close(); err != nil {
t.Fatalf("failed to close response body: %v", err)
}
}

func TestServer_h2_nil_ConnContext(t *testing.T) {
addr, server := newTestServerWithoutConnContext(t)
t.Cleanup(func() {
if err := server.Close(); err != nil {
t.Errorf("failed to close server: %v", err)
}
})

conn, err := net.Dial("tcp", addr)
if err != nil {
t.Fatalf("failed to dial: %v", err)
}
defer func() {
if err := conn.Close(); err != nil {
t.Errorf("failed to close connection: %v", err)
}
}()

proxyHeader := proxyproto.Header{
Version: 2,
Command: proxyproto.LOCAL,
TransportProtocol: proxyproto.UNSPEC,
}
tlvs := []proxyproto.TLV{{
Type: proxyproto.PP2_TYPE_ALPN,
Value: []byte("h2"),
}}
if err := proxyHeader.SetTLVs(tlvs); err != nil {
t.Fatalf("failed to set TLVs: %v", err)
}
if _, err := proxyHeader.WriteTo(conn); err != nil {
t.Fatalf("failed to write PROXY header: %v", err)
}

h2Conn, err := new(http2.Transport).NewClientConn(conn)
if err != nil {
t.Fatalf("failed to create HTTP connection: %v", err)
}

req, err := http.NewRequest(http.MethodGet, "http://"+addr, nil)
if err != nil {
t.Fatalf("failed to create HTTP request: %v", err)
}

resp, err := h2Conn.RoundTrip(req)
if err != nil {
t.Fatalf("failed to perform HTTP request: %v", err)
}
if err := resp.Body.Close(); err != nil {
t.Fatalf("failed to close response body: %v", err)
}
}

// startTestServer listens on a random port, wraps the listener with wrapListener
// (or a proxyproto.Listener if nil), and starts an h2proxy.Server in the background.
// It registers cleanup to wait for the server to finish.
func startTestServer(t *testing.T, server *http.Server, wrapListener func(net.Listener) net.Listener) string {
t.Helper()

ln, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("failed to listen: %v", err)
}

server = &http.Server{
ReadHeaderTimeout: 5 * time.Second,
Handler: http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
if v := r.Context().Value(connContextKey); v == nil {
t.Errorf("http.Request.Context missing connContextKey")
}
if v := r.Context().Value(baseContextKey); v == nil {
t.Errorf("http.Request.Context missing baseContextKey")
}
}),
BaseContext: func(_ net.Listener) context.Context {
return context.WithValue(context.Background(), baseContextKey, struct{}{})
},
ConnContext: func(ctx context.Context, _ net.Conn) context.Context {
return context.WithValue(ctx, connContextKey, struct{}{})
},
var serveLn net.Listener
if wrapListener != nil {
serveLn = wrapListener(ln)
} else {
serveLn = &proxyproto.Listener{Listener: ln}
}

h2Server := h2proxy.NewServer(server, nil)
done := make(chan error, 1)
go func() {
done <- h2Server.Serve(&proxyproto.Listener{Listener: ln})
done <- h2Server.Serve(serveLn)
}()

t.Cleanup(func() {
Expand All @@ -195,16 +259,33 @@ func newTestServer(t *testing.T) (addr string, server *http.Server) {
}
})

return ln.Addr().String(), server
return ln.Addr().String()
}

func newTestServer(t *testing.T) (addr string, server *http.Server) {
server = newContextAssertingServer(t)
return startTestServer(t, server, nil), server
}

func newTLSTestServer(t *testing.T) (addr string, server *http.Server) {
ln, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("failed to listen: %v", err)
}
server = newContextAssertingServer(t)
return startTestServer(t, server, func(ln net.Listener) net.Listener {
return tls.NewListener(ln, testTLSConfig(t))
}), server
}

func newTestServerWithoutConnContext(t *testing.T) (addr string, server *http.Server) {
server = &http.Server{
ReadHeaderTimeout: 5 * time.Second,
Handler: http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {}),
}
return startTestServer(t, server, nil), server
}

// newContextAssertingServer returns an http.Server that asserts connContextKey
// and baseContextKey are present in every request's context.
func newContextAssertingServer(t *testing.T) *http.Server {
return &http.Server{
ReadHeaderTimeout: 5 * time.Second,
Handler: http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
if v := r.Context().Value(connContextKey); v == nil {
Expand All @@ -221,22 +302,6 @@ func newTLSTestServer(t *testing.T) (addr string, server *http.Server) {
return context.WithValue(ctx, connContextKey, struct{}{})
},
}

tlsLn := tls.NewListener(ln, testTLSConfig(t))
h2Server := h2proxy.NewServer(server, nil)
done := make(chan error, 1)
go func() {
done <- h2Server.Serve(tlsLn)
}()

t.Cleanup(func() {
err := <-done
if err != nil && !errors.Is(err, http.ErrServerClosed) {
t.Fatalf("failed to serve: %v", err)
}
})

return ln.Addr().String(), server
}

func testTLSConfig(t *testing.T) *tls.Config {
Expand Down
Loading