@@ -2,11 +2,13 @@ package wgctrl
22
33import (
44 "errors"
5+ "net"
56 "os"
67 "testing"
78
89 "github.com/google/go-cmp/cmp"
910 "golang.zx2c4.com/wireguard/wgctrl/internal/wginternal"
11+ "golang.zx2c4.com/wireguard/wgctrl/internal/wgtest"
1012 "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
1113)
1214
@@ -224,11 +226,263 @@ func TestClientConfigureDevice(t *testing.T) {
224226 }
225227}
226228
229+ func TestClientConfigureDeviceWithShim (t * testing.T ) {
230+ type devicesFunc func () ([]* wgtypes.Device , error )
231+ type supportsFunc func (name string ) (bool , error )
232+
233+ var (
234+ ip = wgtest .MustCIDR ("192.0.2.0/32" )
235+ peerKey = wgtest .MustPublicKey ()
236+ dummyPeerKey = wgtypes.Key {}
237+ device = "wg0"
238+
239+ notSupported = func (_ string ) (bool , error ) {
240+ return false , nil
241+ }
242+
243+ supported = func (_ string ) (bool , error ) {
244+ return true , nil
245+ }
246+
247+ returnsError = func (_ string ) (bool , error ) {
248+ return false , errFoo
249+ }
250+
251+ peerHasIP = func () ([]* wgtypes.Device , error ) {
252+ return []* wgtypes.Device {
253+ {
254+ Name : device ,
255+ Peers : []wgtypes.Peer {
256+ {
257+ PublicKey : peerKey ,
258+ AllowedIPs : []net.IPNet {
259+ ip ,
260+ },
261+ },
262+ },
263+ },
264+ }, nil
265+ }
266+
267+ peerDoesNotHaveIP = func () ([]* wgtypes.Device , error ) {
268+ return []* wgtypes.Device {
269+ {
270+ Name : device ,
271+ Peers : []wgtypes.Peer {
272+ {
273+ PublicKey : peerKey ,
274+ AllowedIPs : []net.IPNet {},
275+ },
276+ },
277+ },
278+ }, nil
279+ }
280+
281+ otherPeerHasIP = func () ([]* wgtypes.Device , error ) {
282+ return []* wgtypes.Device {
283+ {
284+ Name : device ,
285+ Peers : []wgtypes.Peer {
286+ {
287+ PublicKey : peerKey ,
288+ AllowedIPs : []net.IPNet {},
289+ },
290+ {
291+ PublicKey : wgtest .MustPublicKey (),
292+ AllowedIPs : []net.IPNet {
293+ ip ,
294+ },
295+ },
296+ },
297+ },
298+ }, nil
299+ }
300+
301+ removeAllowedIP = wgtypes.Config {
302+ Peers : []wgtypes.PeerConfig {
303+ {
304+ PublicKey : peerKey ,
305+ AllowedIPs : []wgtypes.AllowedIPConfig {
306+ {
307+ IPNet : ip ,
308+ Remove : true ,
309+ },
310+ },
311+ },
312+ },
313+ }
314+
315+ removeAllowedIPUndone = wgtypes.Config {
316+ Peers : []wgtypes.PeerConfig {
317+ {
318+ PublicKey : peerKey ,
319+ AllowedIPs : []wgtypes.AllowedIPConfig {
320+ {
321+ IPNet : ip ,
322+ Remove : true ,
323+ },
324+ {
325+ IPNet : ip ,
326+ },
327+ },
328+ },
329+ },
330+ }
331+
332+ simulateRemoveAllowedIP = wgtypes.Config {
333+ Peers : []wgtypes.PeerConfig {
334+ {
335+ PublicKey : dummyPeerKey ,
336+ AllowedIPs : []wgtypes.AllowedIPConfig {
337+ {
338+ IPNet : ip ,
339+ },
340+ },
341+ },
342+ {
343+ PublicKey : peerKey ,
344+ },
345+ {
346+ PublicKey : dummyPeerKey ,
347+ Remove : true ,
348+ },
349+ },
350+ }
351+
352+ addAllowedIP = wgtypes.Config {
353+ Peers : []wgtypes.PeerConfig {
354+ {
355+ PublicKey : peerKey ,
356+ AllowedIPs : []wgtypes.AllowedIPConfig {
357+ {
358+ IPNet : ip ,
359+ },
360+ },
361+ },
362+ },
363+ }
364+
365+ dontRemoveIP = wgtypes.Config {
366+ Peers : []wgtypes.PeerConfig {
367+ {
368+ PublicKey : peerKey ,
369+ },
370+ },
371+ }
372+ )
373+
374+ tests := []struct {
375+ name string
376+ supportsFn supportsFunc
377+ devicesFn devicesFunc
378+ cfg wgtypes.Config
379+ expectCfg wgtypes.Config
380+ err error
381+ }{
382+ {
383+ name : "not supported + remove IP + peer has IP" ,
384+ supportsFn : notSupported ,
385+ devicesFn : peerHasIP ,
386+ cfg : removeAllowedIP ,
387+ expectCfg : simulateRemoveAllowedIP ,
388+ err : nil ,
389+ },
390+ {
391+ name : "not supported + remove IP + peer does not have IP" ,
392+ supportsFn : notSupported ,
393+ devicesFn : peerDoesNotHaveIP ,
394+ cfg : removeAllowedIP ,
395+ expectCfg : dontRemoveIP ,
396+ err : nil ,
397+ },
398+ {
399+ name : "not supported + remove IP + other peer has IP" ,
400+ supportsFn : notSupported ,
401+ devicesFn : otherPeerHasIP ,
402+ cfg : removeAllowedIP ,
403+ expectCfg : dontRemoveIP ,
404+ err : nil ,
405+ },
406+ {
407+ name : "not supported + remove IP undone" ,
408+ supportsFn : notSupported ,
409+ devicesFn : peerHasIP ,
410+ cfg : removeAllowedIPUndone ,
411+ expectCfg : addAllowedIP ,
412+ err : nil ,
413+ },
414+ {
415+ name : "not supported + don't remove IP" ,
416+ supportsFn : notSupported ,
417+ devicesFn : peerHasIP ,
418+ cfg : addAllowedIP ,
419+ expectCfg : addAllowedIP ,
420+ err : nil ,
421+ },
422+ {
423+ name : "supported + remove IP" ,
424+ supportsFn : supported ,
425+ cfg : removeAllowedIP ,
426+ expectCfg : removeAllowedIP ,
427+ err : nil ,
428+ },
429+ {
430+ name : "supported + don't remove IP" ,
431+ supportsFn : supported ,
432+ cfg : addAllowedIP ,
433+ expectCfg : addAllowedIP ,
434+ err : nil ,
435+ },
436+ {
437+ name : "probe error + remove IP" ,
438+ supportsFn : returnsError ,
439+ cfg : removeAllowedIP ,
440+ expectCfg : wgtypes.Config {},
441+ err : errFoo ,
442+ },
443+ {
444+ name : "probe error + don't remove IP" ,
445+ supportsFn : returnsError ,
446+ cfg : addAllowedIP ,
447+ expectCfg : wgtypes.Config {},
448+ err : errFoo ,
449+ },
450+ }
451+
452+ for _ , tt := range tests {
453+ t .Run (tt .name , func (t * testing.T ) {
454+ var finalCfg wgtypes.Config
455+
456+ cs := WithShim (& testClient {
457+ ConfigureDeviceFunc : func (name string , cfg wgtypes.Config ) error {
458+ finalCfg = cfg
459+
460+ return nil
461+ },
462+ DevicesFunc : tt .devicesFn ,
463+ SupportsAllowedIPRemoveFunc : tt .supportsFn ,
464+ })
465+
466+ c := & Client {cs : []wginternal.Client {cs }}
467+
468+ err := c .ConfigureDevice (device , tt .cfg )
469+ if ! errors .Is (err , tt .err ) {
470+ t .Fatalf ("unexpected error: got %s, want %s" , err , tt .err )
471+ }
472+
473+ if diff := cmp .Diff (tt .expectCfg , finalCfg ); diff != "" {
474+ t .Fatalf ("unexpected config (-want +got):\n %s" , diff )
475+ }
476+ })
477+ }
478+ }
479+
227480type testClient struct {
228- CloseFunc func () error
229- DevicesFunc func () ([]* wgtypes.Device , error )
230- DeviceFunc func (name string ) (* wgtypes.Device , error )
231- ConfigureDeviceFunc func (name string , cfg wgtypes.Config ) error
481+ CloseFunc func () error
482+ DevicesFunc func () ([]* wgtypes.Device , error )
483+ DeviceFunc func (name string ) (* wgtypes.Device , error )
484+ ConfigureDeviceFunc func (name string , cfg wgtypes.Config ) error
485+ SupportsAllowedIPRemoveFunc func (name string ) (bool , error )
232486}
233487
234488func (c * testClient ) Close () error { return c .CloseFunc () }
@@ -240,3 +494,7 @@ func (c *testClient) Device(name string) (*wgtypes.Device, error) {
240494func (c * testClient ) ConfigureDevice (name string , cfg wgtypes.Config ) error {
241495 return c .ConfigureDeviceFunc (name , cfg )
242496}
497+
498+ func (c * testClient ) SupportsAllowedIPRemove (name string ) (bool , error ) {
499+ return c .SupportsAllowedIPRemoveFunc (name )
500+ }
0 commit comments