diff --git a/README.md b/README.md index 2b77659..02d4610 100644 --- a/README.md +++ b/README.md @@ -1,24 +1,58 @@ # vprox -vprox is a high-performance network proxy acting as a split tunnel VPN. The server accepts peering requests from clients, which then establish WireGuard tunnels that direct all traffic on the client's network interface through the server, with IP masquerading. +vprox is a high-performance network proxy acting as a split tunnel VPN, powered by WireGuard. The server accepts peering requests from clients, which then establish WireGuard tunnels that direct all traffic on the client's network interface through the server, with IP masquerading. Both the client and server commands need root access. The server can have multiple public IP addresses attached, and on cloud providers, it automatically uses the instance metadata endpoint to discover its public IP addresses and start one proxy for each. This property allows the server to be high-availability. In the event of a restart or network partition, the tunnels remain open. If the server's IP address is attached to a new host, clients will automatically re-establish connections. This means that IP addresses can be moved to different hosts in event of an outage. +## Architecture + +In single-tunnel mode, vprox creates one WireGuard interface per connection: + +``` +Client Server +┌──────────┐ UDP :50227 ┌──────────┐ +│ vprox0 │◄──────────────────►│ vprox0 │──► Internet +│ (wg) │ │ (wg) │ (SNAT) +└──────────┘ └──────────┘ +``` + +In multi-tunnel mode (`--tunnels N`), vprox creates N parallel WireGuard interfaces, each on a different UDP port. 
On the client, a dummy interface with policy routing presents a single `vprox0` device to applications while distributing traffic across all tunnels: + +``` +Client Server +┌──────────┐ policy routing ┌──────────┐ +│ vprox0 │ (dummy, user- │ vprox0 │◄─── UDP :50227 ───► vprox0t0 (wg) +│ │ facing) │ vprox0t1│◄─── UDP :50228 ───► vprox0t1 (wg) +│ │ │ vprox0t2│◄─── UDP :50229 ───► vprox0t2 (wg) ──► Internet +│ │ ip rule: to wg │ vprox0t3│◄─── UDP :50230 ───► vprox0t3 (wg) (SNAT) +│ │ subnet → table └──────────┘ +│ │ with multipath +│ vprox0t0 │◄──────────────────► +│ vprox0t1 │◄──────────────────► +│ vprox0t2 │◄──────────────────► +│ vprox0t3 │◄──────────────────► +└──────────┘ +``` + +Each tunnel uses a different UDP port, so the NIC's RSS (Receive Side Scaling) hashes them to different hardware RX queues. The kernel's multipath routing distributes flows across tunnels using L4 hashing. Applications bind to the single `vprox0` interface and are unaware of the underlying tunnels. + ## Usage +### Prerequisites + On the Linux VPN server and client, install system requirements (`iptables` and `wireguard`). ```bash # On Ubuntu sudo apt install iptables wireguard -# On Fedora +# On Fedora / Amazon Linux sudo dnf install iptables wireguard-tools ``` -Also, you need to set some kernel settings with Sysctl. Enable IPv4 forwarding, and make sure that [`rp_filter`](https://sysctl-explorer.net/net/ipv4/rp_filter/) is set to 2, or masqueraded packets may be filtered out. You can edit your OS configuration file to set this persistently, or set it once below. +Set the required kernel parameters. Enable IPv4 forwarding, and make sure that [`rp_filter`](https://sysctl-explorer.net/net/ipv4/rp_filter/) is set to 2, or masqueraded packets may be filtered out. 
```bash # Applies until next reboot @@ -26,7 +60,9 @@ sudo sysctl -w net.ipv4.ip_forward=1 sudo sysctl -w net.ipv4.conf.all.rp_filter=2 ``` -To set up `vprox`, you'll need the private IPv4 address of the server connected to an Internet gateway (use the `ip addr` command), as well as a block of IPs to allocate to the WireGuard subnet between server and client. This has no particular meaning and can be arbitrarily chosen to not overlap with other subnets. +### Basic setup + +You'll need the private IPv4 address of the server connected to an Internet gateway (use the `ip addr` command), as well as a block of IPs to allocate to the WireGuard subnet between server and client. This can be arbitrarily chosen to not overlap with other subnets. ```bash # [Machine A: public IP 1.2.3.4, private IP 172.31.64.125] @@ -42,6 +78,61 @@ Note that Machine B must be able to send UDP packets to port 50227 on Machine A, All outbound network traffic seen by `vprox0` will automatically be forwarded through the WireGuard tunnel. The VPN server masquerades the source IP address. +### Multi-tunnel mode (high throughput) + +A single WireGuard tunnel is encapsulated in one UDP flow (fixed 4-tuple). On cloud providers like AWS, NIC hardware hashes flows to RX queues by this 4-tuple, so a single tunnel is limited to the throughput of one hardware queue — typically ~2-2.5 Gbps on AWS ENA. + +Multi-tunnel mode creates N parallel WireGuard tunnels on different UDP ports, spreading traffic across multiple NIC queues: + +```bash +# Server: 4 parallel tunnels per IP +VPROX_PASSWORD=my-password vprox server --ip 172.31.64.125 --wg-block 240.1.0.0/16 --tunnels 4 + +# Client: 4 parallel tunnels (must be <= server's --tunnels value) +VPROX_PASSWORD=my-password vprox connect 1.2.3.4 --interface vprox0 --tunnels 4 +``` + +Both server and client must use `--tunnels`. The `dummy` kernel module must be available on the client (`sudo modprobe dummy`). 
+ +**Required sysctl on both server and client** for multipath flow distribution: + +```bash +sudo sysctl -w net.ipv4.fib_multipath_hash_policy=1 +``` + +Applications bind to the single `vprox0` interface as before — the multi-tunnel routing is transparent. + +**Choosing the number of tunnels:** Start with `--tunnels 4`. The optimal value depends on the number of CPU cores and NIC queues. On a 4-core server, 4 tunnels will typically saturate the CPU. Adding more tunnels than CPU cores provides diminishing returns since WireGuard encryption becomes the bottleneck. + +### Performance tuning + +For maximum throughput, apply these additional sysctl settings on both server and client: + +```bash +# UDP/Socket buffer sizes (WireGuard uses UDP) +sudo sysctl -w net.core.rmem_max=26214400 +sudo sysctl -w net.core.wmem_max=26214400 +sudo sysctl -w net.core.rmem_default=1048576 +sudo sysctl -w net.core.wmem_default=1048576 + +# Network device backlog (for high packet rates) +sudo sysctl -w net.core.netdev_max_backlog=50000 + +# TCP tuning (for traffic inside the tunnel) +sudo sysctl -w net.ipv4.tcp_rmem="4096 1048576 26214400" +sudo sysctl -w net.ipv4.tcp_wmem="4096 1048576 26214400" +sudo sysctl -w net.ipv4.tcp_congestion_control=bbr + +# Multipath flow hashing (required for multi-tunnel) +sudo sysctl -w net.ipv4.fib_multipath_hash_policy=1 + +# Connection tracking limits (for NAT with many peers) +sudo sysctl -w net.netfilter.nf_conntrack_max=1048576 +``` + +To make these settings persistent across reboots, add them to `/etc/sysctl.d/99-vprox.conf` without the `sudo sysctl -w` prefix, then apply with `sudo sysctl --system`. + + ### Building To build `vprox`, run the following command with Go 1.22+ installed: @@ -56,7 +147,7 @@ This produces a static binary in `./vprox`. 
On cloud providers like AWS, you can attach [secondary private IP addresses](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/MultipleIP.html) to an interface and associate each of them with a global IPv4 unicast address. -A `vprox` server listening on multiple IP addresses needs to provide `--ip` option once for every IP, and each IP requires its its own WireGuard VPN subnet with a non-overlapping address range. You can pass `--wg-block-per-ip /22` to split the `--wg-block` into smaller blocks for each IP. +A `vprox` server listening on multiple IP addresses needs to provide `--ip` option once for every IP, and each IP requires its own WireGuard VPN subnet with a non-overlapping address range. You can pass `--wg-block-per-ip /22` to split the `--wg-block` into smaller blocks for each IP. On AWS in particular, the `--cloud aws` option allows you to automatically discover the private IP addresses of the server by periodically querying the instance metadata endpoint. @@ -66,8 +157,46 @@ On AWS in particular, the `--cloud aws` option allows you to automatically disco - Supports forwarding IPv4 packets - Works if the server has multiple IPs, specified with `--wg-block-per-ip` - Automatic discovery of IPs using instance metadata endpoints (AWS) -- Only one vprox server may be running on a host +- Multi-tunnel mode for throughput beyond the single NIC queue limit (`--tunnels N`) +- WireGuard interfaces tuned with GSO/GRO offload, multi-queue, and optimized MTU/MSS +- Connection tracking bypass (NOTRACK) for reduced CPU overhead on WireGuard UDP flows +- TCP MSS clamping to prevent fragmentation inside the tunnel - Control traffic is encrypted with TLS (Warning: does not verify server certificate) +- Only one vprox server may be running on a host + +## How it works + +### Control plane + +The server listens on port 443 (HTTPS) for control traffic. Clients send a `/connect` request with their WireGuard public key. 
The server allocates a peer IP from the WireGuard subnet, adds the client as a peer on all tunnel interfaces, and returns the assigned address along with a list of tunnel endpoints (listen ports). + +### Data plane + +WireGuard handles the data plane. Each tunnel interface encrypts/decrypts traffic independently. The server applies iptables rules for: + +- **SNAT (masquerade)**: Outbound traffic from WireGuard peers is source-NAT'd to the server's bind address. +- **Firewall marks**: Traffic from WireGuard interfaces is marked for routing policy. +- **MSS clamping**: TCP SYN packets are clamped to fit within the WireGuard MTU (1380 bytes). +- **NOTRACK**: WireGuard UDP flows bypass connection tracking to reduce per-packet CPU overhead. + +### Multi-tunnel routing + +In multi-tunnel mode, both server and client use Linux policy routing to distribute traffic: + +- A custom routing table (51820) contains multipath routes across all tunnel interfaces. +- An `ip rule` directs matching traffic to this table. +- On the client, the rule matches traffic sourced from the WireGuard IP (set by the dummy `vprox0` device). +- On the server, the rule matches traffic destined for the WireGuard subnet (forwarded download traffic). +- The kernel's L4 multipath hash (`fib_multipath_hash_policy=1`) distributes different flows to different tunnels. + +### Interface tuning + +WireGuard interfaces are created with performance-optimized settings: + +- **MTU 1420**: Prevents fragmentation on standard 1500 MTU networks (WireGuard adds ~60 bytes overhead). +- **GSO/GRO 65536**: Enables Generic Segmentation/Receive Offload, allowing the kernel to batch packets into 64 KB super-packets before encryption (Linux 5.19+). +- **4 TX/RX queues**: Enables parallel packet processing across multiple CPU cores. +- **TxQLen 1000**: Reduces packet drops during traffic bursts. 
## Authors diff --git a/cmd/connect.go b/cmd/connect.go index 0cf7f12..b938303 100644 --- a/cmd/connect.go +++ b/cmd/connect.go @@ -68,12 +68,15 @@ var ConnectCmd = &cobra.Command{ } var connectCmdArgs struct { - ifname string + ifname string + tunnels int } func init() { ConnectCmd.Flags().StringVar(&connectCmdArgs.ifname, "interface", "vprox0", "Interface name to proxy traffic through the VPN") + ConnectCmd.Flags().IntVar(&connectCmdArgs.tunnels, "tunnels", + 1, "Number of parallel WireGuard tunnels (higher values improve throughput by spreading traffic across NIC queues)") } func runConnect(cmd *cobra.Command, args []string) error { @@ -98,11 +101,12 @@ func runConnect(cmd *cobra.Command, args []string) error { } client := &lib.Client{ - Key: key, - Ifname: connectCmdArgs.ifname, - ServerIp: serverIp, - Password: password, - WgClient: wgClient, + Key: key, + Ifname: connectCmdArgs.ifname, + ServerIp: serverIp, + Password: password, + NumTunnels: connectCmdArgs.tunnels, + WgClient: wgClient, Http: &http.Client{ Timeout: 5 * time.Second, Transport: &http.Transport{ diff --git a/cmd/server.go b/cmd/server.go index 47c366b..7dee489 100644 --- a/cmd/server.go +++ b/cmd/server.go @@ -29,6 +29,7 @@ var serverCmdArgs struct { wgBlockPerIp string cloud string takeover bool + tunnels int } func init() { @@ -42,6 +43,8 @@ func init() { "", "Cloud provider for IP metadata (watches for changes)") ServerCmd.Flags().BoolVar(&serverCmdArgs.takeover, "takeover", false, "Take over existing WireGuard state from a previous server instance (for non-disruptive upgrades)") + ServerCmd.Flags().IntVar(&serverCmdArgs.tunnels, "tunnels", + 1, "Number of parallel WireGuard tunnels per IP (higher values improve throughput by spreading traffic across NIC queues)") } func runServer(cmd *cobra.Command, args []string) error { @@ -98,9 +101,13 @@ func runServer(cmd *cobra.Command, args []string) error { return err } + if serverCmdArgs.tunnels < 1 || serverCmdArgs.tunnels > 
lib.MaxTunnelsPerServer { + return fmt.Errorf("--tunnels must be between 1 and %d", lib.MaxTunnelsPerServer) + } + ctx, done := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) - sm, err := lib.NewServerManager(wgBlock, wgBlockPerIp, ctx, key, password, serverCmdArgs.takeover) + sm, err := lib.NewServerManager(wgBlock, wgBlockPerIp, ctx, key, password, serverCmdArgs.tunnels, serverCmdArgs.takeover) if err != nil { done() return err diff --git a/lib/client.go b/lib/client.go index 65f58e4..12006ef 100644 --- a/lib/client.go +++ b/lib/client.go @@ -20,6 +20,8 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) + + // Used to determine if we can recover from an error during connection setup. type ConnectionError struct { Message string @@ -40,12 +42,30 @@ func IsRecoverableError(err error) bool { return true } -// Client manages a peering connection with with a local WireGuard interface. +// Client manages a peering connection with a local WireGuard interface, or a +// set of parallel WireGuard interfaces when multi-tunnel is enabled. +// +// Single-tunnel (NumTunnels <= 1): +// +// Applications use "vprox0" which is a plain WireGuard interface. +// +// Multi-tunnel (NumTunnels > 1): +// +// WireGuard tunnels: vprox0t0, vprox0t1, vprox0t2, ... +// Dummy device: vprox0 (holds the IP address, user-facing) +// +// Applications bind to "vprox0" (the dummy interface). An ip rule redirects +// traffic sourced from the vprox IP into a custom routing table that has +// equal-cost multipath routes across the WireGuard tunnels. The kernel +// distributes flows across them via L4 hashing. Each tunnel uses a different +// UDP port so the NIC hashes the outer packets to different hardware RX queues. type Client struct { // Key is the private key of the client. Key wgtypes.Key - // Ifname is the name of the client WireGuard interface. + // Ifname is the name of the interface exposed to applications (e.g. "vprox0"). 
+	// In multi-tunnel mode this is a dummy device; individual WireGuard tunnels
+	// are named "<Ifname>t0", "<Ifname>t1", etc.
 	Ifname string
 
 	// ServerIp is the public IPv4 address of the server.
