From f1db8c915f4e00e32da35ccb7f0c7619fb570f32 Mon Sep 17 00:00:00 2001 From: Jordan Rife Date: Thu, 29 May 2025 12:03:41 -0700 Subject: [PATCH 1/3] wgtypes: Add Remove flag to allowed IPs [1] adds the WGALLOWEDIP_F_REMOVE_ME flag to WireGuard's Netlink API which, in the same way that WGPEER_F_REMOVE_ME allows a user to remove a single peer from a WireGuard device's configuration, allows a user to remove an ip from a peer's set of allowed ips. This capability was subsequently ported to wireguard-go as well. Add support for this feature to wgctrl-go, allowing clients to incrementally remove allowed IPs on a peer like so: wgtypes.Config{ Peers: []wgtypes.PeerConfig{ { PublicKey: peerKey, AllowedIPs: []wgtypes.AllowedIPConfig{ { IPNet: ip, Remove: true, }, }, }, }, } [1]: https://lore.kernel.org/netdev/20250517192955.594735-2-jordan@jrife.io/ Signed-off-by: Jordan Rife --- client_integration_test.go | 19 +++++++++++++++---- internal/wgfreebsd/client_freebsd.go | 24 +++++++++++++++++------- internal/wglinux/configure_linux.go | 18 ++++++++++++++---- internal/wguser/configure.go | 11 ++++++++++- internal/wgwindows/client_windows.go | 5 +++++ wgtypes/types.go | 12 +++++++++++- 6 files changed, 72 insertions(+), 17 deletions(-) diff --git a/client_integration_test.go b/client_integration_test.go index ab4bedc..60dd975 100644 --- a/client_integration_test.go +++ b/client_integration_test.go @@ -141,6 +141,18 @@ func testGet(t *testing.T, c *wgctrl.Client, d *wgtypes.Device) { } } +func ipsToAllowedIPConfig(ips []net.IPNet) []wgtypes.AllowedIPConfig { + result := make([]wgtypes.AllowedIPConfig, len(ips)) + + for i := range ips { + result[i] = wgtypes.AllowedIPConfig{ + IPNet: ips[i], + } + } + + return result +} + func testConfigure(t *testing.T, c *wgctrl.Client, d *wgtypes.Device) { var ( port = 8888 @@ -162,7 +174,7 @@ func testConfigure(t *testing.T, c *wgctrl.Client, d *wgtypes.Device) { Peers: []wgtypes.PeerConfig{{ PublicKey: peerKey, ReplaceAllowedIPs: true, - AllowedIPs: ips, + AllowedIPs: ipsToAllowedIPConfig(ips), }}, } @@ -245,7 +257,7 @@ func testConfigureManyIPs(t *testing.T, c *wgctrl.Client, d *wgtypes.Device) { peers = append(peers, wgtypes.PeerConfig{ PublicKey: wgtest.MustPublicKey(), ReplaceAllowedIPs: true, - AllowedIPs: ips, + AllowedIPs: ipsToAllowedIPConfig(ips), }) countIPs += len(ips) @@ -295,7 +307,7 @@ func testConfigureManyPeers(t *testing.T, c *wgctrl.Client, d *wgtypes.Device) { Port: 1111, }, PersistentKeepaliveInterval: &dur, - AllowedIPs: ips, + AllowedIPs: ipsToAllowedIPConfig(ips), }) } @@ -370,7 +382,6 @@ func testConfigurePeersUpdateOnly(t *testing.T, c *wgctrl.Client, d *wgtypes.Dev t.Skip("FreeBSD kernel devices do not support UpdateOnly flag") } - t.Fatalf("failed to configure second time on %q: %v", d.Name, err) } diff --git a/internal/wgfreebsd/client_freebsd.go b/internal/wgfreebsd/client_freebsd.go index 98b1f81..de1940d 100644 --- a/internal/wgfreebsd/client_freebsd.go +++ b/internal/wgfreebsd/client_freebsd.go @@ -179,7 +179,10 @@ func (c *Client) ConfigureDevice(name string, cfg wgtypes.Config) error { } } - m := unparseConfig(cfg) + m, err := unparseConfig(cfg) + if err != nil { + return err + } mem, sz, err := nv.Marshal(m) if err != nil { return err @@ -459,7 +462,7 @@ func parseDevice(data []byte) (*wgtypes.Device, error) { } // unparsePeerConfig encodes a PeerConfig to a name-value list (nvlist). -func unparsePeerConfig(cfg wgtypes.PeerConfig) nv.List { +func unparsePeerConfig(cfg wgtypes.PeerConfig) (nv.List, error) { m := nv.List{} m["public-key"] = cfg.PublicKey[:] @@ -488,17 +491,21 @@ func unparsePeerConfig(cfg wgtypes.PeerConfig) nv.List { aips := []nv.List{} for _, aip := range cfg.AllowedIPs { - aips = append(aips, unparseAllowedIP(aip)) + if aip.Remove { + return nv.List{}, fmt.Errorf("allowed ips remove not supported: %w", os.ErrInvalid) + } + + aips = append(aips, unparseAllowedIP(aip.IPNet)) } m["allowed-ips"] = aips } - return m + return m, nil } // unparseDevice encodes the device configuration as a FreeBSD name-value list (nvlist). -func unparseConfig(cfg wgtypes.Config) nv.List { +func unparseConfig(cfg wgtypes.Config) (nv.List, error) { m := nv.List{} if v := cfg.PrivateKey; v != nil { @@ -521,12 +528,15 @@ func unparseConfig(cfg wgtypes.Config) nv.List { peers := []nv.List{} for _, p := range v { - peer := unparsePeerConfig(p) + peer, err := unparsePeerConfig(p) + if err != nil { + return nv.List{}, err + } peers = append(peers, peer) } m["peers"] = peers } - return m + return m, nil } diff --git a/internal/wglinux/configure_linux.go b/internal/wglinux/configure_linux.go index bf29092..5769281 100644 --- a/internal/wglinux/configure_linux.go +++ b/internal/wglinux/configure_linux.go @@ -15,6 +15,13 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) +// TODO(jrife): These are placeholders. Replace these with values from the +// golang.org/x/sys/unix package after it is updated for the 6.16 kernel. +const ( + WGALLOWEDIP_A_FLAGS = 4 + WGALLOWEDIP_F_REMOVE_ME = 1 +) + // configAttrs creates the required encoded netlink attributes to configure // the device specified by name using the non-nil fields in cfg. func configAttrs(name string, cfg wgtypes.Config) ([]byte, error) { @@ -101,16 +108,16 @@ func buildBatches(cfg wgtypes.Config) []wgtypes.Config { // Iterate until no more allowed IPs. var done bool for !done { - var tmp []net.IPNet + var tmp []wgtypes.AllowedIPConfig if len(p.AllowedIPs) < ipBatchChunk { // IPs all fit within a batch; we are done. - tmp = make([]net.IPNet, len(p.AllowedIPs)) + tmp = make([]wgtypes.AllowedIPConfig, len(p.AllowedIPs)) copy(tmp, p.AllowedIPs) done = true } else { // IPs are larger than a single batch, copy a batch out and // advance the cursor. - tmp = make([]net.IPNet, ipBatchChunk) + tmp = make([]wgtypes.AllowedIPConfig, ipBatchChunk) copy(tmp, p.AllowedIPs[:ipBatchChunk]) p.AllowedIPs = p.AllowedIPs[ipBatchChunk:] @@ -247,7 +254,7 @@ func encodeSockaddr(endpoint net.UDPAddr) func() ([]byte, error) { } // encodeAllowedIPs returns a function to encode allowed IP nested attributes. -func encodeAllowedIPs(ipns []net.IPNet) func(ae *netlink.AttributeEncoder) error { +func encodeAllowedIPs(ipns []wgtypes.AllowedIPConfig) func(ae *netlink.AttributeEncoder) error { return func(ae *netlink.AttributeEncoder) error { for i, ipn := range ipns { if !isValidIP(ipn.IP) { @@ -268,6 +275,9 @@ func encodeAllowedIPs(ipns []net.IPNet) func(ae *netlink.AttributeEncoder) error ones, _ := ipn.Mask.Size() nae.Uint8(unix.WGALLOWEDIP_A_CIDR_MASK, uint8(ones)) + if ipn.Remove { + nae.Uint32(WGALLOWEDIP_A_FLAGS, WGALLOWEDIP_F_REMOVE_ME) + } return nil }) } diff --git a/internal/wguser/configure.go b/internal/wguser/configure.go index 28770e8..3c0f5b1 100644 --- a/internal/wguser/configure.go +++ b/internal/wguser/configure.go @@ -95,11 +95,20 @@ func writeConfig(w io.Writer, cfg wgtypes.Config) { } for _, ip := range p.AllowedIPs { - fmt.Fprintf(w, "allowed_ip=%s\n", ip.String()) + fmt.Fprintf(w, "allowed_ip=%s\n", aipStr(ip)) } } } +func aipStr(aip wgtypes.AllowedIPConfig) string { + s := aip.String() + if aip.Remove { + s = "-" + s + } + + return s +} + // hexKey encodes a wgtypes.Key into a hexadecimal string. func hexKey(k wgtypes.Key) string { return hex.EncodeToString(k[:]) diff --git a/internal/wgwindows/client_windows.go b/internal/wgwindows/client_windows.go index eb4c1c9..f3ff260 100644 --- a/internal/wgwindows/client_windows.go +++ b/internal/wgwindows/client_windows.go @@ -1,6 +1,7 @@ package wgwindows import ( + "fmt" "net" "os" "time" @@ -284,6 +285,10 @@ func (c *Client) ConfigureDevice(name string, cfg wgtypes.Config) error { } b.AppendPeer(peer) for j := range cfg.Peers[i].AllowedIPs { + if cfg.Peers[i].AllowedIPs[j].Remove { + return fmt.Errorf("allowed ips remove not supported: %w", os.ErrInvalid) + } + var family ioctl.AddressFamily var ip net.IP if ip = cfg.Peers[i].AllowedIPs[j].IP.To4(); ip != nil { diff --git a/wgtypes/types.go b/wgtypes/types.go index 3b33b54..5780209 100644 --- a/wgtypes/types.go +++ b/wgtypes/types.go @@ -272,5 +272,15 @@ type PeerConfig struct { // AllowedIPs specifies a list of allowed IP addresses in CIDR notation // for this peer. - AllowedIPs []net.IPNet + AllowedIPs []AllowedIPConfig +} + +// An AllowedIPConfig contains an allowed IP address in CIDR notation and a flag +// indicating whether to add/remove this allowed IP to/from the peer. +type AllowedIPConfig struct { + net.IPNet + + // Remove specifies whether or not to remove this allowed IP from this + // peer. + Remove bool } From b8c2887cde4f1beda25fc652a94a13024ee22c02 Mon Sep 17 00:00:00 2001 From: Jordan Rife Date: Fri, 30 May 2025 20:03:20 -0700 Subject: [PATCH 2/3] Add shim client to simulate allowed IP removal Direct allowed IP removal is a new feature, so its availability is limited for now. Clients who want to take advantage of this capability would need to know ahead of time if WireGuard on their system supports it or probe to see if they're running on a system that supports it. To ease the transition, add the WithShim option for clients: c, err := wgctrl.New(wgctrl.WithShim) WithShim wraps internal clients with a shim client that probes to see if direct IP removal is supported on their system. If not, the shim client emulates the effect by assigning IPs to a dummy peer then removing that peer. At some point in the future, this option should no longer be necessary once the feature becomes commonplace. In my case, I plan to use WithShim to simplify Cilium's WireGuard orchestration logic. Signed-off-by: Jordan Rife --- client.go | 26 ++- client_test.go | 266 ++++++++++++++++++++++++++- internal/wgfreebsd/client_freebsd.go | 4 + internal/wginternal/client.go | 1 + internal/wglinux/client_linux.go | 34 ++++ internal/wgopenbsd/client_openbsd.go | 4 + internal/wgshim/client_shim.go | 124 +++++++++++++ internal/wguser/client.go | 30 +++ internal/wgwindows/client_windows.go | 4 + 9 files changed, 485 insertions(+), 8 deletions(-) create mode 100644 internal/wgshim/client_shim.go diff --git a/client.go b/client.go index 3f6c822..3ca013c 100644 --- a/client.go +++ b/client.go @@ -5,11 +5,22 @@ import ( "os" "golang.zx2c4.com/wireguard/wgctrl/internal/wginternal" + "golang.zx2c4.com/wireguard/wgctrl/internal/wgshim" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) -// Expose an identical interface to the underlying packages. -var _ wginternal.Client = &Client{} +// An Option configures a client in some way. Options may be provided when +// calling New(). +type Option func(wginternal.Client) wginternal.Client + +// WithShim wraps the client in a shim that probes for the capabilities +// supported by the underlying WireGuard implementation and emulates missing +// capabilities. +// +// This option ensures backwards and forwards compatibility. +func WithShim(c wginternal.Client) wginternal.Client { + return wgshim.New(c) +} // A Client provides access to WireGuard device information. type Client struct { @@ -18,13 +29,20 @@ type Client struct { cs []wginternal.Client } -// New creates a new Client. -func New() (*Client, error) { +// New creates a new Client. Callers may provide a list of Options that modify +// client behavior. +func New(opts ...Option) (*Client, error) { cs, err := newClients() if err != nil { return nil, err } + for _, opt := range opts { + for i := range cs { + cs[i] = opt(cs[i]) + } + } + return &Client{ cs: cs, }, nil diff --git a/client_test.go b/client_test.go index d346c9b..8062cb9 100644 --- a/client_test.go +++ b/client_test.go @@ -2,11 +2,13 @@ package wgctrl import ( "errors" + "net" "os" "testing" "github.com/google/go-cmp/cmp" "golang.zx2c4.com/wireguard/wgctrl/internal/wginternal" + "golang.zx2c4.com/wireguard/wgctrl/internal/wgtest" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) @@ -224,11 +226,263 @@ func TestClientConfigureDevice(t *testing.T) { } } +func TestClientConfigureDeviceWithShim(t *testing.T) { + type devicesFunc func() ([]*wgtypes.Device, error) + type supportsFunc func(name string) (bool, error) + + var ( + ip = wgtest.MustCIDR("192.0.2.0/32") + peerKey = wgtest.MustPublicKey() + dummyPeerKey = wgtypes.Key{} + device = "wg0" + + notSupported = func(_ string) (bool, error) { + return false, nil + } + + supported = func(_ string) (bool, error) { + return true, nil + } + + returnsError = func(_ string) (bool, error) { + return false, errFoo + } + + peerHasIP = func() ([]*wgtypes.Device, error) { + return []*wgtypes.Device{ + { + Name: device, + Peers: []wgtypes.Peer{ + { + PublicKey: peerKey, + AllowedIPs: []net.IPNet{ + ip, + }, + }, + }, + }, + }, nil + } + + peerDoesNotHaveIP = func() ([]*wgtypes.Device, error) { + return []*wgtypes.Device{ + { + Name: device, + Peers: []wgtypes.Peer{ + { + PublicKey: peerKey, + AllowedIPs: []net.IPNet{}, + }, + }, + }, + }, nil + } + + otherPeerHasIP = func() ([]*wgtypes.Device, error) { + return []*wgtypes.Device{ + { + Name: device, + Peers: []wgtypes.Peer{ + { + PublicKey: peerKey, + AllowedIPs: []net.IPNet{}, + }, + { + PublicKey: wgtest.MustPublicKey(), + AllowedIPs: []net.IPNet{ + ip, + }, + }, + }, + }, + }, nil + } + + removeAllowedIP = wgtypes.Config{ + Peers: []wgtypes.PeerConfig{ + { + PublicKey: peerKey, + AllowedIPs: []wgtypes.AllowedIPConfig{ + { + IPNet: ip, + Remove: true, + }, + }, + }, + }, + } + + removeAllowedIPUndone = wgtypes.Config{ + Peers: []wgtypes.PeerConfig{ + { + PublicKey: peerKey, + AllowedIPs: []wgtypes.AllowedIPConfig{ + { + IPNet: ip, + Remove: true, + }, + { + IPNet: ip, + }, + }, + }, + }, + } + + simulateRemoveAllowedIP = wgtypes.Config{ + Peers: []wgtypes.PeerConfig{ + { + PublicKey: dummyPeerKey, + AllowedIPs: []wgtypes.AllowedIPConfig{ + { + IPNet: ip, + }, + }, + }, + { + PublicKey: peerKey, + }, + { + PublicKey: dummyPeerKey, + Remove: true, + }, + }, + } + + addAllowedIP = wgtypes.Config{ + Peers: []wgtypes.PeerConfig{ + { + PublicKey: peerKey, + AllowedIPs: []wgtypes.AllowedIPConfig{ + { + IPNet: ip, + }, + }, + }, + }, + } + + dontRemoveIP = wgtypes.Config{ + Peers: []wgtypes.PeerConfig{ + { + PublicKey: peerKey, + }, + }, + } + ) + + tests := []struct { + name string + supportsFn supportsFunc + devicesFn devicesFunc + cfg wgtypes.Config + expectCfg wgtypes.Config + err error + }{ + { + name: "not supported + remove IP + peer has IP", + supportsFn: notSupported, + devicesFn: peerHasIP, + cfg: removeAllowedIP, + expectCfg: simulateRemoveAllowedIP, + err: nil, + }, + { + name: "not supported + remove IP + peer does not have IP", + supportsFn: notSupported, + devicesFn: peerDoesNotHaveIP, + cfg: removeAllowedIP, + expectCfg: dontRemoveIP, + err: nil, + }, + { + name: "not supported + remove IP + other peer has IP", + supportsFn: notSupported, + devicesFn: otherPeerHasIP, + cfg: removeAllowedIP, + expectCfg: dontRemoveIP, + err: nil, + }, + { + name: "not supported + remove IP undone", + supportsFn: notSupported, + devicesFn: peerHasIP, + cfg: removeAllowedIPUndone, + expectCfg: addAllowedIP, + err: nil, + }, + { + name: "not supported + don't remove IP", + supportsFn: notSupported, + devicesFn: peerHasIP, + cfg: addAllowedIP, + expectCfg: addAllowedIP, + err: nil, + }, + { + name: "supported + remove IP", + supportsFn: supported, + cfg: removeAllowedIP, + expectCfg: removeAllowedIP, + err: nil, + }, + { + name: "supported + don't remove IP", + supportsFn: supported, + cfg: addAllowedIP, + expectCfg: addAllowedIP, + err: nil, + }, + { + name: "probe error + remove IP", + supportsFn: returnsError, + cfg: removeAllowedIP, + expectCfg: wgtypes.Config{}, + err: errFoo, + }, + { + name: "probe error + don't remove IP", + supportsFn: returnsError, + cfg: addAllowedIP, + expectCfg: wgtypes.Config{}, + err: errFoo, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var finalCfg wgtypes.Config + + cs := WithShim(&testClient{ + ConfigureDeviceFunc: func(name string, cfg wgtypes.Config) error { + finalCfg = cfg + + return nil + }, + DevicesFunc: tt.devicesFn, + SupportsAllowedIPRemoveFunc: tt.supportsFn, + }) + + c := &Client{cs: []wginternal.Client{cs}} + + err := c.ConfigureDevice(device, tt.cfg) + if !errors.Is(err, tt.err) { + t.Fatalf("unexpected error: got %s, want %s", err, tt.err) + } + + if diff := cmp.Diff(tt.expectCfg, finalCfg); diff != "" { + t.Fatalf("unexpected config (-want +got):\n%s", diff) + } + }) + } +} + type testClient struct { - CloseFunc func() error - DevicesFunc func() ([]*wgtypes.Device, error) - DeviceFunc func(name string) (*wgtypes.Device, error) - ConfigureDeviceFunc func(name string, cfg wgtypes.Config) error + CloseFunc func() error + DevicesFunc func() ([]*wgtypes.Device, error) + DeviceFunc func(name string) (*wgtypes.Device, error) + ConfigureDeviceFunc func(name string, cfg wgtypes.Config) error + SupportsAllowedIPRemoveFunc func(name string) (bool, error) } func (c *testClient) Close() error { return c.CloseFunc() } @@ -240,3 +494,7 @@ func (c *testClient) Device(name string) (*wgtypes.Device, error) { func (c *testClient) ConfigureDevice(name string, cfg wgtypes.Config) error { return c.ConfigureDeviceFunc(name, cfg) } + +func (c *testClient) SupportsAllowedIPRemove(name string) (bool, error) { + return c.SupportsAllowedIPRemoveFunc(name) +} diff --git a/internal/wgfreebsd/client_freebsd.go b/internal/wgfreebsd/client_freebsd.go index de1940d..3c1adb9 100644 --- a/internal/wgfreebsd/client_freebsd.go +++ b/internal/wgfreebsd/client_freebsd.go @@ -217,6 +217,10 @@ func (c *Client) ConfigureDevice(name string, cfg wgtypes.Config) error { return nil } +func (c *Client) SupportsAllowedIPRemove(name string) (bool, error) { + return false, nil +} + // deviceName converts an interface name string to the format required to pass // with wgh.WGGetServ. func deviceName(name string) ([16]byte, error) { diff --git a/internal/wginternal/client.go b/internal/wginternal/client.go index 4401d72..f80e04b 100644 --- a/internal/wginternal/client.go +++ b/internal/wginternal/client.go @@ -18,4 +18,5 @@ type Client interface { Devices() ([]*wgtypes.Device, error) Device(name string) (*wgtypes.Device, error) ConfigureDevice(name string, cfg wgtypes.Config) error + SupportsAllowedIPRemove(name string) (bool, error) } diff --git a/internal/wglinux/client_linux.go b/internal/wglinux/client_linux.go index 8d8a753..670e01f 100644 --- a/internal/wglinux/client_linux.go +++ b/internal/wglinux/client_linux.go @@ -6,6 +6,7 @@ package wglinux import ( "errors" "fmt" + "net" "os" "syscall" @@ -144,6 +145,39 @@ func (c *Client) ConfigureDevice(name string, cfg wgtypes.Config) error { return nil } +func (c *Client) SupportsAllowedIPRemove(name string) (bool, error) { + err := c.ConfigureDevice(name, wgtypes.Config{ + Peers: []wgtypes.PeerConfig{ + { + PublicKey: wgtypes.Key{}, + }, + { + PublicKey: wgtypes.Key{}, + AllowedIPs: []wgtypes.AllowedIPConfig{ + { + IPNet: net.IPNet{ + IP: net.IPv6zero, + Mask: net.CIDRMask(0, 8*net.IPv6len), + }, + Remove: true, + }, + }, + }, + { + PublicKey: wgtypes.Key{}, + Remove: true, + }, + }, + }) + + var errno syscall.Errno + if errors.As(err, &errno) && errno == unix.EINVAL { + return false, nil + } + + return err == nil, err +} + // execute executes a single WireGuard netlink request with the specified command, // header flags, and attribute arguments. func (c *Client) execute(command uint8, flags netlink.HeaderFlags, attrb []byte) ([]genetlink.Message, error) { diff --git a/internal/wgopenbsd/client_openbsd.go b/internal/wgopenbsd/client_openbsd.go index e2dba81..23cae2e 100644 --- a/internal/wgopenbsd/client_openbsd.go +++ b/internal/wgopenbsd/client_openbsd.go @@ -233,6 +233,10 @@ func (c *Client) ConfigureDevice(name string, cfg wgtypes.Config) error { return wginternal.ErrReadOnly } +func (c *Client) SupportsAllowedIPRemove(name string) (bool, error) { + return false, nil +} + // deviceName converts an interface name string to the format required to pass // with wgh.WGGetServ. func deviceName(name string) ([16]byte, error) { diff --git a/internal/wgshim/client_shim.go b/internal/wgshim/client_shim.go new file mode 100644 index 0000000..3404440 --- /dev/null +++ b/internal/wgshim/client_shim.go @@ -0,0 +1,124 @@ +package wgshim + +import ( + "fmt" + + "golang.zx2c4.com/wireguard/wgctrl/internal/wginternal" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" +) + +var _ wginternal.Client = &Client{} + +type Client struct { + wginternal.Client + + probed bool + simulateIPRemoval bool +} + +func New(c wginternal.Client) *Client { + return &Client{ + Client: c, + } +} + +func (c *Client) ConfigureDevice(name string, cfg wgtypes.Config) error { + if !c.probed { + if supports, err := c.Client.SupportsAllowedIPRemove(name); err != nil { + return fmt.Errorf("probing capabilities: %w", err) + } else { + c.simulateIPRemoval = !supports + } + + c.probed = true + } + + if c.simulateIPRemoval { + devices, err := c.Devices() + if err != nil { + return fmt.Errorf("querying devices: %w", err) + } + + cfg = simulateAllowedIPRemovals(cfg, allowedIPs(device(devices, name))) + } + + return c.Client.ConfigureDevice(name, cfg) +} + +func device(devices []*wgtypes.Device, name string) *wgtypes.Device { + for _, d := range devices { + if d.Name == name { + return d + } + } + + return nil +} + +func allowedIPs(device *wgtypes.Device) map[wgtypes.Key]map[string]bool { + aips := make(map[wgtypes.Key]map[string]bool) + + if device != nil { + for _, peer := range device.Peers { + aips[peer.PublicKey] = make(map[string]bool) + + for _, aip := range peer.AllowedIPs { + aips[peer.PublicKey][aip.String()] = true + } + } + } + + return aips +} + +func simulateAllowedIPRemovals(cfg wgtypes.Config, current map[wgtypes.Key]map[string]bool) wgtypes.Config { + newCfg := cfg + newCfg.Peers = make([]wgtypes.PeerConfig, 0, len(cfg.Peers)) + + for _, peer := range cfg.Peers { + // Keep track if the last instance of each IPNet is a removal. + removed := make(map[string]bool) + for _, aip := range peer.AllowedIPs { + removed[aip.String()] = aip.Remove + } + + newPeer := peer + newPeer.AllowedIPs = nil + dummyPeer := wgtypes.PeerConfig{ + PublicKey: wgtypes.Key{}, + } + + for _, aip := range peer.AllowedIPs { + if aip.Remove != removed[aip.String()] { + continue + } + + if aip.Remove { + // Do nothing if aip is not currently owned by peer. + if c := current[peer.PublicKey]; c != nil && c[aip.String()] { + dummyPeer.AllowedIPs = append(dummyPeer.AllowedIPs, aip) + dummyPeer.AllowedIPs[len(dummyPeer.AllowedIPs)-1].Remove = false + } + } else { + newPeer.AllowedIPs = append(newPeer.AllowedIPs, aip) + } + } + + if len(dummyPeer.AllowedIPs) > 0 { + newCfg.Peers = append(newCfg.Peers, + // Move allowed IPs marked with Remove to + // dummy peer. + dummyPeer, + newPeer, + // Clean up dummy peer. + wgtypes.PeerConfig{ + PublicKey: wgtypes.Key{}, + Remove: true, + }) + } else { + newCfg.Peers = append(newCfg.Peers, newPeer) + } + } + + return newCfg +} diff --git a/internal/wguser/client.go b/internal/wguser/client.go index 9d4954e..4144941 100644 --- a/internal/wguser/client.go +++ b/internal/wguser/client.go @@ -1,6 +1,7 @@ package wguser import ( + "errors" "fmt" "net" "os" @@ -89,6 +90,35 @@ func (c *Client) ConfigureDevice(name string, cfg wgtypes.Config) error { return os.ErrNotExist } +func (c *Client) SupportsAllowedIPRemove(name string) (bool, error) { + err := c.ConfigureDevice(name, wgtypes.Config{ + Peers: []wgtypes.PeerConfig{ + { + PublicKey: wgtypes.Key{}, + AllowedIPs: []wgtypes.AllowedIPConfig{ + { + IPNet: net.IPNet{ + IP: net.IPv6zero, + Mask: net.CIDRMask(0, 8*net.IPv6len), + }, + Remove: true, + }, + }, + // Don't create the peer if sucessful; we just + // want to submit the request for validation. + UpdateOnly: true, + }, + }, + }) + + var sysErr *os.SyscallError + if errors.As(err, &sysErr) && strings.Contains(sysErr.Error(), "-22") { + return false, nil + } + + return err == nil, err +} + // deviceName infers a device name from an absolute file path with extension. func deviceName(sock string) string { return strings.TrimSuffix(filepath.Base(sock), filepath.Ext(sock)) diff --git a/internal/wgwindows/client_windows.go b/internal/wgwindows/client_windows.go index f3ff260..a65a469 100644 --- a/internal/wgwindows/client_windows.go +++ b/internal/wgwindows/client_windows.go @@ -310,3 +310,7 @@ func (c *Client) ConfigureDevice(name string, cfg wgtypes.Config) error { interfaze, size := b.Interface() return windows.DeviceIoControl(handle, ioctl.IoctlSet, nil, 0, (*byte)(unsafe.Pointer(interfaze)), size, &size, nil) } + +func (c *Client) SupportsAllowedIPRemove(name string) (bool, error) { + return false, nil +} From 56a82491e8a846f85953d9fde4caf6fa929a0088 Mon Sep 17 00:00:00 2001 From: Jordan Rife Date: Thu, 29 May 2025 12:06:24 -0700 Subject: [PATCH 3/3] Add integration tests for allowed IP removal MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add testRemoveManyIPs to the integration tests to exercise the direct allowed IP removal capability and run the test suite on all platforms. $ WGCTRL_INTEGRATION=yesreallydoit go test . ┏━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓ ┃ OPERATING SYSTEM ┃ DRIVER ┃ REMOVE IP SUPPORTED ┃ RESULT ┃ ┡━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩ │ FreeBSD 14.2 │ native │ no │ PASS │ ├──────────────────┼──────────────┼─────────────────────┼────────┤ │ OpenBSD 7.7 │ native │ no │ PASS* │ ├──────────────────┼──────────────┼─────────────────────┼────────┤ │ Windows 11 │ native │ no │ PASS** │ ├──────────────────┼──────────────┼─────────────────────┼────────┤ │ Linux │ native │ no │ PASS │ ├──────────────────┼──────────────┼─────────────────────┼────────┤ │ Linux │ wireguard-go │ no │ PASS │ ├──────────────────┼──────────────┼─────────────────────┼────────┤ │ Linux │ native │ yes │ PASS │ ├──────────────────┼──────────────┼─────────────────────┼────────┤ │ Linux │ wireguard-go │ yes │ PASS │ └──────────────────┴──────────────┴─────────────────────┴────────┘ I compiled Linux from the bpf-next/master tree which includes commit ba3d7b93dbe3 ("wireguard: allowedips: add WGALLOWEDIP_F_REMOVE_ME flag") and wireguard-go from the head of master which includes commit 256bcbd70d5b ("device: add support for removing allowedips individually") to test platforms with native support. On systems where direct IP removal is not supported, I also made sure that ConfigureDevice returns an error when Remove is used without the shim. * OpenBSD skips this test case, since the driver is read only. ** Two assertions fail in Windows due to missing protocol version, but testRemoveManyIPs passes. Signed-off-by: Jordan Rife --- client_integration_test.go | 109 +++++++++++++++++++++++++++++++++++-- 1 file changed, 105 insertions(+), 4 deletions(-) diff --git a/client_integration_test.go b/client_integration_test.go index 60dd975..18f5bfb 100644 --- a/client_integration_test.go +++ b/client_integration_test.go @@ -20,7 +20,7 @@ import ( ) func TestIntegrationClient(t *testing.T) { - c, done := integrationClient(t) + c, done := integrationClient(t, false) defer done() devices, err := c.Devices() @@ -43,6 +43,14 @@ func TestIntegrationClient(t *testing.T) { name: "configure", fn: testConfigure, }, + { + name: "remove many IPs", + fn: func(t *testing.T, _ *wgctrl.Client, d *wgtypes.Device) { + c, done := integrationClient(t, true) + defer done() + testRemoveManyIPs(t, c, d) + }, + }, { name: "configure many IPs", fn: testConfigureManyIPs, @@ -91,7 +99,7 @@ func TestIntegrationClient(t *testing.T) { } func TestIntegrationClientIsNotExist(t *testing.T) { - c, done := integrationClient(t) + c, done := integrationClient(t, false) defer done() if _, err := c.Device("wgnotexist0"); !errors.Is(err, os.ErrNotExist) { @@ -99,7 +107,7 @@ func TestIntegrationClientIsNotExist(t *testing.T) { } } -func integrationClient(t *testing.T) (*wgctrl.Client, func()) { +func integrationClient(t *testing.T, useShim bool) (*wgctrl.Client, func()) { t.Helper() const ( @@ -112,7 +120,12 @@ func integrationClient(t *testing.T) (*wgctrl.Client, func()) { env, confirm) } - c, err := wgctrl.New() + var opts []wgctrl.Option + if useShim { + opts = append(opts, wgctrl.WithShim) + } + + c, err := wgctrl.New(opts...) if err != nil { if errors.Is(err, os.ErrNotExist) { t.Skip("skipping, wgctrl is not available on this system") @@ -223,6 +236,94 @@ func testConfigure(t *testing.T, c *wgctrl.Client, d *wgtypes.Device) { t.Log(out) } +func testRemoveManyIPs(t *testing.T, c *wgctrl.Client, d *wgtypes.Device) { + // Apply 511 IPs per peer. + var ( + countIPs int + peers []wgtypes.PeerConfig + peersRemoveIPs []wgtypes.PeerConfig + ) + + for i := 0; i < 2; i++ { + cidr := "2001:db8::/119" + if i == 1 { + cidr = "2001:db8:ffff::/119" + } + + cur, err := ipaddr.Parse(cidr) + if err != nil { + t.Fatalf("failed to create cursor: %v", err) + } + + var ips []net.IPNet + for pos := cur.Next(); pos != nil; pos = cur.Next() { + bits := 128 + if pos.IP.To4() != nil { + bits = 32 + } + + ips = append(ips, net.IPNet{ + IP: pos.IP, + Mask: net.CIDRMask(bits, bits), + }) + } + + peers = append(peers, wgtypes.PeerConfig{ + PublicKey: wgtest.MustPublicKey(), + ReplaceAllowedIPs: true, + AllowedIPs: ipsToAllowedIPConfig(ips), + }) + + peersRemoveIPs = append(peersRemoveIPs, wgtypes.PeerConfig{ + PublicKey: peers[len(peers)-1].PublicKey, + AllowedIPs: ipsToAllowedIPConfig(ips), + }) + + // Remove every other IP + for i := range peersRemoveIPs[len(peersRemoveIPs)-1].AllowedIPs { + peersRemoveIPs[len(peersRemoveIPs)-1].AllowedIPs[i].Remove = i%2 == 0 + } + + countIPs += len(ips) + } + + cfg := wgtypes.Config{ + ReplacePeers: true, + Peers: peers, + } + removeCfg := wgtypes.Config{ + Peers: peersRemoveIPs, + } + + tryConfigure(t, c, d.Name, cfg) + + dn, err := c.Device(d.Name) + if err != nil { + t.Fatalf("failed to get %q by name: %v", d.Name, err) + } + + peerIPs := countPeerIPs(dn) + if diff := cmp.Diff(countIPs, peerIPs); diff != "" { + t.Fatalf("unexpected number of configured peer IPs (-want +got):\n%s", diff) + } + + t.Logf("device: %s: %d IPs", d.Name, peerIPs) + + tryConfigure(t, c, d.Name, removeCfg) + + dn, err = c.Device(d.Name) + if err != nil { + t.Fatalf("failed to get %q by name: %v", d.Name, err) + } + + peerIPs = countPeerIPs(dn) + if diff := cmp.Diff(countIPs/2-1, peerIPs); diff != "" { + t.Fatalf("unexpected number of configured peer IPs (-want +got):\n%s", diff) + } + + t.Logf("device: %s: %d IPs after remove", d.Name, peerIPs) +} + func testConfigureManyIPs(t *testing.T, c *wgctrl.Client, d *wgtypes.Device) { // Apply 511 IPs per peer. var (