Skip to content

Commit 2a1ddbe

Browse files
committed
Initial netip implementation
Signed-off-by: MrMelon54 <[email protected]>
1 parent a9ab227 commit 2a1ddbe

File tree

12 files changed

+141
-177
lines changed

12 files changed

+141
-177
lines changed

client_integration_test.go

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
package wgctrl_test
22

33
import (
4-
"bytes"
54
"errors"
65
"fmt"
76
"net"
7+
"net/netip"
88
"os"
99
"sort"
1010
"strings"
@@ -144,9 +144,9 @@ func testGet(t *testing.T, c *wgctrl.Client, d *wgtypes.Device) {
144144
func testConfigure(t *testing.T, c *wgctrl.Client, d *wgtypes.Device) {
145145
var (
146146
port = 8888
147-
ips = []net.IPNet{
148-
wgtest.MustCIDR("192.0.2.0/32"),
149-
wgtest.MustCIDR("2001:db8::/128"),
147+
ips = []netip.Prefix{
148+
netip.MustParsePrefix("192.0.2.0/32"),
149+
netip.MustParsePrefix("2001:db8::/128"),
150150
}
151151

152152
priv = wgtest.MustPrivateKey()
@@ -194,7 +194,7 @@ func testConfigure(t *testing.T, c *wgctrl.Client, d *wgtypes.Device) {
194194
for i := range dn.Peers {
195195
ips := dn.Peers[i].AllowedIPs
196196
sort.Slice(ips, func(i, j int) bool {
197-
return bytes.Compare(ips[i].IP, ips[j].IP) > 0
197+
return ips[i].Addr().Compare(ips[j].Addr()) > 0
198198
})
199199
}
200200

@@ -229,17 +229,19 @@ func testConfigureManyIPs(t *testing.T, c *wgctrl.Client, d *wgtypes.Device) {
229229
t.Fatalf("failed to create cursor: %v", err)
230230
}
231231

232-
var ips []net.IPNet
232+
var ips []netip.Prefix
233233
for pos := cur.Next(); pos != nil; pos = cur.Next() {
234234
bits := 128
235235
if pos.IP.To4() != nil {
236236
bits = 32
237237
}
238238

239-
ips = append(ips, net.IPNet{
240-
IP: pos.IP,
241-
Mask: net.CIDRMask(bits, bits),
242-
})
239+
addr, ok := netip.AddrFromSlice(pos.IP)
240+
if !ok {
241+
t.Fatalf("failed to convert net.IP to netip.Addr: %s", pos.IP)
242+
}
243+
244+
ips = append(ips, netip.PrefixFrom(addr, bits))
243245
}
244246

245247
peers = append(peers, wgtypes.PeerConfig{
@@ -291,7 +293,7 @@ func testConfigureManyPeers(t *testing.T, c *wgctrl.Client, d *wgtypes.Device) {
291293
PresharedKey: &pk,
292294
ReplaceAllowedIPs: true,
293295
Endpoint: &net.UDPAddr{
294-
IP: ips[0].IP,
296+
IP: ips[0].Addr().AsSlice(),
295297
Port: 1111,
296298
},
297299
PersistentKeepaliveInterval: &dur,
@@ -370,7 +372,6 @@ func testConfigurePeersUpdateOnly(t *testing.T, c *wgctrl.Client, d *wgtypes.Dev
370372
t.Skip("FreeBSD kernel devices do not support UpdateOnly flag")
371373
}
372374

373-
374375
t.Fatalf("failed to configure second time on %q: %v", d.Name, err)
375376
}
376377

@@ -428,7 +429,7 @@ func countPeerIPs(d *wgtypes.Device) int {
428429
return count
429430
}
430431

431-
func ipsString(ipns []net.IPNet) string {
432+
func ipsString(ipns []netip.Prefix) string {
432433
ss := make([]string, 0, len(ipns))
433434
for _, ipn := range ipns {
434435
ss = append(ss, ipn.String())
@@ -437,23 +438,25 @@ func ipsString(ipns []net.IPNet) string {
437438
return strings.Join(ss, ", ")
438439
}
439440

440-
func generateIPs(n int) []net.IPNet {
441+
func generateIPs(n int) []netip.Prefix {
441442
cur, err := ipaddr.Parse("2001:db8::/64")
442443
if err != nil {
443444
panicf("failed to create cursor: %v", err)
444445
}
445446

446-
ips := make([]net.IPNet, 0, n)
447+
ips := make([]netip.Prefix, 0, n)
447448
for i := 0; i < n; i++ {
448449
pos := cur.Next()
449450
if pos == nil {
450451
panic("hit nil IP during IP generation")
451452
}
452453

453-
ips = append(ips, net.IPNet{
454-
IP: pos.IP,
455-
Mask: net.CIDRMask(128, 128),
456-
})
454+
addr, ok := netip.AddrFromSlice(pos.IP)
455+
if !ok {
456+
panicf("failed to convert net.IP to netip.Addr: %s", pos.IP)
457+
}
458+
459+
ips = append(ips, netip.PrefixFrom(addr, 128))
457460
}
458461

459462
return ips

cmd/wgctrl/main.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ import (
66
"flag"
77
"fmt"
88
"log"
9-
"net"
9+
"net/netip"
1010
"strings"
1111

1212
"golang.zx2c4.com/wireguard/wgctrl"
@@ -83,7 +83,7 @@ func printPeer(p wgtypes.Peer) {
8383
)
8484
}
8585

86-
func ipsString(ipns []net.IPNet) string {
86+
func ipsString(ipns []netip.Prefix) string {
8787
ss := make([]string, 0, len(ipns))
8888
for _, ipn := range ipns {
8989
ss = append(ss, ipn.String())

internal/wglinux/client_linux_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ package wglinux
66
import (
77
"errors"
88
"fmt"
9-
"net"
9+
"net/netip"
1010
"os"
1111
"os/user"
1212
"syscall"
@@ -325,7 +325,7 @@ func diffAttrs(x, y []netlink.Attribute) string {
325325
return cmp.Diff(xPrime, yPrime)
326326
}
327327

328-
func mustAllowedIPs(ipns []net.IPNet) []byte {
328+
func mustAllowedIPs(ipns []netip.Prefix) []byte {
329329
ae := netlink.NewAttributeEncoder()
330330
if err := encodeAllowedIPs(ipns)(ae); err != nil {
331331
panicf("failed to create allowed IP attributes: %v", err)

internal/wglinux/configure_linux.go

Lines changed: 15 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"encoding/binary"
88
"fmt"
99
"net"
10+
"net/netip"
1011
"unsafe"
1112

1213
"github.com/mdlayher/netlink"
@@ -101,16 +102,16 @@ func buildBatches(cfg wgtypes.Config) []wgtypes.Config {
101102
// Iterate until no more allowed IPs.
102103
var done bool
103104
for !done {
104-
var tmp []net.IPNet
105+
var tmp []netip.Prefix
105106
if len(p.AllowedIPs) < ipBatchChunk {
106107
// IPs all fit within a batch; we are done.
107-
tmp = make([]net.IPNet, len(p.AllowedIPs))
108+
tmp = make([]netip.Prefix, len(p.AllowedIPs))
108109
copy(tmp, p.AllowedIPs)
109110
done = true
110111
} else {
111112
// IPs are larger than a single batch, copy a batch out and
112113
// advance the cursor.
113-
tmp = make([]net.IPNet, ipBatchChunk)
114+
tmp = make([]netip.Prefix, ipBatchChunk)
114115
copy(tmp, p.AllowedIPs[:ipBatchChunk])
115116

116117
p.AllowedIPs = p.AllowedIPs[ipBatchChunk:]
@@ -214,59 +215,52 @@ func encodePeer(p wgtypes.PeerConfig) func(ae *netlink.AttributeEncoder) error {
214215
// sockaddr_in or sockaddr_in6 bytes.
215216
func encodeSockaddr(endpoint net.UDPAddr) func() ([]byte, error) {
216217
return func() ([]byte, error) {
217-
if !isValidIP(endpoint.IP) {
218+
addrPort := endpoint.AddrPort()
219+
if !addrPort.Addr().IsValid() {
218220
return nil, fmt.Errorf("wglinux: invalid endpoint IP: %s", endpoint.IP.String())
219221
}
220222

221223
// Is this an IPv6 address?
222-
if isIPv6(endpoint.IP) {
223-
var addr [16]byte
224-
copy(addr[:], endpoint.IP.To16())
225-
224+
if addrPort.Addr().Is6() {
226225
sa := unix.RawSockaddrInet6{
227226
Family: unix.AF_INET6,
228227
Port: sockaddrPort(endpoint.Port),
229-
Addr: addr,
228+
Addr: addrPort.Addr().As16(),
230229
}
231230

232231
return (*(*[unix.SizeofSockaddrInet6]byte)(unsafe.Pointer(&sa)))[:], nil
233232
}
234233

235-
// IPv4 address handling.
236-
var addr [4]byte
237-
copy(addr[:], endpoint.IP.To4())
238-
239234
sa := unix.RawSockaddrInet4{
240235
Family: unix.AF_INET,
241236
Port: sockaddrPort(endpoint.Port),
242-
Addr: addr,
237+
Addr: addrPort.Addr().As4(),
243238
}
244239

245240
return (*(*[unix.SizeofSockaddrInet4]byte)(unsafe.Pointer(&sa)))[:], nil
246241
}
247242
}
248243

249244
// encodeAllowedIPs returns a function to encode allowed IP nested attributes.
250-
func encodeAllowedIPs(ipns []net.IPNet) func(ae *netlink.AttributeEncoder) error {
245+
func encodeAllowedIPs(ipns []netip.Prefix) func(ae *netlink.AttributeEncoder) error {
251246
return func(ae *netlink.AttributeEncoder) error {
252247
for i, ipn := range ipns {
253-
if !isValidIP(ipn.IP) {
254-
return fmt.Errorf("wglinux: invalid allowed IP: %s", ipn.IP.String())
248+
if !ipn.Addr().IsValid() {
249+
return fmt.Errorf("wglinux: invalid allowed IP: %s", ipn.Addr())
255250
}
256251

257252
family := uint16(unix.AF_INET6)
258-
if !isIPv6(ipn.IP) {
253+
if ipn.Addr().Is4() {
259254
// Make sure address is 4 bytes if IPv4.
260255
family = unix.AF_INET
261-
ipn.IP = ipn.IP.To4()
262256
}
263257

264258
// Netlink arrays use type as an array index.
265259
ae.Nested(uint16(i), func(nae *netlink.AttributeEncoder) error {
266260
nae.Uint16(unix.WGALLOWEDIP_A_FAMILY, family)
267-
nae.Bytes(unix.WGALLOWEDIP_A_IPADDR, ipn.IP)
261+
nae.Bytes(unix.WGALLOWEDIP_A_IPADDR, ipn.Addr().AsSlice())
268262

269-
ones, _ := ipn.Mask.Size()
263+
ones := ipn.Bits()
270264
nae.Uint8(unix.WGALLOWEDIP_A_CIDR_MASK, uint8(ones))
271265
return nil
272266
})
@@ -276,16 +270,6 @@ func encodeAllowedIPs(ipns []net.IPNet) func(ae *netlink.AttributeEncoder) error
276270
}
277271
}
278272

279-
// isValidIP determines if IP is a valid IPv4 or IPv6 address.
280-
func isValidIP(ip net.IP) bool {
281-
return ip.To16() != nil
282-
}
283-
284-
// isIPv6 determines if IP is a valid IPv6 address.
285-
func isIPv6(ip net.IP) bool {
286-
return isValidIP(ip) && ip.To4() == nil
287-
}
288-
289273
// sockaddrPort interprets port as a big endian uint16 for use passing sockaddr
290274
// structures to the kernel.
291275
func sockaddrPort(port int) uint16 {

internal/wglinux/configure_linux_test.go

Lines changed: 26 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ package wglinux
55

66
import (
77
"net"
8+
"net/netip"
89
"testing"
910
"time"
1011
"unsafe"
@@ -45,9 +46,9 @@ func TestLinuxClientConfigureDevice(t *testing.T) {
4546
name: "bad peer allowed IP",
4647
cfg: wgtypes.Config{
4748
Peers: []wgtypes.PeerConfig{{
48-
AllowedIPs: []net.IPNet{{
49-
IP: net.IP{0xff},
50-
}},
49+
AllowedIPs: []netip.Prefix{
50+
{},
51+
},
5152
}},
5253
},
5354
},
@@ -71,8 +72,8 @@ func TestLinuxClientConfigureDevice(t *testing.T) {
7172
PresharedKey: keyPtr(wgtest.MustHexKey("188515093e952f5f22e865cef3012e72f8b5f0b598ac0309d5dacce3b70fcf52")),
7273
Endpoint: wgtest.MustUDPAddr("[abcd:23::33%2]:51820"),
7374
ReplaceAllowedIPs: true,
74-
AllowedIPs: []net.IPNet{
75-
wgtest.MustCIDR("192.168.4.4/32"),
75+
AllowedIPs: []netip.Prefix{
76+
netip.MustParsePrefix("192.168.4.4/32"),
7677
},
7778
},
7879
{
@@ -81,17 +82,17 @@ func TestLinuxClientConfigureDevice(t *testing.T) {
8182
Endpoint: wgtest.MustUDPAddr("182.122.22.19:3233"),
8283
PersistentKeepaliveInterval: durPtr(111 * time.Second),
8384
ReplaceAllowedIPs: true,
84-
AllowedIPs: []net.IPNet{
85-
wgtest.MustCIDR("192.168.4.6/32"),
85+
AllowedIPs: []netip.Prefix{
86+
netip.MustParsePrefix("192.168.4.6/32"),
8687
},
8788
},
8889
{
8990
PublicKey: wgtest.MustHexKey("662e14fd594556f522604703340351258903b64f35553763f19426ab2a515c58"),
9091
Endpoint: wgtest.MustUDPAddr("5.152.198.39:51820"),
9192
ReplaceAllowedIPs: true,
92-
AllowedIPs: []net.IPNet{
93-
wgtest.MustCIDR("192.168.4.10/32"),
94-
wgtest.MustCIDR("192.168.4.11/32"),
93+
AllowedIPs: []netip.Prefix{
94+
netip.MustParsePrefix("192.168.4.10/32"),
95+
netip.MustParsePrefix("192.168.4.11/32"),
9596
},
9697
},
9798
{
@@ -151,8 +152,8 @@ func TestLinuxClientConfigureDevice(t *testing.T) {
151152
},
152153
{
153154
Type: netlink.Nested | unix.WGPEER_A_ALLOWEDIPS,
154-
Data: mustAllowedIPs([]net.IPNet{
155-
wgtest.MustCIDR("192.168.4.4/32"),
155+
Data: mustAllowedIPs([]netip.Prefix{
156+
netip.MustParsePrefix("192.168.4.4/32"),
156157
}),
157158
},
158159
}...),
@@ -182,8 +183,8 @@ func TestLinuxClientConfigureDevice(t *testing.T) {
182183
},
183184
{
184185
Type: netlink.Nested | unix.WGPEER_A_ALLOWEDIPS,
185-
Data: mustAllowedIPs([]net.IPNet{
186-
wgtest.MustCIDR("192.168.4.6/32"),
186+
Data: mustAllowedIPs([]netip.Prefix{
187+
netip.MustParsePrefix("192.168.4.6/32"),
187188
}),
188189
},
189190
}...),
@@ -209,9 +210,9 @@ func TestLinuxClientConfigureDevice(t *testing.T) {
209210
},
210211
{
211212
Type: netlink.Nested | unix.WGPEER_A_ALLOWEDIPS,
212-
Data: mustAllowedIPs([]net.IPNet{
213-
wgtest.MustCIDR("192.168.4.10/32"),
214-
wgtest.MustCIDR("192.168.4.11/32"),
213+
Data: mustAllowedIPs([]netip.Prefix{
214+
netip.MustParsePrefix("192.168.4.10/32"),
215+
netip.MustParsePrefix("192.168.4.11/32"),
215216
}),
216217
},
217218
}...),
@@ -513,23 +514,25 @@ func keyBytes(s string) []byte {
513514
return k[:]
514515
}
515516

516-
func generateIPs(n int) []net.IPNet {
517+
func generateIPs(n int) []netip.Prefix {
517518
cur, err := ipaddr.Parse("2001:db8::/64")
518519
if err != nil {
519520
panicf("failed to create cursor: %v", err)
520521
}
521522

522-
ips := make([]net.IPNet, 0, n)
523+
ips := make([]netip.Prefix, 0, n)
523524
for i := 0; i < n; i++ {
524525
pos := cur.Next()
525526
if pos == nil {
526527
panic("hit nil IP during IP generation")
527528
}
528529

529-
ips = append(ips, net.IPNet{
530-
IP: pos.IP,
531-
Mask: net.CIDRMask(128, 128),
532-
})
530+
addr, ok := netip.AddrFromSlice(pos.IP)
531+
if !ok {
532+
panicf("failed to convert net.IP to netip.Addr: %s", pos.IP)
533+
}
534+
535+
ips = append(ips, netip.PrefixFrom(addr, 128))
533536
}
534537

535538
return ips

0 commit comments

Comments
 (0)