@@ -54,142 +74,415 @@ type Client struct {
 	// Password authenticates the client connection.
 	Password string
 
+	// NumTunnels is the number of parallel WireGuard tunnels to create.
+	// When <= 1, the client creates a single plain WireGuard interface.
+	// When > 1, a dummy device + policy routing is created over N WireGuard tunnels.
+	NumTunnels int
+
 	// WgClient is a shared client for interacting with the WireGuard kernel module.
 	WgClient *wgctrl.Client
 
 	// Http is used to make connect requests to the server.
 	Http *http.Client
 
-	// wgCidr is the current subnet assigned to the WireGuard interface, if any.
+	// wgCidr is the current subnet assigned to the interface, if any.
 	wgCidr netip.Prefix
+
+	// activeTunnels tracks how many tunnel interfaces were actually created
+	// during the last successful Connect().
+	activeTunnels int
+}
+
+// ---------------------------------------------------------------------------
+// Naming helpers
+// ---------------------------------------------------------------------------
+
+// numTunnels returns the effective tunnel count, defaulting to 1.
+func (c *Client) numTunnels() int {
+	if c.NumTunnels <= 1 {
+		return 1
+	}
+	return c.NumTunnels
+}
+
+// isMultiTunnel returns true when we should create a dummy device.
+func (c *Client) isMultiTunnel() bool {
+	return c.numTunnels() > 1
+}
+
+// tunnelIfname returns the WireGuard interface name for the t-th tunnel.
+// - Single-tunnel mode: returns Ifname directly (e.g. "vprox0").
+// - Multi-tunnel mode: returns Ifname + "t" + tunnel index (e.g. "vprox0t0", "vprox0t1").
+func (c *Client) tunnelIfname(t int) string {
+	if !c.isMultiTunnel() {
+		return c.Ifname
+	}
+	return fmt.Sprintf("%st%d", c.Ifname, t)
+}
+
+// tunnelLink builds a linkWireguard with tuned LinkAttrs for the t-th tunnel.
+func (c *Client) tunnelLink(t int) *linkWireguard { + return &linkWireguard{LinkAttrs: netlink.LinkAttrs{ + Name: c.tunnelIfname(t), + MTU: WireguardMTU, + TxQLen: WireguardTxQLen, + NumTxQueues: WireguardNumQueues, + NumRxQueues: WireguardNumQueues, + GSOMaxSize: WireguardGSOMaxSize, + GROMaxSize: WireguardGSOMaxSize, + }} } -// CreateInterface creates a new interface for wireguard. DeleteInterface() needs -// to be called to clean this up. +// --------------------------------------------------------------------------- +// Interface creation / deletion +// --------------------------------------------------------------------------- + +// CreateInterface creates the network interface(s) that applications will use. +// - Single-tunnel: one plain WireGuard interface named Ifname. +// - Multi-tunnel: N WireGuard interfaces + a dummy device named Ifname +// with policy routing to distribute traffic across the tunnels. +// +// DeleteInterface() must be called to clean up. func (c *Client) CreateInterface() error { - link := c.link() + nt := c.numTunnels() + + // Create the WireGuard tunnel interfaces. + for t := 0; t < nt; t++ { + if err := c.createTunnelInterface(t); err != nil { + for rb := 0; rb < t; rb++ { + c.deleteTunnelInterface(rb) + } + return err + } + } + + // In multi-tunnel mode, create a dummy device for the user-facing interface. + if c.isMultiTunnel() { + if err := c.createDummyInterface(); err != nil { + for t := 0; t < nt; t++ { + c.deleteTunnelInterface(t) + } + return err + } + log.Printf("created dummy %s with %d WireGuard tunnels (%s .. %s)", + c.Ifname, nt, c.tunnelIfname(0), c.tunnelIfname(nt-1)) + } + + return nil +} + +// createTunnelInterface creates a single WireGuard tunnel interface. 
+func (c *Client) createTunnelInterface(t int) error { + link := c.tunnelLink(t) err := netlink.LinkAdd(link) if err != nil { - return fmt.Errorf("error creating vprox interface: %v", err) + return fmt.Errorf("error creating WireGuard interface %s: %v", link.Name, err) + } + + // Set MTU explicitly (some kernels ignore LinkAttrs.MTU on creation). + if err := netlink.LinkSetMTU(link, WireguardMTU); err != nil { + netlink.LinkDel(link) + return fmt.Errorf("error setting MTU on %s: %v", link.Name, err) + } + + // Set TxQLen for improved burst handling (non-fatal). + if err := netlink.LinkSetTxQLen(link, WireguardTxQLen); err != nil { + log.Printf("warning: failed to set TxQLen on %s: %v", link.Name, err) } return nil } -// Connect attempts to reconnect to the peer. A network interface needs to -// have already been created with CreateInterface() before calling Connect() +// createDummyInterface creates a dummy network interface named Ifname. This is +// the user-facing device that applications bind to. A policy routing rule will +// redirect its traffic into a custom table with multipath routes across the +// WireGuard tunnels. +func (c *Client) createDummyInterface() error { + // Remove any stale interface with this name. + if existing, _ := netlink.LinkByName(c.Ifname); existing != nil { + _ = netlink.LinkDel(existing) + } + + dummy := &netlink.Dummy{ + LinkAttrs: netlink.LinkAttrs{ + Name: c.Ifname, + MTU: WireguardMTU, + TxQLen: WireguardTxQLen, + }, + } + + if err := netlink.LinkAdd(dummy); err != nil { + return fmt.Errorf("failed to create dummy interface %s: %v", c.Ifname, err) + } + + return nil +} + +// DeleteInterface removes all interfaces and policy routing rules. +func (c *Client) DeleteInterface() { + if c.isMultiTunnel() { + // Clean up policy routing. + c.cleanupPolicyRouting() + + // Delete the dummy interface. 
+ log.Printf("About to delete dummy interface %v", c.Ifname) + if dummy, err := netlink.LinkByName(c.Ifname); err == nil { + if err := netlink.LinkDel(dummy); err != nil { + log.Printf("error deleting dummy %v: %v", c.Ifname, err) + } else { + log.Printf("successfully deleted dummy %v", c.Ifname) + } + } + } + + nt := c.numTunnels() + for t := nt - 1; t >= 0; t-- { + c.deleteTunnelInterface(t) + } +} + +// deleteTunnelInterface removes a single WireGuard tunnel interface. +func (c *Client) deleteTunnelInterface(t int) { + ifname := c.tunnelIfname(t) + log.Printf("About to delete vprox interface %v", ifname) + link := &linkWireguard{LinkAttrs: netlink.LinkAttrs{Name: ifname}} + if err := netlink.LinkDel(link); err != nil { + log.Printf("error deleting vprox interface %v: %v", ifname, err) + } else { + log.Printf("successfully deleted vprox interface %v", ifname) + } +} + +// --------------------------------------------------------------------------- +// Connect / Disconnect +// --------------------------------------------------------------------------- + +// Connect attempts to connect (or reconnect) to the server. All tunnel +// interfaces must already exist via CreateInterface(). func (c *Client) Connect() error { resp, err := c.sendConnectionRequest() if err != nil { return err } - link := c.link() - err = netlink.LinkSetUp(link) - if err != nil { - return fmt.Errorf("error setting up vprox interface: %v", err) + // Determine how many tunnels to actually use — minimum of what the client + // wants and what the server offers. + nt := c.numTunnels() + serverTunnels := len(resp.Tunnels) + if serverTunnels > 0 && serverTunnels < nt { + nt = serverTunnels } + if serverTunnels == 0 { + nt = 1 + } + c.activeTunnels = nt - err = c.updateInterface(resp) - if err != nil { + // Configure WireGuard on each tunnel interface. 
+	for t := 0; t < nt; t++ {
+		ifname := c.tunnelIfname(t)
+
+		link := c.tunnelLink(t)
+		if err := netlink.LinkSetUp(link); err != nil {
+			return fmt.Errorf("error setting up %s: %v", ifname, err)
+		}
+
+		// Pick the listen port: tunnel 0 always uses ServerListenPort
+		// (backwards compatible with old servers); tunnels 1+ use Tunnels[t].
+		port := resp.ServerListenPort
+		if t > 0 && t < len(resp.Tunnels) {
+			port = resp.Tunnels[t].ListenPort
+		}
+
+		if err := c.configureWireguardTunnel(t, resp, port); err != nil {
+			return fmt.Errorf("error configuring wireguard on %s: %v", ifname, err)
+		}
+	}
+
+	// Bring up the user-facing interface and assign the address.
+	if err := c.bringUpUserInterface(); err != nil {
 		return err
 	}
 
