From 2a1ddbec0a8cae06a44f60c2fde615a84ebb829b Mon Sep 17 00:00:00 2001 From: MrMelon54 Date: Mon, 3 Feb 2025 23:46:45 +0000 Subject: [PATCH 1/8] Initial netip implementation Signed-off-by: MrMelon54 --- client_integration_test.go | 41 +++++++------- cmd/wgctrl/main.go | 4 +- internal/wglinux/client_linux_test.go | 4 +- internal/wglinux/configure_linux.go | 46 ++++++---------- internal/wglinux/configure_linux_test.go | 49 +++++++++-------- internal/wglinux/parse_linux.go | 37 ++++++------- internal/wglinux/parse_linux_test.go | 69 ++++++++++++------------ internal/wgtest/wgtest.go | 13 +---- internal/wguser/configure_test.go | 16 +++--- internal/wguser/parse.go | 7 +-- internal/wguser/parse_test.go | 27 +++------- wgtypes/types.go | 5 +- 12 files changed, 141 insertions(+), 177 deletions(-) diff --git a/client_integration_test.go b/client_integration_test.go index ab4bedc..6eb6071 100644 --- a/client_integration_test.go +++ b/client_integration_test.go @@ -1,10 +1,10 @@ package wgctrl_test import ( - "bytes" "errors" "fmt" "net" + "net/netip" "os" "sort" "strings" @@ -144,9 +144,9 @@ func testGet(t *testing.T, c *wgctrl.Client, d *wgtypes.Device) { func testConfigure(t *testing.T, c *wgctrl.Client, d *wgtypes.Device) { var ( port = 8888 - ips = []net.IPNet{ - wgtest.MustCIDR("192.0.2.0/32"), - wgtest.MustCIDR("2001:db8::/128"), + ips = []netip.Prefix{ + netip.MustParsePrefix("192.0.2.0/32"), + netip.MustParsePrefix("2001:db8::/128"), } priv = wgtest.MustPrivateKey() @@ -194,7 +194,7 @@ func testConfigure(t *testing.T, c *wgctrl.Client, d *wgtypes.Device) { for i := range dn.Peers { ips := dn.Peers[i].AllowedIPs sort.Slice(ips, func(i, j int) bool { - return bytes.Compare(ips[i].IP, ips[j].IP) > 0 + return ips[i].Addr().Compare(ips[j].Addr()) > 0 }) } @@ -229,17 +229,19 @@ func testConfigureManyIPs(t *testing.T, c *wgctrl.Client, d *wgtypes.Device) { t.Fatalf("failed to create cursor: %v", err) } - var ips []net.IPNet + var ips []netip.Prefix for pos := cur.Next(); pos != nil; pos = cur.Next() { bits := 128 if pos.IP.To4() != nil { bits = 32 } - ips = append(ips, net.IPNet{ - IP: pos.IP, - Mask: net.CIDRMask(bits, bits), - }) + addr, ok := netip.AddrFromSlice(pos.IP) + if !ok { + t.Fatalf("failed to convert net.IP to netip.Addr: %s", pos.IP) + } + + ips = append(ips, netip.PrefixFrom(addr, bits)) } peers = append(peers, wgtypes.PeerConfig{ @@ -291,7 +293,7 @@ func testConfigureManyPeers(t *testing.T, c *wgctrl.Client, d *wgtypes.Device) { PresharedKey: &pk, ReplaceAllowedIPs: true, Endpoint: &net.UDPAddr{ - IP: ips[0].IP, + IP: ips[0].Addr().AsSlice(), Port: 1111, }, PersistentKeepaliveInterval: &dur, @@ -370,7 +372,6 @@ func testConfigurePeersUpdateOnly(t *testing.T, c *wgctrl.Client, d *wgtypes.Dev t.Skip("FreeBSD kernel devices do not support UpdateOnly flag") } - t.Fatalf("failed to configure second time on %q: %v", d.Name, err) } @@ -428,7 +429,7 @@ func countPeerIPs(d *wgtypes.Device) int { return count } -func ipsString(ipns []net.IPNet) string { +func ipsString(ipns []netip.Prefix) string { ss := make([]string, 0, len(ipns)) for _, ipn := range ipns { ss = append(ss, ipn.String()) @@ -437,23 +438,25 @@ func ipsString(ipns []net.IPNet) string { return strings.Join(ss, ", ") } -func generateIPs(n int) []net.IPNet { +func generateIPs(n int) []netip.Prefix { cur, err := ipaddr.Parse("2001:db8::/64") if err != nil { panicf("failed to create cursor: %v", err) } - ips := make([]net.IPNet, 0, n) + ips := make([]netip.Prefix, 0, n) for i := 0; i < n; i++ { pos := cur.Next() if pos == nil { panic("hit nil IP during IP generation") } - ips = append(ips, net.IPNet{ - IP: pos.IP, - Mask: net.CIDRMask(128, 128), - }) + addr, ok := netip.AddrFromSlice(pos.IP) + if !ok { + panicf("failed to convert net.IP to netip.Addr: %s", pos.IP) + } + + ips = append(ips, netip.PrefixFrom(addr, 128)) } return ips diff --git a/cmd/wgctrl/main.go b/cmd/wgctrl/main.go index 9fbba3a..5830fa6 100644 --- a/cmd/wgctrl/main.go +++ b/cmd/wgctrl/main.go @@ -6,7 +6,7 @@ import ( "flag" "fmt" "log" - "net" + "net/netip" "strings" "golang.zx2c4.com/wireguard/wgctrl" @@ -83,7 +83,7 @@ func printPeer(p wgtypes.Peer) { ) } -func ipsString(ipns []net.IPNet) string { +func ipsString(ipns []netip.Prefix) string { ss := make([]string, 0, len(ipns)) for _, ipn := range ipns { ss = append(ss, ipn.String()) diff --git a/internal/wglinux/client_linux_test.go b/internal/wglinux/client_linux_test.go index 5f6c054..4aedfea 100644 --- a/internal/wglinux/client_linux_test.go +++ b/internal/wglinux/client_linux_test.go @@ -6,7 +6,7 @@ package wglinux import ( "errors" "fmt" - "net" + "net/netip" "os" "os/user" "syscall" @@ -325,7 +325,7 @@ func diffAttrs(x, y []netlink.Attribute) string { return cmp.Diff(xPrime, yPrime) } -func mustAllowedIPs(ipns []net.IPNet) []byte { +func mustAllowedIPs(ipns []netip.Prefix) []byte { ae := netlink.NewAttributeEncoder() if err := encodeAllowedIPs(ipns)(ae); err != nil { panicf("failed to create allowed IP attributes: %v", err) diff --git a/internal/wglinux/configure_linux.go b/internal/wglinux/configure_linux.go index bf29092..f8f281c 100644 --- a/internal/wglinux/configure_linux.go +++ b/internal/wglinux/configure_linux.go @@ -7,6 +7,7 @@ import ( "encoding/binary" "fmt" "net" + "net/netip" "unsafe" "github.com/mdlayher/netlink" @@ -101,16 +102,16 @@ func buildBatches(cfg wgtypes.Config) []wgtypes.Config { // Iterate until no more allowed IPs. var done bool for !done { - var tmp []net.IPNet + var tmp []netip.Prefix if len(p.AllowedIPs) < ipBatchChunk { // IPs all fit within a batch; we are done. - tmp = make([]net.IPNet, len(p.AllowedIPs)) + tmp = make([]netip.Prefix, len(p.AllowedIPs)) copy(tmp, p.AllowedIPs) done = true } else { // IPs are larger than a single batch, copy a batch out and // advance the cursor. - tmp = make([]net.IPNet, ipBatchChunk) + tmp = make([]netip.Prefix, ipBatchChunk) copy(tmp, p.AllowedIPs[:ipBatchChunk]) p.AllowedIPs = p.AllowedIPs[ipBatchChunk:] @@ -214,32 +215,26 @@ func encodePeer(p wgtypes.PeerConfig) func(ae *netlink.AttributeEncoder) error { // sockaddr_in or sockaddr_in6 bytes. func encodeSockaddr(endpoint net.UDPAddr) func() ([]byte, error) { return func() ([]byte, error) { - if !isValidIP(endpoint.IP) { + addrPort := endpoint.AddrPort() + if !addrPort.Addr().IsValid() { return nil, fmt.Errorf("wglinux: invalid endpoint IP: %s", endpoint.IP.String()) } // Is this an IPv6 address? - if isIPv6(endpoint.IP) { - var addr [16]byte - copy(addr[:], endpoint.IP.To16()) - + if addrPort.Addr().Is6() { sa := unix.RawSockaddrInet6{ Family: unix.AF_INET6, Port: sockaddrPort(endpoint.Port), - Addr: addr, + Addr: addrPort.Addr().As16(), } return (*(*[unix.SizeofSockaddrInet6]byte)(unsafe.Pointer(&sa)))[:], nil } - // IPv4 address handling. - var addr [4]byte - copy(addr[:], endpoint.IP.To4()) - sa := unix.RawSockaddrInet4{ Family: unix.AF_INET, Port: sockaddrPort(endpoint.Port), - Addr: addr, + Addr: addrPort.Addr().As4(), } return (*(*[unix.SizeofSockaddrInet4]byte)(unsafe.Pointer(&sa)))[:], nil @@ -247,26 +242,25 @@ func encodeSockaddr(endpoint net.UDPAddr) func() ([]byte, error) { } // encodeAllowedIPs returns a function to encode allowed IP nested attributes. -func encodeAllowedIPs(ipns []net.IPNet) func(ae *netlink.AttributeEncoder) error { +func encodeAllowedIPs(ipns []netip.Prefix) func(ae *netlink.AttributeEncoder) error { return func(ae *netlink.AttributeEncoder) error { for i, ipn := range ipns { - if !isValidIP(ipn.IP) { - return fmt.Errorf("wglinux: invalid allowed IP: %s", ipn.IP.String()) + if !ipn.Addr().IsValid() { + return fmt.Errorf("wglinux: invalid allowed IP: %s", ipn.Addr()) } family := uint16(unix.AF_INET6) - if !isIPv6(ipn.IP) { + if ipn.Addr().Is4() { // Make sure address is 4 bytes if IPv4. family = unix.AF_INET - ipn.IP = ipn.IP.To4() } // Netlink arrays use type as an array index. ae.Nested(uint16(i), func(nae *netlink.AttributeEncoder) error { nae.Uint16(unix.WGALLOWEDIP_A_FAMILY, family) - nae.Bytes(unix.WGALLOWEDIP_A_IPADDR, ipn.IP) + nae.Bytes(unix.WGALLOWEDIP_A_IPADDR, ipn.Addr().AsSlice()) - ones, _ := ipn.Mask.Size() + ones := ipn.Bits() nae.Uint8(unix.WGALLOWEDIP_A_CIDR_MASK, uint8(ones)) return nil }) @@ -276,16 +270,6 @@ func encodeAllowedIPs(ipns []net.IPNet) func(ae *netlink.AttributeEncoder) error } } -// isValidIP determines if IP is a valid IPv4 or IPv6 address. -func isValidIP(ip net.IP) bool { - return ip.To16() != nil -} - -// isIPv6 determines if IP is a valid IPv6 address. -func isIPv6(ip net.IP) bool { - return isValidIP(ip) && ip.To4() == nil -} - // sockaddrPort interprets port as a big endian uint16 for use passing sockaddr // structures to the kernel. func sockaddrPort(port int) uint16 { diff --git a/internal/wglinux/configure_linux_test.go b/internal/wglinux/configure_linux_test.go index 858ad30..a3d5f67 100644 --- a/internal/wglinux/configure_linux_test.go +++ b/internal/wglinux/configure_linux_test.go @@ -5,6 +5,7 @@ package wglinux import ( "net" + "net/netip" "testing" "time" "unsafe" @@ -45,9 +46,9 @@ func TestLinuxClientConfigureDevice(t *testing.T) { name: "bad peer allowed IP", cfg: wgtypes.Config{ Peers: []wgtypes.PeerConfig{{ - AllowedIPs: []net.IPNet{{ - IP: net.IP{0xff}, - }}, + AllowedIPs: []netip.Prefix{ + {}, + }, }}, }, }, @@ -71,8 +72,8 @@ func TestLinuxClientConfigureDevice(t *testing.T) { PresharedKey: keyPtr(wgtest.MustHexKey("188515093e952f5f22e865cef3012e72f8b5f0b598ac0309d5dacce3b70fcf52")), Endpoint: wgtest.MustUDPAddr("[abcd:23::33%2]:51820"), ReplaceAllowedIPs: true, - AllowedIPs: []net.IPNet{ - wgtest.MustCIDR("192.168.4.4/32"), + AllowedIPs: []netip.Prefix{ + netip.MustParsePrefix("192.168.4.4/32"), }, }, { @@ -81,17 +82,17 @@ func TestLinuxClientConfigureDevice(t *testing.T) { Endpoint: wgtest.MustUDPAddr("182.122.22.19:3233"), PersistentKeepaliveInterval: durPtr(111 * time.Second), ReplaceAllowedIPs: true, - AllowedIPs: []net.IPNet{ - wgtest.MustCIDR("192.168.4.6/32"), + AllowedIPs: []netip.Prefix{ + netip.MustParsePrefix("192.168.4.6/32"), }, }, { PublicKey: wgtest.MustHexKey("662e14fd594556f522604703340351258903b64f35553763f19426ab2a515c58"), Endpoint: wgtest.MustUDPAddr("5.152.198.39:51820"), ReplaceAllowedIPs: true, - AllowedIPs: []net.IPNet{ - wgtest.MustCIDR("192.168.4.10/32"), - wgtest.MustCIDR("192.168.4.11/32"), + AllowedIPs: []netip.Prefix{ + netip.MustParsePrefix("192.168.4.10/32"), + netip.MustParsePrefix("192.168.4.11/32"), }, }, { @@ -151,8 +152,8 @@ func TestLinuxClientConfigureDevice(t *testing.T) { }, { Type: netlink.Nested | unix.WGPEER_A_ALLOWEDIPS, - Data: mustAllowedIPs([]net.IPNet{ - wgtest.MustCIDR("192.168.4.4/32"), + Data: mustAllowedIPs([]netip.Prefix{ + netip.MustParsePrefix("192.168.4.4/32"), }), }, }...), @@ -182,8 +183,8 @@ func TestLinuxClientConfigureDevice(t *testing.T) { }, { Type: netlink.Nested | unix.WGPEER_A_ALLOWEDIPS, - Data: mustAllowedIPs([]net.IPNet{ - wgtest.MustCIDR("192.168.4.6/32"), + Data: mustAllowedIPs([]netip.Prefix{ + netip.MustParsePrefix("192.168.4.6/32"), }), }, }...), @@ -209,9 +210,9 @@ func TestLinuxClientConfigureDevice(t *testing.T) { }, { Type: netlink.Nested | unix.WGPEER_A_ALLOWEDIPS, - Data: mustAllowedIPs([]net.IPNet{ - wgtest.MustCIDR("192.168.4.10/32"), - wgtest.MustCIDR("192.168.4.11/32"), + Data: mustAllowedIPs([]netip.Prefix{ + netip.MustParsePrefix("192.168.4.10/32"), + netip.MustParsePrefix("192.168.4.11/32"), }), }, }...), @@ -513,23 +514,25 @@ func keyBytes(s string) []byte { return k[:] } -func generateIPs(n int) []net.IPNet { +func generateIPs(n int) []netip.Prefix { cur, err := ipaddr.Parse("2001:db8::/64") if err != nil { panicf("failed to create cursor: %v", err) } - ips := make([]net.IPNet, 0, n) + ips := make([]netip.Prefix, 0, n) for i := 0; i < n; i++ { pos := cur.Next() if pos == nil { panic("hit nil IP during IP generation") } - ips = append(ips, net.IPNet{ - IP: pos.IP, - Mask: net.CIDRMask(128, 128), - }) + addr, ok := netip.AddrFromSlice(pos.IP) + if !ok { + panicf("failed to convert net.IP to netip.Addr: %s", pos.IP) + } + + ips = append(ips, netip.PrefixFrom(addr, 128)) } return ips diff --git a/internal/wglinux/parse_linux.go b/internal/wglinux/parse_linux.go index 9630ffc..d8893bc 100644 --- a/internal/wglinux/parse_linux.go +++ b/internal/wglinux/parse_linux.go @@ -6,6 +6,7 @@ package wglinux import ( "fmt" "net" + "net/netip" "time" "unsafe" @@ -130,24 +131,26 @@ func parsePeer(ad *netlink.AttributeDecoder) wgtypes.Peer { } // parseAllowedIPs parses a slice of net.IPNet from a netlink attribute payload. -func parseAllowedIPs(ipns *[]net.IPNet) func(ad *netlink.AttributeDecoder) error { +func parseAllowedIPs(ipns *[]netip.Prefix) func(ad *netlink.AttributeDecoder) error { return func(ad *netlink.AttributeDecoder) error { // Initialize to the number of allowed IPs and begin iterating through // the netlink array to decode each one. - *ipns = make([]net.IPNet, 0, ad.Len()) + *ipns = make([]netip.Prefix, 0, ad.Len()) for ad.Next() { // Allowed IP nested attributes. ad.Nested(func(nad *netlink.AttributeDecoder) error { var ( - ipn net.IPNet - mask int + ipn netip.Addr + mask int + // TODO: we already have the family stored in ipn, is this needed? family int + _ = family ) for nad.Next() { switch nad.Type() { case unix.WGALLOWEDIP_A_IPADDR: - nad.Do(parseAddr(&ipn.IP)) + nad.Do(parseAddr(&ipn)) case unix.WGALLOWEDIP_A_CIDR_MASK: mask = int(nad.Uint8()) case unix.WGALLOWEDIP_A_FAMILY: @@ -159,16 +162,9 @@ func parseAllowedIPs(ipns *[]net.IPNet) func(ad *netlink.AttributeDecoder) error return err } - // The address family determines the correct number of bits in - // the mask. - switch family { - case unix.AF_INET: - ipn.Mask = net.CIDRMask(mask, 32) - case unix.AF_INET6: - ipn.Mask = net.CIDRMask(mask, 128) - } + ipp := netip.PrefixFrom(ipn, mask) - *ipns = append(*ipns, ipn) + *ipns = append(*ipns, ipp) return nil }) } @@ -191,17 +187,14 @@ func parseKey(key *wgtypes.Key) func(b []byte) error { } // parseAddr parses a net.IP from raw in_addr or in6_addr struct bytes. -func parseAddr(ip *net.IP) func(b []byte) error { +func parseAddr(ip *netip.Addr) func(b []byte) error { return func(b []byte) error { - switch len(b) { - case net.IPv4len, net.IPv6len: - // Okay to convert directly to net.IP; memory layout is identical. - *ip = make(net.IP, len(b)) - copy(*ip, b) - return nil - default: + parsedIP, ok := netip.AddrFromSlice(b) + if !ok { return fmt.Errorf("wglinux: unexpected IP address size: %d", len(b)) } + *ip = parsedIP + return nil } } diff --git a/internal/wglinux/parse_linux_test.go b/internal/wglinux/parse_linux_test.go index fcb4eeb..67ba8d9 100644 --- a/internal/wglinux/parse_linux_test.go +++ b/internal/wglinux/parse_linux_test.go @@ -5,6 +5,7 @@ package wglinux import ( "net" + "net/netip" "runtime" "testing" "time" @@ -229,9 +230,9 @@ func TestLinuxClientDevicesOK(t *testing.T) { }, { Type: unix.WGPEER_A_ALLOWEDIPS, - Data: mustAllowedIPs([]net.IPNet{ - wgtest.MustCIDR("192.168.1.10/32"), - wgtest.MustCIDR("fd00::1/128"), + Data: mustAllowedIPs([]netip.Prefix{ + netip.MustParsePrefix("192.168.1.10/32"), + netip.MustParsePrefix("fd00::1/128"), }), }, { @@ -286,9 +287,9 @@ func TestLinuxClientDevicesOK(t *testing.T) { LastHandshakeTime: time.Unix(10, 20), ReceiveBytes: 100, TransmitBytes: 200, - AllowedIPs: []net.IPNet{ - wgtest.MustCIDR("192.168.1.10/32"), - wgtest.MustCIDR("fd00::1/128"), + AllowedIPs: []netip.Prefix{ + netip.MustParsePrefix("192.168.1.10/32"), + netip.MustParsePrefix("fd00::1/128"), }, ProtocolVersion: 1, }, @@ -328,9 +329,9 @@ func TestLinuxClientDevicesOK(t *testing.T) { }, { Type: unix.WGPEER_A_ALLOWEDIPS, - Data: mustAllowedIPs([]net.IPNet{ - wgtest.MustCIDR("192.168.1.10/32"), - wgtest.MustCIDR("192.168.1.11/32"), + Data: mustAllowedIPs([]netip.Prefix{ + netip.MustParsePrefix("192.168.1.10/32"), + netip.MustParsePrefix("192.168.1.11/32"), }), }, }...), @@ -352,9 +353,9 @@ func TestLinuxClientDevicesOK(t *testing.T) { }, { Type: unix.WGPEER_A_ALLOWEDIPS, - Data: mustAllowedIPs([]net.IPNet{ - wgtest.MustCIDR("fd00:dead:beef:dead::/64"), - wgtest.MustCIDR("fd00:dead:beef:ffff::/64"), + Data: mustAllowedIPs([]netip.Prefix{ + netip.MustParsePrefix("fd00:dead:beef:dead::/64"), + netip.MustParsePrefix("fd00:dead:beef:ffff::/64"), }), }, }...), @@ -368,9 +369,9 @@ func TestLinuxClientDevicesOK(t *testing.T) { }, { Type: unix.WGPEER_A_ALLOWEDIPS, - Data: mustAllowedIPs([]net.IPNet{ - wgtest.MustCIDR("10.10.10.0/24"), - wgtest.MustCIDR("10.10.11.0/24"), + Data: mustAllowedIPs([]netip.Prefix{ + netip.MustParsePrefix("10.10.10.0/24"), + netip.MustParsePrefix("10.10.11.0/24"), }), }, }...), @@ -392,9 +393,9 @@ func TestLinuxClientDevicesOK(t *testing.T) { }, { Type: unix.WGPEER_A_ALLOWEDIPS, - Data: mustAllowedIPs([]net.IPNet{ - wgtest.MustCIDR("10.10.12.0/24"), - wgtest.MustCIDR("10.10.13.0/24"), + Data: mustAllowedIPs([]netip.Prefix{ + netip.MustParsePrefix("10.10.12.0/24"), + netip.MustParsePrefix("10.10.13.0/24"), }), }, }...), @@ -408,9 +409,9 @@ func TestLinuxClientDevicesOK(t *testing.T) { }, { Type: unix.WGPEER_A_ALLOWEDIPS, - Data: mustAllowedIPs([]net.IPNet{ - wgtest.MustCIDR("fd00:1234::/32"), - wgtest.MustCIDR("fd00:4567::/32"), + Data: mustAllowedIPs([]netip.Prefix{ + netip.MustParsePrefix("fd00:1234::/32"), + netip.MustParsePrefix("fd00:4567::/32"), }), }, }...), @@ -427,27 +428,27 @@ func TestLinuxClientDevicesOK(t *testing.T) { Peers: []wgtypes.Peer{ { PublicKey: keyA, - AllowedIPs: []net.IPNet{ - wgtest.MustCIDR("192.168.1.10/32"), - wgtest.MustCIDR("192.168.1.11/32"), - wgtest.MustCIDR("fd00:dead:beef:dead::/64"), - wgtest.MustCIDR("fd00:dead:beef:ffff::/64"), + AllowedIPs: []netip.Prefix{ + netip.MustParsePrefix("192.168.1.10/32"), + netip.MustParsePrefix("192.168.1.11/32"), + netip.MustParsePrefix("fd00:dead:beef:dead::/64"), + netip.MustParsePrefix("fd00:dead:beef:ffff::/64"), }, }, { PublicKey: keyB, - AllowedIPs: []net.IPNet{ - wgtest.MustCIDR("10.10.10.0/24"), - wgtest.MustCIDR("10.10.11.0/24"), - wgtest.MustCIDR("10.10.12.0/24"), - wgtest.MustCIDR("10.10.13.0/24"), + AllowedIPs: []netip.Prefix{ + netip.MustParsePrefix("10.10.10.0/24"), + netip.MustParsePrefix("10.10.11.0/24"), + netip.MustParsePrefix("10.10.12.0/24"), + netip.MustParsePrefix("10.10.13.0/24"), }, }, { PublicKey: keyC, - AllowedIPs: []net.IPNet{ - wgtest.MustCIDR("fd00:1234::/32"), - wgtest.MustCIDR("fd00:4567::/32"), + AllowedIPs: []netip.Prefix{ + netip.MustParsePrefix("fd00:1234::/32"), + netip.MustParsePrefix("fd00:4567::/32"), }, }, }, diff --git a/internal/wgtest/wgtest.go b/internal/wgtest/wgtest.go index c288e7f..62f2b68 100644 --- a/internal/wgtest/wgtest.go +++ b/internal/wgtest/wgtest.go @@ -3,21 +3,10 @@ package wgtest import ( "encoding/hex" "fmt" - "net" - "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + "net" ) -// MustCIDR converts CIDR string s into a net.IPNet or panics. -func MustCIDR(s string) net.IPNet { - _, cidr, err := net.ParseCIDR(s) - if err != nil { - panicf("wgtest: failed to parse CIDR: %v", err) - } - - return *cidr -} - // MustHexKey decodes a hex string s as a key or panics. func MustHexKey(s string) wgtypes.Key { b, err := hex.DecodeString(s) diff --git a/internal/wguser/configure_test.go b/internal/wguser/configure_test.go index 7058c8b..a5dbf38 100644 --- a/internal/wguser/configure_test.go +++ b/internal/wguser/configure_test.go @@ -2,7 +2,7 @@ package wguser import ( "errors" - "net" + "net/netip" "os" "testing" "time" @@ -109,8 +109,8 @@ func TestClientConfigureDeviceOK(t *testing.T) { PresharedKey: keyPtr(wgtest.MustHexKey("188515093e952f5f22e865cef3012e72f8b5f0b598ac0309d5dacce3b70fcf52")), Endpoint: wgtest.MustUDPAddr("[abcd:23::33%2]:51820"), ReplaceAllowedIPs: true, - AllowedIPs: []net.IPNet{ - wgtest.MustCIDR("192.168.4.4/32"), + AllowedIPs: []netip.Prefix{ + netip.MustParsePrefix("192.168.4.4/32"), }, }, { @@ -119,17 +119,17 @@ func TestClientConfigureDeviceOK(t *testing.T) { Endpoint: wgtest.MustUDPAddr("182.122.22.19:3233"), PersistentKeepaliveInterval: durPtr(111 * time.Second), ReplaceAllowedIPs: true, - AllowedIPs: []net.IPNet{ - wgtest.MustCIDR("192.168.4.6/32"), + AllowedIPs: []netip.Prefix{ + netip.MustParsePrefix("192.168.4.6/32"), }, }, { PublicKey: wgtest.MustHexKey("662e14fd594556f522604703340351258903b64f35553763f19426ab2a515c58"), Endpoint: wgtest.MustUDPAddr("5.152.198.39:51820"), ReplaceAllowedIPs: true, - AllowedIPs: []net.IPNet{ - wgtest.MustCIDR("192.168.4.10/32"), - wgtest.MustCIDR("192.168.4.11/32"), + AllowedIPs: []netip.Prefix{ + netip.MustParsePrefix("192.168.4.10/32"), + netip.MustParsePrefix("192.168.4.11/32"), }, }, { diff --git a/internal/wguser/parse.go b/internal/wguser/parse.go index dc996b2..e68d3e2 100644 --- a/internal/wguser/parse.go +++ b/internal/wguser/parse.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "net" + "net/netip" "os" "strconv" "time" @@ -243,16 +244,16 @@ func (dp *deviceParser) parseAddr(s string) *net.UDPAddr { } // parseInt parses an address CIDR from a string. -func (dp *deviceParser) parseCIDR(s string) *net.IPNet { +func (dp *deviceParser) parseCIDR(s string) *netip.Prefix { if dp.err != nil { return nil } - _, cidr, err := net.ParseCIDR(s) + prefix, err := netip.ParsePrefix(s) if err != nil { dp.err = err return nil } - return cidr + return &prefix } diff --git a/internal/wguser/parse_test.go b/internal/wguser/parse_test.go index 79e6b97..563a77a 100644 --- a/internal/wguser/parse_test.go +++ b/internal/wguser/parse_test.go @@ -2,6 +2,7 @@ package wguser import ( "net" + "net/netip" "testing" "time" @@ -99,11 +100,8 @@ func TestClientDevices(t *testing.T) { Zone: "2", }, LastHandshakeTime: time.Unix(1, 2), - AllowedIPs: []net.IPNet{ - { - IP: net.IP{0xc0, 0xa8, 0x4, 0x4}, - Mask: net.IPMask{0xff, 0xff, 0xff, 0xff}, - }, + AllowedIPs: []netip.Prefix{ + netip.PrefixFrom(netip.AddrFrom4([4]byte{0xc0, 0xa8, 0x4, 0x4}), 32), }, }, { @@ -119,11 +117,8 @@ func TestClientDevices(t *testing.T) { PersistentKeepaliveInterval: 111000000000, ReceiveBytes: 2224, TransmitBytes: 38333, - AllowedIPs: []net.IPNet{ - { - IP: net.IP{0xc0, 0xa8, 0x4, 0x6}, - Mask: net.IPMask{0xff, 0xff, 0xff, 0xff}, - }, + AllowedIPs: []netip.Prefix{ + netip.PrefixFrom(netip.AddrFrom4([4]byte{0xc0, 0xa8, 0x4, 0x6}), 32), }, }, { @@ -134,15 +129,9 @@ func TestClientDevices(t *testing.T) { }, ReceiveBytes: 1929999999, TransmitBytes: 1212111, - AllowedIPs: []net.IPNet{ - { - IP: net.IP{0xc0, 0xa8, 0x4, 0xa}, - Mask: net.IPMask{0xff, 0xff, 0xff, 0xff}, - }, - { - IP: net.IP{0xc0, 0xa8, 0x4, 0xb}, - Mask: net.IPMask{0xff, 0xff, 0xff, 0xff}, - }, + AllowedIPs: []netip.Prefix{ + netip.PrefixFrom(netip.AddrFrom4([4]byte{0xc0, 0xa8, 0x4, 0xa}), 32), + netip.PrefixFrom(netip.AddrFrom4([4]byte{0xc0, 0xa8, 0x4, 0xb}), 32), }, ProtocolVersion: 1, }, diff --git a/wgtypes/types.go b/wgtypes/types.go index 3b33b54..fe35e81 100644 --- a/wgtypes/types.go +++ b/wgtypes/types.go @@ -5,6 +5,7 @@ import ( "encoding/base64" "fmt" "net" + "net/netip" "time" "golang.org/x/crypto/curve25519" @@ -195,7 +196,7 @@ type Peer struct { // // 0.0.0.0/0 indicates that all IPv4 addresses are allowed, and ::/0 // indicates that all IPv6 addresses are allowed. - AllowedIPs []net.IPNet + AllowedIPs []netip.Prefix // ProtocolVersion specifies which version of the WireGuard protocol is used // for this Peer. @@ -272,5 +273,5 @@ type PeerConfig struct { // AllowedIPs specifies a list of allowed IP addresses in CIDR notation // for this peer. - AllowedIPs []net.IPNet + AllowedIPs []netip.Prefix } From 73d5b3810321f336ebc878eac21c2e5c8a1b94c3 Mon Sep 17 00:00:00 2001 From: MrMelon54 Date: Tue, 4 Mar 2025 19:48:57 +0000 Subject: [PATCH 2/8] Set EquateComparable for netip.Prefix Signed-off-by: MrMelon54 --- internal/wglinux/parse_linux_test.go | 3 ++- internal/wguser/parse_test.go | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/internal/wglinux/parse_linux_test.go b/internal/wglinux/parse_linux_test.go index 67ba8d9..5af44af 100644 --- a/internal/wglinux/parse_linux_test.go +++ b/internal/wglinux/parse_linux_test.go @@ -4,6 +4,7 @@ package wglinux import ( + "github.com/google/go-cmp/cmp/cmpopts" "net" "net/netip" "runtime" @@ -485,7 +486,7 @@ func TestLinuxClientDevicesOK(t *testing.T) { t.Fatalf("failed to get devices: %v", err) } - if diff := cmp.Diff(tt.devices, devices); diff != "" { + if diff := cmp.Diff(tt.devices, devices, cmpopts.EquateComparable(netip.Prefix{})); diff != "" { t.Fatalf("unexpected devices (-want +got):\n%s", diff) } }) diff --git a/internal/wguser/parse_test.go b/internal/wguser/parse_test.go index 563a77a..292c488 100644 --- a/internal/wguser/parse_test.go +++ b/internal/wguser/parse_test.go @@ -1,6 +1,7 @@ package wguser import ( + "github.com/google/go-cmp/cmp/cmpopts" "net" "net/netip" "testing" @@ -157,7 +158,7 @@ func TestClientDevices(t *testing.T) { return } - if diff := cmp.Diff([]*wgtypes.Device{tt.d}, devs); diff != "" { + if diff := cmp.Diff([]*wgtypes.Device{tt.d}, devs, cmpopts.EquateComparable(netip.Prefix{})); diff != "" { t.Fatalf("unexpected Devices (-want +got):\n%s", diff) } }) From 856913145c20d1bfe53ed9fa4d609e1db45e9672 Mon Sep 17 00:00:00 2001 From: MrMelon54 Date: Tue, 4 Mar 2025 20:11:29 +0000 Subject: [PATCH 3/8] Unmap 4in6 addresses to maintain previous compatibility Signed-off-by: MrMelon54 --- internal/wglinux/configure_linux.go | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/internal/wglinux/configure_linux.go b/internal/wglinux/configure_linux.go index f8f281c..e9f3f9d 100644 --- a/internal/wglinux/configure_linux.go +++ b/internal/wglinux/configure_linux.go @@ -220,12 +220,15 @@ func encodeSockaddr(endpoint net.UDPAddr) func() ([]byte, error) { return nil, fmt.Errorf("wglinux: invalid endpoint IP: %s", endpoint.IP.String()) } + // Unmap 4in6 addresses to maintain previous compatibility + addr := addrPort.Addr().Unmap() + // Is this an IPv6 address? - if addrPort.Addr().Is6() { + if addr.Is6() { sa := unix.RawSockaddrInet6{ Family: unix.AF_INET6, Port: sockaddrPort(endpoint.Port), - Addr: addrPort.Addr().As16(), + Addr: addr.As16(), } return (*(*[unix.SizeofSockaddrInet6]byte)(unsafe.Pointer(&sa)))[:], nil @@ -234,7 +237,7 @@ func encodeSockaddr(endpoint net.UDPAddr) func() ([]byte, error) { sa := unix.RawSockaddrInet4{ Family: unix.AF_INET, Port: sockaddrPort(endpoint.Port), - Addr: addrPort.Addr().As4(), + Addr: addr.As4(), } return (*(*[unix.SizeofSockaddrInet4]byte)(unsafe.Pointer(&sa)))[:], nil @@ -249,8 +252,11 @@ func encodeAllowedIPs(ipns []netip.Prefix) func(ae *netlink.AttributeEncoder) er return fmt.Errorf("wglinux: invalid allowed IP: %s", ipn.Addr()) } + // Unmap 4in6 addresses to maintain previous compatibility + addr := ipn.Addr().Unmap() + family := uint16(unix.AF_INET6) - if ipn.Addr().Is4() { + if addr.Is4() { // Make sure address is 4 bytes if IPv4. family = unix.AF_INET } @@ -258,7 +264,7 @@ func encodeAllowedIPs(ipns []netip.Prefix) func(ae *netlink.AttributeEncoder) er // Netlink arrays use type as an array index. ae.Nested(uint16(i), func(nae *netlink.AttributeEncoder) error { nae.Uint16(unix.WGALLOWEDIP_A_FAMILY, family) - nae.Bytes(unix.WGALLOWEDIP_A_IPADDR, ipn.Addr().AsSlice()) + nae.Bytes(unix.WGALLOWEDIP_A_IPADDR, addr.AsSlice()) ones := ipn.Bits() nae.Uint8(unix.WGALLOWEDIP_A_CIDR_MASK, uint8(ones)) From b92b7eddfaecb0956d37acf234cc2c70a2fbbf0c Mon Sep 17 00:00:00 2001 From: MrMelon54 Date: Tue, 4 Mar 2025 20:13:31 +0000 Subject: [PATCH 4/8] Use the correct bit length for the provided address Signed-off-by: MrMelon54 --- internal/wglinux/configure_linux_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/wglinux/configure_linux_test.go b/internal/wglinux/configure_linux_test.go index a3d5f67..80cb380 100644 --- a/internal/wglinux/configure_linux_test.go +++ b/internal/wglinux/configure_linux_test.go @@ -532,7 +532,7 @@ func generateIPs(n int) []netip.Prefix { panicf("failed to convert net.IP to netip.Addr: %s", pos.IP) } - ips = append(ips, netip.PrefixFrom(addr, 128)) + ips = append(ips, netip.PrefixFrom(addr, addr.BitLen())) } return ips From bc95217b06687abfd21edf2b134d739e66870d6f Mon Sep 17 00:00:00 2001 From: MrMelon54 Date: Sun, 31 Aug 2025 23:40:07 +0100 Subject: [PATCH 5/8] Validate the address family when parsing Allowed IP nested attributes Signed-off-by: MrMelon54 --- internal/wglinux/parse_linux.go | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/internal/wglinux/parse_linux.go b/internal/wglinux/parse_linux.go index d8893bc..e914a57 100644 --- a/internal/wglinux/parse_linux.go +++ b/internal/wglinux/parse_linux.go @@ -140,11 +140,9 @@ func parseAllowedIPs(ipns *[]netip.Prefix) func(ad *netlink.AttributeDecoder) er // Allowed IP nested attributes. ad.Nested(func(nad *netlink.AttributeDecoder) error { var ( - ipn netip.Addr - mask int - // TODO: we already have the family stored in ipn, is this needed? + ipn netip.Addr + mask int family int - _ = family ) for nad.Next() { @@ -162,6 +160,19 @@ func parseAllowedIPs(ipns *[]netip.Prefix) func(ad *netlink.AttributeDecoder) er return err } + switch family { + case unix.AF_INET: + if !ipn.Is4() { + return fmt.Errorf("decoded IP address does not match the address family") + } + case unix.AF_INET6: + if !ipn.Is6() { + return fmt.Errorf("decoded IP address does not match the address family") + } + default: + return fmt.Errorf("invalid IP address family") + } + ipp := netip.PrefixFrom(ipn, mask) *ipns = append(*ipns, ipp) From 6964aabb18ecf455e329d9a4efafcc9663c0dffc Mon Sep 17 00:00:00 2001 From: MrMelon54 Date: Mon, 1 Sep 2025 21:34:20 +0100 Subject: [PATCH 6/8] Convert wgopenbsd to use netip.Prefix Signed-off-by: MrMelon54 --- internal/wgopenbsd/client_openbsd.go | 17 ++++++----------- internal/wgopenbsd/client_openbsd_test.go | 12 ++++++------ 2 files changed, 12 insertions(+), 17 deletions(-) diff --git a/internal/wgopenbsd/client_openbsd.go b/internal/wgopenbsd/client_openbsd.go index e2dba81..3b42c63 100644 --- a/internal/wgopenbsd/client_openbsd.go +++ b/internal/wgopenbsd/client_openbsd.go @@ -8,6 +8,7 @@ import ( "encoding/binary" "fmt" "net" + "net/netip" "os" "runtime" "time" @@ -201,7 +202,7 @@ func parseDevice(name string, ifio *wgh.WGInterfaceIO) (*wgtypes.Device, error) // Same idea, we know how many allowed IPs we need to account for, so // reserve the space and advance the pointer through each WGAIP structure. - p.AllowedIPs = make([]net.IPNet, 0, peer.Aips_count) + p.AllowedIPs = make([]netip.Prefix, 0, peer.Aips_count) for j := uintptr(0); j < uintptr(peer.Aips_count); j++ { aip := (*wgh.WGAIPIO)(unsafe.Pointer( uintptr(unsafe.Pointer(peer)) + wgh.SizeofWGPeerIO + j*wgh.SizeofWGAIPIO, @@ -283,21 +284,15 @@ func parsePeer(pio *wgh.WGPeerIO) wgtypes.Peer { } // parseAllowedIP unpacks a net.IPNet from a WGAIP structure. -func parseAllowedIP(aip *wgh.WGAIPIO) net.IPNet { +func parseAllowedIP(aip *wgh.WGAIPIO) netip.Prefix { switch aip.Af { case unix.AF_INET: - return net.IPNet{ - IP: net.IP(aip.Addr[:net.IPv4len]), - Mask: net.CIDRMask(int(aip.Cidr), 32), - } + return netip.PrefixFrom(netip.AddrFrom4([4]byte(aip.Addr[:4])), int(aip.Cidr)) case unix.AF_INET6: - return net.IPNet{ - IP: net.IP(aip.Addr[:]), - Mask: net.CIDRMask(int(aip.Cidr), 128), - } + return netip.PrefixFrom(netip.AddrFrom16(aip.Addr), int(aip.Cidr)) default: panicf("wgopenbsd: invalid address family for allowed IP: %+v", aip) - return net.IPNet{} + return netip.Prefix{} } } diff --git a/internal/wgopenbsd/client_openbsd_test.go b/internal/wgopenbsd/client_openbsd_test.go index 32043cd..6959c2f 100644 --- a/internal/wgopenbsd/client_openbsd_test.go +++ b/internal/wgopenbsd/client_openbsd_test.go @@ -5,7 +5,7 @@ package wgopenbsd import ( "errors" - "net" + "net/netip" "os" "testing" "time" @@ -239,20 +239,20 @@ func TestClientDeviceBasic(t *testing.T) { ReceiveBytes: 2, TransmitBytes: 1, LastHandshakeTime: time.Unix(1, 2), - AllowedIPs: []net.IPNet{ - wgtest.MustCIDR("192.168.1.0/24"), - wgtest.MustCIDR("fd00::/64"), + AllowedIPs: []netip.Prefix{ + netip.MustParsePrefix("192.168.1.0/24"), + netip.MustParsePrefix("fd00::/64"), }, ProtocolVersion: 1, }, { PublicKey: peerB, Endpoint: wgtest.MustUDPAddr("[::1]:2048"), - AllowedIPs: []net.IPNet{wgtest.MustCIDR("2001:db8::1/128")}, + AllowedIPs: []netip.Prefix{netip.MustParsePrefix("2001:db8::1/128")}, }, { PublicKey: peerC, - AllowedIPs: []net.IPNet{}, + AllowedIPs: []netip.Prefix{}, }, }, } From 2a11b05da6ee189a3420b955d2089264f23a1bef Mon Sep 17 00:00:00 2001 From: MrMelon54 Date: Mon, 1 Sep 2025 21:48:55 +0100 Subject: [PATCH 7/8] Convert wgwindows to use netip.Prefix Signed-off-by: MrMelon54 --- internal/wgwindows/client_windows.go | 39 +++++++++---------- .../internal/ioctl/winipcfg_windows.go | 21 +++++----- 2 files changed, 29 insertions(+), 31 deletions(-) diff --git a/internal/wgwindows/client_windows.go b/internal/wgwindows/client_windows.go index eb4c1c9..a1f9c86 100644 --- a/internal/wgwindows/client_windows.go +++ b/internal/wgwindows/client_windows.go @@ -2,6 +2,7 @@ package wgwindows import ( "net" + "net/netip" "os" "time" "unsafe" @@ -209,19 +210,14 @@ func (c *Client) Device(name string) (*wgtypes.Device, error) { } else { a = a.NextAllowedIP() } - var ip net.IP - var bits int - if a.AddressFamily == windows.AF_INET { - ip = a.Address[:4] - bits = 32 - } else if a.AddressFamily == windows.AF_INET6 { - ip = a.Address[:16] - bits = 128 + var prefix netip.Prefix + switch a.AddressFamily { + case windows.AF_INET: + prefix = netip.PrefixFrom(netip.AddrFrom4([4]byte(a.Address[:4])), int(a.Cidr)) + case windows.AF_INET6: + prefix = netip.PrefixFrom(netip.AddrFrom16(a.Address), int(a.Cidr)) } - peer.AllowedIPs = append(peer.AllowedIPs, net.IPNet{ - IP: ip, - Mask: net.CIDRMask(int(a.Cidr), bits), - }) + peer.AllowedIPs = append(peer.AllowedIPs, prefix) } device.Peers = append(device.Peers, peer) } @@ -276,7 +272,7 @@ func (c *Client) ConfigureDevice(name string, cfg wgtypes.Config) error { } if cfg.Peers[i].Endpoint != nil { peer.Flags |= ioctl.PeerHasEndpoint - peer.Endpoint.SetIP(cfg.Peers[i].Endpoint.IP, uint16(cfg.Peers[i].Endpoint.Port)) + peer.Endpoint.SetAddrPort(cfg.Peers[i].Endpoint.AddrPort()) } if cfg.Peers[i].PersistentKeepaliveInterval != nil { peer.Flags |= ioctl.PeerHasPersistentKeepalive @@ -285,20 +281,21 @@ func (c *Client) ConfigureDevice(name string, cfg wgtypes.Config) error { b.AppendPeer(peer) for j := range cfg.Peers[i].AllowedIPs { var family ioctl.AddressFamily - var ip net.IP - if ip = cfg.Peers[i].AllowedIPs[j].IP.To4(); ip != nil { + prefix := cfg.Peers[i].AllowedIPs[j] + + // Unmap 4in6 addresses to maintain previous compatibility + addr := prefix.Addr().Unmap() + switch { + case addr.Is4(): family = windows.AF_INET - } else if ip = cfg.Peers[i].AllowedIPs[j].IP.To16(); ip != nil { + case addr.Is6(): family = windows.AF_INET6 - } else { - ip = cfg.Peers[i].AllowedIPs[j].IP } - cidr, _ := cfg.Peers[i].AllowedIPs[j].Mask.Size() a := &ioctl.AllowedIP{ AddressFamily: family, - Cidr: uint8(cidr), + Cidr: uint8(prefix.Bits()), } - copy(a.Address[:], ip) + copy(a.Address[:], addr.AsSlice()) b.AppendAllowedIP(a) } } diff --git a/internal/wgwindows/internal/ioctl/winipcfg_windows.go b/internal/wgwindows/internal/ioctl/winipcfg_windows.go index 55d8cc2..298ddad 100644 --- a/internal/wgwindows/internal/ioctl/winipcfg_windows.go +++ b/internal/wgwindows/internal/ioctl/winipcfg_windows.go @@ -8,6 +8,7 @@ package ioctl import ( "encoding/binary" "net" + "net/netip" "unsafe" "golang.org/x/sys/windows" @@ -33,26 +34,26 @@ func htons(i uint16) uint16 { return *(*uint16)(unsafe.Pointer(&b[0])) } -// SetIP method sets family, address, and port to the given IPv4 or IPv6 address and port. +// SetAddrPort method sets family, address, and port to the given IPv4 or IPv6 address and port. // All other members of the structure are set to zero. -func (addr *RawSockaddrInet) SetIP(ip net.IP, port uint16) error { - if v4 := ip.To4(); v4 != nil { +func (addr *RawSockaddrInet) SetAddrPort(addrPort netip.AddrPort) error { + a := addrPort.Addr().Unmap() + switch { + case a.Is4(): addr4 := (*windows.RawSockaddrInet4)(unsafe.Pointer(addr)) addr4.Family = windows.AF_INET - copy(addr4.Addr[:], v4) - addr4.Port = htons(port) + addr4.Addr = a.As4() + addr4.Port = htons(addrPort.Port()) for i := 0; i < 8; i++ { addr4.Zero[i] = 0 } return nil - } - - if v6 := ip.To16(); v6 != nil { + case a.Is6(): addr6 := (*windows.RawSockaddrInet6)(unsafe.Pointer(addr)) addr6.Family = windows.AF_INET6 - addr6.Port = htons(port) + addr6.Port = htons(addrPort.Port()) addr6.Flowinfo = 0 - copy(addr6.Addr[:], v6) + addr6.Addr = a.As16() addr6.Scope_id = 0 return nil } From 49dce0a131087f766c17d58dd171f419b765d3b4 Mon Sep 17 00:00:00 2001 From: MrMelon54 Date: Mon, 1 Sep 2025 23:15:15 +0100 Subject: [PATCH 8/8] Convert wgfreebsd to use netip.Prefix Signed-off-by: MrMelon54 --- internal/wgfreebsd/client_freebsd.go | 67 +++++++++++++--------------- 1 file changed, 32 insertions(+), 35 deletions(-) diff --git a/internal/wgfreebsd/client_freebsd.go b/internal/wgfreebsd/client_freebsd.go index 98b1f81..7782c04 100644 --- a/internal/wgfreebsd/client_freebsd.go +++ b/internal/wgfreebsd/client_freebsd.go @@ -12,8 +12,10 @@ import ( "encoding/binary" "fmt" "net" + "net/netip" "os" "runtime" + "strconv" "time" "unsafe" @@ -275,24 +277,18 @@ func parseEndpoint(ep []byte) *net.UDPAddr { case unix.AF_INET: sa := (*unix.RawSockaddrInet4)(unsafe.Pointer(&ep[0])) - ep := &net.UDPAddr{ - IP: make(net.IP, net.IPv4len), - Port: ntohs(sa.Port), - } - copy(ep.IP, sa.Addr[:]) - - return ep + return net.UDPAddrFromAddrPort(netip.AddrPortFrom(netip.AddrFrom4(sa.Addr), sa.Port)) case unix.AF_INET6: sa := (*unix.RawSockaddrInet6)(unsafe.Pointer(&ep[0])) - // TODO(mdlayher): IPv6 zone? - ep := &net.UDPAddr{ - IP: make(net.IP, net.IPv6len), - Port: ntohs(sa.Port), - } - copy(ep.IP, sa.Addr[:]) + addr := netip.AddrFrom16(sa.Addr) - return ep + // If the address is an IPv6 link-local address and the scope ID is non-zero + // then use the scope ID as the zone + if addr.Is6() && addr.IsLinkLocalUnicast() && sa.Scope_id != 0 { + addr = addr.WithZone(strconv.FormatUint(uint64(sa.Scope_id), 10)) + } + return net.UDPAddrFromAddrPort(netip.AddrPortFrom(addr, sa.Port)) default: // No endpoint configured. return nil @@ -302,54 +298,55 @@ func parseEndpoint(ep []byte) *net.UDPAddr { func unparseEndpoint(ep net.UDPAddr) []byte { var b []byte - if v4 := ep.IP.To4(); v4 != nil { + addrPort := ep.AddrPort() + addr := addrPort.Addr().Unmap() + + switch { + case addr.Is4(): b = make([]byte, unsafe.Sizeof(unix.RawSockaddrInet4{})) sa := (*unix.RawSockaddrInet4)(unsafe.Pointer(&b[0])) sa.Family = unix.AF_INET sa.Port = htons(ep.Port) - copy(sa.Addr[:], v4) - } else if v6 := ep.IP.To16(); v6 != nil { + sa.Addr = addr.As4() + case addr.Is6(): b = make([]byte, unsafe.Sizeof(unix.RawSockaddrInet6{})) sa := (*unix.RawSockaddrInet6)(unsafe.Pointer(&b[0])) sa.Family = unix.AF_INET6 sa.Port = htons(ep.Port) - copy(sa.Addr[:], v6) + sa.Addr = addr.As16() } return b } // parseAllowedIP unpacks a net.IPNet from a WGAIP structure. -func parseAllowedIP(aip nv.List) net.IPNet { +func parseAllowedIP(aip nv.List) netip.Prefix { cidr := int(aip["cidr"].(uint64)) if ip, ok := aip["ipv4"]; ok { - return net.IPNet{ - IP: net.IP(ip.([]byte)), - Mask: net.CIDRMask(cidr, 32), - } + addr, _ := netip.AddrFromSlice(ip.([]byte)) + return netip.PrefixFrom(addr, cidr) } else if ip, ok := aip["ipv6"]; ok { - return net.IPNet{ - IP: net.IP(ip.([]byte)), - Mask: net.CIDRMask(cidr, 128), - } + addr, _ := netip.AddrFromSlice(ip.([]byte)) + return netip.PrefixFrom(addr, cidr) } else { panicf("wgfreebsd: invalid address family for allowed IP: %+v", aip) - return net.IPNet{} + return netip.Prefix{} } } -func unparseAllowedIP(aip net.IPNet) nv.List { +func unparseAllowedIP(aip netip.Prefix) nv.List { m := nv.List{} - ones, _ := aip.Mask.Size() - m["cidr"] = uint64(ones) + m["cidr"] = uint64(aip.Bits()) - if v4 := aip.IP.To4(); v4 != nil { - m["ipv4"] = []byte(v4) - } else if v6 := aip.IP.To16(); v6 != nil { - m["ipv6"] = []byte(v6) + addr := aip.Addr().Unmap() + switch { + case addr.Is4(): + m["ipv4"] = addr.AsSlice() + case addr.Is6(): + m["ipv6"] = addr.AsSlice() } return m