Skip to content

Commit b8c2887

Browse files
committed
Add shim client to simulate allowed IP removal
Direct allowed IP removal is a new feature, so its availability is limited for now. Clients who want to take advantage of this capability would need to know ahead of time if WireGuard on their system supports it or probe to see if they're running on a system that supports it. To ease the transition, add the WithShim option for clients: c, err := wgctrl.New(wgctrl.WithShim) WithShim wraps internal clients with a shim client that probes to see if direct IP removal is supported on their system. If not, the shim client emulates the effect by assigning IPs to a dummy peer then removing that peer. At some point in the future, this option should no longer be necessary once the feature becomes commonplace. In my case, I plan to use WithShim to simplify Cilium's WireGuard orchestration logic. Signed-off-by: Jordan Rife <[email protected]>
1 parent f1db8c9 commit b8c2887

File tree

9 files changed

+485
-8
lines changed

9 files changed

+485
-8
lines changed

client.go

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,22 @@ import (
55
"os"
66

77
"golang.zx2c4.com/wireguard/wgctrl/internal/wginternal"
8+
"golang.zx2c4.com/wireguard/wgctrl/internal/wgshim"
89
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
910
)
1011

11-
// Expose an identical interface to the underlying packages.
12-
var _ wginternal.Client = &Client{}
12+
// An Option configures a client in some way. Options may be provided when
13+
// calling New().
14+
type Option func(wginternal.Client) wginternal.Client
15+
16+
// WithShim wraps the client in a shim that probes for the capabilities
17+
// supported by the underlying WireGuard implementation and emulates missing
18+
// capabilities.
19+
//
20+
// This option ensures backwards and forwards compatibility.
21+
func WithShim(c wginternal.Client) wginternal.Client {
22+
return wgshim.New(c)
23+
}
1324

