Skip to content

Commit 616ba7f

Browse files
committed
Add shim client to simulate allowed IP removal
Direct allowed IP removal is a new feature, so its availability is limited for now. Clients who want to take advantage of this capability would need to know ahead of time if WireGuard on their system supports it or probe to see if they're running on a system that supports it. To ease the transition, add the WithShim option for clients: c, err := wgctrl.New(wgctrl.WithShim) WithShim wraps internal clients with a shim client that probes to see if direct IP removal is supported on their system. If not, the shim client emulates the effect by assigning IPs to a dummy peer then removing that peer. At some point in the future, this option should no longer be necessary once the feature becomes commonplace. In my case, I plan to use WithShim to simplify Cilium's WireGuard orchestration logic. Signed-off-by: Jordan Rife <[email protected]>
1 parent 75de1a4 commit 616ba7f

File tree

9 files changed

+362
-9
lines changed

9 files changed

+362
-9
lines changed

client.go

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,22 @@ import (
55
"os"
66

77
"golang.zx2c4.com/wireguard/wgctrl/internal/wginternal"
8+
"golang.zx2c4.com/wireguard/wgctrl/internal/wgshim"
89
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
910
)
1011

11-
// Expose an identical interface to the underlying packages.
12-
var _ wginternal.Client = &Client{}
12+
// An Option configures a client in some way. Options may be provided when
13+
// calling New().
14+
type Option func(wginternal.Client) wginternal.Client
15+
16+
// WithShim wraps the client in a shim that probes for the capabilities
17+
// supported by the underlying WireGuard implementation and emulates missing
18+
// capabilities.
19+
//
20+
// This option ensures backwards and forwards compatibility.
21+
func WithShim(c wginternal.Client) wginternal.Client {
22+
return wgshim.New(c)
23+
}
1324

