Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion p2p/transport/websocket/addrs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func TestConvertWebsocketMultiaddrToNetAddr(t *testing.T) {
}

func TestListeningOnDNSAddr(t *testing.T) {
ln, err := newListener(ma.StringCast("/dns/localhost/tcp/0/ws"), nil, nil)
ln, err := newListener(ma.StringCast("/dns/localhost/tcp/0/ws"), nil, nil, false, nil)
require.NoError(t, err)
addr := ln.Multiaddr()
first, rest := ma.SplitFirst(addr)
Expand Down
7 changes: 5 additions & 2 deletions p2p/transport/websocket/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,18 @@ var _ manet.Conn = (*Conn)(nil)
// NewConn creates a Conn given a regular gorilla/websocket Conn.
//
// Deprecated: There's no reason to use this method externally. It'll be unexported in a future release.
func NewConn(raw *ws.Conn, secure bool) *Conn {
func NewConn(raw *ws.Conn, secure bool, remoteAddr string) *Conn {
lna := NewAddrWithScheme(raw.LocalAddr().String(), secure)
laddr, err := manet.FromNetAddr(lna)
if err != nil {
log.Errorf("BUG: invalid localaddr on websocket conn", raw.LocalAddr())
return nil
}

rna := NewAddrWithScheme(raw.RemoteAddr().String(), secure)
if remoteAddr == "" {
remoteAddr = raw.RemoteAddr().String()
}
rna := NewAddrWithScheme(remoteAddr, secure)
raddr, err := manet.FromNetAddr(rna)
if err != nil {
log.Errorf("BUG: invalid remoteaddr on websocket conn", raw.RemoteAddr())
Expand Down
43 changes: 41 additions & 2 deletions p2p/transport/websocket/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ type listener struct {
// so we can't rely on checking if server.TLSConfig is set.
isWss bool

allowForwardedHeader bool
trustedProxies []*net.IPNet

laddr ma.Multiaddr

incoming chan *Conn
Expand All @@ -52,7 +55,7 @@ func (pwma *parsedWebsocketMultiaddr) toMultiaddr() ma.Multiaddr {

// newListener creates a new listener from a raw net.Listener.
// tlsConf may be nil (for unencrypted websockets).
func newListener(a ma.Multiaddr, tlsConf *tls.Config, sharedTcp *tcpreuse.ConnMgr) (*listener, error) {
func newListener(a ma.Multiaddr, tlsConf *tls.Config, sharedTcp *tcpreuse.ConnMgr, allowForwardedHeader bool, trustedProxies []*net.IPNet) (*listener, error) {
parsed, err := parseWebsocketMultiaddr(a)
if err != nil {
return nil, err
Expand Down Expand Up @@ -106,6 +109,9 @@ func newListener(a ma.Multiaddr, tlsConf *tls.Config, sharedTcp *tcpreuse.ConnMg
laddr: parsed.toMultiaddr(),
incoming: make(chan *Conn),
closed: make(chan struct{}),

allowForwardedHeader: allowForwardedHeader,
trustedProxies: trustedProxies,
}
ln.server = http.Server{Handler: ln, ErrorLog: stdLog}
if parsed.isWSS {
Expand All @@ -124,13 +130,46 @@ func (l *listener) serve() {
}
}

// isProxyTrusted checks if the given IP address belongs to a trusted proxy
func (l *listener) isProxyTrusted(remoteAddr string) bool {
if len(l.trustedProxies) == 0 {
// allow any
return true
}

// Extract IP from remoteAddr (format: "IP:port")
host, _, err := net.SplitHostPort(remoteAddr)
if err != nil {
return false
}

ip := net.ParseIP(host)
if ip == nil {
return false
}

// Check if IP is in any of the trusted CIDR ranges
for _, cidr := range l.trustedProxies {
if cidr.Contains(ip) {
return true
}
}
return false
}

func (l *listener) ServeHTTP(w http.ResponseWriter, r *http.Request) {
c, err := upgrader.Upgrade(w, r, nil)
if err != nil {
// The upgrader writes a response for us.
return
}
nc := NewConn(c, l.isWss)

var overrideRemoteAddr string
if l.allowForwardedHeader && l.isProxyTrusted(c.RemoteAddr().String()) {
overrideRemoteAddr = GetRealIP(c.RemoteAddr(), r.Header)
}

nc := NewConn(c, l.isWss, overrideRemoteAddr)
if nc == nil {
c.Close()
w.WriteHeader(500)
Expand Down
67 changes: 67 additions & 0 deletions p2p/transport/websocket/listener_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package websocket

import (
"github.com/stretchr/testify/require"
"testing"
)

func TestIsProxyTrusted(t *testing.T) {
tests := []struct {
name string
trustedProxies []string
remoteAddr string
want bool
}{
{
name: "Single IP trusted",
trustedProxies: []string{"192.168.1.1"},
remoteAddr: "192.168.1.1:1234",
want: true,
},
{
name: "IP not in trusted list",
trustedProxies: []string{"192.168.1.1"},
remoteAddr: "192.168.1.2:1234",
want: false,
},
{
name: "CIDR range trusted",
trustedProxies: []string{"192.168.1.0/24"},
remoteAddr: "192.168.1.100:1234",
want: true,
},
{
name: "IPv6 address trusted",
trustedProxies: []string{"2001:db8::1"},
remoteAddr: "[2001:db8::1]:1234",
want: true,
},
{
name: "IPv6 CIDR range trusted",
trustedProxies: []string{"2001:db8::/32"},
remoteAddr: "[2001:db8:1:2:3:4:5:6]:1234",
want: true,
},
{
name: "Empty trusted proxies list",
trustedProxies: []string{},
remoteAddr: "192.168.1.1:1234",
want: true, // Everything is trusted when list is empty
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
transport := &WebsocketTransport{}
err := WithAllowForwardedHeader(tt.trustedProxies)(transport)
require.NoError(t, err)
l := &listener{
trustedProxies: transport.trustedProxies,
}
got := l.isProxyTrusted(tt.remoteAddr)
if got != tt.want {
t.Errorf("isProxyTrusted() = %v, want %v", got, tt.want)
}
})
}
}
79 changes: 79 additions & 0 deletions p2p/transport/websocket/real_ip.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package websocket

import (
"fmt"
"net"
"net/http"
"strconv"
"strings"
)

func GetRealIP(addr net.Addr, h http.Header) string {
remoteAddr := addr.String()
remoteIp := GetRealIPFromHeader(h)
if remoteIp != nil {
remoteTcpAddr, ok := addr.(*net.TCPAddr)
if ok {
remoteAddr = IpPort(remoteIp, strconv.Itoa(remoteTcpAddr.Port))
} else {
_, port, err := net.SplitHostPort(remoteAddr)
if err == nil {
remoteAddr = IpPort(remoteIp, port)
}
}
}
return remoteAddr
}

// GetRealIPFromHeader extracts the client's real IP address from HTTP request header.
// It checks various proxy header to find the actual IP.
func GetRealIPFromHeader(h http.Header) net.IP {
// Check X-Real-IP header (used by Nginx and others)
ipStr := h.Get("X-Real-IP")
if ip := validateIp(ipStr); ip != nil {
return ip
}

// Check X-Forwarded-For header (used by most proxies)
// Format: client, proxy1, proxy2, ...
ipStr = h.Get("X-Forwarded-For")
if ipStr != "" {
// Extract the first IP from the comma-separated list
ips := strings.Split(ipStr, ",")
for _, ipItem := range ips {
ipItem = strings.TrimSpace(ipItem)
if ip := validateIp(ipItem); ip != nil {
return ip
}
}
}

// Check CF-Connecting-IP header (used by Cloudflare)
ipStr = h.Get("CF-Connecting-IP")
if ip := validateIp(ipStr); ip != nil {
return ip
}

// Check True-Client-IP header (used by Akamai, Cloudflare, etc.)
ipStr = h.Get("True-Client-IP")
if ip := validateIp(ipStr); ip != nil {
return ip
}

return nil
}

// validateIp checks if a string is a valid IP address
func validateIp(ip string) net.IP {
if ip == "" {
return nil
}
return net.ParseIP(ip)
}

func IpPort(ip net.IP, port string) string {
if ip.To4() == nil {
return fmt.Sprintf("[%s]:%s", ip.String(), port)
}
return fmt.Sprintf("%s:%s", ip.String(), port)
}
Loading