From 8e6435b650d01e905c011deddb71ea9a26722f1e Mon Sep 17 00:00:00 2001 From: Luis Capelo Date: Fri, 9 Jan 2026 16:20:04 -0500 Subject: [PATCH 01/14] Updates sysctl settings to improve wireguard performance --- lib/client.go | 24 ++++++++- lib/server.go | 144 ++++++++++++++++++++++++++++++++++++++++++++++---- 2 files changed, 156 insertions(+), 12 deletions(-) diff --git a/lib/client.go b/lib/client.go index 65f58e4..17f5a57 100644 --- a/lib/client.go +++ b/lib/client.go @@ -74,6 +74,20 @@ func (c *Client) CreateInterface() error { return fmt.Errorf("error creating vprox interface: %v", err) } + // Set MTU explicitly for optimal throughput (matching server-side WireguardMTU) + err = netlink.LinkSetMTU(link, 1420) + if err != nil { + netlink.LinkDel(link) + return fmt.Errorf("error setting MTU on vprox interface: %v", err) + } + + // Set TxQLen for improved burst handling (matching server-side WireguardTxQLen) + err = netlink.LinkSetTxQLen(link, 1000) + if err != nil { + // Non-fatal: log warning but continue + log.Printf("warning: failed to set TxQLen on vprox interface: %v", err) + } + return nil } @@ -262,7 +276,15 @@ func (c *Client) DeleteInterface() { } func (c *Client) link() *linkWireguard { - return &linkWireguard{LinkAttrs: netlink.LinkAttrs{Name: c.Ifname}} + return &linkWireguard{LinkAttrs: netlink.LinkAttrs{ + Name: c.Ifname, + MTU: 1420, // WireguardMTU - must match server + TxQLen: 1000, // WireguardTxQLen - for improved burst handling + NumTxQueues: 4, // WireguardNumQueues - for parallel packet processing + NumRxQueues: 4, // WireguardNumQueues - for parallel packet processing + GSOMaxSize: 65536, // WireguardGSOMaxSize - for GSO/GRO offload on Linux 5.19+ + GROMaxSize: 65536, + }} } // CheckConnection checks the status of the connection with the wireguard peer, diff --git a/lib/server.go b/lib/server.go index 0d1ff5a..271dd1d 100644 --- a/lib/server.go +++ b/lib/server.go @@ -28,6 +28,22 @@ const FwmarkBase = 0x54437D00 // UDP listen port base value 
for WireGuard connections. const WireguardListenPortBase = 50227 +// WireGuard interface MTU. WireGuard adds ~60 bytes overhead (40 for IPv4/UDP + 16 for WG header + padding). +// Setting MTU to 1420 prevents fragmentation on standard 1500 MTU networks. +const WireguardMTU = 1420 + +// WireGuard interface transmit queue length. Higher values reduce packet drops during traffic bursts. +const WireguardTxQLen = 1000 + +// GSO/GRO max size for improved throughput on Linux 5.19+. Allows batching packets before encryption. +const WireguardGSOMaxSize = 65536 + +// TCP MSS for WireGuard traffic. Calculated as MTU (1420) - TCP/IP headers (40) = 1380. +const WireguardMSS = 1380 + +// Number of TX/RX queues for parallel packet processing on multi-core systems. +const WireguardNumQueues = 4 + // A new peer must connect with a handshake within this time. const FirstHandshakeTimeout = 10 * time.Second @@ -422,7 +438,15 @@ func (srv *Server) Ifname() string { func (srv *Server) StartWireguard() error { ifname := srv.Ifname() - link := &linkWireguard{LinkAttrs: netlink.LinkAttrs{Name: ifname}} + link := &linkWireguard{LinkAttrs: netlink.LinkAttrs{ + Name: ifname, + MTU: WireguardMTU, + TxQLen: WireguardTxQLen, + NumTxQueues: WireguardNumQueues, + NumRxQueues: WireguardNumQueues, + GSOMaxSize: WireguardGSOMaxSize, + GROMaxSize: WireguardGSOMaxSize, + }} // Track whether we created a fresh interface (for cleanup on error) createdFreshInterface := false @@ -450,9 +474,11 @@ func (srv *Server) StartWireguard() error { } listenPort := WireguardListenPortBase + int(srv.Index) + firewallMark := FwmarkBase + int(srv.Index) err := srv.WgClient.ConfigureDevice(ifname, wgtypes.Config{ - PrivateKey: &srv.Key, - ListenPort: &listenPort, + PrivateKey: &srv.Key, + ListenPort: &listenPort, + FirewallMark: &firewallMark, }) if err != nil { if createdFreshInterface { @@ -478,6 +504,20 @@ func (srv *Server) createFreshInterface(link *linkWireguard) error { return fmt.Errorf("failed to add address to 
WireGuard device: %v", err) } + // Set MTU explicitly after link creation (some kernels ignore it in LinkAttrs) + err = netlink.LinkSetMTU(link, WireguardMTU) + if err != nil { + netlink.LinkDel(link) + return fmt.Errorf("failed to set MTU on WireGuard device: %v", err) + } + + // Set TxQLen for improved burst handling + err = netlink.LinkSetTxQLen(link, WireguardTxQLen) + if err != nil { + // Non-fatal: log warning but continue + log.Printf("warning: failed to set TxQLen on WireGuard device: %v", err) + } + err = netlink.LinkSetUp(link) if err != nil { netlink.LinkDel(link) @@ -531,20 +571,88 @@ func (srv *Server) iptablesSnatRule(enabled bool) error { } } -// iptablesMssRule adds or removes the FORWARD chain rule for TCP MSS adjustment +// iptablesNotrackRule adds or removes NOTRACK rules in the raw table to bypass +// connection tracking for established WireGuard UDP flows. This significantly +// reduces CPU overhead for high-throughput scenarios. +func (srv *Server) iptablesNotrackRule(enabled bool) error { + listenPort := WireguardListenPortBase + int(srv.Index) + ifname := srv.Ifname() + + // NOTRACK for incoming WireGuard UDP packets (PREROUTING) + ruleIn := []string{ + "-p", "udp", + "--dport", strconv.Itoa(listenPort), + "-j", "NOTRACK", + "-m", "comment", "--comment", fmt.Sprintf("vprox notrack in for %s", ifname), + } + // NOTRACK for outgoing WireGuard UDP packets (OUTPUT) + ruleOut := []string{ + "-p", "udp", + "--sport", strconv.Itoa(listenPort), + "-j", "NOTRACK", + "-m", "comment", "--comment", fmt.Sprintf("vprox notrack out for %s", ifname), + } + + if enabled { + if err := srv.Ipt.AppendUnique("raw", "PREROUTING", ruleIn...); err != nil { + return err + } + if err := srv.Ipt.AppendUnique("raw", "OUTPUT", ruleOut...); err != nil { + srv.Ipt.Delete("raw", "PREROUTING", ruleIn...) + return err + } + return nil + } else { + errIn := srv.Ipt.Delete("raw", "PREROUTING", ruleIn...) + errOut := srv.Ipt.Delete("raw", "OUTPUT", ruleOut...) 
+ if errIn != nil { + return errIn + } + return errOut + } +} + +// iptablesMssRule adds or removes the FORWARD chain rule for TCP MSS adjustment. +// The rule is scoped to traffic from/to the WireGuard interface for this server. func (srv *Server) iptablesMssRule(enabled bool) error { - rule := []string{ + ifname := srv.Ifname() + // Rule for traffic coming from the WireGuard interface + ruleIn := []string{ + "-i", ifname, + "-p", "tcp", + "--tcp-flags", "SYN,RST", "SYN", + "-j", "TCPMSS", + "--set-mss", strconv.Itoa(WireguardMSS), + "-m", "comment", "--comment", fmt.Sprintf("vprox mss rule in for %s", ifname), + } + // Rule for traffic going to the WireGuard interface + ruleOut := []string{ + "-o", ifname, "-p", "tcp", "--tcp-flags", "SYN,RST", "SYN", "-j", "TCPMSS", - "--set-mss", "1160", - "-m", "comment", "--comment", fmt.Sprintf("vprox mss rule for %s", srv.Ifname()), + "--set-mss", strconv.Itoa(WireguardMSS), + "-m", "comment", "--comment", fmt.Sprintf("vprox mss rule out for %s", ifname), } if enabled { - return srv.Ipt.AppendUnique("filter", "FORWARD", rule...) + if err := srv.Ipt.AppendUnique("filter", "FORWARD", ruleIn...); err != nil { + return err + } + if err := srv.Ipt.AppendUnique("filter", "FORWARD", ruleOut...); err != nil { + // Try to clean up the first rule if the second fails + srv.Ipt.Delete("filter", "FORWARD", ruleIn...) + return err + } + return nil } else { - return srv.Ipt.Delete("filter", "FORWARD", rule...) + // Delete both rules, ignoring errors (rules may not exist) + errIn := srv.Ipt.Delete("filter", "FORWARD", ruleIn...) + errOut := srv.Ipt.Delete("filter", "FORWARD", ruleOut...) 
+ if errIn != nil { + return errIn + } + return errOut } } @@ -567,15 +675,29 @@ func (srv *Server) StartIptables() error { return fmt.Errorf("failed to add MSS rule: %v", err) } + err = srv.iptablesNotrackRule(true) + if err != nil { + srv.iptablesMssRule(false) + srv.iptablesSnatRule(false) + srv.iptablesInputFwmarkRule(false) + return fmt.Errorf("failed to add NOTRACK rule: %v", err) + } + return nil } func (srv *Server) CleanupIptables() { if err := srv.iptablesInputFwmarkRule(false); err != nil { - log.Printf("warning: error cleaning up IP tables: failed to add fwmark rule: %v\n", err) + log.Printf("warning: error cleaning up IP tables: failed to remove fwmark rule: %v\n", err) } if err := srv.iptablesSnatRule(false); err != nil { - log.Printf("warning: error cleaning up IP tables: failed to add SNAT rule: %v\n", err) + log.Printf("warning: error cleaning up IP tables: failed to remove SNAT rule: %v\n", err) + } + if err := srv.iptablesMssRule(false); err != nil { + log.Printf("warning: error cleaning up IP tables: failed to remove MSS rule: %v\n", err) + } + if err := srv.iptablesNotrackRule(false); err != nil { + log.Printf("warning: error cleaning up IP tables: failed to remove NOTRACK rule: %v\n", err) } } From 904ab5ce514d07a7e1272008546b840398bdb8b6 Mon Sep 17 00:00:00 2001 From: Luis Capelo Date: Wed, 11 Mar 2026 23:50:45 -0400 Subject: [PATCH 02/14] make further changes --- lib/client.go | 10 ++-- lib/server.go | 123 +++++++++++++++++++++++--------------------------- 2 files changed, 61 insertions(+), 72 deletions(-) diff --git a/lib/client.go b/lib/client.go index 17f5a57..2c70ea2 100644 --- a/lib/client.go +++ b/lib/client.go @@ -278,11 +278,11 @@ func (c *Client) DeleteInterface() { func (c *Client) link() *linkWireguard { return &linkWireguard{LinkAttrs: netlink.LinkAttrs{ Name: c.Ifname, - MTU: 1420, // WireguardMTU - must match server - TxQLen: 1000, // WireguardTxQLen - for improved burst handling - NumTxQueues: 4, // WireguardNumQueues - for 
parallel packet processing - NumRxQueues: 4, // WireguardNumQueues - for parallel packet processing - GSOMaxSize: 65536, // WireguardGSOMaxSize - for GSO/GRO offload on Linux 5.19+ + MTU: 1420, + TxQLen: 1000, + NumTxQueues: 4, + NumRxQueues: 4, + GSOMaxSize: 65536, GROMaxSize: 65536, }} } diff --git a/lib/server.go b/lib/server.go index 271dd1d..a26dbbc 100644 --- a/lib/server.go +++ b/lib/server.go @@ -28,17 +28,21 @@ const FwmarkBase = 0x54437D00 // UDP listen port base value for WireGuard connections. const WireguardListenPortBase = 50227 -// WireGuard interface MTU. WireGuard adds ~60 bytes overhead (40 for IPv4/UDP + 16 for WG header + padding). -// Setting MTU to 1420 prevents fragmentation on standard 1500 MTU networks. +// WireGuard interface MTU. WireGuard adds ~60 bytes overhead (40 for IPv4/UDP +// + 16 for WG header + padding). Setting MTU to 1420 prevents fragmentation +// on standard 1500 MTU networks. const WireguardMTU = 1420 -// WireGuard interface transmit queue length. Higher values reduce packet drops during traffic bursts. +// WireGuard interface transmit queue length. Higher values reduce packet drops +// during traffic bursts. const WireguardTxQLen = 1000 -// GSO/GRO max size for improved throughput on Linux 5.19+. Allows batching packets before encryption. +// GSO/GRO max size for improved throughput on Linux 5.19+. Allows the kernel +// to batch packets into large 64 KB super-packets before encryption/decryption. const WireguardGSOMaxSize = 65536 -// TCP MSS for WireGuard traffic. Calculated as MTU (1420) - TCP/IP headers (40) = 1380. +// TCP MSS for traffic through the WireGuard tunnel, calculated as +// MTU (1420) - IP header (20) - TCP header (20) = 1380. const WireguardMSS = 1380 // Number of TX/RX queues for parallel packet processing on multi-core systems. 
@@ -571,89 +575,79 @@ func (srv *Server) iptablesSnatRule(enabled bool) error { } } -// iptablesNotrackRule adds or removes NOTRACK rules in the raw table to bypass -// connection tracking for established WireGuard UDP flows. This significantly -// reduces CPU overhead for high-throughput scenarios. +// iptablesNotrackRule adds or removes a NOTRACK rule in the raw table to bypass +// connection tracking for WireGuard UDP traffic. This significantly reduces +// per-packet CPU overhead for tunneled flows. func (srv *Server) iptablesNotrackRule(enabled bool) error { - listenPort := WireguardListenPortBase + int(srv.Index) - ifname := srv.Ifname() - - // NOTRACK for incoming WireGuard UDP packets (PREROUTING) - ruleIn := []string{ + listenPort := strconv.Itoa(WireguardListenPortBase + int(srv.Index)) + // Inbound WireGuard UDP + inRule := []string{ "-p", "udp", - "--dport", strconv.Itoa(listenPort), + "--dport", listenPort, "-j", "NOTRACK", - "-m", "comment", "--comment", fmt.Sprintf("vprox notrack in for %s", ifname), + "-m", "comment", "--comment", fmt.Sprintf("vprox notrack in for %s", srv.Ifname()), } - // NOTRACK for outgoing WireGuard UDP packets (OUTPUT) - ruleOut := []string{ + // Outbound WireGuard UDP + outRule := []string{ "-p", "udp", - "--sport", strconv.Itoa(listenPort), + "--sport", listenPort, "-j", "NOTRACK", - "-m", "comment", "--comment", fmt.Sprintf("vprox notrack out for %s", ifname), + "-m", "comment", "--comment", fmt.Sprintf("vprox notrack out for %s", srv.Ifname()), } - if enabled { - if err := srv.Ipt.AppendUnique("raw", "PREROUTING", ruleIn...); err != nil { - return err + if err := srv.Ipt.AppendUnique("raw", "PREROUTING", inRule...); err != nil { + return fmt.Errorf("failed to add NOTRACK PREROUTING rule: %v", err) } - if err := srv.Ipt.AppendUnique("raw", "OUTPUT", ruleOut...); err != nil { - srv.Ipt.Delete("raw", "PREROUTING", ruleIn...) 
- return err + if err := srv.Ipt.AppendUnique("raw", "OUTPUT", outRule...); err != nil { + // Clean up the first rule on failure. + srv.Ipt.Delete("raw", "PREROUTING", inRule...) + return fmt.Errorf("failed to add NOTRACK OUTPUT rule: %v", err) } return nil - } else { - errIn := srv.Ipt.Delete("raw", "PREROUTING", ruleIn...) - errOut := srv.Ipt.Delete("raw", "OUTPUT", ruleOut...) - if errIn != nil { - return errIn - } - return errOut } + // Cleanup: best-effort, ignore errors. + srv.Ipt.Delete("raw", "PREROUTING", inRule...) + srv.Ipt.Delete("raw", "OUTPUT", outRule...) + return nil } -// iptablesMssRule adds or removes the FORWARD chain rule for TCP MSS adjustment. -// The rule is scoped to traffic from/to the WireGuard interface for this server. +// iptablesMssRule adds or removes FORWARD chain rules for TCP MSS clamping in +// both directions. Uses the mangle table which is the correct place for packet +// modification. We need both -o (traffic entering the tunnel, server→client) +// and -i (traffic leaving the tunnel, client→server) so that SYN packets in +// either direction get their MSS clamped to fit within the WireGuard MTU. 
func (srv *Server) iptablesMssRule(enabled bool) error { - ifname := srv.Ifname() - // Rule for traffic coming from the WireGuard interface - ruleIn := []string{ - "-i", ifname, + outRule := []string{ + "-o", srv.Ifname(), "-p", "tcp", "--tcp-flags", "SYN,RST", "SYN", "-j", "TCPMSS", "--set-mss", strconv.Itoa(WireguardMSS), - "-m", "comment", "--comment", fmt.Sprintf("vprox mss rule in for %s", ifname), + "-m", "comment", "--comment", fmt.Sprintf("vprox mss out rule for %s", srv.Ifname()), } - // Rule for traffic going to the WireGuard interface - ruleOut := []string{ - "-o", ifname, + inRule := []string{ + "-i", srv.Ifname(), "-p", "tcp", "--tcp-flags", "SYN,RST", "SYN", "-j", "TCPMSS", "--set-mss", strconv.Itoa(WireguardMSS), - "-m", "comment", "--comment", fmt.Sprintf("vprox mss rule out for %s", ifname), + "-m", "comment", "--comment", fmt.Sprintf("vprox mss in rule for %s", srv.Ifname()), } if enabled { - if err := srv.Ipt.AppendUnique("filter", "FORWARD", ruleIn...); err != nil { + if err := srv.Ipt.AppendUnique("mangle", "FORWARD", outRule...); err != nil { return err } - if err := srv.Ipt.AppendUnique("filter", "FORWARD", ruleOut...); err != nil { - // Try to clean up the first rule if the second fails - srv.Ipt.Delete("filter", "FORWARD", ruleIn...) + if err := srv.Ipt.AppendUnique("mangle", "FORWARD", inRule...); err != nil { + srv.Ipt.Delete("mangle", "FORWARD", outRule...) return err } return nil - } else { - // Delete both rules, ignoring errors (rules may not exist) - errIn := srv.Ipt.Delete("filter", "FORWARD", ruleIn...) - errOut := srv.Ipt.Delete("filter", "FORWARD", ruleOut...) - if errIn != nil { - return errIn - } - return errOut } + // Cleanup: best-effort both directions. + srv.Ipt.Delete("mangle", "FORWARD", outRule...) + srv.Ipt.Delete("mangle", "FORWARD", inRule...) 
+ return nil } func (srv *Server) StartIptables() error { @@ -675,12 +669,9 @@ func (srv *Server) StartIptables() error { return fmt.Errorf("failed to add MSS rule: %v", err) } - err = srv.iptablesNotrackRule(true) - if err != nil { - srv.iptablesMssRule(false) - srv.iptablesSnatRule(false) - srv.iptablesInputFwmarkRule(false) - return fmt.Errorf("failed to add NOTRACK rule: %v", err) + // NOTRACK is best-effort — don't fail startup if the raw table isn't available. + if err = srv.iptablesNotrackRule(true); err != nil { + log.Printf("warning: failed to add NOTRACK rules (non-fatal): %v", err) } return nil @@ -688,17 +679,15 @@ func (srv *Server) StartIptables() error { func (srv *Server) CleanupIptables() { if err := srv.iptablesInputFwmarkRule(false); err != nil { - log.Printf("warning: error cleaning up IP tables: failed to remove fwmark rule: %v\n", err) + log.Printf("warning: error cleaning up iptables fwmark rule: %v\n", err) } if err := srv.iptablesSnatRule(false); err != nil { - log.Printf("warning: error cleaning up IP tables: failed to remove SNAT rule: %v\n", err) + log.Printf("warning: error cleaning up iptables SNAT rule: %v\n", err) } if err := srv.iptablesMssRule(false); err != nil { - log.Printf("warning: error cleaning up IP tables: failed to remove MSS rule: %v\n", err) - } - if err := srv.iptablesNotrackRule(false); err != nil { - log.Printf("warning: error cleaning up IP tables: failed to remove NOTRACK rule: %v\n", err) + log.Printf("warning: error cleaning up iptables MSS rule: %v\n", err) } + srv.iptablesNotrackRule(false) } func (srv *Server) removeIdlePeersLoop() { From 0a29730c4235b125915b3f5b904325800b0afdb5 Mon Sep 17 00:00:00 2001 From: Luis Capelo Date: Thu, 12 Mar 2026 00:52:20 -0400 Subject: [PATCH 03/14] creates multiple tunnels --- cmd/connect.go | 16 +- cmd/server.go | 9 +- lib/client.go | 254 +++++++++++++++++++++----- lib/server.go | 416 ++++++++++++++++++++++++++++-------------- lib/server_manager.go | 30 ++- 5 files changed, 
524 insertions(+), 201 deletions(-) diff --git a/cmd/connect.go b/cmd/connect.go index 0cf7f12..b938303 100644 --- a/cmd/connect.go +++ b/cmd/connect.go @@ -68,12 +68,15 @@ var ConnectCmd = &cobra.Command{ } var connectCmdArgs struct { - ifname string + ifname string + tunnels int } func init() { ConnectCmd.Flags().StringVar(&connectCmdArgs.ifname, "interface", "vprox0", "Interface name to proxy traffic through the VPN") + ConnectCmd.Flags().IntVar(&connectCmdArgs.tunnels, "tunnels", + 1, "Number of parallel WireGuard tunnels (higher values improve throughput by spreading traffic across NIC queues)") } func runConnect(cmd *cobra.Command, args []string) error { @@ -98,11 +101,12 @@ func runConnect(cmd *cobra.Command, args []string) error { } client := &lib.Client{ - Key: key, - Ifname: connectCmdArgs.ifname, - ServerIp: serverIp, - Password: password, - WgClient: wgClient, + Key: key, + Ifname: connectCmdArgs.ifname, + ServerIp: serverIp, + Password: password, + NumTunnels: connectCmdArgs.tunnels, + WgClient: wgClient, Http: &http.Client{ Timeout: 5 * time.Second, Transport: &http.Transport{ diff --git a/cmd/server.go b/cmd/server.go index 47c366b..7dee489 100644 --- a/cmd/server.go +++ b/cmd/server.go @@ -29,6 +29,7 @@ var serverCmdArgs struct { wgBlockPerIp string cloud string takeover bool + tunnels int } func init() { @@ -42,6 +43,8 @@ func init() { "", "Cloud provider for IP metadata (watches for changes)") ServerCmd.Flags().BoolVar(&serverCmdArgs.takeover, "takeover", false, "Take over existing WireGuard state from a previous server instance (for non-disruptive upgrades)") + ServerCmd.Flags().IntVar(&serverCmdArgs.tunnels, "tunnels", + 1, "Number of parallel WireGuard tunnels per IP (higher values improve throughput by spreading traffic across NIC queues)") } func runServer(cmd *cobra.Command, args []string) error { @@ -98,9 +101,13 @@ func runServer(cmd *cobra.Command, args []string) error { return err } + if serverCmdArgs.tunnels < 1 || serverCmdArgs.tunnels 
> lib.MaxTunnelsPerServer { + return fmt.Errorf("--tunnels must be between 1 and %d", lib.MaxTunnelsPerServer) + } + ctx, done := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) - sm, err := lib.NewServerManager(wgBlock, wgBlockPerIp, ctx, key, password, serverCmdArgs.takeover) + sm, err := lib.NewServerManager(wgBlock, wgBlockPerIp, ctx, key, password, serverCmdArgs.tunnels, serverCmdArgs.takeover) if err != nil { done() return err diff --git a/lib/client.go b/lib/client.go index 2c70ea2..97ed3d9 100644 --- a/lib/client.go +++ b/lib/client.go @@ -40,12 +40,15 @@ func IsRecoverableError(err error) bool { return true } -// Client manages a peering connection with with a local WireGuard interface. +// Client manages a peering connection with a local WireGuard interface (or a +// set of parallel WireGuard interfaces when multi-tunnel is enabled). type Client struct { // Key is the private key of the client. Key wgtypes.Key - // Ifname is the name of the client WireGuard interface. + // Ifname is the base name of the client WireGuard interface (e.g. "vprox0"). + // With multi-tunnel this becomes the primary interface; additional tunnels + // are named "vprox0t1", "vprox0t2", etc. Ifname string // ServerIp is the public IPv4 address of the server. @@ -54,6 +57,10 @@ type Client struct { // Password authenticates the client connection. Password string + // NumTunnels is the number of parallel WireGuard tunnels to create. + // When <= 1, the client behaves exactly as before (single interface). + NumTunnels int + // WgClient is a shared client for interacting with the WireGuard kernel module. WgClient *wgctrl.Client @@ -62,45 +69,120 @@ type Client struct { // wgCidr is the current subnet assigned to the WireGuard interface, if any. wgCidr netip.Prefix + + // activeTunnels tracks how many tunnel interfaces were actually created + // during the last successful Connect(). 
This may be less than NumTunnels + // if the server returned fewer Tunnels entries (e.g. old server). + activeTunnels int +} + +// numTunnels returns the effective tunnel count, defaulting to 1. +func (c *Client) numTunnels() int { + if c.NumTunnels <= 1 { + return 1 + } + return c.NumTunnels } -// CreateInterface creates a new interface for wireguard. DeleteInterface() needs -// to be called to clean this up. +// tunnelIfname returns the interface name for the t-th tunnel. +// Tunnel 0 uses Ifname directly (e.g. "vprox0"). +// Tunnel 1+ appends "t1", "t2", etc. (e.g. "vprox0t1", "vprox0t2"). +func (c *Client) tunnelIfname(t int) string { + if t == 0 { + return c.Ifname + } + return fmt.Sprintf("%st%d", c.Ifname, t) +} + +// tunnelLink builds a linkWireguard for the t-th tunnel with tuned LinkAttrs. +func (c *Client) tunnelLink(t int) *linkWireguard { + return &linkWireguard{LinkAttrs: netlink.LinkAttrs{ + Name: c.tunnelIfname(t), + MTU: WireguardMTU, + TxQLen: WireguardTxQLen, + NumTxQueues: WireguardNumQueues, + NumRxQueues: WireguardNumQueues, + GSOMaxSize: WireguardGSOMaxSize, + GROMaxSize: WireguardGSOMaxSize, + }} +} + +// link returns a linkWireguard for the primary (tunnel 0) interface. +func (c *Client) link() *linkWireguard { + return c.tunnelLink(0) +} + +// CreateInterface creates the WireGuard interface(s). For single-tunnel mode +// this creates one interface; for multi-tunnel mode it creates N interfaces. +// DeleteInterface() must be called to clean up. func (c *Client) CreateInterface() error { - link := c.link() + nt := c.numTunnels() + for t := 0; t < nt; t++ { + if err := c.createTunnelInterface(t); err != nil { + // Clean up any interfaces we already created. + for rb := 0; rb < t; rb++ { + c.deleteTunnelInterface(rb) + } + return err + } + } + if nt > 1 { + log.Printf("created %d tunnel interfaces (%s .. 
%s)", nt, c.tunnelIfname(0), c.tunnelIfname(nt-1)) + } + return nil +} + +// createTunnelInterface creates a single WireGuard tunnel interface. +func (c *Client) createTunnelInterface(t int) error { + link := c.tunnelLink(t) err := netlink.LinkAdd(link) if err != nil { - return fmt.Errorf("error creating vprox interface: %v", err) + return fmt.Errorf("error creating vprox interface %s: %v", link.Name, err) } - // Set MTU explicitly for optimal throughput (matching server-side WireguardMTU) - err = netlink.LinkSetMTU(link, 1420) + // Set MTU explicitly (some kernels ignore LinkAttrs.MTU on creation) + err = netlink.LinkSetMTU(link, WireguardMTU) if err != nil { netlink.LinkDel(link) - return fmt.Errorf("error setting MTU on vprox interface: %v", err) + return fmt.Errorf("error setting MTU on vprox interface %s: %v", link.Name, err) } - // Set TxQLen for improved burst handling (matching server-side WireguardTxQLen) - err = netlink.LinkSetTxQLen(link, 1000) + // Set TxQLen for improved burst handling + err = netlink.LinkSetTxQLen(link, WireguardTxQLen) if err != nil { // Non-fatal: log warning but continue - log.Printf("warning: failed to set TxQLen on vprox interface: %v", err) + log.Printf("warning: failed to set TxQLen on vprox interface %s: %v", link.Name, err) } return nil } -// Connect attempts to reconnect to the peer. A network interface needs to -// have already been created with CreateInterface() before calling Connect() +// Connect attempts to connect (or reconnect) to the server. All tunnel +// interfaces must already exist via CreateInterface(). func (c *Client) Connect() error { resp, err := c.sendConnectionRequest() if err != nil { return err } - link := c.link() - err = netlink.LinkSetUp(link) + // Determine how many tunnels to actually use. Use the minimum of what + // the client wants and what the server offers. 
+ nt := c.numTunnels() + serverTunnels := len(resp.Tunnels) + if serverTunnels > 0 && serverTunnels < nt { + nt = serverTunnels + } + // If the server returned no Tunnels list (old server), use 1 tunnel. + if serverTunnels == 0 { + nt = 1 + } + c.activeTunnels = nt + + // Bring up and configure the primary interface (tunnel 0) — this is the + // one that gets the IP address and subnet route. + primaryLink := c.tunnelLink(0) + err = netlink.LinkSetUp(primaryLink) if err != nil { return fmt.Errorf("error setting up vprox interface: %v", err) } @@ -110,15 +192,43 @@ func (c *Client) Connect() error { return err } - err = c.configureWireguard(resp) + // Configure WireGuard on tunnel 0 using the primary ServerListenPort + // (works for both old and new servers). + err = c.configureWireguardTunnel(0, resp, resp.ServerListenPort) if err != nil { - return fmt.Errorf("error configuring wireguard interface: %v", err) + return fmt.Errorf("error configuring wireguard on %s: %v", c.tunnelIfname(0), err) + } + + // Configure additional tunnels if the server provided them. + for t := 1; t < nt; t++ { + link := c.tunnelLink(t) + err = netlink.LinkSetUp(link) + if err != nil { + return fmt.Errorf("error setting up vprox interface %s: %v", link.Name, err) + } + + port := resp.Tunnels[t].ListenPort + err = c.configureWireguardTunnel(t, resp, port) + if err != nil { + return fmt.Errorf("error configuring wireguard on %s: %v", c.tunnelIfname(t), err) + } + } + + // Set up multipath routing if we have multiple active tunnels. + if nt > 1 { + if err := c.setupMultipathRouting(nt); err != nil { + log.Printf("warning: failed to set up multipath routing: %v", err) + // Fall back: traffic will just use the primary interface's route. 
+ } else { + log.Printf("multipath routing configured across %d tunnels", nt) + } } return nil } -// updateInterface updates the wireguard interface based on the provided connectionResponse +// updateInterface updates the primary WireGuard interface (tunnel 0) address +// based on the connect response. func (c *Client) updateInterface(resp connectResponse) error { cidr, err := netip.ParsePrefix(resp.AssignedAddr) if err != nil { @@ -131,7 +241,6 @@ func (c *Client) updateInterface(resp connectResponse) error { if c.wgCidr.IsValid() { oldIpnet := prefixToIPNet(c.wgCidr) err = netlink.AddrDel(link, &netlink.Addr{IPNet: &oldIpnet}) - if err != nil { log.Printf("warning: failed to remove old address from vprox interface when reconnecting: %v", err) } @@ -147,6 +256,59 @@ func (c *Client) updateInterface(resp connectResponse) error { return nil } +// setupMultipathRouting creates equal-cost multipath routes across all active +// tunnel interfaces so that the kernel distributes flows across them. +func (c *Client) setupMultipathRouting(nt int) error { + if !c.wgCidr.IsValid() { + return fmt.Errorf("no valid CIDR assigned yet") + } + + // The server's WireGuard IP is the first address in the subnet (the + // gateway for our multipath nexthops). + gwAddr := c.wgCidr.Masked().Addr().Next() + gwIP := addrToIp(gwAddr) + + // Build multipath nexthops — one per tunnel interface. + var nexthops []*netlink.NexthopInfo + for t := 0; t < nt; t++ { + ifname := c.tunnelIfname(t) + link, err := netlink.LinkByName(ifname) + if err != nil { + return fmt.Errorf("failed to find interface %s: %v", ifname, err) + } + nexthops = append(nexthops, &netlink.NexthopInfo{ + LinkIndex: link.Attrs().Index, + Gw: gwIP, + Hops: 0, // equal weight + }) + } + + // Remove any existing default route on the primary interface first. + // (The kernel creates one when we assign the address.) 
+ existingRoutes, _ := netlink.RouteList(nil, netlink.FAMILY_V4) + for i := range existingRoutes { + r := &existingRoutes[i] + if r.Dst != nil && r.Dst.String() == c.wgCidr.Masked().String() { + // This is the subnet route — we need to replace it with multipath. + _ = netlink.RouteDel(r) + } + } + + // Add the multipath route for the WireGuard subnet. + subnetIPNet := prefixToIPNet(c.wgCidr.Masked()) + route := &netlink.Route{ + Dst: &subnetIPNet, + MultiPath: nexthops, + } + + err := netlink.RouteReplace(route) + if err != nil { + return fmt.Errorf("failed to add multipath route: %v", err) + } + + return nil +} + // sendConnectionRequest attempts to send a connection request to the peer func (c *Client) sendConnectionRequest() (connectResponse, error) { connectUrl, err := url.Parse(fmt.Sprintf("https://%s/connect", c.ServerIp)) @@ -195,15 +357,17 @@ func (c *Client) sendConnectionRequest() (connectResponse, error) { return respJson, nil } -// configureWireguard configures the WireGuard peer. -func (c *Client) configureWireguard(connectionResponse connectResponse) error { - serverPublicKey, err := wgtypes.ParseKey(connectionResponse.ServerPublicKey) +// configureWireguardTunnel configures a single WireGuard tunnel interface with +// the server as a peer on the given port. 
+func (c *Client) configureWireguardTunnel(t int, resp connectResponse, serverPort int) error { + serverPublicKey, err := wgtypes.ParseKey(resp.ServerPublicKey) if err != nil { return fmt.Errorf("failed to parse server public key: %v", err) } keepalive := 25 * time.Second - return c.WgClient.ConfigureDevice(c.Ifname, wgtypes.Config{ + ifname := c.tunnelIfname(t) + return c.WgClient.ConfigureDevice(ifname, wgtypes.Config{ PrivateKey: &c.Key, ReplacePeers: true, Peers: []wgtypes.PeerConfig{ @@ -211,7 +375,7 @@ func (c *Client) configureWireguard(connectionResponse connectResponse) error { PublicKey: serverPublicKey, Endpoint: &net.UDPAddr{ IP: addrToIp(c.ServerIp), - Port: connectionResponse.ServerListenPort, + Port: serverPort, }, PersistentKeepaliveInterval: &keepalive, ReplaceAllowedIPs: true, @@ -224,6 +388,11 @@ func (c *Client) configureWireguard(connectionResponse connectResponse) error { }) } +// configureWireguard configures WireGuard on the primary tunnel (backwards compat). +func (c *Client) configureWireguard(connectionResponse connectResponse) error { + return c.configureWireguardTunnel(0, connectionResponse, connectionResponse.ServerListenPort) +} + // Disconnect notifies the server that this client is disconnecting, allowing the // server to immediately reclaim resources (wireguard peer and subnet IP) instead of // waiting for the idle timeout. @@ -264,32 +433,31 @@ func (c *Client) Disconnect() error { return nil } +// DeleteInterface removes all WireGuard tunnel interfaces. func (c *Client) DeleteInterface() { - // Delete the WireGuard interface. 
- log.Printf("About to delete vprox interface %v", c.Ifname) - err := netlink.LinkDel(c.link()) - if err != nil { - log.Printf("error deleting vprox interface %v: %v", c.Ifname, err) - } else { - log.Printf("successfully deleted vprox interface %v", c.Ifname) + nt := c.numTunnels() + for t := nt - 1; t >= 0; t-- { + c.deleteTunnelInterface(t) } } -func (c *Client) link() *linkWireguard { - return &linkWireguard{LinkAttrs: netlink.LinkAttrs{ - Name: c.Ifname, - MTU: 1420, - TxQLen: 1000, - NumTxQueues: 4, - NumRxQueues: 4, - GSOMaxSize: 65536, - GROMaxSize: 65536, - }} +// deleteTunnelInterface removes a single WireGuard tunnel interface. +func (c *Client) deleteTunnelInterface(t int) { + ifname := c.tunnelIfname(t) + log.Printf("About to delete vprox interface %v", ifname) + link := &linkWireguard{LinkAttrs: netlink.LinkAttrs{Name: ifname}} + err := netlink.LinkDel(link) + if err != nil { + log.Printf("error deleting vprox interface %v: %v", ifname, err) + } else { + log.Printf("successfully deleted vprox interface %v", ifname) + } } // CheckConnection checks the status of the connection with the wireguard peer, // and returns true if it is healthy. This sends 3 pings in succession, and blocks // until they receive a response or the timeout passes. +// Pings are sent through the primary tunnel interface (tunnel 0). func (c *Client) CheckConnection(timeout time.Duration, cancelCtx context.Context) bool { pinger, err := probing.NewPinger(c.wgCidr.Masked().Addr().Next().String()) if err != nil { @@ -301,7 +469,7 @@ func (c *Client) CheckConnection(timeout time.Duration, cancelCtx context.Contex pinger.Timeout = timeout pinger.Count = 3 pinger.Interval = 10 * time.Millisecond // Send approximately all at once - err = pinger.RunWithContext(cancelCtx) // Blocks until finished. + err = pinger.RunWithContext(cancelCtx) // Blocks until finished. 
if err != nil { log.Printf("error running pinger: %v", err) return false diff --git a/lib/server.go b/lib/server.go index a26dbbc..6cd052a 100644 --- a/lib/server.go +++ b/lib/server.go @@ -48,6 +48,16 @@ const WireguardMSS = 1380 // Number of TX/RX queues for parallel packet processing on multi-core systems. const WireguardNumQueues = 4 +// MaxTunnelsPerServer is the maximum number of parallel WireGuard tunnels +// allowed per server (per bind IP). Each tunnel uses a different UDP port +// so that the NIC hashes them to different hardware RX queues. +const MaxTunnelsPerServer = 16 + +// PortsPerIndex is the number of UDP ports reserved per server index. +// This must be >= MaxTunnelsPerServer. With this spacing, server index 0 +// uses ports 50227..50242, index 1 uses 50243..50258, etc. +const PortsPerIndex = MaxTunnelsPerServer + // A new peer must connect with a handshake within this time. const FirstHandshakeTimeout = 10 * time.Second @@ -88,6 +98,12 @@ type Server struct { // Index is a unique server index for firewall marks and other uses. It starts at 0. Index uint16 + // NumTunnels is the number of parallel WireGuard tunnels to create for + // this server. Each tunnel listens on a different UDP port so that the NIC + // hashes them to different hardware RX queues, increasing throughput beyond + // the single-flow limit. Defaults to 1 for backwards compatibility. + NumTunnels int + // Ipt is the iptables client for managing firewall rules. Ipt *iptables.IPTables @@ -114,6 +130,17 @@ type Server struct { takeover bool } +// numTunnels returns the effective tunnel count, defaulting to 1. +func (srv *Server) numTunnels() int { + if srv.NumTunnels <= 0 { + return 1 + } + if srv.NumTunnels > MaxTunnelsPerServer { + return MaxTunnelsPerServer + } + return srv.NumTunnels +} + // InitState initializes the private server state. 
func (srv *Server) InitState() error { if srv.BindIface == nil { @@ -197,10 +224,25 @@ func (srv *Server) indexHandler(w http.ResponseWriter, r *http.Request) { type connectRequest struct { PeerPublicKey string } + +// TunnelInfo describes a single WireGuard tunnel endpoint within a multi-tunnel +// connection. Clients that support multi-tunnel will use these to set up +// parallel WireGuard interfaces. +type TunnelInfo struct { + ListenPort int `json:"ListenPort"` + Ifname string `json:"Ifname"` +} + type connectResponse struct { AssignedAddr string ServerPublicKey string ServerListenPort int + + // Tunnels lists all available tunnel endpoints for this server. Clients + // that support multi-tunnel create one WireGuard interface per entry. + // Clients that don't understand this field will fall back to the single + // ServerListenPort above (backwards compatible). + Tunnels []TunnelInfo `json:"Tunnels,omitempty"` } // Handle a new connection. @@ -262,31 +304,56 @@ func (srv *Server) connectHandler(w http.ResponseWriter, r *http.Request) { clientIp := strings.Split(r.RemoteAddr, ":")[0] // for logging log.Printf("[%v] new peer %v at %v: %v", srv.BindAddr, clientIp, peerIp, peerKey) - err = srv.WgClient.ConfigureDevice(srv.Ifname(), wgtypes.Config{ - Peers: []wgtypes.PeerConfig{ - { - PublicKey: peerKey, - ReplaceAllowedIPs: true, - AllowedIPs: []net.IPNet{prefixToIPNet(netip.PrefixFrom(peerIp, 32))}, + + // Add the peer to ALL tunnel interfaces so that traffic arriving on any + // tunnel is accepted, and the server can send traffic back on any tunnel. 
+ nt := srv.numTunnels() + for t := 0; t < nt; t++ { + ifname := srv.TunnelIfname(t) + err = srv.WgClient.ConfigureDevice(ifname, wgtypes.Config{ + Peers: []wgtypes.PeerConfig{ + { + PublicKey: peerKey, + ReplaceAllowedIPs: true, + AllowedIPs: []net.IPNet{prefixToIPNet(netip.PrefixFrom(peerIp, 32))}, + }, }, - }, - }) - if err != nil { - srv.mu.Lock() - delete(srv.allPeers, peerKey) - srv.mu.Unlock() + }) + if err != nil { + // Roll back: remove from any interfaces we already configured. + for rb := 0; rb < t; rb++ { + rbIfname := srv.TunnelIfname(rb) + _ = srv.WgClient.ConfigureDevice(rbIfname, wgtypes.Config{ + Peers: []wgtypes.PeerConfig{{PublicKey: peerKey, Remove: true}}, + }) + } - srv.ipAllocator.Free(peerIp) - log.Printf("failed to configure WireGuard peer: %v", err) - http.Error(w, "failed to configure WireGuard peer", http.StatusInternalServerError) - return + srv.mu.Lock() + delete(srv.allPeers, peerKey) + srv.mu.Unlock() + srv.ipAllocator.Free(peerIp) + + log.Printf("failed to configure WireGuard peer on %s: %v", ifname, err) + http.Error(w, "failed to configure WireGuard peer", http.StatusInternalServerError) + return + } + } + + // Build the Tunnels list for multi-tunnel clients. + tunnels := make([]TunnelInfo, nt) + for t := 0; t < nt; t++ { + tunnels[t] = TunnelInfo{ + ListenPort: srv.tunnelListenPort(t), + Ifname: srv.TunnelIfname(t), + } } // Return the assigned IP address and the server's public key. resp := &connectResponse{ AssignedAddr: fmt.Sprintf("%v/%d", peerIp, srv.WgCidr.Bits()), ServerPublicKey: srv.Key.PublicKey().String(), - ServerListenPort: WireguardListenPortBase + int(srv.Index), + ServerListenPort: srv.tunnelListenPort(0), // primary tunnel for old clients + Tunnels: tunnels, } respBuf, err := json.Marshal(resp) @@ -436,12 +503,50 @@ func (srv *Server) versionHandler(w http.ResponseWriter, r *http.Request) { w.Write(respBuf) } +// Ifname returns the primary WireGuard interface name (tunnel 0). 
This is used +// for backwards-compatible code paths like takeover and idle peer removal. func (srv *Server) Ifname() string { - return fmt.Sprintf("vprox%d", srv.Index) + return srv.TunnelIfname(0) +} + +// TunnelIfname returns the WireGuard interface name for the t-th tunnel. +// When NumTunnels == 1, this returns "vprox0" (same as before). +// When NumTunnels > 1, tunnel 0 is "vprox0", tunnel 1 is "vprox0t1", etc. +func (srv *Server) TunnelIfname(t int) string { + base := fmt.Sprintf("vprox%d", srv.Index) + if t == 0 { + return base + } + return fmt.Sprintf("%st%d", base, t) } +// tunnelListenPort returns the UDP listen port for the t-th tunnel. +func (srv *Server) tunnelListenPort(t int) int { + return WireguardListenPortBase + int(srv.Index)*PortsPerIndex + t +} + +// StartWireguard creates and configures all tunnel WireGuard interfaces. func (srv *Server) StartWireguard() error { - ifname := srv.Ifname() + nt := srv.numTunnels() + for t := 0; t < nt; t++ { + if err := srv.startWireguardTunnel(t); err != nil { + // Clean up any tunnels we already created. + for rb := 0; rb < t; rb++ { + srv.cleanupWireguardTunnel(rb) + } + return err + } + } + if nt > 1 { + log.Printf("[%v] started %d WireGuard tunnels (ports %d..%d)", + srv.BindAddr, nt, srv.tunnelListenPort(0), srv.tunnelListenPort(nt-1)) + } + return nil +} + +// startWireguardTunnel creates and configures a single WireGuard tunnel interface. 
+func (srv *Server) startWireguardTunnel(t int) error { + ifname := srv.TunnelIfname(t) link := &linkWireguard{LinkAttrs: netlink.LinkAttrs{ Name: ifname, MTU: WireguardMTU, @@ -452,37 +557,31 @@ func (srv *Server) StartWireguard() error { GROMaxSize: WireguardGSOMaxSize, }} - // Track whether we created a fresh interface (for cleanup on error) createdFreshInterface := false if srv.takeover { - // In takeover mode, use existing interface if available, otherwise create fresh _, err := netlink.LinkByName(ifname) if err == nil { log.Printf("[%v] takeover mode: using existing WireGuard interface %s", srv.BindAddr, ifname) } else { - // Interface doesn't exist, create fresh log.Printf("[%v] takeover mode: interface %s not found, creating fresh", srv.BindAddr, ifname) - if err := srv.createFreshInterface(link); err != nil { + if err := srv.createFreshInterface(link, t); err != nil { return err } createdFreshInterface = true } } else { - // Normal mode: delete and recreate the interface _ = netlink.LinkDel(link) // remove if it already exists - if err := srv.createFreshInterface(link); err != nil { + if err := srv.createFreshInterface(link, t); err != nil { return err } createdFreshInterface = true } - listenPort := WireguardListenPortBase + int(srv.Index) - firewallMark := FwmarkBase + int(srv.Index) + listenPort := srv.tunnelListenPort(t) err := srv.WgClient.ConfigureDevice(ifname, wgtypes.Config{ - PrivateKey: &srv.Key, - ListenPort: &listenPort, - FirewallMark: &firewallMark, + PrivateKey: &srv.Key, + ListenPort: &listenPort, }) if err != nil { if createdFreshInterface { @@ -495,42 +594,48 @@ func (srv *Server) StartWireguard() error { } // createFreshInterface creates and configures a new WireGuard interface. -func (srv *Server) createFreshInterface(link *linkWireguard) error { +// Only tunnel 0 gets an IP address assigned — the other tunnels share the same +// subnet via the kernel routing table entry that tunnel 0 creates. 
+func (srv *Server) createFreshInterface(link *linkWireguard, tunnelIndex int) error { err := netlink.LinkAdd(link) if err != nil { - return fmt.Errorf("failed to create WireGuard device: %v", err) + return fmt.Errorf("failed to create WireGuard device %s: %v", link.Name, err) } - ipnet := prefixToIPNet(srv.WgCidr) - err = netlink.AddrAdd(link, &netlink.Addr{IPNet: &ipnet}) - if err != nil { - netlink.LinkDel(link) - return fmt.Errorf("failed to add address to WireGuard device: %v", err) + // Only the primary tunnel (index 0) gets the subnet IP address. + // Additional tunnels participate in the same subnet without their own address. + if tunnelIndex == 0 { + ipnet := prefixToIPNet(srv.WgCidr) + err = netlink.AddrAdd(link, &netlink.Addr{IPNet: &ipnet}) + if err != nil { + netlink.LinkDel(link) + return fmt.Errorf("failed to add address to WireGuard device %s: %v", link.Name, err) + } } // Set MTU explicitly after link creation (some kernels ignore it in LinkAttrs) err = netlink.LinkSetMTU(link, WireguardMTU) if err != nil { netlink.LinkDel(link) - return fmt.Errorf("failed to set MTU on WireGuard device: %v", err) + return fmt.Errorf("failed to set MTU on WireGuard device %s: %v", link.Name, err) } // Set TxQLen for improved burst handling err = netlink.LinkSetTxQLen(link, WireguardTxQLen) if err != nil { - // Non-fatal: log warning but continue - log.Printf("warning: failed to set TxQLen on WireGuard device: %v", err) + log.Printf("warning: failed to set TxQLen on WireGuard device %s: %v", link.Name, err) } err = netlink.LinkSetUp(link) if err != nil { netlink.LinkDel(link) - return fmt.Errorf("failed to bring up WireGuard device: %v", err) + return fmt.Errorf("failed to bring up WireGuard device %s: %v", link.Name, err) } return nil } +// CleanupWireguard removes all tunnel WireGuard interfaces. 
func (srv *Server) CleanupWireguard() { srv.mu.Lock() relinquished := srv.relinquished @@ -538,35 +643,53 @@ func (srv *Server) CleanupWireguard() { if relinquished { log.Printf("[%v] skipping WireGuard cleanup (relinquished)", srv.BindAddr) - } else { - log.Printf("[%v] cleaning up WireGuard state", srv.BindAddr) - ifname := srv.Ifname() - _ = netlink.LinkDel(&linkWireguard{LinkAttrs: netlink.LinkAttrs{Name: ifname}}) + return + } + + log.Printf("[%v] cleaning up WireGuard state", srv.BindAddr) + nt := srv.numTunnels() + for t := 0; t < nt; t++ { + srv.cleanupWireguardTunnel(t) } } -// iptablesInputFwmarkRule adds or removes the mangle PREROUTING rule for traffic from WireGuard. +// cleanupWireguardTunnel removes a single WireGuard tunnel interface. +func (srv *Server) cleanupWireguardTunnel(t int) { + ifname := srv.TunnelIfname(t) + _ = netlink.LinkDel(&linkWireguard{LinkAttrs: netlink.LinkAttrs{Name: ifname}}) +} + +// iptablesInputFwmarkRule adds or removes the mangle PREROUTING rule for traffic +// from WireGuard. One rule per tunnel interface, all using the same fwmark. func (srv *Server) iptablesInputFwmarkRule(enabled bool) error { firewallMark := FwmarkBase + int(srv.Index) - rule := []string{ - "-i", srv.Ifname(), - "-j", "MARK", "--set-mark", strconv.Itoa(firewallMark), - "-m", "comment", "--comment", fmt.Sprintf("vprox fwmark rule for %s", srv.Ifname()), - } - if enabled { - return srv.Ipt.AppendUnique("mangle", "PREROUTING", rule...) - } else { - return srv.Ipt.Delete("mangle", "PREROUTING", rule...) + nt := srv.numTunnels() + for t := 0; t < nt; t++ { + ifname := srv.TunnelIfname(t) + rule := []string{ + "-i", ifname, + "-j", "MARK", "--set-mark", strconv.Itoa(firewallMark), + "-m", "comment", "--comment", fmt.Sprintf("vprox fwmark rule for %s", ifname), + } + if enabled { + if err := srv.Ipt.AppendUnique("mangle", "PREROUTING", rule...); err != nil { + return err + } + } else { + srv.Ipt.Delete("mangle", "PREROUTING", rule...) 
+ } } + return nil } // iptablesSnatRule adds or removes the nat POSTROUTING rule for outbound traffic. +// This is shared across all tunnels via fwmark (only one rule needed). func (srv *Server) iptablesSnatRule(enabled bool) error { firewallMark := FwmarkBase + int(srv.Index) rule := []string{ "-m", "mark", "--mark", strconv.Itoa(firewallMark), "-j", "SNAT", "--to-source", srv.BindAddr.String(), - "-m", "comment", "--comment", fmt.Sprintf("vprox snat rule for %s", srv.Ifname()), + "-m", "comment", "--comment", fmt.Sprintf("vprox snat rule for index %d", srv.Index), } if enabled { return srv.Ipt.AppendUnique("nat", "POSTROUTING", rule...) @@ -575,78 +698,77 @@ func (srv *Server) iptablesSnatRule(enabled bool) error { } } -// iptablesNotrackRule adds or removes a NOTRACK rule in the raw table to bypass -// connection tracking for WireGuard UDP traffic. This significantly reduces -// per-packet CPU overhead for tunneled flows. +// iptablesNotrackRule adds or removes NOTRACK rules in the raw table to bypass +// connection tracking for WireGuard UDP traffic on all tunnel ports. 
func (srv *Server) iptablesNotrackRule(enabled bool) error { - listenPort := strconv.Itoa(WireguardListenPortBase + int(srv.Index)) - // Inbound WireGuard UDP - inRule := []string{ - "-p", "udp", - "--dport", listenPort, - "-j", "NOTRACK", - "-m", "comment", "--comment", fmt.Sprintf("vprox notrack in for %s", srv.Ifname()), - } - // Outbound WireGuard UDP - outRule := []string{ - "-p", "udp", - "--sport", listenPort, - "-j", "NOTRACK", - "-m", "comment", "--comment", fmt.Sprintf("vprox notrack out for %s", srv.Ifname()), - } - if enabled { - if err := srv.Ipt.AppendUnique("raw", "PREROUTING", inRule...); err != nil { - return fmt.Errorf("failed to add NOTRACK PREROUTING rule: %v", err) + nt := srv.numTunnels() + for t := 0; t < nt; t++ { + listenPort := strconv.Itoa(srv.tunnelListenPort(t)) + ifname := srv.TunnelIfname(t) + inRule := []string{ + "-p", "udp", + "--dport", listenPort, + "-j", "NOTRACK", + "-m", "comment", "--comment", fmt.Sprintf("vprox notrack in for %s", ifname), + } + outRule := []string{ + "-p", "udp", + "--sport", listenPort, + "-j", "NOTRACK", + "-m", "comment", "--comment", fmt.Sprintf("vprox notrack out for %s", ifname), } - if err := srv.Ipt.AppendUnique("raw", "OUTPUT", outRule...); err != nil { - // Clean up the first rule on failure. + if enabled { + if err := srv.Ipt.AppendUnique("raw", "PREROUTING", inRule...); err != nil { + return fmt.Errorf("failed to add NOTRACK PREROUTING rule for %s: %v", ifname, err) + } + if err := srv.Ipt.AppendUnique("raw", "OUTPUT", outRule...); err != nil { + srv.Ipt.Delete("raw", "PREROUTING", inRule...) + return fmt.Errorf("failed to add NOTRACK OUTPUT rule for %s: %v", ifname, err) + } + } else { srv.Ipt.Delete("raw", "PREROUTING", inRule...) - return fmt.Errorf("failed to add NOTRACK OUTPUT rule: %v", err) + srv.Ipt.Delete("raw", "OUTPUT", outRule...) } - return nil } - // Cleanup: best-effort, ignore errors. - srv.Ipt.Delete("raw", "PREROUTING", inRule...) - srv.Ipt.Delete("raw", "OUTPUT", outRule...) 
return nil } // iptablesMssRule adds or removes FORWARD chain rules for TCP MSS clamping in -// both directions. Uses the mangle table which is the correct place for packet -// modification. We need both -o (traffic entering the tunnel, server→client) -// and -i (traffic leaving the tunnel, client→server) so that SYN packets in -// either direction get their MSS clamped to fit within the WireGuard MTU. +// both directions on all tunnel interfaces. func (srv *Server) iptablesMssRule(enabled bool) error { - outRule := []string{ - "-o", srv.Ifname(), - "-p", "tcp", - "--tcp-flags", "SYN,RST", "SYN", - "-j", "TCPMSS", - "--set-mss", strconv.Itoa(WireguardMSS), - "-m", "comment", "--comment", fmt.Sprintf("vprox mss out rule for %s", srv.Ifname()), - } - inRule := []string{ - "-i", srv.Ifname(), - "-p", "tcp", - "--tcp-flags", "SYN,RST", "SYN", - "-j", "TCPMSS", - "--set-mss", strconv.Itoa(WireguardMSS), - "-m", "comment", "--comment", fmt.Sprintf("vprox mss in rule for %s", srv.Ifname()), - } - - if enabled { - if err := srv.Ipt.AppendUnique("mangle", "FORWARD", outRule...); err != nil { - return err + nt := srv.numTunnels() + for t := 0; t < nt; t++ { + ifname := srv.TunnelIfname(t) + outRule := []string{ + "-o", ifname, + "-p", "tcp", + "--tcp-flags", "SYN,RST", "SYN", + "-j", "TCPMSS", + "--set-mss", strconv.Itoa(WireguardMSS), + "-m", "comment", "--comment", fmt.Sprintf("vprox mss out rule for %s", ifname), } - if err := srv.Ipt.AppendUnique("mangle", "FORWARD", inRule...); err != nil { + inRule := []string{ + "-i", ifname, + "-p", "tcp", + "--tcp-flags", "SYN,RST", "SYN", + "-j", "TCPMSS", + "--set-mss", strconv.Itoa(WireguardMSS), + "-m", "comment", "--comment", fmt.Sprintf("vprox mss in rule for %s", ifname), + } + + if enabled { + if err := srv.Ipt.AppendUnique("mangle", "FORWARD", outRule...); err != nil { + return err + } + if err := srv.Ipt.AppendUnique("mangle", "FORWARD", inRule...); err != nil { + srv.Ipt.Delete("mangle", "FORWARD", outRule...) 
+ return err + } + } else { srv.Ipt.Delete("mangle", "FORWARD", outRule...) - return err + srv.Ipt.Delete("mangle", "FORWARD", inRule...) } - return nil } - // Cleanup: best-effort both directions. - srv.Ipt.Delete("mangle", "FORWARD", outRule...) - srv.Ipt.Delete("mangle", "FORWARD", inRule...) return nil } @@ -705,15 +827,13 @@ func (srv *Server) removeIdlePeersLoop() { } } -// cleanupPeer removes a peer from the WireGuard interface, reclaims its subnet IP, -// and removes it from the allPeers map. This function is idempotent and safe to call -// even if the peer doesn't exist. It uses the allPeers map as the source of truth. +// cleanupPeer removes a peer from ALL WireGuard tunnel interfaces, reclaims its +// subnet IP, and removes it from the allPeers map. func (srv *Server) cleanupPeer(publicKey wgtypes.Key) error { // Look up the peer in allPeers map to get its IP address. srv.mu.Lock() peerInfo, exists := srv.allPeers[publicKey] if !exists { - // Peer not in allPeers - it likely already got cleaned up. srv.mu.Unlock() log.Printf("[%v] peer unexpectedly not found in allPeers - did /disconnect race with the periodic peer-GC loop?: %v", srv.BindAddr, publicKey) return nil @@ -724,18 +844,23 @@ func (srv *Server) cleanupPeer(publicKey wgtypes.Key) error { delete(srv.allPeers, publicKey) srv.mu.Unlock() - // Remove the peer from WireGuard (no lock held during WireGuard operations). - log.Printf("[%v] removing peer at %v: %v", srv.BindAddr, peerIp, publicKey) - err := srv.WgClient.ConfigureDevice(srv.Ifname(), wgtypes.Config{ - Peers: []wgtypes.PeerConfig{ - { - PublicKey: publicKey, - Remove: true, + // Remove the peer from ALL tunnel interfaces. 
+ nt := srv.numTunnels() + log.Printf("[%v] removing peer at %v from %d tunnel(s): %v", srv.BindAddr, peerIp, nt, publicKey) + var firstErr error + for t := 0; t < nt; t++ { + ifname := srv.TunnelIfname(t) + err := srv.WgClient.ConfigureDevice(ifname, wgtypes.Config{ + Peers: []wgtypes.PeerConfig{ + { + PublicKey: publicKey, + Remove: true, + }, }, - }, - }) - if err != nil { - return fmt.Errorf("failed to remove WireGuard peer: %v", err) + }) + if err != nil && firstErr == nil { + firstErr = fmt.Errorf("failed to remove WireGuard peer from %s: %v", ifname, err) + } } // Free the IP address. @@ -743,10 +868,12 @@ func (srv *Server) cleanupPeer(publicKey wgtypes.Key) error { srv.ipAllocator.Free(peerIp) } - return nil + return firstErr } func (srv *Server) removeIdlePeers() error { + // Check idle status using the primary tunnel interface (tunnel 0). + // All tunnels share the same peers, so we only need to inspect one. device, err := srv.WgClient.Device(srv.Ifname()) if err != nil { return fmt.Errorf("failed to get WireGuard device: %v", err) @@ -756,7 +883,7 @@ func (srv *Server) removeIdlePeers() error { srv.mu.Lock() defer srv.mu.Unlock() - var removePeers []wgtypes.PeerConfig + var removePeerKeys []wgtypes.Key var removeIps []netip.Addr for _, peer := range device.Peers { var idle bool @@ -765,8 +892,6 @@ func (srv *Server) removeIdlePeers() error { if exists { idle = time.Since(peerInfo.ConnectionTime) > PeerIdleTimeout } else { - // If we somehow have a WireGuard interface for a peer but no allPeers entry, - // let's just assume it's idle and remove it. 
idle = true } } else { @@ -782,19 +907,28 @@ func (srv *Server) removeIdlePeers() error { removeIps = append(removeIps, netip.AddrFrom4([4]byte(ipv4))) } } - removePeers = append(removePeers, wgtypes.PeerConfig{ - PublicKey: peer.PublicKey, - Remove: true, - }) + removePeerKeys = append(removePeerKeys, peer.PublicKey) delete(srv.allPeers, peer.PublicKey) } } - if len(removePeers) > 0 { - err := srv.WgClient.ConfigureDevice(srv.Ifname(), wgtypes.Config{Peers: removePeers}) - if err != nil { - return err + if len(removePeerKeys) > 0 { + // Build the peer removal config. + removePeers := make([]wgtypes.PeerConfig, len(removePeerKeys)) + for i, pk := range removePeerKeys { + removePeers[i] = wgtypes.PeerConfig{PublicKey: pk, Remove: true} } + + // Remove from ALL tunnel interfaces. + nt := srv.numTunnels() + for t := 0; t < nt; t++ { + ifname := srv.TunnelIfname(t) + err := srv.WgClient.ConfigureDevice(ifname, wgtypes.Config{Peers: removePeers}) + if err != nil { + log.Printf("warning: failed to remove idle peers from %s: %v", ifname, err) + } + } + for _, ip := range removeIps { srv.ipAllocator.Free(ip) } diff --git a/lib/server_manager.go b/lib/server_manager.go index 632541f..1f4576d 100644 --- a/lib/server_manager.go +++ b/lib/server_manager.go @@ -25,6 +25,7 @@ type ServerManager struct { ipt *iptables.IPTables key wgtypes.Key password string + numTunnels int ctx context.Context waitGroup *sync.WaitGroup wgBlock netip.Prefix @@ -41,7 +42,7 @@ type ServerManager struct { } // NewServerManager creates a new server manager -func NewServerManager(wgBlock netip.Prefix, wgBlockPerIp uint, ctx context.Context, key wgtypes.Key, password string, takeover bool) (*ServerManager, error) { +func NewServerManager(wgBlock netip.Prefix, wgBlockPerIp uint, ctx context.Context, key wgtypes.Key, password string, numTunnels int, takeover bool) (*ServerManager, error) { // Make a shared WireGuard client. 
wgClient, err := wgctrl.New() if err != nil { @@ -58,11 +59,19 @@ func NewServerManager(wgBlock netip.Prefix, wgBlockPerIp uint, ctx context.Conte color.New(color.Bold).Sprint("server public key:"), key.PublicKey().String()) + if numTunnels <= 0 { + numTunnels = 1 + } + if numTunnels > MaxTunnelsPerServer { + numTunnels = MaxTunnelsPerServer + } + sm := new(ServerManager) sm.wgClient = wgClient sm.ipt = ipt sm.key = key sm.password = password + sm.numTunnels = numTunnels sm.ctx = ctx sm.waitGroup = new(sync.WaitGroup) sm.wgBlock = wgBlock.Masked() @@ -109,15 +118,16 @@ func (sm *ServerManager) Start(ip netip.Addr) error { wgCidr := netip.PrefixFrom(subnetStart.Next(), int(sm.wgBlockPerIp)) srv := &Server{ - Key: sm.key, - BindAddr: ip, - Password: sm.password, - Index: i, - Ipt: sm.ipt, - WgClient: sm.wgClient, - WgCidr: wgCidr, - Ctx: subctx, - takeover: sm.takeover, + Key: sm.key, + BindAddr: ip, + Password: sm.password, + Index: i, + NumTunnels: sm.numTunnels, + Ipt: sm.ipt, + WgClient: sm.wgClient, + WgCidr: wgCidr, + Ctx: subctx, + takeover: sm.takeover, } if err := srv.InitState(); err != nil { _ = cancel // cancel should be discarded From 32fdb02448a8559b5cbf2c82d58da87e24890de0 Mon Sep 17 00:00:00 2001 From: Luis Capelo Date: Thu, 12 Mar 2026 00:55:22 -0400 Subject: [PATCH 04/14] use many tunnels --- lib/client.go | 77 +++++++++++++++++++++++++-------------------------- 1 file changed, 38 insertions(+), 39 deletions(-) diff --git a/lib/client.go b/lib/client.go index 97ed3d9..df96346 100644 --- a/lib/client.go +++ b/lib/client.go @@ -179,35 +179,31 @@ func (c *Client) Connect() error { } c.activeTunnels = nt - // Bring up and configure the primary interface (tunnel 0) — this is the - // one that gets the IP address and subnet route. 
- primaryLink := c.tunnelLink(0) - err = netlink.LinkSetUp(primaryLink) - if err != nil { - return fmt.Errorf("error setting up vprox interface: %v", err) - } - - err = c.updateInterface(resp) - if err != nil { - return err - } - - // Configure WireGuard on tunnel 0 using the primary ServerListenPort - // (works for both old and new servers). - err = c.configureWireguardTunnel(0, resp, resp.ServerListenPort) - if err != nil { - return fmt.Errorf("error configuring wireguard on %s: %v", c.tunnelIfname(0), err) - } - - // Configure additional tunnels if the server provided them. - for t := 1; t < nt; t++ { + // Bring up, assign address, and configure WireGuard on ALL tunnel interfaces. + // Each interface gets the same IP address so the kernel knows how to reach + // the gateway (server WireGuard IP) through any of them. + for t := 0; t < nt; t++ { link := c.tunnelLink(t) err = netlink.LinkSetUp(link) if err != nil { return fmt.Errorf("error setting up vprox interface %s: %v", link.Name, err) } - port := resp.Tunnels[t].ListenPort + // Assign the same address to every tunnel interface. The first call + // also updates c.wgCidr; subsequent calls for the same CIDR are + // handled by updateTunnelInterface which skips if already set. + if err := c.updateTunnelInterface(t, resp); err != nil { + return fmt.Errorf("error updating interface %s: %v", link.Name, err) + } + + // Pick the listen port: tunnel 0 always uses ServerListenPort + // (backwards compatible with old servers); tunnels 1+ use Tunnels[t]. + var port int + if t == 0 { + port = resp.ServerListenPort + } else { + port = resp.Tunnels[t].ListenPort + } err = c.configureWireguardTunnel(t, resp, port) if err != nil { return fmt.Errorf("error configuring wireguard on %s: %v", c.tunnelIfname(t), err) @@ -227,30 +223,33 @@ func (c *Client) Connect() error { return nil } -// updateInterface updates the primary WireGuard interface (tunnel 0) address -// based on the connect response. 
-func (c *Client) updateInterface(resp connectResponse) error { +// updateTunnelInterface assigns the WireGuard address to tunnel interface t. +// Every tunnel interface gets the same IP/CIDR so that the server gateway is +// reachable through each of them (required for multipath routing). +func (c *Client) updateTunnelInterface(t int, resp connectResponse) error { cidr, err := netip.ParsePrefix(resp.AssignedAddr) if err != nil { return fmt.Errorf("failed to parse assigned address %v: %v", resp.AssignedAddr, err) } - if cidr != c.wgCidr { - link := c.link() + link := c.tunnelLink(t) - if c.wgCidr.IsValid() { - oldIpnet := prefixToIPNet(c.wgCidr) - err = netlink.AddrDel(link, &netlink.Addr{IPNet: &oldIpnet}) - if err != nil { - log.Printf("warning: failed to remove old address from vprox interface when reconnecting: %v", err) - } + if t == 0 && c.wgCidr.IsValid() && cidr != c.wgCidr { + // On reconnect the primary tunnel may need the old address removed. + oldIpnet := prefixToIPNet(c.wgCidr) + if err := netlink.AddrDel(link, &netlink.Addr{IPNet: &oldIpnet}); err != nil { + log.Printf("warning: failed to remove old address from %s when reconnecting: %v", c.tunnelIfname(t), err) } + } - ipnet := prefixToIPNet(cidr) - err = netlink.AddrAdd(link, &netlink.Addr{IPNet: &ipnet}) - if err != nil { - return fmt.Errorf("failed to add new address to vprox interface: %v", err) - } + ipnet := prefixToIPNet(cidr) + err = netlink.AddrReplace(link, &netlink.Addr{IPNet: &ipnet}) + if err != nil { + return fmt.Errorf("failed to add address to %s: %v", c.tunnelIfname(t), err) + } + + // Track the CIDR on the first tunnel. 
+ if t == 0 { c.wgCidr = cidr } return nil From 97a69626e75b52e75f318cb26c0aa30599963493 Mon Sep 17 00:00:00 2001 From: Luis Capelo Date: Thu, 12 Mar 2026 01:00:34 -0400 Subject: [PATCH 05/14] fixes routing --- lib/client.go | 67 ++++++++++++++++++++++++++++++++++++++------------- 1 file changed, 50 insertions(+), 17 deletions(-) diff --git a/lib/client.go b/lib/client.go index df96346..b68922c 100644 --- a/lib/client.go +++ b/lib/client.go @@ -257,17 +257,51 @@ func (c *Client) updateTunnelInterface(t int, resp connectResponse) error { // setupMultipathRouting creates equal-cost multipath routes across all active // tunnel interfaces so that the kernel distributes flows across them. +// +// WireGuard interfaces are POINTOPOINT and don't automatically get subnet +// routes, so we first add a per-device route for each tunnel, then replace +// them with a single multipath route. func (c *Client) setupMultipathRouting(nt int) error { if !c.wgCidr.IsValid() { return fmt.Errorf("no valid CIDR assigned yet") } - // The server's WireGuard IP is the first address in the subnet (the - // gateway for our multipath nexthops). + subnetIPNet := prefixToIPNet(c.wgCidr.Masked()) + + // Step 1: Remove any existing routes for this subnet so we start clean. + existingRoutes, _ := netlink.RouteList(nil, netlink.FAMILY_V4) + for i := range existingRoutes { + r := &existingRoutes[i] + if r.Dst != nil && r.Dst.String() == subnetIPNet.String() { + _ = netlink.RouteDel(r) + } + } + + // Step 2: Ensure each tunnel interface has a device-scoped route for the + // subnet. This makes the gateway reachable through every interface. 
+ for t := 0; t < nt; t++ { + ifname := c.tunnelIfname(t) + link, err := netlink.LinkByName(ifname) + if err != nil { + return fmt.Errorf("failed to find interface %s: %v", ifname, err) + } + devRoute := &netlink.Route{ + LinkIndex: link.Attrs().Index, + Dst: &subnetIPNet, + Scope: netlink.SCOPE_LINK, + } + if err := netlink.RouteReplace(devRoute); err != nil { + return fmt.Errorf("failed to add device route on %s: %v", ifname, err) + } + } + + // Step 3: Build multipath nexthops — one per tunnel interface, using the + // server's WireGuard IP (first address in subnet) as the gateway. Now + // that each interface has a device-scoped route, the gateway is reachable + // through all of them. gwAddr := c.wgCidr.Masked().Addr().Next() gwIP := addrToIp(gwAddr) - // Build multipath nexthops — one per tunnel interface. var nexthops []*netlink.NexthopInfo for t := 0; t < nt; t++ { ifname := c.tunnelIfname(t) @@ -282,26 +316,25 @@ func (c *Client) setupMultipathRouting(nt int) error { }) } - // Remove any existing default route on the primary interface first. - // (The kernel creates one when we assign the address.) - existingRoutes, _ := netlink.RouteList(nil, netlink.FAMILY_V4) - for i := range existingRoutes { - r := &existingRoutes[i] - if r.Dst != nil && r.Dst.String() == c.wgCidr.Masked().String() { - // This is the subnet route — we need to replace it with multipath. - _ = netlink.RouteDel(r) + // Step 4: Remove the per-device routes and replace with a single + // multipath route. + for t := 0; t < nt; t++ { + ifname := c.tunnelIfname(t) + link, _ := netlink.LinkByName(ifname) + if link != nil { + _ = netlink.RouteDel(&netlink.Route{ + LinkIndex: link.Attrs().Index, + Dst: &subnetIPNet, + Scope: netlink.SCOPE_LINK, + }) } } - // Add the multipath route for the WireGuard subnet. 
- subnetIPNet := prefixToIPNet(c.wgCidr.Masked()) - route := &netlink.Route{ + mpRoute := &netlink.Route{ Dst: &subnetIPNet, MultiPath: nexthops, } - - err := netlink.RouteReplace(route) - if err != nil { + if err := netlink.RouteReplace(mpRoute); err != nil { return fmt.Errorf("failed to add multipath route: %v", err) } From 46c90cdbcbe5df9f7ebb8552ae9b1d33ad514f13 Mon Sep 17 00:00:00 2001 From: Luis Capelo Date: Thu, 12 Mar 2026 01:02:17 -0400 Subject: [PATCH 06/14] remove gateway --- lib/client.go | 65 ++++++++------------------------------------------- 1 file changed, 10 insertions(+), 55 deletions(-) diff --git a/lib/client.go b/lib/client.go index b68922c..782933c 100644 --- a/lib/client.go +++ b/lib/client.go @@ -255,12 +255,10 @@ func (c *Client) updateTunnelInterface(t int, resp connectResponse) error { return nil } -// setupMultipathRouting creates equal-cost multipath routes across all active -// tunnel interfaces so that the kernel distributes flows across them. -// -// WireGuard interfaces are POINTOPOINT and don't automatically get subnet -// routes, so we first add a per-device route for each tunnel, then replace -// them with a single multipath route. +// setupMultipathRouting adds one equal-cost device-scoped route per tunnel +// interface for the WireGuard subnet. The kernel automatically round-robins +// across equal-cost routes to the same destination, distributing flows across +// the tunnel interfaces. func (c *Client) setupMultipathRouting(nt int) error { if !c.wgCidr.IsValid() { return fmt.Errorf("no valid CIDR assigned yet") @@ -268,7 +266,7 @@ func (c *Client) setupMultipathRouting(nt int) error { subnetIPNet := prefixToIPNet(c.wgCidr.Masked()) - // Step 1: Remove any existing routes for this subnet so we start clean. + // Remove any existing routes for this subnet so we start clean. 
existingRoutes, _ := netlink.RouteList(nil, netlink.FAMILY_V4) for i := range existingRoutes { r := &existingRoutes[i] @@ -277,67 +275,24 @@ func (c *Client) setupMultipathRouting(nt int) error { } } - // Step 2: Ensure each tunnel interface has a device-scoped route for the - // subnet. This makes the gateway reachable through every interface. + // Add a device-scoped route for each tunnel interface. Equal-cost routes + // to the same destination cause the kernel to distribute flows across them. for t := 0; t < nt; t++ { ifname := c.tunnelIfname(t) link, err := netlink.LinkByName(ifname) if err != nil { return fmt.Errorf("failed to find interface %s: %v", ifname, err) } - devRoute := &netlink.Route{ + route := &netlink.Route{ LinkIndex: link.Attrs().Index, Dst: &subnetIPNet, Scope: netlink.SCOPE_LINK, } - if err := netlink.RouteReplace(devRoute); err != nil { - return fmt.Errorf("failed to add device route on %s: %v", ifname, err) + if err := netlink.RouteAdd(route); err != nil { + return fmt.Errorf("failed to add route on %s: %v", ifname, err) } } - // Step 3: Build multipath nexthops — one per tunnel interface, using the - // server's WireGuard IP (first address in subnet) as the gateway. Now - // that each interface has a device-scoped route, the gateway is reachable - // through all of them. - gwAddr := c.wgCidr.Masked().Addr().Next() - gwIP := addrToIp(gwAddr) - - var nexthops []*netlink.NexthopInfo - for t := 0; t < nt; t++ { - ifname := c.tunnelIfname(t) - link, err := netlink.LinkByName(ifname) - if err != nil { - return fmt.Errorf("failed to find interface %s: %v", ifname, err) - } - nexthops = append(nexthops, &netlink.NexthopInfo{ - LinkIndex: link.Attrs().Index, - Gw: gwIP, - Hops: 0, // equal weight - }) - } - - // Step 4: Remove the per-device routes and replace with a single - // multipath route. 
- for t := 0; t < nt; t++ { - ifname := c.tunnelIfname(t) - link, _ := netlink.LinkByName(ifname) - if link != nil { - _ = netlink.RouteDel(&netlink.Route{ - LinkIndex: link.Attrs().Index, - Dst: &subnetIPNet, - Scope: netlink.SCOPE_LINK, - }) - } - } - - mpRoute := &netlink.Route{ - Dst: &subnetIPNet, - MultiPath: nexthops, - } - if err := netlink.RouteReplace(mpRoute); err != nil { - return fmt.Errorf("failed to add multipath route: %v", err) - } - return nil } From 4049f33e12394ff15050001f38c0008681af6e7a Mon Sep 17 00:00:00 2001 From: Luis Capelo Date: Thu, 12 Mar 2026 01:04:54 -0400 Subject: [PATCH 07/14] create different routes --- lib/client.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/client.go b/lib/client.go index 782933c..78be68b 100644 --- a/lib/client.go +++ b/lib/client.go @@ -288,7 +288,7 @@ func (c *Client) setupMultipathRouting(nt int) error { Dst: &subnetIPNet, Scope: netlink.SCOPE_LINK, } - if err := netlink.RouteAdd(route); err != nil { + if err := netlink.RouteAppend(route); err != nil { return fmt.Errorf("failed to add route on %s: %v", ifname, err) } } From 33558fac90cfa2947cc6f973d1b7aa14e2518ef9 Mon Sep 17 00:00:00 2001 From: Luis Capelo Date: Thu, 12 Mar 2026 01:20:38 -0400 Subject: [PATCH 08/14] fix routing --- lib/server.go | 61 +++++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 50 insertions(+), 11 deletions(-) diff --git a/lib/server.go b/lib/server.go index 6cd052a..d6256ed 100644 --- a/lib/server.go +++ b/lib/server.go @@ -540,6 +540,14 @@ func (srv *Server) StartWireguard() error { if nt > 1 { log.Printf("[%v] started %d WireGuard tunnels (ports %d..%d)", srv.BindAddr, nt, srv.tunnelListenPort(0), srv.tunnelListenPort(nt-1)) + + // Set up equal-cost routes across all tunnel interfaces so the kernel + // distributes reply traffic across them (same approach as client side). 
+ if err := srv.setupMultipathRouting(nt); err != nil { + log.Printf("[%v] warning: failed to set up multipath routing: %v", srv.BindAddr, err) + } else { + log.Printf("[%v] multipath routing configured across %d tunnels", srv.BindAddr, nt) + } } return nil } @@ -594,23 +602,20 @@ func (srv *Server) startWireguardTunnel(t int) error { } // createFreshInterface creates and configures a new WireGuard interface. -// Only tunnel 0 gets an IP address assigned — the other tunnels share the same -// subnet via the kernel routing table entry that tunnel 0 creates. +// Every tunnel interface gets the same subnet IP so the kernel can route +// reply packets back through any of them. func (srv *Server) createFreshInterface(link *linkWireguard, tunnelIndex int) error { err := netlink.LinkAdd(link) if err != nil { return fmt.Errorf("failed to create WireGuard device %s: %v", link.Name, err) } - // Only the primary tunnel (index 0) gets the subnet IP address. - // Additional tunnels participate in the same subnet without their own address. - if tunnelIndex == 0 { - ipnet := prefixToIPNet(srv.WgCidr) - err = netlink.AddrAdd(link, &netlink.Addr{IPNet: &ipnet}) - if err != nil { - netlink.LinkDel(link) - return fmt.Errorf("failed to add address to WireGuard device %s: %v", link.Name, err) - } + // Assign the subnet IP to every tunnel interface. + ipnet := prefixToIPNet(srv.WgCidr) + err = netlink.AddrReplace(link, &netlink.Addr{IPNet: &ipnet}) + if err != nil { + netlink.LinkDel(link) + return fmt.Errorf("failed to add address to WireGuard device %s: %v", link.Name, err) } // Set MTU explicitly after link creation (some kernels ignore it in LinkAttrs) @@ -635,6 +640,40 @@ func (srv *Server) createFreshInterface(link *linkWireguard, tunnelIndex int) er return nil } +// setupMultipathRouting adds equal-cost device-scoped routes for the WireGuard +// subnet across all tunnel interfaces so the kernel round-robins reply traffic. 
+func (srv *Server) setupMultipathRouting(nt int) error { + subnetIPNet := prefixToIPNet(srv.WgCidr.Masked()) + + // Remove any existing routes for this subnet so we start clean. + existingRoutes, _ := netlink.RouteList(nil, netlink.FAMILY_V4) + for i := range existingRoutes { + r := &existingRoutes[i] + if r.Dst != nil && r.Dst.String() == subnetIPNet.String() { + _ = netlink.RouteDel(r) + } + } + + // Append one route per tunnel interface. Equal-cost routes to the same + // destination cause the kernel to distribute flows across them. + for t := 0; t < nt; t++ { + ifname := srv.TunnelIfname(t) + link, err := netlink.LinkByName(ifname) + if err != nil { + return fmt.Errorf("failed to find interface %s: %v", ifname, err) + } + route := &netlink.Route{ + LinkIndex: link.Attrs().Index, + Dst: &subnetIPNet, + Scope: netlink.SCOPE_LINK, + } + if err := netlink.RouteAppend(route); err != nil { + return fmt.Errorf("failed to append route on %s: %v", ifname, err) + } + } + return nil +} + // CleanupWireguard removes all tunnel WireGuard interfaces. func (srv *Server) CleanupWireguard() { srv.mu.Lock() From f0c079dc944712ec4b362089b61d967402fbba19 Mon Sep 17 00:00:00 2001 From: Luis Capelo Date: Thu, 12 Mar 2026 01:32:57 -0400 Subject: [PATCH 09/14] uses bonding --- lib/client.go | 458 +++++++++++++++++++++++++++++--------------------- 1 file changed, 266 insertions(+), 192 deletions(-) diff --git a/lib/client.go b/lib/client.go index 78be68b..90cf948 100644 --- a/lib/client.go +++ b/lib/client.go @@ -12,6 +12,7 @@ import ( "net/http" "net/netip" "net/url" + "os" "time" probing "github.com/prometheus-community/pro-bing" @@ -40,15 +41,29 @@ func IsRecoverableError(err error) bool { return true } -// Client manages a peering connection with a local WireGuard interface (or a -// set of parallel WireGuard interfaces when multi-tunnel is enabled). 
+// Client manages a peering connection with a local WireGuard interface, or a +// set of parallel WireGuard interfaces bonded together when multi-tunnel is +// enabled. +// +// Single-tunnel (NumTunnels <= 1): +// +// Applications use "vprox0" which is a plain WireGuard interface. +// +// Multi-tunnel (NumTunnels > 1): +// +// WireGuard slaves: vprox0t0, vprox0t1, vprox0t2, ... +// Bond master: vprox0 (balance-rr, presents a single interface) +// +// Applications bind to "vprox0" and the bonding driver distributes packets +// round-robin across the WireGuard slaves. Each slave uses a different UDP +// port to the server, so the NIC hashes them to different hardware RX queues. type Client struct { // Key is the private key of the client. Key wgtypes.Key - // Ifname is the base name of the client WireGuard interface (e.g. "vprox0"). - // With multi-tunnel this becomes the primary interface; additional tunnels - // are named "vprox0t1", "vprox0t2", etc. + // Ifname is the name of the interface exposed to applications (e.g. "vprox0"). + // In multi-tunnel mode this is the bond device; individual WireGuard tunnels + // are named t0, t1, etc. Ifname string // ServerIp is the public IPv4 address of the server. @@ -58,7 +73,8 @@ type Client struct { Password string // NumTunnels is the number of parallel WireGuard tunnels to create. - // When <= 1, the client behaves exactly as before (single interface). + // When <= 1, the client creates a single plain WireGuard interface. + // When > 1, a bonding device is created over N WireGuard slaves. NumTunnels int // WgClient is a shared client for interacting with the WireGuard kernel module. @@ -67,15 +83,18 @@ type Client struct { // Http is used to make connect requests to the server. Http *http.Client - // wgCidr is the current subnet assigned to the WireGuard interface, if any. + // wgCidr is the current subnet assigned to the interface, if any. 
wgCidr netip.Prefix // activeTunnels tracks how many tunnel interfaces were actually created - // during the last successful Connect(). This may be less than NumTunnels - // if the server returned fewer Tunnels entries (e.g. old server). + // during the last successful Connect(). activeTunnels int } +// --------------------------------------------------------------------------- +// Naming helpers +// --------------------------------------------------------------------------- + // numTunnels returns the effective tunnel count, defaulting to 1. func (c *Client) numTunnels() int { if c.NumTunnels <= 1 { @@ -84,17 +103,22 @@ func (c *Client) numTunnels() int { return c.NumTunnels } -// tunnelIfname returns the interface name for the t-th tunnel. -// Tunnel 0 uses Ifname directly (e.g. "vprox0"). -// Tunnel 1+ appends "t1", "t2", etc. (e.g. "vprox0t1", "vprox0t2"). +// isMultiTunnel returns true when we should create a bond device. +func (c *Client) isMultiTunnel() bool { + return c.numTunnels() > 1 +} + +// tunnelIfname returns the WireGuard interface name for the t-th tunnel. +// - Single-tunnel mode: returns Ifname directly (e.g. "vprox0"). +// - Multi-tunnel mode: returns "t" (e.g. "vprox0t0", "vprox0t1"). func (c *Client) tunnelIfname(t int) string { - if t == 0 { + if !c.isMultiTunnel() { return c.Ifname } return fmt.Sprintf("%st%d", c.Ifname, t) } -// tunnelLink builds a linkWireguard for the t-th tunnel with tuned LinkAttrs. +// tunnelLink builds a linkWireguard with tuned LinkAttrs for the t-th tunnel. func (c *Client) tunnelLink(t int) *linkWireguard { return &linkWireguard{LinkAttrs: netlink.LinkAttrs{ Name: c.tunnelIfname(t), @@ -107,28 +131,40 @@ func (c *Client) tunnelLink(t int) *linkWireguard { }} } -// link returns a linkWireguard for the primary (tunnel 0) interface. 
-func (c *Client) link() *linkWireguard { - return c.tunnelLink(0) -} +// --------------------------------------------------------------------------- +// Interface creation / deletion +// --------------------------------------------------------------------------- -// CreateInterface creates the WireGuard interface(s). For single-tunnel mode -// this creates one interface; for multi-tunnel mode it creates N interfaces. +// CreateInterface creates the network interface(s) that applications will use. +// - Single-tunnel: one plain WireGuard interface named Ifname. +// - Multi-tunnel: N WireGuard interfaces + a bond master named Ifname. +// // DeleteInterface() must be called to clean up. func (c *Client) CreateInterface() error { nt := c.numTunnels() + + // Create the WireGuard tunnel interfaces. for t := 0; t < nt; t++ { if err := c.createTunnelInterface(t); err != nil { - // Clean up any interfaces we already created. for rb := 0; rb < t; rb++ { c.deleteTunnelInterface(rb) } return err } } - if nt > 1 { - log.Printf("created %d tunnel interfaces (%s .. %s)", nt, c.tunnelIfname(0), c.tunnelIfname(nt-1)) + + // In multi-tunnel mode, create a bond over the WireGuard slaves. + if c.isMultiTunnel() { + if err := c.createBond(nt); err != nil { + for t := 0; t < nt; t++ { + c.deleteTunnelInterface(t) + } + return err + } + log.Printf("created bond %s over %d tunnel slaves (%s .. %s)", + c.Ifname, nt, c.tunnelIfname(0), c.tunnelIfname(nt-1)) } + return nil } @@ -138,26 +174,112 @@ func (c *Client) createTunnelInterface(t int) error { err := netlink.LinkAdd(link) if err != nil { - return fmt.Errorf("error creating vprox interface %s: %v", link.Name, err) + return fmt.Errorf("error creating WireGuard interface %s: %v", link.Name, err) } - // Set MTU explicitly (some kernels ignore LinkAttrs.MTU on creation) - err = netlink.LinkSetMTU(link, WireguardMTU) - if err != nil { + // Set MTU explicitly (some kernels ignore LinkAttrs.MTU on creation). 
+ if err := netlink.LinkSetMTU(link, WireguardMTU); err != nil { netlink.LinkDel(link) - return fmt.Errorf("error setting MTU on vprox interface %s: %v", link.Name, err) + return fmt.Errorf("error setting MTU on %s: %v", link.Name, err) + } + + // Set TxQLen for improved burst handling (non-fatal). + if err := netlink.LinkSetTxQLen(link, WireguardTxQLen); err != nil { + log.Printf("warning: failed to set TxQLen on %s: %v", link.Name, err) } - // Set TxQLen for improved burst handling - err = netlink.LinkSetTxQLen(link, WireguardTxQLen) + // In multi-tunnel mode the slaves are brought up later when enslaved + // to the bond. In single-tunnel mode we don't bring it up yet either + // — Connect() will do it. + + return nil +} + +// createBond creates a balance-rr bond device named Ifname and enslaves all +// WireGuard tunnel interfaces to it. +func (c *Client) createBond(nt int) error { + // Ensure the bonding kernel module is loaded. + _ = writeSysFile("/sys/module/bonding/initstate", "") + + bond := netlink.NewLinkBond(netlink.LinkAttrs{ + Name: c.Ifname, + MTU: WireguardMTU, + TxQLen: WireguardTxQLen, + }) + bond.Mode = netlink.BOND_MODE_BALANCE_RR + // MIIMon: link monitoring interval in ms. We set a low value so that if + // a slave goes down, the bond reacts quickly. + bond.Miimon = 100 + + // Remove any stale bond with this name. + if existing, _ := netlink.LinkByName(c.Ifname); existing != nil { + _ = netlink.LinkDel(existing) + } + + if err := netlink.LinkAdd(bond); err != nil { + return fmt.Errorf("failed to create bond %s: %v", c.Ifname, err) + } + + // Enslave each WireGuard tunnel interface to the bond. 
+ bondLink, err := netlink.LinkByName(c.Ifname) if err != nil { - // Non-fatal: log warning but continue - log.Printf("warning: failed to set TxQLen on vprox interface %s: %v", link.Name, err) + netlink.LinkDel(bond) + return fmt.Errorf("failed to find bond %s after creation: %v", c.Ifname, err) + } + + for t := 0; t < nt; t++ { + slave, err := netlink.LinkByName(c.tunnelIfname(t)) + if err != nil { + netlink.LinkDel(bond) + return fmt.Errorf("failed to find slave %s: %v", c.tunnelIfname(t), err) + } + // The slave must be down before enslaving. + _ = netlink.LinkSetDown(slave) + if err := netlink.LinkSetMaster(slave, bondLink); err != nil { + netlink.LinkDel(bond) + return fmt.Errorf("failed to enslave %s to %s: %v", c.tunnelIfname(t), c.Ifname, err) + } } return nil } +// DeleteInterface removes all interfaces (bond + WireGuard tunnels). +func (c *Client) DeleteInterface() { + if c.isMultiTunnel() { + // Deleting the bond master also releases the slaves. + log.Printf("About to delete bond interface %v", c.Ifname) + if bond, err := netlink.LinkByName(c.Ifname); err == nil { + if err := netlink.LinkDel(bond); err != nil { + log.Printf("error deleting bond %v: %v", c.Ifname, err) + } else { + log.Printf("successfully deleted bond %v", c.Ifname) + } + } + } + + nt := c.numTunnels() + for t := nt - 1; t >= 0; t-- { + c.deleteTunnelInterface(t) + } +} + +// deleteTunnelInterface removes a single WireGuard tunnel interface. 
+func (c *Client) deleteTunnelInterface(t int) { + ifname := c.tunnelIfname(t) + log.Printf("About to delete vprox interface %v", ifname) + link := &linkWireguard{LinkAttrs: netlink.LinkAttrs{Name: ifname}} + if err := netlink.LinkDel(link); err != nil { + log.Printf("error deleting vprox interface %v: %v", ifname, err) + } else { + log.Printf("successfully deleted vprox interface %v", ifname) + } +} + +// --------------------------------------------------------------------------- +// Connect / Disconnect +// --------------------------------------------------------------------------- + // Connect attempts to connect (or reconnect) to the server. All tunnel // interfaces must already exist via CreateInterface(). func (c *Client) Connect() error { @@ -166,182 +288,93 @@ func (c *Client) Connect() error { return err } - // Determine how many tunnels to actually use. Use the minimum of what - // the client wants and what the server offers. + // Determine how many tunnels to actually use — minimum of what the client + // wants and what the server offers. nt := c.numTunnels() serverTunnels := len(resp.Tunnels) if serverTunnels > 0 && serverTunnels < nt { nt = serverTunnels } - // If the server returned no Tunnels list (old server), use 1 tunnel. if serverTunnels == 0 { nt = 1 } c.activeTunnels = nt - // Bring up, assign address, and configure WireGuard on ALL tunnel interfaces. - // Each interface gets the same IP address so the kernel knows how to reach - // the gateway (server WireGuard IP) through any of them. + // Configure WireGuard on each tunnel interface. for t := 0; t < nt; t++ { - link := c.tunnelLink(t) - err = netlink.LinkSetUp(link) - if err != nil { - return fmt.Errorf("error setting up vprox interface %s: %v", link.Name, err) - } + ifname := c.tunnelIfname(t) - // Assign the same address to every tunnel interface. 
The first call - // also updates c.wgCidr; subsequent calls for the same CIDR are - // handled by updateTunnelInterface which skips if already set. - if err := c.updateTunnelInterface(t, resp); err != nil { - return fmt.Errorf("error updating interface %s: %v", link.Name, err) + // Bring the slave up (bond requires slaves to be up for traffic). + slave := c.tunnelLink(t) + if err := netlink.LinkSetUp(slave); err != nil { + return fmt.Errorf("error setting up %s: %v", ifname, err) } // Pick the listen port: tunnel 0 always uses ServerListenPort // (backwards compatible with old servers); tunnels 1+ use Tunnels[t]. - var port int - if t == 0 { - port = resp.ServerListenPort - } else { + port := resp.ServerListenPort + if t > 0 && t < len(resp.Tunnels) { port = resp.Tunnels[t].ListenPort } - err = c.configureWireguardTunnel(t, resp, port) - if err != nil { - return fmt.Errorf("error configuring wireguard on %s: %v", c.tunnelIfname(t), err) - } - } - // Set up multipath routing if we have multiple active tunnels. - if nt > 1 { - if err := c.setupMultipathRouting(nt); err != nil { - log.Printf("warning: failed to set up multipath routing: %v", err) - // Fall back: traffic will just use the primary interface's route. - } else { - log.Printf("multipath routing configured across %d tunnels", nt) + if err := c.configureWireguardTunnel(t, resp, port); err != nil { + return fmt.Errorf("error configuring wireguard on %s: %v", ifname, err) } } - return nil -} - -// updateTunnelInterface assigns the WireGuard address to tunnel interface t. -// Every tunnel interface gets the same IP/CIDR so that the server gateway is -// reachable through each of them (required for multipath routing). 
-func (c *Client) updateTunnelInterface(t int, resp connectResponse) error { - cidr, err := netip.ParsePrefix(resp.AssignedAddr) - if err != nil { - return fmt.Errorf("failed to parse assigned address %v: %v", resp.AssignedAddr, err) - } - - link := c.tunnelLink(t) - - if t == 0 && c.wgCidr.IsValid() && cidr != c.wgCidr { - // On reconnect the primary tunnel may need the old address removed. - oldIpnet := prefixToIPNet(c.wgCidr) - if err := netlink.AddrDel(link, &netlink.Addr{IPNet: &oldIpnet}); err != nil { - log.Printf("warning: failed to remove old address from %s when reconnecting: %v", c.tunnelIfname(t), err) - } + // Bring up the user-facing interface and assign the address. + if err := c.bringUpUserInterface(); err != nil { + return err } - ipnet := prefixToIPNet(cidr) - err = netlink.AddrReplace(link, &netlink.Addr{IPNet: &ipnet}) - if err != nil { - return fmt.Errorf("failed to add address to %s: %v", c.tunnelIfname(t), err) + if err := c.updateAddress(resp); err != nil { + return err } - // Track the CIDR on the first tunnel. - if t == 0 { - c.wgCidr = cidr - } return nil } -// setupMultipathRouting adds one equal-cost device-scoped route per tunnel -// interface for the WireGuard subnet. The kernel automatically round-robins -// across equal-cost routes to the same destination, distributing flows across -// the tunnel interfaces. -func (c *Client) setupMultipathRouting(nt int) error { - if !c.wgCidr.IsValid() { - return fmt.Errorf("no valid CIDR assigned yet") - } - - subnetIPNet := prefixToIPNet(c.wgCidr.Masked()) - - // Remove any existing routes for this subnet so we start clean. - existingRoutes, _ := netlink.RouteList(nil, netlink.FAMILY_V4) - for i := range existingRoutes { - r := &existingRoutes[i] - if r.Dst != nil && r.Dst.String() == subnetIPNet.String() { - _ = netlink.RouteDel(r) - } +// bringUpUserInterface brings up the interface that applications will use. 
+// In single-tunnel mode this is the WireGuard interface itself; in multi-tunnel +// mode this is the bond device. +func (c *Client) bringUpUserInterface() error { + link, err := netlink.LinkByName(c.Ifname) + if err != nil { + return fmt.Errorf("failed to find interface %s: %v", c.Ifname, err) } - - // Add a device-scoped route for each tunnel interface. Equal-cost routes - // to the same destination cause the kernel to distribute flows across them. - for t := 0; t < nt; t++ { - ifname := c.tunnelIfname(t) - link, err := netlink.LinkByName(ifname) - if err != nil { - return fmt.Errorf("failed to find interface %s: %v", ifname, err) - } - route := &netlink.Route{ - LinkIndex: link.Attrs().Index, - Dst: &subnetIPNet, - Scope: netlink.SCOPE_LINK, - } - if err := netlink.RouteAppend(route); err != nil { - return fmt.Errorf("failed to add route on %s: %v", ifname, err) - } + if err := netlink.LinkSetUp(link); err != nil { + return fmt.Errorf("error setting up %s: %v", c.Ifname, err) } - return nil } -// sendConnectionRequest attempts to send a connection request to the peer -func (c *Client) sendConnectionRequest() (connectResponse, error) { - connectUrl, err := url.Parse(fmt.Sprintf("https://%s/connect", c.ServerIp)) - if err != nil { - return connectResponse{}, fmt.Errorf("failed to parse connect URL: %v", err) - } - - reqJson := &connectRequest{ - PeerPublicKey: c.Key.PublicKey().String(), - } - buf, err := json.Marshal(reqJson) +// updateAddress assigns (or updates) the IP address on the user-facing +// interface (Ifname). 
+func (c *Client) updateAddress(resp connectResponse) error { + cidr, err := netip.ParsePrefix(resp.AssignedAddr) if err != nil { - return connectResponse{}, fmt.Errorf("failed to marshal connect request: %v", err) - } - - req := &http.Request{ - Method: http.MethodPost, - URL: connectUrl, - Header: http.Header{ - "Authorization": []string{"Bearer " + c.Password}, - }, - Body: io.NopCloser(bytes.NewBuffer(buf)), + return fmt.Errorf("failed to parse assigned address %v: %v", resp.AssignedAddr, err) } - resp, err := c.Http.Do(req) + link, err := netlink.LinkByName(c.Ifname) if err != nil { - return connectResponse{}, fmt.Errorf("failed to connect to server: %v", err) + return fmt.Errorf("failed to find interface %s: %v", c.Ifname, err) } - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - recoverable := resp.StatusCode != http.StatusUnauthorized - return connectResponse{}, &ConnectionError{ - Message: fmt.Sprintf("server returned status %v", resp.Status), - Recoverable: recoverable, + if cidr != c.wgCidr { + if c.wgCidr.IsValid() { + oldIpnet := prefixToIPNet(c.wgCidr) + if err := netlink.AddrDel(link, &netlink.Addr{IPNet: &oldIpnet}); err != nil { + log.Printf("warning: failed to remove old address from %s: %v", c.Ifname, err) + } } + ipnet := prefixToIPNet(cidr) + if err := netlink.AddrAdd(link, &netlink.Addr{IPNet: &ipnet}); err != nil { + return fmt.Errorf("failed to add address to %s: %v", c.Ifname, err) + } + c.wgCidr = cidr } - - buf, err = io.ReadAll(resp.Body) - if err != nil { - return connectResponse{}, fmt.Errorf("failed to read response body: %v", err) - } - - var respJson connectResponse - json.Unmarshal(buf, &respJson) - return respJson, nil + return nil } // configureWireguardTunnel configures a single WireGuard tunnel interface with @@ -375,11 +408,6 @@ func (c *Client) configureWireguardTunnel(t int, resp connectResponse, serverPor }) } -// configureWireguard configures WireGuard on the primary tunnel (backwards compat). 
-func (c *Client) configureWireguard(connectionResponse connectResponse) error { - return c.configureWireguardTunnel(0, connectionResponse, connectionResponse.ServerListenPort) -} - // Disconnect notifies the server that this client is disconnecting, allowing the // server to immediately reclaim resources (wireguard peer and subnet IP) instead of // waiting for the idle timeout. @@ -420,31 +448,15 @@ func (c *Client) Disconnect() error { return nil } -// DeleteInterface removes all WireGuard tunnel interfaces. -func (c *Client) DeleteInterface() { - nt := c.numTunnels() - for t := nt - 1; t >= 0; t-- { - c.deleteTunnelInterface(t) - } -} - -// deleteTunnelInterface removes a single WireGuard tunnel interface. -func (c *Client) deleteTunnelInterface(t int) { - ifname := c.tunnelIfname(t) - log.Printf("About to delete vprox interface %v", ifname) - link := &linkWireguard{LinkAttrs: netlink.LinkAttrs{Name: ifname}} - err := netlink.LinkDel(link) - if err != nil { - log.Printf("error deleting vprox interface %v: %v", ifname, err) - } else { - log.Printf("successfully deleted vprox interface %v", ifname) - } -} +// --------------------------------------------------------------------------- +// Health check +// --------------------------------------------------------------------------- // CheckConnection checks the status of the connection with the wireguard peer, // and returns true if it is healthy. This sends 3 pings in succession, and blocks // until they receive a response or the timeout passes. -// Pings are sent through the primary tunnel interface (tunnel 0). +// Pings are sent through the user-facing interface (Ifname — either the plain +// WireGuard device in single-tunnel mode, or the bond in multi-tunnel mode). 
func (c *Client) CheckConnection(timeout time.Duration, cancelCtx context.Context) bool { pinger, err := probing.NewPinger(c.wgCidr.Masked().Addr().Next().String()) if err != nil { @@ -467,3 +479,65 @@ func (c *Client) CheckConnection(timeout time.Duration, cancelCtx context.Contex } return stats.PacketsRecv > 0 } + +// --------------------------------------------------------------------------- +// HTTPS / control-plane +// --------------------------------------------------------------------------- + +// sendConnectionRequest attempts to send a connection request to the peer. +func (c *Client) sendConnectionRequest() (connectResponse, error) { + connectUrl, err := url.Parse(fmt.Sprintf("https://%s/connect", c.ServerIp)) + if err != nil { + return connectResponse{}, fmt.Errorf("failed to parse connect URL: %v", err) + } + + reqJson := &connectRequest{ + PeerPublicKey: c.Key.PublicKey().String(), + } + buf, err := json.Marshal(reqJson) + if err != nil { + return connectResponse{}, fmt.Errorf("failed to marshal connect request: %v", err) + } + + req := &http.Request{ + Method: http.MethodPost, + URL: connectUrl, + Header: http.Header{ + "Authorization": []string{"Bearer " + c.Password}, + }, + Body: io.NopCloser(bytes.NewBuffer(buf)), + } + + resp, err := c.Http.Do(req) + if err != nil { + return connectResponse{}, fmt.Errorf("failed to connect to server: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + recoverable := resp.StatusCode != http.StatusUnauthorized + return connectResponse{}, &ConnectionError{ + Message: fmt.Sprintf("server returned status %v", resp.Status), + Recoverable: recoverable, + } + } + + buf, err = io.ReadAll(resp.Body) + if err != nil { + return connectResponse{}, fmt.Errorf("failed to read response body: %v", err) + } + + var respJson connectResponse + json.Unmarshal(buf, &respJson) + return respJson, nil +} + +// --------------------------------------------------------------------------- +// Sysfs helper +// 
--------------------------------------------------------------------------- + +// writeSysFile is a best-effort helper to write a value to a sysfs file. +// Used to poke kernel module parameters. Errors are silently ignored. +func writeSysFile(path, value string) error { + return os.WriteFile(path, []byte(value), 0644) +} From ba1a6330ef85ba4c6b4cd42ee04c618191b52819 Mon Sep 17 00:00:00 2001 From: Luis Capelo Date: Thu, 12 Mar 2026 01:38:11 -0400 Subject: [PATCH 10/14] create dummy interface --- lib/client.go | 243 ++++++++++++++++++++++++++++++++++---------------- 1 file changed, 167 insertions(+), 76 deletions(-) diff --git a/lib/client.go b/lib/client.go index 90cf948..a4b028b 100644 --- a/lib/client.go +++ b/lib/client.go @@ -12,7 +12,6 @@ import ( "net/http" "net/netip" "net/url" - "os" "time" probing "github.com/prometheus-community/pro-bing" @@ -21,6 +20,14 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) +// PolicyRoutingTable is the custom routing table number used for multi-tunnel +// multipath routing. Traffic from the vprox IP is redirected here via an +// ip rule, and this table contains equal-cost routes across all tunnels. +const PolicyRoutingTable = 51820 + +// PolicyRoutingPriority is the ip rule priority for the vprox policy route. +const PolicyRoutingPriority = 100 + // Used to determine if we can recover from an error during connection setup. type ConnectionError struct { Message string @@ -42,8 +49,7 @@ func IsRecoverableError(err error) bool { } // Client manages a peering connection with a local WireGuard interface, or a -// set of parallel WireGuard interfaces bonded together when multi-tunnel is -// enabled. +// set of parallel WireGuard interfaces when multi-tunnel is enabled. // // Single-tunnel (NumTunnels <= 1): // @@ -51,18 +57,20 @@ func IsRecoverableError(err error) bool { // // Multi-tunnel (NumTunnels > 1): // -// WireGuard slaves: vprox0t0, vprox0t1, vprox0t2, ... 
-// Bond master: vprox0 (balance-rr, presents a single interface) +// WireGuard tunnels: vprox0t0, vprox0t1, vprox0t2, ... +// Dummy device: vprox0 (holds the IP address, user-facing) // -// Applications bind to "vprox0" and the bonding driver distributes packets -// round-robin across the WireGuard slaves. Each slave uses a different UDP -// port to the server, so the NIC hashes them to different hardware RX queues. +// Applications bind to "vprox0" (the dummy interface). An ip rule redirects +// traffic sourced from the vprox IP into a custom routing table that has +// equal-cost multipath routes across the WireGuard tunnels. The kernel +// distributes flows across them via L4 hashing. Each tunnel uses a different +// UDP port so the NIC hashes the outer packets to different hardware RX queues. type Client struct { // Key is the private key of the client. Key wgtypes.Key // Ifname is the name of the interface exposed to applications (e.g. "vprox0"). - // In multi-tunnel mode this is the bond device; individual WireGuard tunnels + // In multi-tunnel mode this is a dummy device; individual WireGuard tunnels // are named t0, t1, etc. Ifname string @@ -74,7 +82,7 @@ type Client struct { // NumTunnels is the number of parallel WireGuard tunnels to create. // When <= 1, the client creates a single plain WireGuard interface. - // When > 1, a bonding device is created over N WireGuard slaves. + // When > 1, a dummy device + policy routing is created over N WireGuard tunnels. NumTunnels int // WgClient is a shared client for interacting with the WireGuard kernel module. @@ -137,7 +145,8 @@ func (c *Client) tunnelLink(t int) *linkWireguard { // CreateInterface creates the network interface(s) that applications will use. // - Single-tunnel: one plain WireGuard interface named Ifname. -// - Multi-tunnel: N WireGuard interfaces + a bond master named Ifname. 
+// - Multi-tunnel: N WireGuard interfaces + a dummy device named Ifname +// with policy routing to distribute traffic across the tunnels. // // DeleteInterface() must be called to clean up. func (c *Client) CreateInterface() error { @@ -153,15 +162,15 @@ func (c *Client) CreateInterface() error { } } - // In multi-tunnel mode, create a bond over the WireGuard slaves. + // In multi-tunnel mode, create a dummy device for the user-facing interface. if c.isMultiTunnel() { - if err := c.createBond(nt); err != nil { + if err := c.createDummyInterface(); err != nil { for t := 0; t < nt; t++ { c.deleteTunnelInterface(t) } return err } - log.Printf("created bond %s over %d tunnel slaves (%s .. %s)", + log.Printf("created dummy %s with %d WireGuard tunnels (%s .. %s)", c.Ifname, nt, c.tunnelIfname(0), c.tunnelIfname(nt-1)) } @@ -188,72 +197,47 @@ func (c *Client) createTunnelInterface(t int) error { log.Printf("warning: failed to set TxQLen on %s: %v", link.Name, err) } - // In multi-tunnel mode the slaves are brought up later when enslaved - // to the bond. In single-tunnel mode we don't bring it up yet either - // — Connect() will do it. - return nil } -// createBond creates a balance-rr bond device named Ifname and enslaves all -// WireGuard tunnel interfaces to it. -func (c *Client) createBond(nt int) error { - // Ensure the bonding kernel module is loaded. - _ = writeSysFile("/sys/module/bonding/initstate", "") - - bond := netlink.NewLinkBond(netlink.LinkAttrs{ - Name: c.Ifname, - MTU: WireguardMTU, - TxQLen: WireguardTxQLen, - }) - bond.Mode = netlink.BOND_MODE_BALANCE_RR - // MIIMon: link monitoring interval in ms. We set a low value so that if - // a slave goes down, the bond reacts quickly. - bond.Miimon = 100 - - // Remove any stale bond with this name. +// createDummyInterface creates a dummy network interface named Ifname. This is +// the user-facing device that applications bind to. 
A policy routing rule will +// redirect its traffic into a custom table with multipath routes across the +// WireGuard tunnels. +func (c *Client) createDummyInterface() error { + // Remove any stale interface with this name. if existing, _ := netlink.LinkByName(c.Ifname); existing != nil { _ = netlink.LinkDel(existing) } - if err := netlink.LinkAdd(bond); err != nil { - return fmt.Errorf("failed to create bond %s: %v", c.Ifname, err) - } - - // Enslave each WireGuard tunnel interface to the bond. - bondLink, err := netlink.LinkByName(c.Ifname) - if err != nil { - netlink.LinkDel(bond) - return fmt.Errorf("failed to find bond %s after creation: %v", c.Ifname, err) + dummy := &netlink.Dummy{ + LinkAttrs: netlink.LinkAttrs{ + Name: c.Ifname, + MTU: WireguardMTU, + TxQLen: WireguardTxQLen, + }, } - for t := 0; t < nt; t++ { - slave, err := netlink.LinkByName(c.tunnelIfname(t)) - if err != nil { - netlink.LinkDel(bond) - return fmt.Errorf("failed to find slave %s: %v", c.tunnelIfname(t), err) - } - // The slave must be down before enslaving. - _ = netlink.LinkSetDown(slave) - if err := netlink.LinkSetMaster(slave, bondLink); err != nil { - netlink.LinkDel(bond) - return fmt.Errorf("failed to enslave %s to %s: %v", c.tunnelIfname(t), c.Ifname, err) - } + if err := netlink.LinkAdd(dummy); err != nil { + return fmt.Errorf("failed to create dummy interface %s: %v", c.Ifname, err) } return nil } -// DeleteInterface removes all interfaces (bond + WireGuard tunnels). +// DeleteInterface removes all interfaces and policy routing rules. func (c *Client) DeleteInterface() { if c.isMultiTunnel() { - // Deleting the bond master also releases the slaves. - log.Printf("About to delete bond interface %v", c.Ifname) - if bond, err := netlink.LinkByName(c.Ifname); err == nil { - if err := netlink.LinkDel(bond); err != nil { - log.Printf("error deleting bond %v: %v", c.Ifname, err) + // Clean up policy routing. + c.cleanupPolicyRouting() + + // Delete the dummy interface. 
+ log.Printf("About to delete dummy interface %v", c.Ifname) + if dummy, err := netlink.LinkByName(c.Ifname); err == nil { + if err := netlink.LinkDel(dummy); err != nil { + log.Printf("error deleting dummy %v: %v", c.Ifname, err) } else { - log.Printf("successfully deleted bond %v", c.Ifname) + log.Printf("successfully deleted dummy %v", c.Ifname) } } } @@ -304,9 +288,8 @@ func (c *Client) Connect() error { for t := 0; t < nt; t++ { ifname := c.tunnelIfname(t) - // Bring the slave up (bond requires slaves to be up for traffic). - slave := c.tunnelLink(t) - if err := netlink.LinkSetUp(slave); err != nil { + link := c.tunnelLink(t) + if err := netlink.LinkSetUp(link); err != nil { return fmt.Errorf("error setting up %s: %v", ifname, err) } @@ -331,6 +314,15 @@ func (c *Client) Connect() error { return err } + // In multi-tunnel mode, assign addresses to each tunnel and set up + // policy routing to distribute traffic across them. + if c.isMultiTunnel() && nt > 1 { + if err := c.setupPolicyRouting(nt); err != nil { + return fmt.Errorf("error setting up policy routing: %v", err) + } + log.Printf("policy routing configured across %d tunnels", nt) + } + return nil } @@ -377,6 +369,115 @@ func (c *Client) updateAddress(resp connectResponse) error { return nil } +// --------------------------------------------------------------------------- +// Policy routing (multi-tunnel) +// --------------------------------------------------------------------------- + +// setupPolicyRouting creates: +// 1. An ip rule that matches traffic from the vprox source IP and directs it +// to a custom routing table. +// 2. Equal-cost multipath routes in that table across all WireGuard tunnels. +// +// This allows applications binding to the dummy vprox0 interface (or using the +// vprox IP as source) to have their traffic distributed across tunnels by the +// kernel's L4 flow hash. 
+func (c *Client) setupPolicyRouting(nt int) error { + if !c.wgCidr.IsValid() { + return fmt.Errorf("no valid CIDR assigned yet") + } + + // Assign the same IP address to each WireGuard tunnel interface so that + // the kernel can use any of them to reach the server WireGuard IP. + for t := 0; t < nt; t++ { + ifname := c.tunnelIfname(t) + tunnelLink, err := netlink.LinkByName(ifname) + if err != nil { + return fmt.Errorf("failed to find tunnel %s: %v", ifname, err) + } + ipnet := prefixToIPNet(c.wgCidr) + if err := netlink.AddrReplace(tunnelLink, &netlink.Addr{IPNet: &ipnet}); err != nil { + return fmt.Errorf("failed to assign address to %s: %v", ifname, err) + } + } + + // Build multipath nexthops — one per tunnel interface. + gwAddr := c.wgCidr.Masked().Addr().Next() + gwIP := addrToIp(gwAddr) + + var nexthops []*netlink.NexthopInfo + for t := 0; t < nt; t++ { + ifname := c.tunnelIfname(t) + tunnelLink, err := netlink.LinkByName(ifname) + if err != nil { + return fmt.Errorf("failed to find tunnel %s: %v", ifname, err) + } + nexthops = append(nexthops, &netlink.NexthopInfo{ + LinkIndex: tunnelLink.Attrs().Index, + Gw: gwIP, + Hops: 0, + }) + } + + // Add default multipath route in the custom table. + _, defaultDst, _ := net.ParseCIDR("0.0.0.0/0") + mpRoute := &netlink.Route{ + Table: PolicyRoutingTable, + Dst: defaultDst, + MultiPath: nexthops, + } + if err := netlink.RouteReplace(mpRoute); err != nil { + return fmt.Errorf("failed to add multipath route to table %d: %v", PolicyRoutingTable, err) + } + + // Add an ip rule: from lookup table PolicyRoutingTable. + srcIP := c.wgCidr.Addr() + srcNet := &net.IPNet{ + IP: addrToIp(srcIP), + Mask: net.CIDRMask(32, 32), + } + rule := netlink.NewRule() + rule.Src = srcNet + rule.Table = PolicyRoutingTable + rule.Priority = PolicyRoutingPriority + + // Remove any stale rule first (idempotent). 
+ _ = netlink.RuleDel(rule) + + if err := netlink.RuleAdd(rule); err != nil { + return fmt.Errorf("failed to add ip rule for %v: %v", srcIP, err) + } + + return nil +} + +// cleanupPolicyRouting removes the ip rule and flushes the custom routing table. +func (c *Client) cleanupPolicyRouting() { + if c.wgCidr.IsValid() { + srcIP := c.wgCidr.Addr() + srcNet := &net.IPNet{ + IP: addrToIp(srcIP), + Mask: net.CIDRMask(32, 32), + } + rule := netlink.NewRule() + rule.Src = srcNet + rule.Table = PolicyRoutingTable + rule.Priority = PolicyRoutingPriority + if err := netlink.RuleDel(rule); err != nil { + log.Printf("warning: failed to delete ip rule: %v", err) + } + } + + // Flush routes in our custom table. + routes, err := netlink.RouteListFiltered(netlink.FAMILY_V4, &netlink.Route{ + Table: PolicyRoutingTable, + }, netlink.RT_FILTER_TABLE) + if err == nil { + for i := range routes { + _ = netlink.RouteDel(&routes[i]) + } + } +} + // configureWireguardTunnel configures a single WireGuard tunnel interface with // the server as a peer on the given port. func (c *Client) configureWireguardTunnel(t int, resp connectResponse, serverPort int) error { @@ -531,13 +632,3 @@ func (c *Client) sendConnectionRequest() (connectResponse, error) { json.Unmarshal(buf, &respJson) return respJson, nil } - -// --------------------------------------------------------------------------- -// Sysfs helper -// --------------------------------------------------------------------------- - -// writeSysFile is a best-effort helper to write a value to a sysfs file. -// Used to poke kernel module parameters. Errors are silently ignored. 
-func writeSysFile(path, value string) error { - return os.WriteFile(path, []byte(value), 0644) -} From 3c294187b747827e0c5bd8502235e42206184a13 Mon Sep 17 00:00:00 2001 From: Luis Capelo Date: Thu, 12 Mar 2026 01:40:22 -0400 Subject: [PATCH 11/14] rebuild dummy interface --- lib/client.go | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/lib/client.go b/lib/client.go index a4b028b..f58529f 100644 --- a/lib/client.go +++ b/lib/client.go @@ -556,8 +556,8 @@ func (c *Client) Disconnect() error { // CheckConnection checks the status of the connection with the wireguard peer, // and returns true if it is healthy. This sends 3 pings in succession, and blocks // until they receive a response or the timeout passes. -// Pings are sent through the user-facing interface (Ifname — either the plain -// WireGuard device in single-tunnel mode, or the bond in multi-tunnel mode). +// In multi-tunnel mode, pings are sent through the first WireGuard tunnel +// interface (not the dummy device) so replies are received on the same interface. func (c *Client) CheckConnection(timeout time.Duration, cancelCtx context.Context) bool { pinger, err := probing.NewPinger(c.wgCidr.Masked().Addr().Next().String()) if err != nil { @@ -565,7 +565,10 @@ func (c *Client) CheckConnection(timeout time.Duration, cancelCtx context.Contex return false } - pinger.InterfaceName = c.Ifname + // Use the first WireGuard tunnel for health checks. In single-tunnel mode + // tunnelIfname(0) == Ifname; in multi-tunnel mode it's the actual WireGuard + // device (e.g. "vprox0t0") rather than the dummy ("vprox0"). 
+ pinger.InterfaceName = c.tunnelIfname(0) pinger.Timeout = timeout pinger.Count = 3 pinger.Interval = 10 * time.Millisecond // Send approximately all at once From 325243ef88c4605a450f83347774415703d73da0 Mon Sep 17 00:00:00 2001 From: Luis Capelo Date: Thu, 12 Mar 2026 01:48:51 -0400 Subject: [PATCH 12/14] fix download path --- lib/client.go | 6 --- lib/server.go | 114 ++++++++++++++++++++++++++++++++++++++------------ 2 files changed, 87 insertions(+), 33 deletions(-) diff --git a/lib/client.go b/lib/client.go index f58529f..12006ef 100644 --- a/lib/client.go +++ b/lib/client.go @@ -20,13 +20,7 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) -// PolicyRoutingTable is the custom routing table number used for multi-tunnel -// multipath routing. Traffic from the vprox IP is redirected here via an -// ip rule, and this table contains equal-cost routes across all tunnels. -const PolicyRoutingTable = 51820 -// PolicyRoutingPriority is the ip rule priority for the vprox policy route. -const PolicyRoutingPriority = 100 // Used to determine if we can recover from an error during connection setup. type ConnectionError struct { diff --git a/lib/server.go b/lib/server.go index d6256ed..d427379 100644 --- a/lib/server.go +++ b/lib/server.go @@ -25,6 +25,13 @@ import ( // FwmarkBase is the base value for firewall marks used by vprox. const FwmarkBase = 0x54437D00 +// PolicyRoutingTable is the custom routing table number used for multi-tunnel +// multipath routing on both server and client. +const PolicyRoutingTable = 51820 + +// PolicyRoutingPriority is the ip rule priority for the vprox policy route. +const PolicyRoutingPriority = 100 + // UDP listen port base value for WireGuard connections. 
const WireguardListenPortBase = 50227 @@ -541,12 +548,11 @@ func (srv *Server) StartWireguard() error { log.Printf("[%v] started %d WireGuard tunnels (ports %d..%d)", srv.BindAddr, nt, srv.tunnelListenPort(0), srv.tunnelListenPort(nt-1)) - // Set up equal-cost routes across all tunnel interfaces so the kernel - // distributes reply traffic across them (same approach as client side). - if err := srv.setupMultipathRouting(nt); err != nil { - log.Printf("[%v] warning: failed to set up multipath routing: %v", srv.BindAddr, err) + // Set up policy routing so reply traffic is distributed across tunnels. + if err := srv.setupPolicyRouting(nt); err != nil { + log.Printf("[%v] warning: failed to set up policy routing: %v", srv.BindAddr, err) } else { - log.Printf("[%v] multipath routing configured across %d tunnels", srv.BindAddr, nt) + log.Printf("[%v] policy routing configured across %d tunnels", srv.BindAddr, nt) } } return nil @@ -640,41 +646,90 @@ func (srv *Server) createFreshInterface(link *linkWireguard, tunnelIndex int) er return nil } -// setupMultipathRouting adds equal-cost device-scoped routes for the WireGuard -// subnet across all tunnel interfaces so the kernel round-robins reply traffic. -func (srv *Server) setupMultipathRouting(nt int) error { +// setupPolicyRouting creates: +// 1. Multipath routes in a custom routing table across all WireGuard tunnels. +// 2. An ip rule that matches traffic from the server's WireGuard IP and directs +// it to that custom table. +// +// This ensures reply traffic from the server is distributed across all tunnels +// by the kernel's L4 flow hash, not just sent through tunnel 0. +func (srv *Server) setupPolicyRouting(nt int) error { + // Build multipath nexthops — one per tunnel interface. We use the subnet + // as the destination (not default) since the server only needs to reach + // the WireGuard peer subnet via these tunnels. 
subnetIPNet := prefixToIPNet(srv.WgCidr.Masked()) - // Remove any existing routes for this subnet so we start clean. - existingRoutes, _ := netlink.RouteList(nil, netlink.FAMILY_V4) - for i := range existingRoutes { - r := &existingRoutes[i] - if r.Dst != nil && r.Dst.String() == subnetIPNet.String() { - _ = netlink.RouteDel(r) - } - } - - // Append one route per tunnel interface. Equal-cost routes to the same - // destination cause the kernel to distribute flows across them. + var nexthops []*netlink.NexthopInfo for t := 0; t < nt; t++ { ifname := srv.TunnelIfname(t) link, err := netlink.LinkByName(ifname) if err != nil { return fmt.Errorf("failed to find interface %s: %v", ifname, err) } - route := &netlink.Route{ + nexthops = append(nexthops, &netlink.NexthopInfo{ LinkIndex: link.Attrs().Index, - Dst: &subnetIPNet, - Scope: netlink.SCOPE_LINK, - } - if err := netlink.RouteAppend(route); err != nil { - return fmt.Errorf("failed to append route on %s: %v", ifname, err) - } + Hops: 0, + }) + } + + // Add the multipath route in custom table. + mpRoute := &netlink.Route{ + Table: PolicyRoutingTable, + Dst: &subnetIPNet, + MultiPath: nexthops, + Scope: netlink.SCOPE_LINK, + } + if err := netlink.RouteReplace(mpRoute); err != nil { + return fmt.Errorf("failed to add multipath route to table %d: %v", PolicyRoutingTable, err) } + + // Add an ip rule: from lookup custom table. + srcIP := srv.WgCidr.Addr() + srcNet := &net.IPNet{ + IP: addrToIp(srcIP), + Mask: net.CIDRMask(32, 32), + } + rule := netlink.NewRule() + rule.Src = srcNet + rule.Table = PolicyRoutingTable + rule.Priority = PolicyRoutingPriority + + // Remove any stale rule first (idempotent). + _ = netlink.RuleDel(rule) + + if err := netlink.RuleAdd(rule); err != nil { + return fmt.Errorf("failed to add ip rule for %v: %v", srcIP, err) + } + return nil } -// CleanupWireguard removes all tunnel WireGuard interfaces. +// cleanupPolicyRouting removes the ip rule and flushes the custom routing table. 
+func (srv *Server) cleanupPolicyRouting() { + srcIP := srv.WgCidr.Addr() + srcNet := &net.IPNet{ + IP: addrToIp(srcIP), + Mask: net.CIDRMask(32, 32), + } + rule := netlink.NewRule() + rule.Src = srcNet + rule.Table = PolicyRoutingTable + rule.Priority = PolicyRoutingPriority + if err := netlink.RuleDel(rule); err != nil { + log.Printf("warning: failed to delete ip rule: %v", err) + } + + routes, err := netlink.RouteListFiltered(netlink.FAMILY_V4, &netlink.Route{ + Table: PolicyRoutingTable, + }, netlink.RT_FILTER_TABLE) + if err == nil { + for i := range routes { + _ = netlink.RouteDel(&routes[i]) + } + } +} + +// CleanupWireguard removes all tunnel WireGuard interfaces and policy routing. func (srv *Server) CleanupWireguard() { srv.mu.Lock() relinquished := srv.relinquished @@ -686,6 +741,11 @@ func (srv *Server) CleanupWireguard() { } log.Printf("[%v] cleaning up WireGuard state", srv.BindAddr) + + if srv.numTunnels() > 1 { + srv.cleanupPolicyRouting() + } + nt := srv.numTunnels() for t := 0; t < nt; t++ { srv.cleanupWireguardTunnel(t) From 96bcf1087ad1cd9acf77bd141946dcaf3c34c6e6 Mon Sep 17 00:00:00 2001 From: Luis Capelo Date: Thu, 12 Mar 2026 02:03:47 -0400 Subject: [PATCH 13/14] updates server --- lib/server.go | 33 +++++++++++++++++---------------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/lib/server.go b/lib/server.go index d427379..061ce0c 100644 --- a/lib/server.go +++ b/lib/server.go @@ -648,11 +648,12 @@ func (srv *Server) createFreshInterface(link *linkWireguard, tunnelIndex int) er // setupPolicyRouting creates: // 1. Multipath routes in a custom routing table across all WireGuard tunnels. -// 2. An ip rule that matches traffic from the server's WireGuard IP and directs -// it to that custom table. +// 2. An ip rule that matches traffic destined for the WireGuard subnet and +// directs it to that custom table. 
// -// This ensures reply traffic from the server is distributed across all tunnels -// by the kernel's L4 flow hash, not just sent through tunnel 0. +// The rule matches on destination (not source) because the server forwards +// traffic from the internet to clients — the source IP of forwarded packets is +// the remote host, not the server's WireGuard IP. func (srv *Server) setupPolicyRouting(nt int) error { // Build multipath nexthops — one per tunnel interface. We use the subnet // as the destination (not default) since the server only needs to reach @@ -683,14 +684,15 @@ func (srv *Server) setupPolicyRouting(nt int) error { return fmt.Errorf("failed to add multipath route to table %d: %v", PolicyRoutingTable, err) } - // Add an ip rule: from lookup custom table. - srcIP := srv.WgCidr.Addr() - srcNet := &net.IPNet{ - IP: addrToIp(srcIP), - Mask: net.CIDRMask(32, 32), + // Add an ip rule: to lookup custom table. + // This catches both locally-originated replies and forwarded (NAT'd) traffic + // destined for any client in the WireGuard subnet. + dstNet := &net.IPNet{ + IP: addrToIp(srv.WgCidr.Masked().Addr()), + Mask: net.CIDRMask(srv.WgCidr.Bits(), 32), } rule := netlink.NewRule() - rule.Src = srcNet + rule.Dst = dstNet rule.Table = PolicyRoutingTable rule.Priority = PolicyRoutingPriority @@ -698,7 +700,7 @@ func (srv *Server) setupPolicyRouting(nt int) error { _ = netlink.RuleDel(rule) if err := netlink.RuleAdd(rule); err != nil { - return fmt.Errorf("failed to add ip rule for %v: %v", srcIP, err) + return fmt.Errorf("failed to add ip rule for dst %v: %v", dstNet, err) } return nil @@ -706,13 +708,12 @@ func (srv *Server) setupPolicyRouting(nt int) error { // cleanupPolicyRouting removes the ip rule and flushes the custom routing table. 
func (srv *Server) cleanupPolicyRouting() { - srcIP := srv.WgCidr.Addr() - srcNet := &net.IPNet{ - IP: addrToIp(srcIP), - Mask: net.CIDRMask(32, 32), + dstNet := &net.IPNet{ + IP: addrToIp(srv.WgCidr.Masked().Addr()), + Mask: net.CIDRMask(srv.WgCidr.Bits(), 32), } rule := netlink.NewRule() - rule.Src = srcNet + rule.Dst = dstNet rule.Table = PolicyRoutingTable rule.Priority = PolicyRoutingPriority if err := netlink.RuleDel(rule); err != nil { From 1a03d13aa77069d0ae356c37e7c34ef12eb54e32 Mon Sep 17 00:00:00 2001 From: Luis Capelo Date: Thu, 12 Mar 2026 03:27:51 -0400 Subject: [PATCH 14/14] update readme --- README.md | 141 +++++++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 135 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 2b77659..02d4610 100644 --- a/README.md +++ b/README.md @@ -1,24 +1,58 @@ # vprox -vprox is a high-performance network proxy acting as a split tunnel VPN. The server accepts peering requests from clients, which then establish WireGuard tunnels that direct all traffic on the client's network interface through the server, with IP masquerading. +vprox is a high-performance network proxy acting as a split tunnel VPN, powered by WireGuard. The server accepts peering requests from clients, which then establish WireGuard tunnels that direct all traffic on the client's network interface through the server, with IP masquerading. Both the client and server commands need root access. The server can have multiple public IP addresses attached, and on cloud providers, it automatically uses the instance metadata endpoint to discover its public IP addresses and start one proxy for each. This property allows the server to be high-availability. In the event of a restart or network partition, the tunnels remain open. If the server's IP address is attached to a new host, clients will automatically re-establish connections. This means that IP addresses can be moved to different hosts in event of an outage. 
+## Architecture + +In single-tunnel mode, vprox creates one WireGuard interface per connection: + +``` +Client Server +┌──────────┐ UDP :50227 ┌──────────┐ +│ vprox0 │◄──────────────────►│ vprox0 │──► Internet +│ (wg) │ │ (wg) │ (SNAT) +└──────────┘ └──────────┘ +``` + +In multi-tunnel mode (`--tunnels N`), vprox creates N parallel WireGuard interfaces, each on a different UDP port. On the client, a dummy interface with policy routing presents a single `vprox0` device to applications while distributing traffic across all tunnels: + +``` +Client Server +┌──────────┐ policy routing ┌──────────┐ +│ vprox0 │ (dummy, user- │ vprox0 │◄─── UDP :50227 ───► vprox0t0 (wg) +│ │ facing) │ vprox0t1│◄─── UDP :50228 ───► vprox0t1 (wg) +│ │ │ vprox0t2│◄─── UDP :50229 ───► vprox0t2 (wg) ──► Internet +│ │ ip rule: to wg │ vprox0t3│◄─── UDP :50230 ───► vprox0t3 (wg) (SNAT) +│ │ subnet → table └──────────┘ +│ │ with multipath +│ vprox0t0 │◄──────────────────► +│ vprox0t1 │◄──────────────────► +│ vprox0t2 │◄──────────────────► +│ vprox0t3 │◄──────────────────► +└──────────┘ +``` + +Each tunnel uses a different UDP port, so the NIC's RSS (Receive Side Scaling) hashes them to different hardware RX queues. The kernel's multipath routing distributes flows across tunnels using L4 hashing. Applications bind to the single `vprox0` interface and are unaware of the underlying tunnels. + ## Usage +### Prerequisites + On the Linux VPN server and client, install system requirements (`iptables` and `wireguard`). ```bash # On Ubuntu sudo apt install iptables wireguard -# On Fedora +# On Fedora / Amazon Linux sudo dnf install iptables wireguard-tools ``` -Also, you need to set some kernel settings with Sysctl. Enable IPv4 forwarding, and make sure that [`rp_filter`](https://sysctl-explorer.net/net/ipv4/rp_filter/) is set to 2, or masqueraded packets may be filtered out. You can edit your OS configuration file to set this persistently, or set it once below. +Set the required kernel parameters. 
Enable IPv4 forwarding, and make sure that [`rp_filter`](https://sysctl-explorer.net/net/ipv4/rp_filter/) is set to 2, or masqueraded packets may be filtered out. ```bash # Applies until next reboot @@ -26,7 +60,9 @@ sudo sysctl -w net.ipv4.ip_forward=1 sudo sysctl -w net.ipv4.conf.all.rp_filter=2 ``` -To set up `vprox`, you'll need the private IPv4 address of the server connected to an Internet gateway (use the `ip addr` command), as well as a block of IPs to allocate to the WireGuard subnet between server and client. This has no particular meaning and can be arbitrarily chosen to not overlap with other subnets. +### Basic setup + +You'll need the private IPv4 address of the server connected to an Internet gateway (use the `ip addr` command), as well as a block of IPs to allocate to the WireGuard subnet between server and client. This can be arbitrarily chosen to not overlap with other subnets. ```bash # [Machine A: public IP 1.2.3.4, private IP 172.31.64.125] @@ -42,6 +78,61 @@ Note that Machine B must be able to send UDP packets to port 50227 on Machine A, All outbound network traffic seen by `vprox0` will automatically be forwarded through the WireGuard tunnel. The VPN server masquerades the source IP address. +### Multi-tunnel mode (high throughput) + +A single WireGuard tunnel is encapsulated in one UDP flow (fixed 4-tuple). On cloud providers like AWS, NIC hardware hashes flows to RX queues by this 4-tuple, so a single tunnel is limited to the throughput of one hardware queue — typically ~2-2.5 Gbps on AWS ENA. 
+ +Multi-tunnel mode creates N parallel WireGuard tunnels on different UDP ports, spreading traffic across multiple NIC queues: + +```bash +# Server: 4 parallel tunnels per IP +VPROX_PASSWORD=my-password vprox server --ip 172.31.64.125 --wg-block 240.1.0.0/16 --tunnels 4 + +# Client: 4 parallel tunnels (must be <= server's --tunnels value) +VPROX_PASSWORD=my-password vprox connect 1.2.3.4 --interface vprox0 --tunnels 4 +``` + +Both server and client must use `--tunnels`. The `dummy` kernel module must be available on the client (`sudo modprobe dummy`). + +**Required sysctl on both server and client** for multipath flow distribution: + +```bash +sudo sysctl -w net.ipv4.fib_multipath_hash_policy=1 +``` + +Applications bind to the single `vprox0` interface as before — the multi-tunnel routing is transparent. + +**Choosing the number of tunnels:** Start with `--tunnels 4`. The optimal value depends on the number of CPU cores and NIC queues. On a 4-core server, 4 tunnels will typically saturate the CPU. Adding more tunnels than CPU cores provides diminishing returns since WireGuard encryption becomes the bottleneck. 
+
+### Performance tuning
+
+For maximum throughput, apply these additional sysctl settings on both server and client:
+
+```bash
+# UDP/Socket buffer sizes (WireGuard uses UDP)
+sudo sysctl -w net.core.rmem_max=26214400
+sudo sysctl -w net.core.wmem_max=26214400
+sudo sysctl -w net.core.rmem_default=1048576
+sudo sysctl -w net.core.wmem_default=1048576
+
+# Network device backlog (for high packet rates)
+sudo sysctl -w net.core.netdev_max_backlog=50000
+
+# TCP tuning (for traffic inside the tunnel; bbr requires the tcp_bbr kernel module)
+sudo sysctl -w net.ipv4.tcp_rmem="4096 1048576 26214400"
+sudo sysctl -w net.ipv4.tcp_wmem="4096 1048576 26214400"
+sudo sysctl -w net.ipv4.tcp_congestion_control=bbr
+
+# Multipath flow hashing (required for multi-tunnel)
+sudo sysctl -w net.ipv4.fib_multipath_hash_policy=1
+
+# Connection tracking limits (for NAT with many peers)
+sudo sysctl -w net.netfilter.nf_conntrack_max=1048576
+```
+
+To make these settings persistent across reboots, add them to `/etc/sysctl.d/99-vprox.conf` without the `sudo sysctl -w` prefix, then apply with `sudo sysctl --system`.
+
+
### Building

To build `vprox`, run the following command with Go 1.22+ installed:
@@ -56,7 +147,7 @@ This produces a static binary in `./vprox`.

On cloud providers like AWS, you can attach [secondary private IP addresses](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/MultipleIP.html) to an interface and associate each of them with a global IPv4 unicast address.

-A `vprox` server listening on multiple IP addresses needs to provide `--ip` option once for every IP, and each IP requires its its own WireGuard VPN subnet with a non-overlapping address range. You can pass `--wg-block-per-ip /22` to split the `--wg-block` into smaller blocks for each IP.
+A `vprox` server listening on multiple IP addresses needs to provide the `--ip` option once for every IP, and each IP requires its own WireGuard VPN subnet with a non-overlapping address range.
You can pass `--wg-block-per-ip /22` to split the `--wg-block` into smaller blocks for each IP.

On AWS in particular, the `--cloud aws` option allows you to automatically discover the private IP addresses of the server by periodically querying the instance metadata endpoint.

@@ -66,8 +157,46 @@ On AWS in particular, the `--cloud aws` option allows you to automatically disco

- Supports forwarding IPv4 packets
- Works if the server has multiple IPs, specified with `--wg-block-per-ip`
- Automatic discovery of IPs using instance metadata endpoints (AWS)
-- Only one vprox server may be running on a host
+- Multi-tunnel mode for throughput beyond the single NIC queue limit (`--tunnels N`)
+- WireGuard interfaces tuned with GSO/GRO offload, multi-queue, and optimized MTU/MSS
+- Connection tracking bypass (NOTRACK) for reduced CPU overhead on WireGuard UDP flows
+- TCP MSS clamping to prevent fragmentation inside the tunnel
- Control traffic is encrypted with TLS (Warning: does not verify server certificate)
+- Only one vprox server may be running on a host
+
+## How it works
+
+### Control plane
+
+The server listens on port 443 (HTTPS) for control traffic. Clients send a `/connect` request with their WireGuard public key. The server allocates a peer IP from the WireGuard subnet, adds the client as a peer on all tunnel interfaces, and returns the assigned address along with a list of tunnel endpoints (listen ports).
+
+### Data plane
+
+WireGuard handles the data plane. Each tunnel interface encrypts/decrypts traffic independently. The server applies iptables rules for:
+
+- **SNAT (masquerade)**: Outbound traffic from WireGuard peers is source-NAT'd to the server's bind address.
+- **Firewall marks**: Traffic from WireGuard interfaces is marked for routing policy.
+- **MSS clamping**: The TCP MSS on SYN packets is clamped to 1380 bytes so that segments fit within the 1420-byte WireGuard MTU (1420 minus 40 bytes of IP and TCP headers).
+- **NOTRACK**: WireGuard UDP flows bypass connection tracking to reduce per-packet CPU overhead.
+
+### Multi-tunnel routing
+
+In multi-tunnel mode, both server and client use Linux policy routing to distribute traffic:
+
+- A custom routing table (51820) contains multipath routes across all tunnel interfaces.
+- An `ip rule` directs matching traffic to this table.
+- On the client, the rule matches traffic sourced from the WireGuard IP (set by the dummy `vprox0` device).
+- On the server, the rule matches traffic destined for the WireGuard subnet (forwarded download traffic).
+- The kernel's L4 multipath hash (`fib_multipath_hash_policy=1`) distributes different flows to different tunnels.
+
+### Interface tuning
+
+WireGuard interfaces are created with performance-optimized settings:
+
+- **MTU 1420**: Prevents fragmentation on standard 1500 MTU networks (WireGuard encapsulation adds 60 bytes over IPv4 and 80 bytes over IPv6, so 1420 = 1500 − 80 covers the worst case).
+- **GSO/GRO 65536**: Enables Generic Segmentation/Receive Offload, allowing the kernel to batch packets into 64 KB super-packets before encryption (Linux 5.19+).
+- **4 TX/RX queues**: Enables parallel packet processing across multiple CPU cores.
+- **TxQLen 1000**: Reduces packet drops during traffic bursts.

## Authors