1425
// A Client provides access to WireGuard device information.
1526
type Client struct {
@@ -18,13 +29,20 @@ type Client struct {
1829
cs []wginternal.Client
1930
}
2031

21-
// New creates a new Client.
22-
func New() (*Client, error) {
32+
// New creates a new Client. Callers may provide a list of Options that modify
33+
// client behavior.
34+
func New(opts ...Option) (*Client, error) {
2335
cs, err := newClients()
2436
if err != nil {
2537
return nil, err
2638
}
2739

40+
for _, opt := range opts {
41+
for i := range cs {
42+
cs[i] = opt(cs[i])
43+
}
44+
}
45+
2846
return &Client{
2947
cs: cs,
3048
}, nil

client_test.go

Lines changed: 177 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77

88
"github.com/google/go-cmp/cmp"
99
"golang.zx2c4.com/wireguard/wgctrl/internal/wginternal"
10+
"golang.zx2c4.com/wireguard/wgctrl/internal/wgtest"
1011
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
1112
)
1213

@@ -224,11 +225,182 @@ func TestClientConfigureDevice(t *testing.T) {
224225
}
225226
}
226227

228+
func TestClientConfigureDeviceWithShim(t *testing.T) {
229+
type supportsFunc func(name string) (bool, error)
230+
231+
var (
232+
dummyPeerKey = wgtypes.Key{}
233+
peerKey = wgtest.MustPublicKey()
234+
ip = wgtest.MustCIDR("192.0.2.0/32")
235+
236+
notSupported = func(_ string) (bool, error) {
237+
return false, nil
238+
}
239+
240+
supported = func(_ string) (bool, error) {
241+
return true, nil
242+
}
243+
244+
returnsError = func(_ string) (bool, error) {
245+
return false, errFoo
246+
}
247+
248+
removeAllowedIP = wgtypes.Config{
249+
Peers: []wgtypes.PeerConfig{
250+
{
251+
PublicKey: peerKey,
252+
AllowedIPs: []wgtypes.AllowedIPConfig{
253+
{
254+
IPNet: ip,
255+
Remove: true,
256+
},
257+
},
258+
},
259+
},
260+
}
261+
262+
removeAllowedIPUndone = wgtypes.Config{
263+
Peers: []wgtypes.PeerConfig{
264+
{
265+
PublicKey: peerKey,
266+
AllowedIPs: []wgtypes.AllowedIPConfig{
267+
{
268+
IPNet: ip,
269+
Remove: true,
270+
},
271+
{
272+
IPNet: ip,
273+
},
274+
},
275+
},
276+
},
277+
}
278+
279+
simulatedRemoveAllowedIP = wgtypes.Config{
280+
Peers: []wgtypes.PeerConfig{
281+
{
282+
PublicKey: dummyPeerKey,
283+
AllowedIPs: []wgtypes.AllowedIPConfig{
284+
{
285+
IPNet: ip,
286+
},
287+
},
288+
},
289+
{
290+
PublicKey: peerKey,
291+
},
292+
{
293+
PublicKey: dummyPeerKey,
294+
Remove: true,
295+
},
296+
},
297+
}
298+
299+
dontRemoveAllowedIP = wgtypes.Config{
300+
Peers: []wgtypes.PeerConfig{
301+
{
302+
PublicKey: peerKey,
303+
AllowedIPs: []wgtypes.AllowedIPConfig{
304+
{
305+
IPNet: ip,
306+
},
307+
},
308+
},
309+
},
310+
}
311+
)
312+
313+
tests := []struct {
314+
name string
315+
fn supportsFunc
316+
cfg wgtypes.Config
317+
expectCfg wgtypes.Config
318+
err error
319+
}{
320+
{
321+
name: "not supported + remove IP",
322+
fn: notSupported,
323+
cfg: removeAllowedIP,
324+
expectCfg: simulatedRemoveAllowedIP,
325+
err: nil,
326+
},
327+
{
328+
name: "not supported + remove IP undone",
329+
fn: notSupported,
330+
cfg: removeAllowedIPUndone,
331+
expectCfg: dontRemoveAllowedIP,
332+
err: nil,
333+
},
334+
{
335+
name: "not supported + don't remove IP",
336+
fn: notSupported,
337+
cfg: dontRemoveAllowedIP,
338+
expectCfg: dontRemoveAllowedIP,
339+
err: nil,
340+
},
341+
{
342+
name: "supported + remove IP",
343+
fn: supported,
344+
cfg: removeAllowedIP,
345+
expectCfg: removeAllowedIP,
346+
err: nil,
347+
},
348+
{
349+
name: "supported + don't remove IP",
350+
fn: supported,
351+
cfg: dontRemoveAllowedIP,
352+
expectCfg: dontRemoveAllowedIP,
353+
err: nil,
354+
},
355+
{
356+
name: "probe error + remove IP",
357+
fn: returnsError,
358+
cfg: removeAllowedIP,
359+
expectCfg: wgtypes.Config{},
360+
err: errFoo,
361+
},
362+
{
363+
name: "probe error + don't remove IP",
364+
fn: returnsError,
365+
cfg: dontRemoveAllowedIP,
366+
expectCfg: wgtypes.Config{},
367+
err: errFoo,
368+
},
369+
}
370+
371+
for _, tt := range tests {
372+
t.Run(tt.name, func(t *testing.T) {
373+
var finalCfg wgtypes.Config
374+
375+
cs := WithShim(&testClient{
376+
ConfigureDeviceFunc: func(name string, cfg wgtypes.Config) error {
377+
finalCfg = cfg
378+
379+
return nil
380+
},
381+
SupportsAllowedIPRemoveFunc: tt.fn,
382+
})
383+
384+
c := &Client{cs: []wginternal.Client{cs}}
385+
386+
err := c.ConfigureDevice("", tt.cfg)
387+
if !errors.Is(err, tt.err) {
388+
t.Fatalf("unexpected error: got %s, want %s", err, tt.err)
389+
}
390+
391+
if diff := cmp.Diff(tt.expectCfg, finalCfg); diff != "" {
392+
t.Fatalf("unexpected config (-want +got):\n%s", diff)
393+
}
394+
})
395+
}
396+
}
397+
227398
type testClient struct {
228-
CloseFunc func() error
229-
DevicesFunc func() ([]*wgtypes.Device, error)
230-
DeviceFunc func(name string) (*wgtypes.Device, error)
231-
ConfigureDeviceFunc func(name string, cfg wgtypes.Config) error
399+
CloseFunc func() error
400+
DevicesFunc func() ([]*wgtypes.Device, error)
401+
DeviceFunc func(name string) (*wgtypes.Device, error)
402+
ConfigureDeviceFunc func(name string, cfg wgtypes.Config) error
403+
SupportsAllowedIPRemoveFunc func(name string) (bool, error)
232404
}
233405

234406
func (c *testClient) Close() error { return c.CloseFunc() }
@@ -242,5 +414,5 @@ func (c *testClient) ConfigureDevice(name string, cfg wgtypes.Config) error {
242414
}
243415

244416
func (c *testClient) SupportsAllowedIPRemove(name string) (bool, error) {
245-
return false, nil
417+
return c.SupportsAllowedIPRemoveFunc(name)
246418
}

internal/wgfreebsd/client_freebsd.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,10 @@ func (c *Client) ConfigureDevice(name string, cfg wgtypes.Config) error {
217217
return nil
218218
}
219219

220+
func (c *Client) SupportsAllowedIPRemove(name string) (bool, error) {
221+
return false, nil
222+
}
223+
220224
// deviceName converts an interface name string to the format required to pass
221225
// with wgh.WGGetServ.
222226
func deviceName(name string) ([16]byte, error) {

internal/wginternal/client.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,5 @@ type Client interface {
1818
Devices() ([]*wgtypes.Device, error)
1919
Device(name string) (*wgtypes.Device, error)
2020
ConfigureDevice(name string, cfg wgtypes.Config) error
21+
SupportsAllowedIPRemove(name string) (bool, error)
2122
}

internal/wglinux/client_linux.go

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ package wglinux
66
import (
77
"errors"
88
"fmt"
9+
"net"
910
"os"
1011
"syscall"
1112

@@ -144,6 +145,39 @@ func (c *Client) ConfigureDevice(name string, cfg wgtypes.Config) error {
144145
return nil
145146
}
146147

148+
func (c *Client) SupportsAllowedIPRemove(name string) (bool, error) {
149+
err := c.ConfigureDevice(name, wgtypes.Config{
150+
Peers: []wgtypes.PeerConfig{
151+
{
152+
PublicKey: wgtypes.Key{},
153+
},
154+
{
155+
PublicKey: wgtypes.Key{},
156+
AllowedIPs: []wgtypes.AllowedIPConfig{
157+
{
158+
IPNet: net.IPNet{
159+
IP: net.IPv6zero,
160+
Mask: net.CIDRMask(0, 8*net.IPv6len),
161+
},
162+
Remove: true,
163+
},
164+
},
165+
},
166+
{
167+
PublicKey: wgtypes.Key{},
168+
Remove: true,
169+
},
170+
},
171+
})
172+
173+
var errno syscall.Errno
174+
if errors.As(err, &errno) && errno == unix.EINVAL {
175+
return false, nil
176+
}
177+
178+
return err == nil, err
179+
}
180+
147181
// execute executes a single WireGuard netlink request with the specified command,
148182
// header flags, and attribute arguments.
149183
func (c *Client) execute(command uint8, flags netlink.HeaderFlags, attrb []byte) ([]genetlink.Message, error) {

internal/wgopenbsd/client_openbsd.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,10 @@ func (c *Client) ConfigureDevice(name string, cfg wgtypes.Config) error {
233233
return wginternal.ErrReadOnly
234234
}
235235

236+
func (c *Client) SupportsAllowedIPRemove(name string) (bool, error) {
237+
return false, nil
238+
}
239+
236240
// deviceName converts an interface name string to the format required to pass
237241
// with wgh.WGGetServ.
238242
func deviceName(name string) ([16]byte, error) {

0 commit comments

Comments
 (0)