-	err = c.configureWireguard(resp)
-	if err != nil {
-		return fmt.Errorf("error configuring wireguard interface: %v", err)
+	if err := c.updateAddress(resp); err != nil {
+		return err
+	}
+
+	// In multi-tunnel mode, assign addresses to each tunnel and set up
+	// policy routing to distribute traffic across them.
+	if c.isMultiTunnel() && nt > 1 {
+		if err := c.setupPolicyRouting(nt); err != nil {
+			return fmt.Errorf("error setting up policy routing: %v", err)
+		}
+		log.Printf("policy routing configured across %d tunnels", nt)
 	}
 
 	return nil
 }
 
-// updateInterface updates the wireguard interface based on the provided connectionResponse
-func (c *Client) updateInterface(resp connectResponse) error {
+// bringUpUserInterface brings up the interface that applications will use.
+// In single-tunnel mode this is the WireGuard interface itself; in multi-tunnel
+// mode this is the dummy device.
+func (c *Client) bringUpUserInterface() error { + link, err := netlink.LinkByName(c.Ifname) + if err != nil { + return fmt.Errorf("failed to find interface %s: %v", c.Ifname, err) + } + if err := netlink.LinkSetUp(link); err != nil { + return fmt.Errorf("error setting up %s: %v", c.Ifname, err) + } + return nil +} + +// updateAddress assigns (or updates) the IP address on the user-facing +// interface (Ifname). +func (c *Client) updateAddress(resp connectResponse) error { cidr, err := netip.ParsePrefix(resp.AssignedAddr) if err != nil { return fmt.Errorf("failed to parse assigned address %v: %v", resp.AssignedAddr, err) } - if cidr != c.wgCidr { - link := c.link() + link, err := netlink.LinkByName(c.Ifname) + if err != nil { + return fmt.Errorf("failed to find interface %s: %v", c.Ifname, err) + } + if cidr != c.wgCidr { if c.wgCidr.IsValid() { oldIpnet := prefixToIPNet(c.wgCidr) - err = netlink.AddrDel(link, &netlink.Addr{IPNet: &oldIpnet}) - - if err != nil { - log.Printf("warning: failed to remove old address from vprox interface when reconnecting: %v", err) + if err := netlink.AddrDel(link, &netlink.Addr{IPNet: &oldIpnet}); err != nil { + log.Printf("warning: failed to remove old address from %s: %v", c.Ifname, err) } } - ipnet := prefixToIPNet(cidr) - err = netlink.AddrAdd(link, &netlink.Addr{IPNet: &ipnet}) - if err != nil { - return fmt.Errorf("failed to add new address to vprox interface: %v", err) + if err := netlink.AddrAdd(link, &netlink.Addr{IPNet: &ipnet}); err != nil { + return fmt.Errorf("failed to add address to %s: %v", c.Ifname, err) } c.wgCidr = cidr } return nil } -// sendConnectionRequest attempts to send a connection request to the peer -func (c *Client) sendConnectionRequest() (connectResponse, error) { - connectUrl, err := url.Parse(fmt.Sprintf("https://%s/connect", c.ServerIp)) - if err != nil { - return connectResponse{}, fmt.Errorf("failed to parse connect URL: %v", err) +// 
--------------------------------------------------------------------------- +// Policy routing (multi-tunnel) +// --------------------------------------------------------------------------- + +// setupPolicyRouting creates: +// 1. An ip rule that matches traffic from the vprox source IP and directs it +// to a custom routing table. +// 2. Equal-cost multipath routes in that table across all WireGuard tunnels. +// +// This allows applications binding to the dummy vprox0 interface (or using the +// vprox IP as source) to have their traffic distributed across tunnels by the +// kernel's L4 flow hash. +func (c *Client) setupPolicyRouting(nt int) error { + if !c.wgCidr.IsValid() { + return fmt.Errorf("no valid CIDR assigned yet") } - reqJson := &connectRequest{ - PeerPublicKey: c.Key.PublicKey().String(), + // Assign the same IP address to each WireGuard tunnel interface so that + // the kernel can use any of them to reach the server WireGuard IP. + for t := 0; t < nt; t++ { + ifname := c.tunnelIfname(t) + tunnelLink, err := netlink.LinkByName(ifname) + if err != nil { + return fmt.Errorf("failed to find tunnel %s: %v", ifname, err) + } + ipnet := prefixToIPNet(c.wgCidr) + if err := netlink.AddrReplace(tunnelLink, &netlink.Addr{IPNet: &ipnet}); err != nil { + return fmt.Errorf("failed to assign address to %s: %v", ifname, err) + } } - buf, err := json.Marshal(reqJson) - if err != nil { - return connectResponse{}, fmt.Errorf("failed to marshal connect request: %v", err) + + // Build multipath nexthops — one per tunnel interface. 
+ gwAddr := c.wgCidr.Masked().Addr().Next() + gwIP := addrToIp(gwAddr) + + var nexthops []*netlink.NexthopInfo + for t := 0; t < nt; t++ { + ifname := c.tunnelIfname(t) + tunnelLink, err := netlink.LinkByName(ifname) + if err != nil { + return fmt.Errorf("failed to find tunnel %s: %v", ifname, err) + } + nexthops = append(nexthops, &netlink.NexthopInfo{ + LinkIndex: tunnelLink.Attrs().Index, + Gw: gwIP, + Hops: 0, + }) } - req := &http.Request{ - Method: http.MethodPost, - URL: connectUrl, - Header: http.Header{ - "Authorization": []string{"Bearer " + c.Password}, - }, - Body: io.NopCloser(bytes.NewBuffer(buf)), + // Add default multipath route in the custom table. + _, defaultDst, _ := net.ParseCIDR("0.0.0.0/0") + mpRoute := &netlink.Route{ + Table: PolicyRoutingTable, + Dst: defaultDst, + MultiPath: nexthops, + } + if err := netlink.RouteReplace(mpRoute); err != nil { + return fmt.Errorf("failed to add multipath route to table %d: %v", PolicyRoutingTable, err) } - resp, err := c.Http.Do(req) - if err != nil { - return connectResponse{}, fmt.Errorf("failed to connect to server: %v", err) + // Add an ip rule: from lookup table PolicyRoutingTable. + srcIP := c.wgCidr.Addr() + srcNet := &net.IPNet{ + IP: addrToIp(srcIP), + Mask: net.CIDRMask(32, 32), } - defer resp.Body.Close() + rule := netlink.NewRule() + rule.Src = srcNet + rule.Table = PolicyRoutingTable + rule.Priority = PolicyRoutingPriority - if resp.StatusCode != http.StatusOK { - recoverable := resp.StatusCode != http.StatusUnauthorized - return connectResponse{}, &ConnectionError{ - Message: fmt.Sprintf("server returned status %v", resp.Status), - Recoverable: recoverable, - } + // Remove any stale rule first (idempotent). 
+ _ = netlink.RuleDel(rule) + + if err := netlink.RuleAdd(rule); err != nil { + return fmt.Errorf("failed to add ip rule for %v: %v", srcIP, err) } - buf, err = io.ReadAll(resp.Body) - if err != nil { - return connectResponse{}, fmt.Errorf("failed to read response body: %v", err) + return nil +} + +// cleanupPolicyRouting removes the ip rule and flushes the custom routing table. +func (c *Client) cleanupPolicyRouting() { + if c.wgCidr.IsValid() { + srcIP := c.wgCidr.Addr() + srcNet := &net.IPNet{ + IP: addrToIp(srcIP), + Mask: net.CIDRMask(32, 32), + } + rule := netlink.NewRule() + rule.Src = srcNet + rule.Table = PolicyRoutingTable + rule.Priority = PolicyRoutingPriority + if err := netlink.RuleDel(rule); err != nil { + log.Printf("warning: failed to delete ip rule: %v", err) + } } - var respJson connectResponse - json.Unmarshal(buf, &respJson) - return respJson, nil + // Flush routes in our custom table. + routes, err := netlink.RouteListFiltered(netlink.FAMILY_V4, &netlink.Route{ + Table: PolicyRoutingTable, + }, netlink.RT_FILTER_TABLE) + if err == nil { + for i := range routes { + _ = netlink.RouteDel(&routes[i]) + } + } } -// configureWireguard configures the WireGuard peer. -func (c *Client) configureWireguard(connectionResponse connectResponse) error { - serverPublicKey, err := wgtypes.ParseKey(connectionResponse.ServerPublicKey) +// configureWireguardTunnel configures a single WireGuard tunnel interface with +// the server as a peer on the given port. 
+func (c *Client) configureWireguardTunnel(t int, resp connectResponse, serverPort int) error { + serverPublicKey, err := wgtypes.ParseKey(resp.ServerPublicKey) if err != nil { return fmt.Errorf("failed to parse server public key: %v", err) } keepalive := 25 * time.Second - return c.WgClient.ConfigureDevice(c.Ifname, wgtypes.Config{ + ifname := c.tunnelIfname(t) + return c.WgClient.ConfigureDevice(ifname, wgtypes.Config{ PrivateKey: &c.Key, ReplacePeers: true, Peers: []wgtypes.PeerConfig{ @@ -197,7 +490,7 @@ func (c *Client) configureWireguard(connectionResponse connectResponse) error { PublicKey: serverPublicKey, Endpoint: &net.UDPAddr{ IP: addrToIp(c.ServerIp), - Port: connectionResponse.ServerListenPort, + Port: serverPort, }, PersistentKeepaliveInterval: &keepalive, ReplaceAllowedIPs: true, @@ -250,24 +543,15 @@ func (c *Client) Disconnect() error { return nil } -func (c *Client) DeleteInterface() { - // Delete the WireGuard interface. - log.Printf("About to delete vprox interface %v", c.Ifname) - err := netlink.LinkDel(c.link()) - if err != nil { - log.Printf("error deleting vprox interface %v: %v", c.Ifname, err) - } else { - log.Printf("successfully deleted vprox interface %v", c.Ifname) - } -} - -func (c *Client) link() *linkWireguard { - return &linkWireguard{LinkAttrs: netlink.LinkAttrs{Name: c.Ifname}} -} +// --------------------------------------------------------------------------- +// Health check +// --------------------------------------------------------------------------- // CheckConnection checks the status of the connection with the wireguard peer, // and returns true if it is healthy. This sends 3 pings in succession, and blocks // until they receive a response or the timeout passes. +// In multi-tunnel mode, pings are sent through the first WireGuard tunnel +// interface (not the dummy device) so replies are received on the same interface. 
func (c *Client) CheckConnection(timeout time.Duration, cancelCtx context.Context) bool { pinger, err := probing.NewPinger(c.wgCidr.Masked().Addr().Next().String()) if err != nil { @@ -275,11 +559,14 @@ func (c *Client) CheckConnection(timeout time.Duration, cancelCtx context.Contex return false } - pinger.InterfaceName = c.Ifname + // Use the first WireGuard tunnel for health checks. In single-tunnel mode + // tunnelIfname(0) == Ifname; in multi-tunnel mode it's the actual WireGuard + // device (e.g. "vprox0t0") rather than the dummy ("vprox0"). + pinger.InterfaceName = c.tunnelIfname(0) pinger.Timeout = timeout pinger.Count = 3 pinger.Interval = 10 * time.Millisecond // Send approximately all at once - err = pinger.RunWithContext(cancelCtx) // Blocks until finished. + err = pinger.RunWithContext(cancelCtx) // Blocks until finished. if err != nil { log.Printf("error running pinger: %v", err) return false @@ -290,3 +577,55 @@ func (c *Client) CheckConnection(timeout time.Duration, cancelCtx context.Contex } return stats.PacketsRecv > 0 } + +// --------------------------------------------------------------------------- +// HTTPS / control-plane +// --------------------------------------------------------------------------- + +// sendConnectionRequest attempts to send a connection request to the peer. 
+func (c *Client) sendConnectionRequest() (connectResponse, error) {
+	connectUrl, err := url.Parse(fmt.Sprintf("https://%s/connect", c.ServerIp))
+	if err != nil {
+		return connectResponse{}, fmt.Errorf("failed to parse connect URL: %v", err)
+	}
+
+	reqJson := &connectRequest{
+		PeerPublicKey: c.Key.PublicKey().String(),
+	}
+	buf, err := json.Marshal(reqJson)
+	if err != nil {
+		return connectResponse{}, fmt.Errorf("failed to marshal connect request: %v", err)
+	}
+	req := &http.Request{
+		Method: http.MethodPost,
+		URL:    connectUrl,
+		Header: http.Header{
+			"Authorization": []string{"Bearer " + c.Password},
+		},
+		Body: io.NopCloser(bytes.NewBuffer(buf)),
+	}
+
+	resp, err := c.Http.Do(req)
+	if err != nil {
+		return connectResponse{}, fmt.Errorf("failed to connect to server: %v", err)
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode != http.StatusOK {
+		recoverable := resp.StatusCode != http.StatusUnauthorized
+		return connectResponse{}, &ConnectionError{
+			Message:     fmt.Sprintf("server returned status %v", resp.Status),
+			Recoverable: recoverable,
+		}
+	}
+
+	buf, err = io.ReadAll(resp.Body)
+	if err != nil {
+		return connectResponse{}, fmt.Errorf("failed to read response body: %v", err)
+	}
+	var respJson connectResponse
+	if err := json.Unmarshal(buf, &respJson); err != nil {
+		return connectResponse{}, fmt.Errorf("failed to parse response body: %v", err)
+	}
+	return respJson, nil
+}
diff --git a/lib/server.go b/lib/server.go
index 0d1ff5a..061ce0c 100644
--- a/lib/server.go
+++ b/lib/server.go
@@ -25,9 +25,46 @@ import (
 // FwmarkBase is the base value for firewall marks used by vprox.
 const FwmarkBase = 0x54437D00
 
+// PolicyRoutingTable is the custom routing table number used for multi-tunnel
+// multipath routing on both server and client.
+const PolicyRoutingTable = 51820
+
+// PolicyRoutingPriority is the ip rule priority for the vprox policy route.
+const PolicyRoutingPriority = 100
+
 // UDP listen port base value for WireGuard connections.
 const WireguardListenPortBase = 50227
 
+// WireGuard interface MTU. 
WireGuard adds 60 bytes of overhead over IPv4 (20 IPv4
+// + 8 UDP + 32 WireGuard header/tag). 1420 = 1500 - 80 also leaves headroom
+// for the larger IPv6 overhead on standard 1500 MTU networks.
+const WireguardMTU = 1420
+
+// WireGuard interface transmit queue length. Higher values reduce packet drops
+// during traffic bursts.
+const WireguardTxQLen = 1000
+
+// GSO/GRO max size for improved throughput on Linux 5.19+. Allows the kernel
+// to batch packets into large 64 KB super-packets before encryption/decryption.
+const WireguardGSOMaxSize = 65536
+
+// TCP MSS for traffic through the WireGuard tunnel, calculated as
+// MTU (1420) - IP header (20) - TCP header (20) = 1380.
+const WireguardMSS = 1380
+
+// Number of TX/RX queues for parallel packet processing on multi-core systems.
+const WireguardNumQueues = 4
+
+// MaxTunnelsPerServer is the maximum number of parallel WireGuard tunnels
+// allowed per server (per bind IP). Each tunnel uses a different UDP port
+// so that the NIC hashes them to different hardware RX queues.
+const MaxTunnelsPerServer = 16
+
+// PortsPerIndex is the number of UDP ports reserved per server index.
+// This must be >= MaxTunnelsPerServer. With this spacing, server index 0
+// uses ports 50227..50242, index 1 uses 50243..50258, etc.
+const PortsPerIndex = MaxTunnelsPerServer
+
 // A new peer must connect with a handshake within this time.
 const FirstHandshakeTimeout = 10 * time.Second
 
@@ -68,6 +105,12 @@ type Server struct {
 	// Index is a unique server index for firewall marks and other uses. It starts at 0.
 	Index uint16
 
+	// NumTunnels is the number of parallel WireGuard tunnels to create for
+	// this server. Each tunnel listens on a different UDP port so that the NIC
+	// hashes them to different hardware RX queues, increasing throughput beyond
+	// the single-flow limit. Defaults to 1 for backwards compatibility.
+	NumTunnels int
+
+	// Ipt is the iptables client for managing firewall rules. 
Ipt *iptables.IPTables @@ -94,6 +137,17 @@ type Server struct { takeover bool } +// numTunnels returns the effective tunnel count, defaulting to 1. +func (srv *Server) numTunnels() int { + if srv.NumTunnels <= 0 { + return 1 + } + if srv.NumTunnels > MaxTunnelsPerServer { + return MaxTunnelsPerServer + } + return srv.NumTunnels +} + // InitState initializes the private server state. func (srv *Server) InitState() error { if srv.BindIface == nil { @@ -177,10 +231,25 @@ func (srv *Server) indexHandler(w http.ResponseWriter, r *http.Request) { type connectRequest struct { PeerPublicKey string } + +// TunnelInfo describes a single WireGuard tunnel endpoint within a multi-tunnel +// connection. Clients that support multi-tunnel will use these to set up +// parallel WireGuard interfaces. +type TunnelInfo struct { + ListenPort int `json:"ListenPort"` + Ifname string `json:"Ifname"` +} + type connectResponse struct { AssignedAddr string ServerPublicKey string ServerListenPort int + + // Tunnels lists all available tunnel endpoints for this server. Clients + // that support multi-tunnel create one WireGuard interface per entry. + // Clients that don't understand this field will fall back to the single + // ServerListenPort above (backwards compatible). + Tunnels []TunnelInfo `json:"Tunnels,omitempty"` } // Handle a new connection. @@ -242,31 +311,56 @@ func (srv *Server) connectHandler(w http.ResponseWriter, r *http.Request) { clientIp := strings.Split(r.RemoteAddr, ":")[0] // for logging log.Printf("[%v] new peer %v at %v: %v", srv.BindAddr, clientIp, peerIp, peerKey) - err = srv.WgClient.ConfigureDevice(srv.Ifname(), wgtypes.Config{ - Peers: []wgtypes.PeerConfig{ - { - PublicKey: peerKey, - ReplaceAllowedIPs: true, - AllowedIPs: []net.IPNet{prefixToIPNet(netip.PrefixFrom(peerIp, 32))}, + + // Add the peer to ALL tunnel interfaces so that traffic arriving on any + // tunnel is accepted, and the server can send traffic back on any tunnel. 
+ nt := srv.numTunnels() + for t := 0; t < nt; t++ { + ifname := srv.TunnelIfname(t) + err = srv.WgClient.ConfigureDevice(ifname, wgtypes.Config{ + Peers: []wgtypes.PeerConfig{ + { + PublicKey: peerKey, + ReplaceAllowedIPs: true, + AllowedIPs: []net.IPNet{prefixToIPNet(netip.PrefixFrom(peerIp, 32))}, + }, }, - }, - }) - if err != nil { - srv.mu.Lock() - delete(srv.allPeers, peerKey) - srv.mu.Unlock() + }) + if err != nil { + // Roll back: remove from any interfaces we already configured. + for rb := 0; rb < t; rb++ { + rbIfname := srv.TunnelIfname(rb) + _ = srv.WgClient.ConfigureDevice(rbIfname, wgtypes.Config{ + Peers: []wgtypes.PeerConfig{{PublicKey: peerKey, Remove: true}}, + }) + } - srv.ipAllocator.Free(peerIp) - log.Printf("failed to configure WireGuard peer: %v", err) - http.Error(w, "failed to configure WireGuard peer", http.StatusInternalServerError) - return + srv.mu.Lock() + delete(srv.allPeers, peerKey) + srv.mu.Unlock() + srv.ipAllocator.Free(peerIp) + + log.Printf("failed to configure WireGuard peer on %s: %v", ifname, err) + http.Error(w, "failed to configure WireGuard peer", http.StatusInternalServerError) + return + } + } + + // Build the Tunnels list for multi-tunnel clients. + tunnels := make([]TunnelInfo, nt) + for t := 0; t < nt; t++ { + tunnels[t] = TunnelInfo{ + ListenPort: srv.tunnelListenPort(t), + Ifname: srv.TunnelIfname(t), + } } // Return the assigned IP address and the server's public key. resp := &connectResponse{ AssignedAddr: fmt.Sprintf("%v/%d", peerIp, srv.WgCidr.Bits()), ServerPublicKey: srv.Key.PublicKey().String(), - ServerListenPort: WireguardListenPortBase + int(srv.Index), + ServerListenPort: srv.tunnelListenPort(0), // primary tunnel for old clients + Tunnels: tunnels, } respBuf, err := json.Marshal(resp) @@ -416,40 +510,89 @@ func (srv *Server) versionHandler(w http.ResponseWriter, r *http.Request) { w.Write(respBuf) } +// Ifname returns the primary WireGuard interface name (tunnel 0). 
This is used +// for backwards-compatible code paths like takeover and idle peer removal. func (srv *Server) Ifname() string { - return fmt.Sprintf("vprox%d", srv.Index) + return srv.TunnelIfname(0) } +// TunnelIfname returns the WireGuard interface name for the t-th tunnel. +// When NumTunnels == 1, this returns "vprox0" (same as before). +// When NumTunnels > 1, tunnel 0 is "vprox0", tunnel 1 is "vprox0t1", etc. +func (srv *Server) TunnelIfname(t int) string { + base := fmt.Sprintf("vprox%d", srv.Index) + if t == 0 { + return base + } + return fmt.Sprintf("%st%d", base, t) +} + +// tunnelListenPort returns the UDP listen port for the t-th tunnel. +func (srv *Server) tunnelListenPort(t int) int { + return WireguardListenPortBase + int(srv.Index)*PortsPerIndex + t +} + +// StartWireguard creates and configures all tunnel WireGuard interfaces. func (srv *Server) StartWireguard() error { - ifname := srv.Ifname() - link := &linkWireguard{LinkAttrs: netlink.LinkAttrs{Name: ifname}} + nt := srv.numTunnels() + for t := 0; t < nt; t++ { + if err := srv.startWireguardTunnel(t); err != nil { + // Clean up any tunnels we already created. + for rb := 0; rb < t; rb++ { + srv.cleanupWireguardTunnel(rb) + } + return err + } + } + if nt > 1 { + log.Printf("[%v] started %d WireGuard tunnels (ports %d..%d)", + srv.BindAddr, nt, srv.tunnelListenPort(0), srv.tunnelListenPort(nt-1)) + + // Set up policy routing so reply traffic is distributed across tunnels. + if err := srv.setupPolicyRouting(nt); err != nil { + log.Printf("[%v] warning: failed to set up policy routing: %v", srv.BindAddr, err) + } else { + log.Printf("[%v] policy routing configured across %d tunnels", srv.BindAddr, nt) + } + } + return nil +} + +// startWireguardTunnel creates and configures a single WireGuard tunnel interface. 
+func (srv *Server) startWireguardTunnel(t int) error { + ifname := srv.TunnelIfname(t) + link := &linkWireguard{LinkAttrs: netlink.LinkAttrs{ + Name: ifname, + MTU: WireguardMTU, + TxQLen: WireguardTxQLen, + NumTxQueues: WireguardNumQueues, + NumRxQueues: WireguardNumQueues, + GSOMaxSize: WireguardGSOMaxSize, + GROMaxSize: WireguardGSOMaxSize, + }} - // Track whether we created a fresh interface (for cleanup on error) createdFreshInterface := false if srv.takeover { - // In takeover mode, use existing interface if available, otherwise create fresh _, err := netlink.LinkByName(ifname) if err == nil { log.Printf("[%v] takeover mode: using existing WireGuard interface %s", srv.BindAddr, ifname) } else { - // Interface doesn't exist, create fresh log.Printf("[%v] takeover mode: interface %s not found, creating fresh", srv.BindAddr, ifname) - if err := srv.createFreshInterface(link); err != nil { + if err := srv.createFreshInterface(link, t); err != nil { return err } createdFreshInterface = true } } else { - // Normal mode: delete and recreate the interface _ = netlink.LinkDel(link) // remove if it already exists - if err := srv.createFreshInterface(link); err != nil { + if err := srv.createFreshInterface(link, t); err != nil { return err } createdFreshInterface = true } - listenPort := WireguardListenPortBase + int(srv.Index) + listenPort := srv.tunnelListenPort(t) err := srv.WgClient.ConfigureDevice(ifname, wgtypes.Config{ PrivateKey: &srv.Key, ListenPort: &listenPort, @@ -465,28 +608,129 @@ func (srv *Server) StartWireguard() error { } // createFreshInterface creates and configures a new WireGuard interface. -func (srv *Server) createFreshInterface(link *linkWireguard) error { +// Every tunnel interface gets the same subnet IP so the kernel can route +// reply packets back through any of them. 
+func (srv *Server) createFreshInterface(link *linkWireguard, tunnelIndex int) error { err := netlink.LinkAdd(link) if err != nil { - return fmt.Errorf("failed to create WireGuard device: %v", err) + return fmt.Errorf("failed to create WireGuard device %s: %v", link.Name, err) } + // Assign the subnet IP to every tunnel interface. ipnet := prefixToIPNet(srv.WgCidr) - err = netlink.AddrAdd(link, &netlink.Addr{IPNet: &ipnet}) + err = netlink.AddrReplace(link, &netlink.Addr{IPNet: &ipnet}) if err != nil { netlink.LinkDel(link) - return fmt.Errorf("failed to add address to WireGuard device: %v", err) + return fmt.Errorf("failed to add address to WireGuard device %s: %v", link.Name, err) + } + + // Set MTU explicitly after link creation (some kernels ignore it in LinkAttrs) + err = netlink.LinkSetMTU(link, WireguardMTU) + if err != nil { + netlink.LinkDel(link) + return fmt.Errorf("failed to set MTU on WireGuard device %s: %v", link.Name, err) + } + + // Set TxQLen for improved burst handling + err = netlink.LinkSetTxQLen(link, WireguardTxQLen) + if err != nil { + log.Printf("warning: failed to set TxQLen on WireGuard device %s: %v", link.Name, err) } err = netlink.LinkSetUp(link) if err != nil { netlink.LinkDel(link) - return fmt.Errorf("failed to bring up WireGuard device: %v", err) + return fmt.Errorf("failed to bring up WireGuard device %s: %v", link.Name, err) + } + + return nil +} + +// setupPolicyRouting creates: +// 1. Multipath routes in a custom routing table across all WireGuard tunnels. +// 2. An ip rule that matches traffic destined for the WireGuard subnet and +// directs it to that custom table. +// +// The rule matches on destination (not source) because the server forwards +// traffic from the internet to clients — the source IP of forwarded packets is +// the remote host, not the server's WireGuard IP. +func (srv *Server) setupPolicyRouting(nt int) error { + // Build multipath nexthops — one per tunnel interface. 
We use the subnet + // as the destination (not default) since the server only needs to reach + // the WireGuard peer subnet via these tunnels. + subnetIPNet := prefixToIPNet(srv.WgCidr.Masked()) + + var nexthops []*netlink.NexthopInfo + for t := 0; t < nt; t++ { + ifname := srv.TunnelIfname(t) + link, err := netlink.LinkByName(ifname) + if err != nil { + return fmt.Errorf("failed to find interface %s: %v", ifname, err) + } + nexthops = append(nexthops, &netlink.NexthopInfo{ + LinkIndex: link.Attrs().Index, + Hops: 0, + }) + } + + // Add the multipath route in custom table. + mpRoute := &netlink.Route{ + Table: PolicyRoutingTable, + Dst: &subnetIPNet, + MultiPath: nexthops, + Scope: netlink.SCOPE_LINK, + } + if err := netlink.RouteReplace(mpRoute); err != nil { + return fmt.Errorf("failed to add multipath route to table %d: %v", PolicyRoutingTable, err) + } + + // Add an ip rule: to lookup custom table. + // This catches both locally-originated replies and forwarded (NAT'd) traffic + // destined for any client in the WireGuard subnet. + dstNet := &net.IPNet{ + IP: addrToIp(srv.WgCidr.Masked().Addr()), + Mask: net.CIDRMask(srv.WgCidr.Bits(), 32), + } + rule := netlink.NewRule() + rule.Dst = dstNet + rule.Table = PolicyRoutingTable + rule.Priority = PolicyRoutingPriority + + // Remove any stale rule first (idempotent). + _ = netlink.RuleDel(rule) + + if err := netlink.RuleAdd(rule); err != nil { + return fmt.Errorf("failed to add ip rule for dst %v: %v", dstNet, err) } return nil } +// cleanupPolicyRouting removes the ip rule and flushes the custom routing table. 
+func (srv *Server) cleanupPolicyRouting() { + dstNet := &net.IPNet{ + IP: addrToIp(srv.WgCidr.Masked().Addr()), + Mask: net.CIDRMask(srv.WgCidr.Bits(), 32), + } + rule := netlink.NewRule() + rule.Dst = dstNet + rule.Table = PolicyRoutingTable + rule.Priority = PolicyRoutingPriority + if err := netlink.RuleDel(rule); err != nil { + log.Printf("warning: failed to delete ip rule: %v", err) + } + + routes, err := netlink.RouteListFiltered(netlink.FAMILY_V4, &netlink.Route{ + Table: PolicyRoutingTable, + }, netlink.RT_FILTER_TABLE) + if err == nil { + for i := range routes { + _ = netlink.RouteDel(&routes[i]) + } + } +} + +// CleanupWireguard removes all tunnel WireGuard interfaces and policy routing. func (srv *Server) CleanupWireguard() { srv.mu.Lock() relinquished := srv.relinquished @@ -494,35 +738,58 @@ func (srv *Server) CleanupWireguard() { if relinquished { log.Printf("[%v] skipping WireGuard cleanup (relinquished)", srv.BindAddr) - } else { - log.Printf("[%v] cleaning up WireGuard state", srv.BindAddr) - ifname := srv.Ifname() - _ = netlink.LinkDel(&linkWireguard{LinkAttrs: netlink.LinkAttrs{Name: ifname}}) + return + } + + log.Printf("[%v] cleaning up WireGuard state", srv.BindAddr) + + if srv.numTunnels() > 1 { + srv.cleanupPolicyRouting() + } + + nt := srv.numTunnels() + for t := 0; t < nt; t++ { + srv.cleanupWireguardTunnel(t) } } -// iptablesInputFwmarkRule adds or removes the mangle PREROUTING rule for traffic from WireGuard. +// cleanupWireguardTunnel removes a single WireGuard tunnel interface. +func (srv *Server) cleanupWireguardTunnel(t int) { + ifname := srv.TunnelIfname(t) + _ = netlink.LinkDel(&linkWireguard{LinkAttrs: netlink.LinkAttrs{Name: ifname}}) +} + +// iptablesInputFwmarkRule adds or removes the mangle PREROUTING rule for traffic +// from WireGuard. One rule per tunnel interface, all using the same fwmark. 
func (srv *Server) iptablesInputFwmarkRule(enabled bool) error { firewallMark := FwmarkBase + int(srv.Index) - rule := []string{ - "-i", srv.Ifname(), - "-j", "MARK", "--set-mark", strconv.Itoa(firewallMark), - "-m", "comment", "--comment", fmt.Sprintf("vprox fwmark rule for %s", srv.Ifname()), - } - if enabled { - return srv.Ipt.AppendUnique("mangle", "PREROUTING", rule...) - } else { - return srv.Ipt.Delete("mangle", "PREROUTING", rule...) + nt := srv.numTunnels() + for t := 0; t < nt; t++ { + ifname := srv.TunnelIfname(t) + rule := []string{ + "-i", ifname, + "-j", "MARK", "--set-mark", strconv.Itoa(firewallMark), + "-m", "comment", "--comment", fmt.Sprintf("vprox fwmark rule for %s", ifname), + } + if enabled { + if err := srv.Ipt.AppendUnique("mangle", "PREROUTING", rule...); err != nil { + return err + } + } else { + srv.Ipt.Delete("mangle", "PREROUTING", rule...) + } } + return nil } // iptablesSnatRule adds or removes the nat POSTROUTING rule for outbound traffic. +// This is shared across all tunnels via fwmark (only one rule needed). func (srv *Server) iptablesSnatRule(enabled bool) error { firewallMark := FwmarkBase + int(srv.Index) rule := []string{ "-m", "mark", "--mark", strconv.Itoa(firewallMark), "-j", "SNAT", "--to-source", srv.BindAddr.String(), - "-m", "comment", "--comment", fmt.Sprintf("vprox snat rule for %s", srv.Ifname()), + "-m", "comment", "--comment", fmt.Sprintf("vprox snat rule for index %d", srv.Index), } if enabled { return srv.Ipt.AppendUnique("nat", "POSTROUTING", rule...) 
@@ -531,21 +798,78 @@ func (srv *Server) iptablesSnatRule(enabled bool) error { } } -// iptablesMssRule adds or removes the FORWARD chain rule for TCP MSS adjustment -func (srv *Server) iptablesMssRule(enabled bool) error { - rule := []string{ - "-p", "tcp", - "--tcp-flags", "SYN,RST", "SYN", - "-j", "TCPMSS", - "--set-mss", "1160", - "-m", "comment", "--comment", fmt.Sprintf("vprox mss rule for %s", srv.Ifname()), +// iptablesNotrackRule adds or removes NOTRACK rules in the raw table to bypass +// connection tracking for WireGuard UDP traffic on all tunnel ports. +func (srv *Server) iptablesNotrackRule(enabled bool) error { + nt := srv.numTunnels() + for t := 0; t < nt; t++ { + listenPort := strconv.Itoa(srv.tunnelListenPort(t)) + ifname := srv.TunnelIfname(t) + inRule := []string{ + "-p", "udp", + "--dport", listenPort, + "-j", "NOTRACK", + "-m", "comment", "--comment", fmt.Sprintf("vprox notrack in for %s", ifname), + } + outRule := []string{ + "-p", "udp", + "--sport", listenPort, + "-j", "NOTRACK", + "-m", "comment", "--comment", fmt.Sprintf("vprox notrack out for %s", ifname), + } + if enabled { + if err := srv.Ipt.AppendUnique("raw", "PREROUTING", inRule...); err != nil { + return fmt.Errorf("failed to add NOTRACK PREROUTING rule for %s: %v", ifname, err) + } + if err := srv.Ipt.AppendUnique("raw", "OUTPUT", outRule...); err != nil { + srv.Ipt.Delete("raw", "PREROUTING", inRule...) + return fmt.Errorf("failed to add NOTRACK OUTPUT rule for %s: %v", ifname, err) + } + } else { + srv.Ipt.Delete("raw", "PREROUTING", inRule...) + srv.Ipt.Delete("raw", "OUTPUT", outRule...) + } } + return nil +} - if enabled { - return srv.Ipt.AppendUnique("filter", "FORWARD", rule...) - } else { - return srv.Ipt.Delete("filter", "FORWARD", rule...) +// iptablesMssRule adds or removes FORWARD chain rules for TCP MSS clamping in +// both directions on all tunnel interfaces. 
+func (srv *Server) iptablesMssRule(enabled bool) error { + nt := srv.numTunnels() + for t := 0; t < nt; t++ { + ifname := srv.TunnelIfname(t) + outRule := []string{ + "-o", ifname, + "-p", "tcp", + "--tcp-flags", "SYN,RST", "SYN", + "-j", "TCPMSS", + "--set-mss", strconv.Itoa(WireguardMSS), + "-m", "comment", "--comment", fmt.Sprintf("vprox mss out rule for %s", ifname), + } + inRule := []string{ + "-i", ifname, + "-p", "tcp", + "--tcp-flags", "SYN,RST", "SYN", + "-j", "TCPMSS", + "--set-mss", strconv.Itoa(WireguardMSS), + "-m", "comment", "--comment", fmt.Sprintf("vprox mss in rule for %s", ifname), + } + + if enabled { + if err := srv.Ipt.AppendUnique("mangle", "FORWARD", outRule...); err != nil { + return err + } + if err := srv.Ipt.AppendUnique("mangle", "FORWARD", inRule...); err != nil { + srv.Ipt.Delete("mangle", "FORWARD", outRule...) + return err + } + } else { + srv.Ipt.Delete("mangle", "FORWARD", outRule...) + srv.Ipt.Delete("mangle", "FORWARD", inRule...) + } } + return nil } func (srv *Server) StartIptables() error { @@ -567,16 +891,25 @@ func (srv *Server) StartIptables() error { return fmt.Errorf("failed to add MSS rule: %v", err) } + // NOTRACK is best-effort — don't fail startup if the raw table isn't available. 
+ if err = srv.iptablesNotrackRule(true); err != nil { + log.Printf("warning: failed to add NOTRACK rules (non-fatal): %v", err) + } + return nil } func (srv *Server) CleanupIptables() { if err := srv.iptablesInputFwmarkRule(false); err != nil { - log.Printf("warning: error cleaning up IP tables: failed to add fwmark rule: %v\n", err) + log.Printf("warning: error cleaning up iptables fwmark rule: %v\n", err) } if err := srv.iptablesSnatRule(false); err != nil { - log.Printf("warning: error cleaning up IP tables: failed to add SNAT rule: %v\n", err) + log.Printf("warning: error cleaning up iptables SNAT rule: %v\n", err) } + if err := srv.iptablesMssRule(false); err != nil { + log.Printf("warning: error cleaning up iptables MSS rule: %v\n", err) + } + srv.iptablesNotrackRule(false) } func (srv *Server) removeIdlePeersLoop() { @@ -594,15 +927,13 @@ func (srv *Server) removeIdlePeersLoop() { } } -// cleanupPeer removes a peer from the WireGuard interface, reclaims its subnet IP, -// and removes it from the allPeers map. This function is idempotent and safe to call -// even if the peer doesn't exist. It uses the allPeers map as the source of truth. +// cleanupPeer removes a peer from ALL WireGuard tunnel interfaces, reclaims its +// subnet IP, and removes it from the allPeers map. func (srv *Server) cleanupPeer(publicKey wgtypes.Key) error { // Look up the peer in allPeers map to get its IP address. srv.mu.Lock() peerInfo, exists := srv.allPeers[publicKey] if !exists { - // Peer not in allPeers - it likely already got cleaned up. srv.mu.Unlock() log.Printf("[%v] peer unexpectedly not found in allPeers - did /disconnect race with the periodic peer-GC loop?: %v", srv.BindAddr, publicKey) return nil @@ -613,18 +944,23 @@ func (srv *Server) cleanupPeer(publicKey wgtypes.Key) error { delete(srv.allPeers, publicKey) srv.mu.Unlock() - // Remove the peer from WireGuard (no lock held during WireGuard operations). 
- log.Printf("[%v] removing peer at %v: %v", srv.BindAddr, peerIp, publicKey) - err := srv.WgClient.ConfigureDevice(srv.Ifname(), wgtypes.Config{ - Peers: []wgtypes.PeerConfig{ - { - PublicKey: publicKey, - Remove: true, + // Remove the peer from ALL tunnel interfaces. + nt := srv.numTunnels() + log.Printf("[%v] removing peer at %v from %d tunnel(s): %v", srv.BindAddr, peerIp, nt, publicKey) + var firstErr error + for t := 0; t < nt; t++ { + ifname := srv.TunnelIfname(t) + err := srv.WgClient.ConfigureDevice(ifname, wgtypes.Config{ + Peers: []wgtypes.PeerConfig{ + { + PublicKey: publicKey, + Remove: true, + }, }, - }, - }) - if err != nil { - return fmt.Errorf("failed to remove WireGuard peer: %v", err) + }) + if err != nil && firstErr == nil { + firstErr = fmt.Errorf("failed to remove WireGuard peer from %s: %v", ifname, err) + } } // Free the IP address. @@ -632,10 +968,12 @@ func (srv *Server) cleanupPeer(publicKey wgtypes.Key) error { srv.ipAllocator.Free(peerIp) } - return nil + return firstErr } func (srv *Server) removeIdlePeers() error { + // Check idle status using the primary tunnel interface (tunnel 0). + // All tunnels share the same peers, so we only need to inspect one. device, err := srv.WgClient.Device(srv.Ifname()) if err != nil { return fmt.Errorf("failed to get WireGuard device: %v", err) @@ -645,7 +983,7 @@ func (srv *Server) removeIdlePeers() error { srv.mu.Lock() defer srv.mu.Unlock() - var removePeers []wgtypes.PeerConfig + var removePeerKeys []wgtypes.Key var removeIps []netip.Addr for _, peer := range device.Peers { var idle bool @@ -654,8 +992,6 @@ func (srv *Server) removeIdlePeers() error { if exists { idle = time.Since(peerInfo.ConnectionTime) > PeerIdleTimeout } else { - // If we somehow have a WireGuard interface for a peer but no allPeers entry, - // let's just assume it's idle and remove it. 
idle = true } } else { @@ -671,19 +1007,28 @@ func (srv *Server) removeIdlePeers() error { removeIps = append(removeIps, netip.AddrFrom4([4]byte(ipv4))) } } - removePeers = append(removePeers, wgtypes.PeerConfig{ - PublicKey: peer.PublicKey, - Remove: true, - }) + removePeerKeys = append(removePeerKeys, peer.PublicKey) delete(srv.allPeers, peer.PublicKey) } } - if len(removePeers) > 0 { - err := srv.WgClient.ConfigureDevice(srv.Ifname(), wgtypes.Config{Peers: removePeers}) - if err != nil { - return err + if len(removePeerKeys) > 0 { + // Build the peer removal config. + removePeers := make([]wgtypes.PeerConfig, len(removePeerKeys)) + for i, pk := range removePeerKeys { + removePeers[i] = wgtypes.PeerConfig{PublicKey: pk, Remove: true} } + + // Remove from ALL tunnel interfaces. + nt := srv.numTunnels() + for t := 0; t < nt; t++ { + ifname := srv.TunnelIfname(t) + err := srv.WgClient.ConfigureDevice(ifname, wgtypes.Config{Peers: removePeers}) + if err != nil { + log.Printf("warning: failed to remove idle peers from %s: %v", ifname, err) + } + } + for _, ip := range removeIps { srv.ipAllocator.Free(ip) } diff --git a/lib/server_manager.go b/lib/server_manager.go index 632541f..1f4576d 100644 --- a/lib/server_manager.go +++ b/lib/server_manager.go @@ -25,6 +25,7 @@ type ServerManager struct { ipt *iptables.IPTables key wgtypes.Key password string + numTunnels int ctx context.Context waitGroup *sync.WaitGroup wgBlock netip.Prefix @@ -41,7 +42,7 @@ type ServerManager struct { } // NewServerManager creates a new server manager -func NewServerManager(wgBlock netip.Prefix, wgBlockPerIp uint, ctx context.Context, key wgtypes.Key, password string, takeover bool) (*ServerManager, error) { +func NewServerManager(wgBlock netip.Prefix, wgBlockPerIp uint, ctx context.Context, key wgtypes.Key, password string, numTunnels int, takeover bool) (*ServerManager, error) { // Make a shared WireGuard client. 
wgClient, err := wgctrl.New() if err != nil { @@ -58,11 +59,19 @@ func NewServerManager(wgBlock netip.Prefix, wgBlockPerIp uint, ctx context.Conte color.New(color.Bold).Sprint("server public key:"), key.PublicKey().String()) + if numTunnels <= 0 { + numTunnels = 1 + } + if numTunnels > MaxTunnelsPerServer { + numTunnels = MaxTunnelsPerServer + } + sm := new(ServerManager) sm.wgClient = wgClient sm.ipt = ipt sm.key = key sm.password = password + sm.numTunnels = numTunnels sm.ctx = ctx sm.waitGroup = new(sync.WaitGroup) sm.wgBlock = wgBlock.Masked() @@ -109,15 +118,16 @@ func (sm *ServerManager) Start(ip netip.Addr) error { wgCidr := netip.PrefixFrom(subnetStart.Next(), int(sm.wgBlockPerIp)) srv := &Server{ - Key: sm.key, - BindAddr: ip, - Password: sm.password, - Index: i, - Ipt: sm.ipt, - WgClient: sm.wgClient, - WgCidr: wgCidr, - Ctx: subctx, - takeover: sm.takeover, + Key: sm.key, + BindAddr: ip, + Password: sm.password, + Index: i, + NumTunnels: sm.numTunnels, + Ipt: sm.ipt, + WgClient: sm.wgClient, + WgCidr: wgCidr, + Ctx: subctx, + takeover: sm.takeover, } if err := srv.InitState(); err != nil { _ = cancel // cancel should be discarded