Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion p2p/transport/websocket/addrs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func TestConvertWebsocketMultiaddrToNetAddr(t *testing.T) {
}

func TestListeningOnDNSAddr(t *testing.T) {
ln, err := newListener(ma.StringCast("/dns/localhost/tcp/0/ws"), nil, nil)
ln, err := newListener(ma.StringCast("/dns/localhost/tcp/0/ws"), nil, nil, nil, nil)
require.NoError(t, err)
addr := ln.Multiaddr()
first, rest := ma.SplitFirst(addr)
Expand Down
7 changes: 5 additions & 2 deletions p2p/transport/websocket/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,18 @@ var _ manet.Conn = (*Conn)(nil)
// NewConn creates a Conn given a regular gorilla/websocket Conn.
//
// Deprecated: There's no reason to use this method externally. It'll be unexported in a future release.
func NewConn(raw *ws.Conn, secure bool) *Conn {
func NewConn(raw *ws.Conn, secure bool, remoteAddr string) *Conn {
lna := NewAddrWithScheme(raw.LocalAddr().String(), secure)
laddr, err := manet.FromNetAddr(lna)
if err != nil {
log.Errorf("BUG: invalid localaddr on websocket conn", raw.LocalAddr())
return nil
}

rna := NewAddrWithScheme(raw.RemoteAddr().String(), secure)
if remoteAddr == "" {
remoteAddr = raw.RemoteAddr().String()
}
rna := NewAddrWithScheme(remoteAddr, secure)
raddr, err := manet.FromNetAddr(rna)
if err != nil {
log.Errorf("BUG: invalid remoteaddr on websocket conn", raw.RemoteAddr())
Expand Down
43 changes: 41 additions & 2 deletions p2p/transport/websocket/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"net"
"net/http"
"net/netip"
"sync"

"go.uber.org/zap"
Expand All @@ -29,6 +30,8 @@ type listener struct {
// so we can't rely on checking if server.TLSConfig is set.
isWss bool

remoteIpExtractor RemoteAddrExtractor

laddr ma.Multiaddr

incoming chan *Conn
Expand All @@ -52,7 +55,7 @@ func (pwma *parsedWebsocketMultiaddr) toMultiaddr() ma.Multiaddr {

// newListener creates a new listener from a raw net.Listener.
// tlsConf may be nil (for unencrypted websockets).
func newListener(a ma.Multiaddr, tlsConf *tls.Config, sharedTcp *tcpreuse.ConnMgr) (*listener, error) {
func newListener(a ma.Multiaddr, tlsConf *tls.Config, sharedTcp *tcpreuse.ConnMgr, trustedProxies []netip.Prefix, realAddrExtractor RemoteAddrExtractor) (*listener, error) {
parsed, err := parseWebsocketMultiaddr(a)
if err != nil {
return nil, err
Expand All @@ -62,6 +65,16 @@ func newListener(a ma.Multiaddr, tlsConf *tls.Config, sharedTcp *tcpreuse.ConnMg
return nil, fmt.Errorf("cannot listen on wss address %s without a tls.Config", a)
}

var remoteAddrExtractor RemoteAddrExtractor = func(remoteAddr net.Addr, header http.Header) string {
return remoteAddr.String()
}

if realAddrExtractor != nil && len(trustedProxies) > 0 {
remoteAddrExtractor = func(remoteAddr net.Addr, header http.Header) string {
return extractRemoteAddrForProxy(trustedProxies, realAddrExtractor, remoteAddr, header)
}
}

var nl net.Listener

if sharedTcp == nil {
Expand Down Expand Up @@ -106,6 +119,8 @@ func newListener(a ma.Multiaddr, tlsConf *tls.Config, sharedTcp *tcpreuse.ConnMg
laddr: parsed.toMultiaddr(),
incoming: make(chan *Conn),
closed: make(chan struct{}),

remoteIpExtractor: remoteAddrExtractor,
}
ln.server = http.Server{Handler: ln, ErrorLog: stdLog}
if parsed.isWSS {
Expand All @@ -130,7 +145,8 @@ func (l *listener) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// The upgrader writes a response for us.
return
}
nc := NewConn(c, l.isWss)

nc := NewConn(c, l.isWss, l.remoteIpExtractor(c.RemoteAddr(), r.Header))
if nc == nil {
c.Close()
w.WriteHeader(500)
Expand Down Expand Up @@ -174,6 +190,29 @@ func (l *listener) Multiaddr() ma.Multiaddr {
return l.laddr
}

func isProxyTrusted(trustedProxies []netip.Prefix, remoteAddr net.Addr) bool {
remoteAddrPort, err := netip.ParseAddrPort(remoteAddr.String())
if err != nil {
return false
}

for _, prefix := range trustedProxies {
if prefix.Contains(remoteAddrPort.Addr()) {
return true
}
}

return false
}

// extractRemoteAddrForProxy extract real ip if the given IP address belongs to a trusted proxy
func extractRemoteAddrForProxy(trustedProxies []netip.Prefix, realAddrExtractor RemoteAddrExtractor, remoteAddr net.Addr, header http.Header) string {
if isProxyTrusted(trustedProxies, remoteAddr) {
return realAddrExtractor(remoteAddr, header)
}
return remoteAddr.String()
}

type transportListener struct {
transport.Listener
}
Expand Down
67 changes: 67 additions & 0 deletions p2p/transport/websocket/listener_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package websocket

import (
"net/netip"
"testing"
)

func TestIsProxyTrusted(t *testing.T) {
tests := []struct {
name string
trustedProxies []string
remoteAddr string
want bool
}{
{
name: "Single IP trusted",
trustedProxies: []string{"192.168.1.1/32"},
remoteAddr: "192.168.1.1:1234",
want: true,
},
{
name: "IP not in trusted list",
trustedProxies: []string{"192.168.1.1/32"},
remoteAddr: "192.168.1.2:1234",
want: false,
},
{
name: "CIDR range trusted",
trustedProxies: []string{"192.168.2.0/24", "192.168.1.0/24"},
remoteAddr: "192.168.1.100:1234",
want: true,
},
{
name: "IPv6 address trusted",
trustedProxies: []string{"2001:db8::1/128"},
remoteAddr: "[2001:db8::1]:1234",
want: true,
},
{
name: "IPv6 CIDR range trusted",
trustedProxies: []string{"2001:db8::/32", "2001:db9::/32"},
remoteAddr: "[2001:db8:1:2:3:4:5:6]:1234",
want: true,
},
{
name: "Empty trusted proxies list",
trustedProxies: []string{},
remoteAddr: "192.168.1.1:1234",
want: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var trustedProxies []netip.Prefix
for _, cidr := range tt.trustedProxies {
prefix, _ := netip.ParsePrefix(cidr)
trustedProxies = append(trustedProxies, prefix)
}

got := isProxyTrusted(trustedProxies, &fakeAddr{tt.remoteAddr})
if got != tt.want {
t.Errorf("isProxyTrusted() = %v, want %v", got, tt.want)
}
})
}
}
93 changes: 93 additions & 0 deletions p2p/transport/websocket/real_ip.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
package websocket

import (
"fmt"
"net"
"net/http"
"strconv"
"strings"
)

// DefaultGetRealAddr implements RFC 7239 Forwarded header parsing
func DefaultGetRealAddr(addr net.Addr, h http.Header) string {
remoteAddr := addr.String()
remoteIp := GetRealIPFromHeader(h)
if remoteIp != nil {
remoteTcpAddr, ok := addr.(*net.TCPAddr)
if ok {
remoteAddr = IpPort(remoteIp, strconv.Itoa(remoteTcpAddr.Port))
} else {
_, port, err := net.SplitHostPort(remoteAddr)
if err == nil {
remoteAddr = IpPort(remoteIp, port)
}
}
}
return remoteAddr

}

// GetRealIPFromHeader extracts the client's real IP address from HTTP request header.
// implements RFC 7239 Forwarded header parsing
func GetRealIPFromHeader(h http.Header) net.IP {
// RFC 7239 Forwarded header
if forwarded := h.Get("Forwarded"); forwarded != "" {
if host := parseForwardedHeader(forwarded); host != "" {
if ip := validateIp(host); ip != nil {
return ip
}
}
}

// Fallback to X-Forwarded-For
if xff := h.Get("X-Forwarded-For"); xff != "" {
if host := parseXForwardedFor(xff); host != "" {
if ip := validateIp(host); ip != nil {
return ip
}
}
}

return nil
}

func parseForwardedHeader(value string) string {
parts := strings.Split(value, ",")
if len(parts) == 0 {
return ""
}

pair := strings.TrimSpace(parts[0])
for _, elem := range strings.Split(pair, ";") {
if kv := strings.Split(strings.TrimSpace(elem), "="); len(kv) == 2 {
if strings.ToLower(kv[0]) == "for" {
host := strings.Trim(kv[1], "\"[]")
return host
}
}
}
return ""
}

func parseXForwardedFor(value string) string {
ips := strings.Split(value, ",")
if len(ips) == 0 {
return ""
}
return strings.TrimSpace(ips[0])
}

// validateIp checks if a string is a valid IP address
func validateIp(ip string) net.IP {
if ip == "" {
return nil
}
return net.ParseIP(ip)
}

func IpPort(ip net.IP, port string) string {
if ip.To4() == nil {
return fmt.Sprintf("[%s]:%s", ip.String(), port)
}
return fmt.Sprintf("%s:%s", ip.String(), port)
}
Loading
Loading