1425
// A Client provides access to WireGuard device information.
1526
type Client struct {
@@ -18,13 +29,20 @@ type Client struct {
1829
cs []wginternal.Client
1930
}
2031

21-
// New creates a new Client.
22-
func New() (*Client, error) {
32+
// New creates a new Client. Callers may provide a list of Options that modify
33+
// client behavior.
34+
func New(opts ...Option) (*Client, error) {
2335
cs, err := newClients()
2436
if err != nil {
2537
return nil, err
2638
}
2739

40+
for _, opt := range opts {
41+
for i := range cs {
42+
cs[i] = opt(cs[i])
43+
}
44+
}
45+
2846
return &Client{
2947
cs: cs,
3048
}, nil

client_test.go

Lines changed: 262 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,13 @@ package wgctrl
22

33
import (
44
"errors"
5+
"net"
56
"os"
67
"testing"
78

89
"github.com/google/go-cmp/cmp"
910
"golang.zx2c4.com/wireguard/wgctrl/internal/wginternal"
11+
"golang.zx2c4.com/wireguard/wgctrl/internal/wgtest"
1012
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
1113
)
1214

@@ -224,11 +226,263 @@ func TestClientConfigureDevice(t *testing.T) {
224226
}
225227
}
226228

229+
func TestClientConfigureDeviceWithShim(t *testing.T) {
230+
type devicesFunc func() ([]*wgtypes.Device, error)
231+
type supportsFunc func(name string) (bool, error)
232+
233+
var (
234+
ip = wgtest.MustCIDR("192.0.2.0/32")
235+
peerKey = wgtest.MustPublicKey()
236+
dummyPeerKey = wgtypes.Key{}
237+
device = "wg0"
238+
239+
notSupported = func(_ string) (bool, error) {
240+
return false, nil
241+
}
242+
243+
supported = func(_ string) (bool, error) {
244+
return true, nil
245+
}
246+
247+
returnsError = func(_ string) (bool, error) {
248+
return false, errFoo
249+
}
250+
251+
peerHasIP = func() ([]*wgtypes.Device, error) {
252+
return []*wgtypes.Device{
253+
{
254+
Name: device,
255+
Peers: []wgtypes.Peer{
256+
{
257+
PublicKey: peerKey,
258+
AllowedIPs: []net.IPNet{
259+
ip,
260+
},
261+
},
262+
},
263+
},
264+
}, nil
265+
}
266+
267+
peerDoesNotHaveIP = func() ([]*wgtypes.Device, error) {
268+
return []*wgtypes.Device{
269+
{
270+
Name: device,
271+
Peers: []wgtypes.Peer{
272+
{
273+
PublicKey: peerKey,
274+
AllowedIPs: []net.IPNet{},
275+
},
276+
},
277+
},
278+
}, nil
279+
}
280+
281+
otherPeerHasIP = func() ([]*wgtypes.Device, error) {
282+
return []*wgtypes.Device{
283+
{
284+
Name: device,
285+
Peers: []wgtypes.Peer{
286+
{
287+
PublicKey: peerKey,
288+
AllowedIPs: []net.IPNet{},
289+
},
290+
{
291+
PublicKey: wgtest.MustPublicKey(),
292+
AllowedIPs: []net.IPNet{
293+
ip,
294+
},
295+
},
296+
},
297+
},
298+
}, nil
299+
}
300+
301+
removeAllowedIP = wgtypes.Config{
302+
Peers: []wgtypes.PeerConfig{
303+
{
304+
PublicKey: peerKey,
305+
AllowedIPs: []wgtypes.AllowedIPConfig{
306+
{
307+
IPNet: ip,
308+
Remove: true,
309+
},
310+
},
311+
},
312+
},
313+
}
314+
315+
removeAllowedIPUndone = wgtypes.Config{
316+
Peers: []wgtypes.PeerConfig{
317+
{
318+
PublicKey: peerKey,
319+
AllowedIPs: []wgtypes.AllowedIPConfig{
320+
{
321+
IPNet: ip,
322+
Remove: true,
323+
},
324+
{
325+
IPNet: ip,
326+
},
327+
},
328+
},
329+
},
330+
}
331+
332+
simulateRemoveAllowedIP = wgtypes.Config{
333+
Peers: []wgtypes.PeerConfig{
334+
{
335+
PublicKey: dummyPeerKey,
336+
AllowedIPs: []wgtypes.AllowedIPConfig{
337+
{
338+
IPNet: ip,
339+
},
340+
},
341+
},
342+
{
343+
PublicKey: peerKey,
344+
},
345+
{
346+
PublicKey: dummyPeerKey,
347+
Remove: true,
348+
},
349+
},
350+
}
351+
352+
addAllowedIP = wgtypes.Config{
353+
Peers: []wgtypes.PeerConfig{
354+
{
355+
PublicKey: peerKey,
356+
AllowedIPs: []wgtypes.AllowedIPConfig{
357+
{
358+
IPNet: ip,
359+
},
360+
},
361+
},
362+
},
363+
}
364+
365+
dontRemoveIP = wgtypes.Config{
366+
Peers: []wgtypes.PeerConfig{
367+
{
368+
PublicKey: peerKey,
369+
},
370+
},
371+
}
372+
)
373+
374+
tests := []struct {
375+
name string
376+
supportsFn supportsFunc
377+
devicesFn devicesFunc
378+
cfg wgtypes.Config
379+
expectCfg wgtypes.Config
380+
err error
381+
}{
382+
{
383+
name: "not supported + remove IP + peer has IP",
384+
supportsFn: notSupported,
385+
devicesFn: peerHasIP,
386+
cfg: removeAllowedIP,
387+
expectCfg: simulateRemoveAllowedIP,
388+
err: nil,
389+
},
390+
{
391+
name: "not supported + remove IP + peer does not have IP",
392+
supportsFn: notSupported,
393+
devicesFn: peerDoesNotHaveIP,
394+
cfg: removeAllowedIP,
395+
expectCfg: dontRemoveIP,
396+
err: nil,
397+
},
398+
{
399+
name: "not supported + remove IP + other peer has IP",
400+
supportsFn: notSupported,
401+
devicesFn: otherPeerHasIP,
402+
cfg: removeAllowedIP,
403+
expectCfg: dontRemoveIP,
404+
err: nil,
405+
},
406+
{
407+
name: "not supported + remove IP undone",
408+
supportsFn: notSupported,
409+
devicesFn: peerHasIP,
410+
cfg: removeAllowedIPUndone,
411+
expectCfg: addAllowedIP,
412+
err: nil,
413+
},
414+
{
415+
name: "not supported + don't remove IP",
416+
supportsFn: notSupported,
417+
devicesFn: peerHasIP,
418+
cfg: addAllowedIP,
419+
expectCfg: addAllowedIP,
420+
err: nil,
421+
},
422+
{
423+
name: "supported + remove IP",
424+
supportsFn: supported,
425+
cfg: removeAllowedIP,
426+
expectCfg: removeAllowedIP,
427+
err: nil,
428+
},
429+
{
430+
name: "supported + don't remove IP",
431+
supportsFn: supported,
432+
cfg: addAllowedIP,
433+
expectCfg: addAllowedIP,
434+
err: nil,
435+
},
436+
{
437+
name: "probe error + remove IP",
438+
supportsFn: returnsError,
439+
cfg: removeAllowedIP,
440+
expectCfg: wgtypes.Config{},
441+
err: errFoo,
442+
},
443+
{
444+
name: "probe error + don't remove IP",
445+
supportsFn: returnsError,
446+
cfg: addAllowedIP,
447+
expectCfg: wgtypes.Config{},
448+
err: errFoo,
449+
},
450+
}
451+
452+
for _, tt := range tests {
453+
t.Run(tt.name, func(t *testing.T) {
454+
var finalCfg wgtypes.Config
455+
456+
cs := WithShim(&testClient{
457+
ConfigureDeviceFunc: func(name string, cfg wgtypes.Config) error {
458+
finalCfg = cfg
459+
460+
return nil
461+
},
462+
DevicesFunc: tt.devicesFn,
463+
SupportsAllowedIPRemoveFunc: tt.supportsFn,
464+
})
465+
466+
c := &Client{cs: []wginternal.Client{cs}}
467+
468+
err := c.ConfigureDevice(device, tt.cfg)
469+
if !errors.Is(err, tt.err) {
470+
t.Fatalf("unexpected error: got %s, want %s", err, tt.err)
471+
}
472+
473+
if diff := cmp.Diff(tt.expectCfg, finalCfg); diff != "" {
474+
t.Fatalf("unexpected config (-want +got):\n%s", diff)
475+
}
476+
})
477+
}
478+
}
479+
227480
type testClient struct {
228-
CloseFunc func() error
229-
DevicesFunc func() ([]*wgtypes.Device, error)
230-
DeviceFunc func(name string) (*wgtypes.Device, error)
231-
ConfigureDeviceFunc func(name string, cfg wgtypes.Config) error
481+
CloseFunc func() error
482+
DevicesFunc func() ([]*wgtypes.Device, error)
483+
DeviceFunc func(name string) (*wgtypes.Device, error)
484+
ConfigureDeviceFunc func(name string, cfg wgtypes.Config) error
485+
SupportsAllowedIPRemoveFunc func(name string) (bool, error)
232486
}
233487

234488
func (c *testClient) Close() error { return c.CloseFunc() }
@@ -240,3 +494,7 @@ func (c *testClient) Device(name string) (*wgtypes.Device, error) {
240494
func (c *testClient) ConfigureDevice(name string, cfg wgtypes.Config) error {
241495
return c.ConfigureDeviceFunc(name, cfg)
242496
}
497+
498+
func (c *testClient) SupportsAllowedIPRemove(name string) (bool, error) {
499+
return c.SupportsAllowedIPRemoveFunc(name)
500+
}

internal/wgfreebsd/client_freebsd.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,10 @@ func (c *Client) ConfigureDevice(name string, cfg wgtypes.Config) error {
217217
return nil
218218
}
219219

220+
func (c *Client) SupportsAllowedIPRemove(name string) (bool, error) {
221+
return false, nil
222+
}
223+
220224
// deviceName converts an interface name string to the format required to pass
221225
// with wgh.WGGetServ.
222226
func deviceName(name string) ([16]byte, error) {

internal/wginternal/client.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,5 @@ type Client interface {
1818
Devices() ([]*wgtypes.Device, error)
1919
Device(name string) (*wgtypes.Device, error)
2020
ConfigureDevice(name string, cfg wgtypes.Config) error
21+
SupportsAllowedIPRemove(name string) (bool, error)
2122
}

0 commit comments

Comments
 (0)