diff --git a/proxy/cloudflare_simple_proxy_protocol.go b/proxy/cloudflare_simple_proxy_protocol.go new file mode 100644 index 00000000..4ee910f0 --- /dev/null +++ b/proxy/cloudflare_simple_proxy_protocol.go @@ -0,0 +1,86 @@ +package proxy + +import ( + "encoding/binary" + "net" +) + +const CLOUDFLARE_SIMPLE_PROXY_PROTOCOL_MAGIC = 0x56EC + +// CloudflareSimpleProxyProtocol implements Cloudflares Simple Proxy Protocol +// https://developers.cloudflare.com/spectrum/reference/simple-proxy-protocol-header/ +type CloudflareSimpleProxyProtocol struct{} + +func (cspp *CloudflareSimpleProxyProtocol) HeaderSize() int { + return 38 +} + +func (cspp *CloudflareSimpleProxyProtocol) Parse(clientAddress net.Addr, proxyAddress net.Addr, packet []byte) ([]byte, error) { + realClientIP := net.IP(packet[2:18]) + realProxyIP := net.IP(packet[18:34]) + realClientPort := int(binary.BigEndian.Uint16(packet[34:36])) + realProxyPort := int(binary.BigEndian.Uint16(packet[36:38])) + + switch v := clientAddress.(type) { + case *net.TCPAddr: + v.IP = realClientIP + v.Port = realClientPort + case *net.UDPAddr: + v.IP = realClientIP + v.Port = realClientPort + } + + switch v := proxyAddress.(type) { + case *net.TCPAddr: + v.IP = realProxyIP + v.Port = realProxyPort + case *net.UDPAddr: + v.IP = realProxyIP + v.Port = realProxyPort + } + + return packet[cspp.HeaderSize():], nil +} + +func (cspp *CloudflareSimpleProxyProtocol) Encode(clientAddress net.Addr, proxyAddress net.Addr, packet []byte) ([]byte, error) { + var clientIP net.IP + var clientPort int + var proxyIP net.IP + var proxyPort int + + switch v := clientAddress.(type) { + case *net.TCPAddr: + clientIP = v.IP + clientPort = v.Port + case *net.UDPAddr: + clientIP = v.IP + clientPort = v.Port + } + + switch v := proxyAddress.(type) { + case *net.TCPAddr: + proxyIP = v.IP + proxyPort = v.Port + case *net.UDPAddr: + proxyIP = v.IP + proxyPort = v.Port + } + + if clientIP.To4() != nil { + clientIP = clientIP.To16() + } + + if proxyIP.To4() != nil { + proxyIP = proxyIP.To16() + } + + newData := make([]byte, cspp.HeaderSize()+len(packet)) + binary.BigEndian.PutUint16(newData[0:2], CLOUDFLARE_SIMPLE_PROXY_PROTOCOL_MAGIC) + copy(newData[2:18], clientIP) + copy(newData[18:34], proxyIP) + binary.BigEndian.PutUint16(newData[34:36], uint16(clientPort)) + binary.BigEndian.PutUint16(newData[36:38], uint16(proxyPort)) + copy(newData[38:], packet) + + return newData, nil +} diff --git a/proxy/dummy_proxy_protocol.go b/proxy/dummy_proxy_protocol.go new file mode 100644 index 00000000..ec9a1fdc --- /dev/null +++ b/proxy/dummy_proxy_protocol.go @@ -0,0 +1,24 @@ +package proxy + +import ( + "net" +) + +// DummyProxyProtocol has no proxy header. Returns the data as-is +type DummyProxyProtocol struct{} + +func (dpp *DummyProxyProtocol) HeaderSize() int { + return 0 +} + +func (dpp *DummyProxyProtocol) Parse(clientAddress net.Addr, proxyAddress net.Addr, packet []byte) ([]byte, error) { + return packet, nil +} + +func (dpp *DummyProxyProtocol) Encode(clientAddress net.Addr, proxyAddress net.Addr, packet []byte) ([]byte, error) { + return packet, nil +} + +func NewDummyProxyProtocol() *DummyProxyProtocol { + return &DummyProxyProtocol{} +} diff --git a/proxy/haproxy_proxy_protocol_v2.go b/proxy/haproxy_proxy_protocol_v2.go new file mode 100644 index 00000000..a262d6cc --- /dev/null +++ b/proxy/haproxy_proxy_protocol_v2.go @@ -0,0 +1,175 @@ +package proxy + +import ( + "encoding/binary" + "fmt" + "net" +) + +const ( + HAPROXY_V2_SIGNATURE = "\x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A" + + // * Version and Command + HAPROXY_V2_VERSION = 0x2 + HAPROXY_V2_CMD_LOCAL = 0x0 + HAPROXY_V2_CMD_PROXY = 0x1 + + // * Address families + HAPROXY_V2_AF_UNSPEC = 0x0 + HAPROXY_V2_AF_INET = 0x1 + HAPROXY_V2_AF_INET6 = 0x2 + HAPROXY_V2_AF_UNIX = 0x3 + + // * Transport protocols + HAPROXY_V2_TRANSPORT_UNSPEC = 0x0 + HAPROXY_V2_TRANSPORT_STREAM = 0x1 + HAPROXY_V2_TRANSPORT_DGRAM = 0x2 + + // * Combined protocol bytes + HAPROXY_V2_PROTO_UNSPEC = 0x00 + HAPROXY_V2_PROTO_TCP4 = 0x11 + HAPROXY_V2_PROTO_UDP4 = 0x12 + HAPROXY_V2_PROTO_TCP6 = 0x21 + HAPROXY_V2_PROTO_UDP6 = 0x22 + HAPROXY_V2_PROTO_UNIX_STREAM = 0x31 + HAPROXY_V2_PROTO_UNIX_DGRAM = 0x32 +) + +// HAProxyProxyProtocolV2 implements HAProxys PROXY protocol version 2 (binary format) +// https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt#:~:text=2.2.%20Binary%20header%20format%20(version%202) +type HAProxyProxyProtocolV2 struct{} + +func (hpp *HAProxyProxyProtocolV2) HeaderSize() int { + // * V2 is variable length (16 bytes minimum + address length) + // * Return -1 to indicate variable length + return -1 +} + +func (hpp *HAProxyProxyProtocolV2) Parse(clientAddress net.Addr, proxyAddress net.Addr, packet []byte) ([]byte, error) { + // TODO - Add validation checks here + + versionCommand := packet[12] + version := (versionCommand >> 4) & 0x0F + command := versionCommand & 0x0F + + if version != HAPROXY_V2_VERSION { + return nil, fmt.Errorf("unsupported PROXY protocol version: 0x%x", version) + } + + family := packet[13] + addressLengthgth := binary.BigEndian.Uint16(packet[14:16]) + totalHeaderLen := 16 + int(addressLengthgth) + + if command == HAPROXY_V2_CMD_LOCAL { + return packet[totalHeaderLen:], nil + } + + addressData := packet[16:totalHeaderLen] + + switch family { + case HAPROXY_V2_PROTO_TCP4, HAPROXY_V2_PROTO_UDP4: + realClientIP := net.IPv4(addressData[0], addressData[1], addressData[2], addressData[3]) + realClientPort := int(binary.BigEndian.Uint16(addressData[8:10])) + + switch v := clientAddress.(type) { + case *net.TCPAddr: + v.IP = realClientIP + v.Port = realClientPort + case *net.UDPAddr: + v.IP = realClientIP + v.Port = realClientPort + } + case HAPROXY_V2_PROTO_TCP6, HAPROXY_V2_PROTO_UDP6: + realClientIP := net.IP(addressData[0:16]) + realClientPort := int(binary.BigEndian.Uint16(addressData[32:34])) + + switch v := clientAddress.(type) { + case *net.TCPAddr: + v.IP = realClientIP + v.Port = realClientPort + case *net.UDPAddr: + v.IP = realClientIP + v.Port = realClientPort + } + } + + return packet[totalHeaderLen:], nil +} + +func (hpp *HAProxyProxyProtocolV2) Encode(clientAddress net.Addr, proxyAddress net.Addr, packet []byte) ([]byte, error) { + var clientIP net.IP + var clientPort int + var proxyIP net.IP + var proxyPort int + var protocol byte + var addressLength uint16 + + switch v := clientAddress.(type) { + case *net.TCPAddr: + clientIP = v.IP + clientPort = v.Port + case *net.UDPAddr: + clientIP = v.IP + clientPort = v.Port + } + + // TODO - I'm almost positive this is wrong, the PROXY protocol docs say this is the "destination" but it's unclear if that's the proxy server or ustream server + switch v := proxyAddress.(type) { + case *net.TCPAddr: + proxyIP = v.IP + proxyPort = v.Port + case *net.UDPAddr: + proxyIP = v.IP + proxyPort = v.Port + } + + var addressData []byte + if clientIP.To4() != nil { + clientIP = clientIP.To4() + proxyIP = proxyIP.To4() + + switch clientAddress.(type) { + case *net.TCPAddr: + protocol = HAPROXY_V2_PROTO_TCP4 + case *net.UDPAddr: + protocol = HAPROXY_V2_PROTO_UDP4 + } + + addressLength = 12 + addressData = make([]byte, 12) + copy(addressData[0:4], clientIP) + copy(addressData[4:8], proxyIP) + binary.BigEndian.PutUint16(addressData[8:10], uint16(clientPort)) + binary.BigEndian.PutUint16(addressData[10:12], uint16(proxyPort)) + } else { + clientIP = clientIP.To16() + proxyIP = proxyIP.To16() + + switch clientAddress.(type) { + case *net.TCPAddr: + protocol = HAPROXY_V2_PROTO_TCP6 + case *net.UDPAddr: + protocol = HAPROXY_V2_PROTO_UDP6 + } + + addressLength = 36 + addressData = make([]byte, 36) + copy(addressData[0:16], clientIP) + copy(addressData[16:32], proxyIP) + binary.BigEndian.PutUint16(addressData[32:34], uint16(clientPort)) + binary.BigEndian.PutUint16(addressData[34:36], uint16(proxyPort)) + } + + header := make([]byte, 16) + copy(header[0:12], []byte(HAPROXY_V2_SIGNATURE)) + header[12] = (HAPROXY_V2_VERSION << 4) | HAPROXY_V2_CMD_PROXY + header[13] = protocol + binary.BigEndian.PutUint16(header[14:16], addressLength) + + newData := make([]byte, 16+int(addressLength)+len(packet)) + copy(newData[0:16], header) + copy(newData[16:16+addressLength], addressData) + copy(newData[16+addressLength:], packet) + + return newData, nil +} diff --git a/proxy/proxy_protocol.go b/proxy/proxy_protocol.go new file mode 100644 index 00000000..e99f5325 --- /dev/null +++ b/proxy/proxy_protocol.go @@ -0,0 +1,16 @@ +package proxy + +import "net" + +type ProxyProtocol interface { + // HeaderSize returns the size of the proxy protocol header + // attached to the start of all packets + HeaderSize() int + + // Parse extracts proxy header from the packet. The client address and proxy address + // are updated in this function. Returns the real packet payload after the proxy header + Parse(clientAddress net.Addr, proxyAddress net.Addr, packet []byte) ([]byte, error) + + // Encode wraps a payload with the proxy header for the given addresses + Encode(clientAddress net.Addr, proxyAddress net.Addr, packet []byte) ([]byte, error) +} diff --git a/proxy/prudp_simple_proxy_protocol.go b/proxy/prudp_simple_proxy_protocol.go new file mode 100644 index 00000000..7d6791fb --- /dev/null +++ b/proxy/prudp_simple_proxy_protocol.go @@ -0,0 +1,54 @@ +package proxy + +import ( + "encoding/binary" + "net" +) + +const PRUDP_SIMPLE_PROXY_PROTOCOL_VERSION = 0 + +// PRUDPSimpleProxyProtocol implements a custom proxy header for PRUDP. Should only be used when +// one of the other protocols cannot be used +type PRUDPSimpleProxyProtocol struct{} + +func (pspp *PRUDPSimpleProxyProtocol) HeaderSize() int { + return 7 +} + +func (pspp *PRUDPSimpleProxyProtocol) Parse(clientAddress net.Addr, proxyAddress net.Addr, packet []byte) ([]byte, error) { + realClientIP := net.IP(packet[1:5]) + realClientPort := int(binary.BigEndian.Uint16(packet[5:7])) + + switch v := clientAddress.(type) { + case *net.TCPAddr: + v.IP = realClientIP + v.Port = realClientPort + case *net.UDPAddr: + v.IP = realClientIP + v.Port = realClientPort + } + + return packet[pspp.HeaderSize():], nil +} + +func (pspp *PRUDPSimpleProxyProtocol) Encode(clientAddress net.Addr, proxyAddress net.Addr, packet []byte) ([]byte, error) { + var ipv4 net.IP + var port int + + switch v := clientAddress.(type) { + case *net.TCPAddr: + ipv4 = v.IP.To4() + port = v.Port + case *net.UDPAddr: + ipv4 = v.IP.To4() + port = v.Port + } + + newData := make([]byte, pspp.HeaderSize()+len(packet)) + newData[0] = PRUDP_SIMPLE_PROXY_PROTOCOL_VERSION + copy(newData[1:5], ipv4) + binary.BigEndian.PutUint16(newData[5:7], uint16(port)) + copy(newData[7:], packet) + + return newData, nil +} diff --git a/proxy/prudp_simple_proxy_server.go b/proxy/prudp_simple_proxy_server.go new file mode 100644 index 00000000..1fcba10d --- /dev/null +++ b/proxy/prudp_simple_proxy_server.go @@ -0,0 +1,134 @@ +package proxy + +import ( + "encoding/binary" + "fmt" + "net" + "sync" + + "github.com/PretendoNetwork/plogger-go" +) + +var logger = plogger.NewLogger() + +type ProxyServer struct { + mappings map[int]*ProxyMapping + wg sync.WaitGroup +} + +type ProxyMapping struct { + listenPort int + targetAddr *net.UDPAddr + conn *net.UDPConn +} + +func (ps *ProxyServer) handlePort(mapping *ProxyMapping) { + defer ps.wg.Done() + + buffer := make([]byte, 64000) + + for { + read, addr, err := mapping.conn.ReadFromUDP(buffer) + if err != nil { + logger.Errorf("Read error on port %d: %s\n", mapping.listenPort, err) + return + } + + data := make([]byte, read) + copy(data, buffer[:read]) + + if addr.String() == mapping.targetAddr.String() { + ps.handleServerPacket(mapping, data) + } else { + ps.handleClientPacket(mapping, data, addr) + } + } +} + +// handleClientPacket prepends client IP:port and forwards to upstream server +func (ps *ProxyServer) handleClientPacket(mapping *ProxyMapping, data []byte, clientAddr *net.UDPAddr) { + ip4 := clientAddr.IP.To4() + if ip4 == nil { + logger.Warningf("Warning: non-IPv4 address: %s\n", clientAddr.String()) + return + } + + packet := make([]byte, 7+len(data)) + + packet[0] = PRUDP_SIMPLE_PROXY_PROTOCOL_VERSION + copy(packet[1:5], ip4) + binary.BigEndian.PutUint16(packet[5:7], uint16(clientAddr.Port)) + copy(packet[7:], data) + + _, err := mapping.conn.WriteToUDP(packet, mapping.targetAddr) + if err != nil { + logger.Errorf("Error forwarding to server: %s\n", err) + } +} + +// handleServerPacket extracts destination address and forwards payload to client +func (ps *ProxyServer) handleServerPacket(mapping *ProxyMapping, data []byte) { + if len(data) < 6 { + logger.Warningf("Warning: packet too short (%d bytes)\n", len(data)) + return + } + + // * Always version 0 for now, skip the version byte + clientIP := net.IPv4(data[1], data[2], data[3], data[4]) + clientPort := int(binary.BigEndian.Uint16(data[5:7])) + clientAddr := &net.UDPAddr{ + IP: clientIP, + Port: clientPort, + } + + payload := data[7:] + + _, err := mapping.conn.WriteToUDP(payload, clientAddr) + if err != nil { + logger.Errorf("Error forwarding to client: %s\n", err) + } +} + +// AddMapping adds a port mapping (listenPort -> targetAddr) +func (ps *ProxyServer) AddMapping(listenPort int, targetAddr string) error { + target, err := net.ResolveUDPAddr("udp", targetAddr) + if err != nil { + return fmt.Errorf("failed to resolve target address %s: %w", targetAddr, err) + } + + listen := &net.UDPAddr{ + IP: net.IPv4zero, + Port: listenPort, + } + + conn, err := net.ListenUDP("udp", listen) + if err != nil { + return fmt.Errorf("failed to listen on port %d: %w", listenPort, err) + } + + mapping := &ProxyMapping{ + listenPort: listenPort, + targetAddr: target, + conn: conn, + } + + ps.mappings[listenPort] = mapping + + return nil +} + +// Start begins listening on all mapped ports +func (ps *ProxyServer) Start() { + for _, mapping := range ps.mappings { + ps.wg.Add(1) + go ps.handlePort(mapping) + } + ps.wg.Wait() +} + +// NewProxyServer creates a new UDP proxy server +func NewProxyServer() *ProxyServer { + return &ProxyServer{ + mappings: make(map[int]*ProxyMapping), + } +} diff --git a/prudp_server.go b/prudp_server.go index c2c71a08..940860fd 100644 --- a/prudp_server.go +++ b/prudp_server.go @@ -9,6 +9,7 @@ import ( "time" "github.com/PretendoNetwork/nex-go/v2/constants" + "github.com/PretendoNetwork/nex-go/v2/proxy" "github.com/lxzan/gws" ) @@ -28,6 +29,7 @@ type PRUDPServer struct { PRUDPV0Settings *PRUDPV0Settings PRUDPV1Settings *PRUDPV1Settings UseVerboseRMC bool + ProxyProtocol proxy.ProxyProtocol } // BindPRUDPEndPoint binds a provided PRUDPEndPoint to the server @@ -132,7 +134,12 @@ func (ps *PRUDPServer) handleSocketMessage(packetData []byte, address net.Addr, return nil } - readStream := NewByteStreamIn(packetData, ps.LibraryVersions, ps.ByteStreamSettings) + // * The constant calls to NewSocketConnection is just moved from + // * ps.processPacket, but should we really be doing that? Seems wasteful + // TODO - Proxied mode doesn't work on WebSocket connections + socket := NewSocketConnection(ps, address, webSocketConnection) + realPacketPayload, _ := ps.ProxyProtocol.Parse(socket.Address, socket.ProxyAddress, packetData) + readStream := NewByteStreamIn(realPacketPayload, ps.LibraryVersions, ps.ByteStreamSettings) var packets []PRUDPPacketInterface @@ -149,36 +156,36 @@ func (ps *PRUDPServer) handleSocketMessage(packetData []byte, address net.Addr, } for _, packet := range packets { - go ps.processPacket(packet, address, webSocketConnection) + go ps.processPacket(packet, socket) } return nil } -func (ps *PRUDPServer) processPacket(packet PRUDPPacketInterface, address net.Addr, webSocketConnection *gws.Conn) { +func (ps *PRUDPServer) processPacket(packet PRUDPPacketInterface, socket *SocketConnection) { if !ps.Endpoints.Has(packet.DestinationVirtualPortStreamID()) { - logger.Warningf("Client %s trying to connect to unbound PRUDPEndPoint %d", address.String(), packet.DestinationVirtualPortStreamID()) + logger.Warningf("Client %s trying to connect to unbound PRUDPEndPoint %d", socket.Address.String(), packet.DestinationVirtualPortStreamID()) return } endpoint, ok := ps.Endpoints.Get(packet.DestinationVirtualPortStreamID()) if !ok { - logger.Warningf("Client %s trying to connect to unbound PRUDPEndPoint %d", address.String(), packet.DestinationVirtualPortStreamID()) + logger.Warningf("Client %s trying to connect to unbound PRUDPEndPoint %d", socket.Address.String(), packet.DestinationVirtualPortStreamID()) return } if packet.DestinationVirtualPortStreamType() != packet.SourceVirtualPortStreamType() { - logger.Warningf("Client %s trying to use non matching destination and source stream types %d and %d", address.String(), packet.DestinationVirtualPortStreamType(), packet.SourceVirtualPortStreamType()) + logger.Warningf("Client %s trying to use non matching destination and source stream types %d and %d", socket.Address.String(), packet.DestinationVirtualPortStreamType(), packet.SourceVirtualPortStreamType()) return } if packet.DestinationVirtualPortStreamType() > constants.StreamTypeRelay { - logger.Warningf("Client %s trying to use invalid to destination stream type %d", address.String(), packet.DestinationVirtualPortStreamType()) + logger.Warningf("Client %s trying to use invalid to destination stream type %d", socket.Address.String(), packet.DestinationVirtualPortStreamType()) return } if packet.SourceVirtualPortStreamType() > constants.StreamTypeRelay { - logger.Warningf("Client %s trying to use invalid to source stream type %d", address.String(), packet.DestinationVirtualPortStreamType()) + logger.Warningf("Client %s trying to use invalid to source stream type %d", socket.Address.String(), packet.DestinationVirtualPortStreamType()) return } @@ -196,11 +203,10 @@ func (ps *PRUDPServer) processPacket(packet PRUDPPacketInterface, address net.Ad } if invalidSourcePort { - logger.Warningf("Client %s trying to use invalid to source port number %d. Port number too large", address.String(), sourcePortNumber) + logger.Warningf("Client %s trying to use invalid to source port number %d. Port number too large", socket.Address.String(), sourcePortNumber) return } - socket := NewSocketConnection(ps, address, webSocketConnection) endpoint.processPacket(packet, socket) } @@ -304,13 +310,14 @@ func (ps *PRUDPServer) sendPacket(packet PRUDPPacketInterface) { // SendRaw will send the given socket the provided packet func (ps *PRUDPServer) SendRaw(socket *SocketConnection, data []byte) { // TODO - Should this return the error too? - var err error - if address, ok := socket.Address.(*net.UDPAddr); ok && ps.udpSocket != nil { - _, err = ps.udpSocket.WriteToUDP(data, address) + sendData, _ := ps.ProxyProtocol.Encode(socket.Address, socket.ProxyAddress, data) + + if address, ok := socket.ProxyAddress.(*net.UDPAddr); ok && ps.udpSocket != nil { + _, err = ps.udpSocket.WriteToUDP(sendData, address) } else if socket.WebSocketConnection != nil { - err = socket.WebSocketConnection.WriteMessage(gws.OpcodeBinary, data) + err = socket.WebSocketConnection.WriteMessage(gws.OpcodeBinary, sendData) } if err != nil { @@ -344,5 +351,6 @@ func NewPRUDPServer() *PRUDPServer { ByteStreamSettings: NewByteStreamSettings(), PRUDPV0Settings: NewPRUDPV0Settings(), PRUDPV1Settings: NewPRUDPV1Settings(), + ProxyProtocol: proxy.NewDummyProxyProtocol(), } } diff --git a/socket_connection.go b/socket_connection.go index af6ce250..47aded3b 100644 --- a/socket_connection.go +++ b/socket_connection.go @@ -10,7 +10,8 @@ import ( // A single socket may have many PRUDP connections open on it. type SocketConnection struct { Server *PRUDPServer // * PRUDP server the socket is connected to - Address net.Addr // * Sockets address + ProxyAddress net.Addr // * Address of the proxy server. When not proxied, same as SocketConnection.Address + Address net.Addr // * Address of the real client WebSocketConnection *gws.Conn // * Only used in PRUDPLite } @@ -18,7 +19,29 @@ type SocketConnection struct { func NewSocketConnection(server *PRUDPServer, address net.Addr, webSocketConnection *gws.Conn) *SocketConnection { return &SocketConnection{ Server: server, + ProxyAddress: cloneAddr(address), // * Need to make a copy of the net.Addr so it can be worked with independently Address: address, WebSocketConnection: webSocketConnection, } } + +// TODO - This is sort of a hack, replace this with our own type that implements net.Addr and adds this functionality natively? +func cloneAddr(addr net.Addr) net.Addr { + switch v := addr.(type) { + case *net.TCPAddr: + return &net.TCPAddr{ + IP: append([]byte(nil), v.IP...), + Port: v.Port, + Zone: v.Zone, + } + case *net.UDPAddr: + return &net.UDPAddr{ + IP: append([]byte(nil), v.IP...), + Port: v.Port, + Zone: v.Zone, + } + } + + // TODO - Maybe not safe? + return nil +}