diff --git a/p2p/transport/websocket/addrs_test.go b/p2p/transport/websocket/addrs_test.go index 50a8b9e823..167024b102 100644 --- a/p2p/transport/websocket/addrs_test.go +++ b/p2p/transport/websocket/addrs_test.go @@ -69,7 +69,7 @@ func TestConvertWebsocketMultiaddrToNetAddr(t *testing.T) { } func TestListeningOnDNSAddr(t *testing.T) { - ln, err := newListener(ma.StringCast("/dns/localhost/tcp/0/ws"), nil, nil) + ln, err := newListener(ma.StringCast("/dns/localhost/tcp/0/ws"), nil, nil, nil, nil) require.NoError(t, err) addr := ln.Multiaddr() first, rest := ma.SplitFirst(addr) diff --git a/p2p/transport/websocket/conn.go b/p2p/transport/websocket/conn.go index 1c2ecd03df..6e73c0bdf7 100644 --- a/p2p/transport/websocket/conn.go +++ b/p2p/transport/websocket/conn.go @@ -39,7 +39,7 @@ var _ manet.Conn = (*Conn)(nil) // NewConn creates a Conn given a regular gorilla/websocket Conn. // // Deprecated: There's no reason to use this method externally. It'll be unexported in a future release. -func NewConn(raw *ws.Conn, secure bool) *Conn { +func NewConn(raw *ws.Conn, secure bool, remoteAddr string) *Conn { lna := NewAddrWithScheme(raw.LocalAddr().String(), secure) laddr, err := manet.FromNetAddr(lna) if err != nil { @@ -47,7 +47,10 @@ func NewConn(raw *ws.Conn, secure bool) *Conn { return nil } - rna := NewAddrWithScheme(raw.RemoteAddr().String(), secure) + if remoteAddr == "" { + remoteAddr = raw.RemoteAddr().String() + } + rna := NewAddrWithScheme(remoteAddr, secure) raddr, err := manet.FromNetAddr(rna) if err != nil { log.Errorf("BUG: invalid remoteaddr on websocket conn", raw.RemoteAddr()) diff --git a/p2p/transport/websocket/listener.go b/p2p/transport/websocket/listener.go index e72e7dabd3..d99291cce1 100644 --- a/p2p/transport/websocket/listener.go +++ b/p2p/transport/websocket/listener.go @@ -6,6 +6,7 @@ import ( "fmt" "net" "net/http" + "net/netip" "sync" "go.uber.org/zap" @@ -29,6 +30,8 @@ type listener struct { // so we can't rely on checking if server.TLSConfig is set. isWss bool + remoteIpExtractor RemoteAddrExtractor + laddr ma.Multiaddr incoming chan *Conn @@ -52,7 +55,7 @@ func (pwma *parsedWebsocketMultiaddr) toMultiaddr() ma.Multiaddr { // newListener creates a new listener from a raw net.Listener. // tlsConf may be nil (for unencrypted websockets). -func newListener(a ma.Multiaddr, tlsConf *tls.Config, sharedTcp *tcpreuse.ConnMgr) (*listener, error) { +func newListener(a ma.Multiaddr, tlsConf *tls.Config, sharedTcp *tcpreuse.ConnMgr, trustedProxies []netip.Prefix, realAddrExtractor RemoteAddrExtractor) (*listener, error) { parsed, err := parseWebsocketMultiaddr(a) if err != nil { return nil, err @@ -62,6 +65,16 @@ func newListener(a ma.Multiaddr, tlsConf *tls.Config, sharedTcp *tcpreuse.ConnMg return nil, fmt.Errorf("cannot listen on wss address %s without a tls.Config", a) } + var remoteAddrExtractor RemoteAddrExtractor = func(remoteAddr net.Addr, header http.Header) string { + return remoteAddr.String() + } + + if realAddrExtractor != nil && len(trustedProxies) > 0 { + remoteAddrExtractor = func(remoteAddr net.Addr, header http.Header) string { + return extractRemoteAddrForProxy(trustedProxies, realAddrExtractor, remoteAddr, header) + } + } + var nl net.Listener if sharedTcp == nil { @@ -106,6 +119,8 @@ func newListener(a ma.Multiaddr, tlsConf *tls.Config, sharedTcp *tcpreuse.ConnMg laddr: parsed.toMultiaddr(), incoming: make(chan *Conn), closed: make(chan struct{}), + + remoteIpExtractor: remoteAddrExtractor, } ln.server = http.Server{Handler: ln, ErrorLog: stdLog} if parsed.isWSS { @@ -130,7 +145,8 @@ func (l *listener) ServeHTTP(w http.ResponseWriter, r *http.Request) { // The upgrader writes a response for us. return } - nc := NewConn(c, l.isWss) + + nc := NewConn(c, l.isWss, l.remoteIpExtractor(c.RemoteAddr(), r.Header)) if nc == nil { c.Close() w.WriteHeader(500) @@ -174,6 +190,29 @@ func (l *listener) Multiaddr() ma.Multiaddr { return l.laddr } +func isProxyTrusted(trustedProxies []netip.Prefix, remoteAddr net.Addr) bool { + remoteAddrPort, err := netip.ParseAddrPort(remoteAddr.String()) + if err != nil { + return false + } + + for _, prefix := range trustedProxies { + if prefix.Contains(remoteAddrPort.Addr()) { + return true + } + } + + return false +} + +// extractRemoteAddrForProxy extract real ip if the given IP address belongs to a trusted proxy +func extractRemoteAddrForProxy(trustedProxies []netip.Prefix, realAddrExtractor RemoteAddrExtractor, remoteAddr net.Addr, header http.Header) string { + if isProxyTrusted(trustedProxies, remoteAddr) { + return realAddrExtractor(remoteAddr, header) + } + return remoteAddr.String() +} + type transportListener struct { transport.Listener } diff --git a/p2p/transport/websocket/listener_test.go b/p2p/transport/websocket/listener_test.go new file mode 100644 index 0000000000..97966aaf66 --- /dev/null +++ b/p2p/transport/websocket/listener_test.go @@ -0,0 +1,67 @@ +package websocket + +import ( + "net/netip" + "testing" +) + +func TestIsProxyTrusted(t *testing.T) { + tests := []struct { + name string + trustedProxies []string + remoteAddr string + want bool + }{ + { + name: "Single IP trusted", + trustedProxies: []string{"192.168.1.1/32"}, + remoteAddr: "192.168.1.1:1234", + want: true, + }, + { + name: "IP not in trusted list", + trustedProxies: []string{"192.168.1.1/32"}, + remoteAddr: "192.168.1.2:1234", + want: false, + }, + { + name: "CIDR range trusted", + trustedProxies: []string{"192.168.2.0/24", "192.168.1.0/24"}, + remoteAddr: "192.168.1.100:1234", + want: true, + }, + { + name: "IPv6 address trusted", + trustedProxies: []string{"2001:db8::1/128"}, + remoteAddr: "[2001:db8::1]:1234", + want: true, + }, + { + name: "IPv6 CIDR range trusted", + trustedProxies: []string{"2001:db8::/32", "2001:db9::/32"}, + remoteAddr: "[2001:db8:1:2:3:4:5:6]:1234", + want: true, + }, + { + name: "Empty trusted proxies list", + trustedProxies: []string{}, + remoteAddr: "192.168.1.1:1234", + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var trustedProxies []netip.Prefix + for _, cidr := range tt.trustedProxies { + prefix, _ := netip.ParsePrefix(cidr) + trustedProxies = append(trustedProxies, prefix) + } + + got := isProxyTrusted(trustedProxies, &fakeAddr{tt.remoteAddr}) + if got != tt.want { + t.Errorf("isProxyTrusted() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/p2p/transport/websocket/real_ip.go b/p2p/transport/websocket/real_ip.go new file mode 100644 index 0000000000..df96292b02 --- /dev/null +++ b/p2p/transport/websocket/real_ip.go @@ -0,0 +1,93 @@ +package websocket + +import ( + "fmt" + "net" + "net/http" + "strconv" + "strings" +) + +// DefaultGetRealAddr implements RFC 7239 Forwarded header parsing +func DefaultGetRealAddr(addr net.Addr, h http.Header) string { + remoteAddr := addr.String() + remoteIp := GetRealIPFromHeader(h) + if remoteIp != nil { + remoteTcpAddr, ok := addr.(*net.TCPAddr) + if ok { + remoteAddr = IpPort(remoteIp, strconv.Itoa(remoteTcpAddr.Port)) + } else { + _, port, err := net.SplitHostPort(remoteAddr) + if err == nil { + remoteAddr = IpPort(remoteIp, port) + } + } + } + return remoteAddr + +} + +// GetRealIPFromHeader extracts the client's real IP address from HTTP request header. +// implements RFC 7239 Forwarded header parsing +func GetRealIPFromHeader(h http.Header) net.IP { + // RFC 7239 Forwarded header + if forwarded := h.Get("Forwarded"); forwarded != "" { + if host := parseForwardedHeader(forwarded); host != "" { + if ip := validateIp(host); ip != nil { + return ip + } + } + } + + // Fallback to X-Forwarded-For + if xff := h.Get("X-Forwarded-For"); xff != "" { + if host := parseXForwardedFor(xff); host != "" { + if ip := validateIp(host); ip != nil { + return ip + } + } + } + + return nil +} + +func parseForwardedHeader(value string) string { + parts := strings.Split(value, ",") + if len(parts) == 0 { + return "" + } + + pair := strings.TrimSpace(parts[0]) + for _, elem := range strings.Split(pair, ";") { + if kv := strings.Split(strings.TrimSpace(elem), "="); len(kv) == 2 { + if strings.ToLower(kv[0]) == "for" { + host := strings.Trim(kv[1], "\"[]") + return host + } + } + } + return "" +} + +func parseXForwardedFor(value string) string { + ips := strings.Split(value, ",") + if len(ips) == 0 { + return "" + } + return strings.TrimSpace(ips[0]) +} + +// validateIp checks if a string is a valid IP address +func validateIp(ip string) net.IP { + if ip == "" { + return nil + } + return net.ParseIP(ip) +} + +func IpPort(ip net.IP, port string) string { + if ip.To4() == nil { + return fmt.Sprintf("[%s]:%s", ip.String(), port) + } + return fmt.Sprintf("%s:%s", ip.String(), port) +} diff --git a/p2p/transport/websocket/real_ip_test.go b/p2p/transport/websocket/real_ip_test.go new file mode 100644 index 0000000000..44be10dfe6 --- /dev/null +++ b/p2p/transport/websocket/real_ip_test.go @@ -0,0 +1,150 @@ +package websocket + +import ( + "github.com/stretchr/testify/require" + "net" + "net/http" + "testing" +) + +func TestDefaultGetRealAddr(t *testing.T) { + tests := []struct { + name string + remoteAddr string + header http.Header + want string + }{ + { + name: "basic remote addr without header", + remoteAddr: "192.168.1.1:1234", + header: http.Header{}, + want: "192.168.1.1:1234", + }, + { + name: "with Forwarded header", + remoteAddr: "10.0.0.1:1234", + header: http.Header{ + "Forwarded": []string{"for=192.168.1.2"}, + }, + want: "192.168.1.2:1234", + }, + { + name: "with Forwarded header and IPv6", + remoteAddr: "[::1]:1234", + header: http.Header{ + "Forwarded": []string{`for="[2001:db8:cafe::17]"`}, + }, + want: "[2001:db8:cafe::17]:1234", + }, + { + name: "with X-Forwarded-For header", + remoteAddr: "10.0.0.1:1234", + header: http.Header{ + "X-Forwarded-For": []string{"192.168.1.2, 10.0.0.1"}, + }, + want: "192.168.1.2:1234", + }, + { + name: "with multiple Forwarded values", + remoteAddr: "10.0.0.1:1234", + header: http.Header{ + "Forwarded": []string{"for=192.168.1.2;by=proxy1, for=192.168.1.3"}, + }, + want: "192.168.1.2:1234", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := DefaultGetRealAddr(&fakeAddr{tt.remoteAddr}, tt.header) + require.Equal(t, tt.want, got) + }) + } +} + +func TestValidateIp(t *testing.T) { + tests := []struct { + name string + input string + expected bool + }{ + { + name: "Valid IPv4", + input: "192.168.1.1", + expected: true, + }, + { + name: "Valid IPv6", + input: "2001:db8::1", + expected: true, + }, + { + name: "Empty string", + input: "", + expected: false, + }, + { + name: "Invalid IP", + input: "invalid-ip", + expected: false, + }, + { + name: "Partial IP", + input: "192.168", + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := validateIp(tt.input) + if (got != nil) != tt.expected { + t.Errorf("validateIp(%q) = %v, want %v", tt.input, got, tt.expected) + } + }) + } +} + +func TestIpPort(t *testing.T) { + tests := []struct { + name string + ip string + port string + expected string + }{ + { + name: "IPv4 address", + ip: "192.168.1.1", + port: "8080", + expected: "192.168.1.1:8080", + }, + { + name: "IPv6 address", + ip: "2001:db8::1", + port: "8080", + expected: "[2001:db8::1]:8080", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ip := net.ParseIP(tt.ip) + got := IpPort(ip, tt.port) + if got != tt.expected { + t.Errorf("IpPort() = %v, want %v", got, tt.expected) + } + }) + } +} + +type fakeAddr struct { + addr string +} + +func (f fakeAddr) Network() string { + return "tcp" +} + +func (f fakeAddr) String() string { + return f.addr +} diff --git a/p2p/transport/websocket/websocket.go b/p2p/transport/websocket/websocket.go index 0ddf19050e..e22b20b7d6 100644 --- a/p2p/transport/websocket/websocket.go +++ b/p2p/transport/websocket/websocket.go @@ -6,6 +6,7 @@ import ( "crypto/tls" "net" "net/http" + "net/netip" "time" "github.com/libp2p/go-libp2p/core/network" @@ -81,6 +82,28 @@ func WithTLSConfig(conf *tls.Config) Option { } } +type RemoteAddrExtractor func(remoteAddr net.Addr, header http.Header) string + +// WithAllowForwardedHeader configures whether to allow the usage of +// Forwarded and X-Forwarded-For header to determine the real client +// IP address when behind a proxy or load balancer. +// configurable a list of trusted proxy IP addresses or CIDR ranges. +// Only when a connection is from a trusted proxy, Forwarded header will be used. +// If remoteAddrExtractor is null, the RFC 7239 extractor is used by default. +func WithAllowForwardedHeader(trustedProxies []netip.Prefix, remoteAddrExtractor RemoteAddrExtractor) Option { + return func(t *WebsocketTransport) error { + if remoteAddrExtractor == nil { + t.remoteAddrExtractor = DefaultGetRealAddr + } else { + t.remoteAddrExtractor = remoteAddrExtractor + } + + t.trustedProxies = trustedProxies + + return nil + } +} + // WebsocketTransport is the actual go-libp2p transport type WebsocketTransport struct { upgrader transport.Upgrader @@ -89,7 +112,9 @@ type WebsocketTransport struct { tlsClientConf *tls.Config tlsConf *tls.Config - sharedTcp *tcpreuse.ConnMgr + sharedTcp *tcpreuse.ConnMgr + trustedProxies []netip.Prefix + remoteAddrExtractor RemoteAddrExtractor } var _ transport.Transport = (*WebsocketTransport)(nil) @@ -236,7 +261,7 @@ func (t *WebsocketTransport) maDial(ctx context.Context, raddr ma.Multiaddr) (ma return nil, err } - mnc, err := manet.WrapNetConn(NewConn(wscon, isWss)) + mnc, err := manet.WrapNetConn(NewConn(wscon, isWss, "")) if err != nil { wscon.Close() return nil, err @@ -249,7 +274,7 @@ func (t *WebsocketTransport) maListen(a ma.Multiaddr) (manet.Listener, error) { if t.tlsConf != nil { tlsConf = t.tlsConf.Clone() } - l, err := newListener(a, tlsConf, t.sharedTcp) + l, err := newListener(a, tlsConf, t.sharedTcp, t.trustedProxies, t.remoteAddrExtractor) if err != nil { return nil, err } diff --git a/p2p/transport/websocket/websocket_test.go b/p2p/transport/websocket/websocket_test.go index f83576ba2f..abc829f470 100644 --- a/p2p/transport/websocket/websocket_test.go +++ b/p2p/transport/websocket/websocket_test.go @@ -16,7 +16,9 @@ import ( "math/big" "net" "net/http" + "net/netip" "net/url" + "strconv" "strings" "testing" "time" @@ -623,3 +625,86 @@ func TestSocksProxy(t *testing.T) { }) } } + +func TestWithAllowForwardedHeader(t *testing.T) { + tests := []struct { + name string + trustedProxies []netip.Prefix + remoteAddr string + headers http.Header + expectedAddr string + }{ + { + name: "trusted proxy with forwarded header", + trustedProxies: []netip.Prefix{ + netip.MustParsePrefix("127.0.0.1/32"), + }, + remoteAddr: "127.0.0.1:1234", + headers: http.Header{ + "X-Forwarded-For": []string{"192.168.1.2"}, + }, + expectedAddr: "192.168.1.2:1234", + }, + { + name: "untrusted proxy ignores forwarded header", + trustedProxies: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/8"), + }, + remoteAddr: "192.168.1.1:1234", + headers: http.Header{ + "X-Forwarded-For": []string{"192.168.1.2"}, + }, + expectedAddr: "192.168.1.1:1234", + }, + { + name: "no trusted proxies configured", + trustedProxies: []netip.Prefix{}, + remoteAddr: "192.168.1.1:1234", + headers: http.Header{ + "X-Forwarded-For": []string{"192.168.1.2"}, + }, + expectedAddr: "192.168.1.1:1234", + }, + { + name: "no apply configured", + trustedProxies: nil, + remoteAddr: "192.168.1.1:1234", + headers: http.Header{ + "X-Forwarded-For": []string{"192.168.1.2"}, + }, + expectedAddr: "192.168.1.1:1234", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + transport := &WebsocketTransport{} + + // Configure transport with test proxy settings + if tt.trustedProxies != nil { + err := WithAllowForwardedHeader(tt.trustedProxies, nil)(transport) + require.NoError(t, err) + } + + // Create listener + l, err := newListener(ma.StringCast("/dns/localhost/tcp/0/ws"), nil, nil, tt.trustedProxies, transport.remoteAddrExtractor) + require.NoError(t, err) + _ = l.nl.Close() + + // Test remote IP extraction + host, portStr, _ := net.SplitHostPort(tt.remoteAddr) + port, _ := strconv.ParseInt(portStr, 10, 16) + remoteAddr := &net.TCPAddr{ + IP: net.ParseIP(host), + Port: int(port), + } + + extractedAddr := extractRemoteAddrForProxy(tt.trustedProxies, + DefaultGetRealAddr, + remoteAddr, + tt.headers) + + require.Equal(t, tt.expectedAddr, extractedAddr) + }) + } +}