Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 22 additions & 19 deletions client_integration_test.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
package wgctrl_test

import (
"bytes"
"errors"
"fmt"
"net"
"net/netip"
"os"
"sort"
"strings"
Expand Down Expand Up @@ -144,9 +144,9 @@ func testGet(t *testing.T, c *wgctrl.Client, d *wgtypes.Device) {
func testConfigure(t *testing.T, c *wgctrl.Client, d *wgtypes.Device) {
var (
port = 8888
ips = []net.IPNet{
wgtest.MustCIDR("192.0.2.0/32"),
wgtest.MustCIDR("2001:db8::/128"),
ips = []netip.Prefix{
netip.MustParsePrefix("192.0.2.0/32"),
netip.MustParsePrefix("2001:db8::/128"),
}

priv = wgtest.MustPrivateKey()
Expand Down Expand Up @@ -194,7 +194,7 @@ func testConfigure(t *testing.T, c *wgctrl.Client, d *wgtypes.Device) {
for i := range dn.Peers {
ips := dn.Peers[i].AllowedIPs
sort.Slice(ips, func(i, j int) bool {
return bytes.Compare(ips[i].IP, ips[j].IP) > 0
return ips[i].Addr().Compare(ips[j].Addr()) > 0
})
}

Expand Down Expand Up @@ -229,17 +229,19 @@ func testConfigureManyIPs(t *testing.T, c *wgctrl.Client, d *wgtypes.Device) {
t.Fatalf("failed to create cursor: %v", err)
}

var ips []net.IPNet
var ips []netip.Prefix
for pos := cur.Next(); pos != nil; pos = cur.Next() {
bits := 128
if pos.IP.To4() != nil {
bits = 32
}

ips = append(ips, net.IPNet{
IP: pos.IP,
Mask: net.CIDRMask(bits, bits),
})
addr, ok := netip.AddrFromSlice(pos.IP)
if !ok {
t.Fatalf("failed to convert net.IP to netip.Addr: %s", pos.IP)
}

ips = append(ips, netip.PrefixFrom(addr, bits))
}

peers = append(peers, wgtypes.PeerConfig{
Expand Down Expand Up @@ -291,7 +293,7 @@ func testConfigureManyPeers(t *testing.T, c *wgctrl.Client, d *wgtypes.Device) {
PresharedKey: &pk,
ReplaceAllowedIPs: true,
Endpoint: &net.UDPAddr{
IP: ips[0].IP,
IP: ips[0].Addr().AsSlice(),
Port: 1111,
},
PersistentKeepaliveInterval: &dur,
Expand Down Expand Up @@ -370,7 +372,6 @@ func testConfigurePeersUpdateOnly(t *testing.T, c *wgctrl.Client, d *wgtypes.Dev
t.Skip("FreeBSD kernel devices do not support UpdateOnly flag")
}


t.Fatalf("failed to configure second time on %q: %v", d.Name, err)
}

Expand Down Expand Up @@ -428,7 +429,7 @@ func countPeerIPs(d *wgtypes.Device) int {
return count
}

func ipsString(ipns []net.IPNet) string {
func ipsString(ipns []netip.Prefix) string {
ss := make([]string, 0, len(ipns))
for _, ipn := range ipns {
ss = append(ss, ipn.String())
Expand All @@ -437,23 +438,25 @@ func ipsString(ipns []net.IPNet) string {
return strings.Join(ss, ", ")
}

func generateIPs(n int) []net.IPNet {
func generateIPs(n int) []netip.Prefix {
cur, err := ipaddr.Parse("2001:db8::/64")
if err != nil {
panicf("failed to create cursor: %v", err)
}

ips := make([]net.IPNet, 0, n)
ips := make([]netip.Prefix, 0, n)
for i := 0; i < n; i++ {
pos := cur.Next()
if pos == nil {
panic("hit nil IP during IP generation")
}

ips = append(ips, net.IPNet{
IP: pos.IP,
Mask: net.CIDRMask(128, 128),
})
addr, ok := netip.AddrFromSlice(pos.IP)
if !ok {
panicf("failed to convert net.IP to netip.Addr: %s", pos.IP)
}

ips = append(ips, netip.PrefixFrom(addr, 128))
}

return ips
Expand Down
4 changes: 2 additions & 2 deletions cmd/wgctrl/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"flag"
"fmt"
"log"
"net"
"net/netip"
"strings"

"golang.zx2c4.com/wireguard/wgctrl"
Expand Down Expand Up @@ -83,7 +83,7 @@ func printPeer(p wgtypes.Peer) {
)
}

func ipsString(ipns []net.IPNet) string {
func ipsString(ipns []netip.Prefix) string {
ss := make([]string, 0, len(ipns))
for _, ipn := range ipns {
ss = append(ss, ipn.String())
Expand Down
67 changes: 32 additions & 35 deletions internal/wgfreebsd/client_freebsd.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@ import (
"encoding/binary"
"fmt"
"net"
"net/netip"
"os"
"runtime"
"strconv"
"time"
"unsafe"

Expand Down Expand Up @@ -275,24 +277,18 @@ func parseEndpoint(ep []byte) *net.UDPAddr {
case unix.AF_INET:
sa := (*unix.RawSockaddrInet4)(unsafe.Pointer(&ep[0]))

ep := &net.UDPAddr{
IP: make(net.IP, net.IPv4len),
Port: ntohs(sa.Port),
}
copy(ep.IP, sa.Addr[:])

return ep
return net.UDPAddrFromAddrPort(netip.AddrPortFrom(netip.AddrFrom4(sa.Addr), sa.Port))
case unix.AF_INET6:
sa := (*unix.RawSockaddrInet6)(unsafe.Pointer(&ep[0]))

// TODO(mdlayher): IPv6 zone?
ep := &net.UDPAddr{
IP: make(net.IP, net.IPv6len),
Port: ntohs(sa.Port),
}
copy(ep.IP, sa.Addr[:])
addr := netip.AddrFrom16(sa.Addr)

return ep
// If the address is an IPv6 link-local address and the scope ID is non-zero
// then use the scope ID as the zone
if addr.Is6() && addr.IsLinkLocalUnicast() && sa.Scope_id != 0 {
addr = addr.WithZone(strconv.FormatUint(uint64(sa.Scope_id), 10))
}
return net.UDPAddrFromAddrPort(netip.AddrPortFrom(addr, sa.Port))
default:
// No endpoint configured.
return nil
Expand All @@ -302,54 +298,55 @@ func parseEndpoint(ep []byte) *net.UDPAddr {
func unparseEndpoint(ep net.UDPAddr) []byte {
var b []byte

if v4 := ep.IP.To4(); v4 != nil {
addrPort := ep.AddrPort()
addr := addrPort.Addr().Unmap()

switch {
case addr.Is4():
b = make([]byte, unsafe.Sizeof(unix.RawSockaddrInet4{}))
sa := (*unix.RawSockaddrInet4)(unsafe.Pointer(&b[0]))

sa.Family = unix.AF_INET
sa.Port = htons(ep.Port)
copy(sa.Addr[:], v4)
} else if v6 := ep.IP.To16(); v6 != nil {
sa.Addr = addr.As4()
case addr.Is6():
b = make([]byte, unsafe.Sizeof(unix.RawSockaddrInet6{}))
sa := (*unix.RawSockaddrInet6)(unsafe.Pointer(&b[0]))

sa.Family = unix.AF_INET6
sa.Port = htons(ep.Port)
copy(sa.Addr[:], v6)
sa.Addr = addr.As16()
}

return b
}

// parseAllowedIP unpacks a net.IPNet from a WGAIP structure.
func parseAllowedIP(aip nv.List) net.IPNet {
func parseAllowedIP(aip nv.List) netip.Prefix {
cidr := int(aip["cidr"].(uint64))
if ip, ok := aip["ipv4"]; ok {
return net.IPNet{
IP: net.IP(ip.([]byte)),
Mask: net.CIDRMask(cidr, 32),
}
addr, _ := netip.AddrFromSlice(ip.([]byte))
return netip.PrefixFrom(addr, cidr)
} else if ip, ok := aip["ipv6"]; ok {
return net.IPNet{
IP: net.IP(ip.([]byte)),
Mask: net.CIDRMask(cidr, 128),
}
addr, _ := netip.AddrFromSlice(ip.([]byte))
return netip.PrefixFrom(addr, cidr)
} else {
panicf("wgfreebsd: invalid address family for allowed IP: %+v", aip)
return net.IPNet{}
return netip.Prefix{}
}
}

func unparseAllowedIP(aip net.IPNet) nv.List {
func unparseAllowedIP(aip netip.Prefix) nv.List {
m := nv.List{}

ones, _ := aip.Mask.Size()
m["cidr"] = uint64(ones)
m["cidr"] = uint64(aip.Bits())

if v4 := aip.IP.To4(); v4 != nil {
m["ipv4"] = []byte(v4)
} else if v6 := aip.IP.To16(); v6 != nil {
m["ipv6"] = []byte(v6)
addr := aip.Addr().Unmap()
switch {
case addr.Is4():
m["ipv4"] = addr.AsSlice()
case addr.Is6():
m["ipv6"] = addr.AsSlice()
}

return m
Expand Down
4 changes: 2 additions & 2 deletions internal/wglinux/client_linux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ package wglinux
import (
"errors"
"fmt"
"net"
"net/netip"
"os"
"os/user"
"syscall"
Expand Down Expand Up @@ -325,7 +325,7 @@ func diffAttrs(x, y []netlink.Attribute) string {
return cmp.Diff(xPrime, yPrime)
}

func mustAllowedIPs(ipns []net.IPNet) []byte {
func mustAllowedIPs(ipns []netip.Prefix) []byte {
ae := netlink.NewAttributeEncoder()
if err := encodeAllowedIPs(ipns)(ae); err != nil {
panicf("failed to create allowed IP attributes: %v", err)
Expand Down
52 changes: 21 additions & 31 deletions internal/wglinux/configure_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"encoding/binary"
"fmt"
"net"
"net/netip"
"unsafe"

"github.com/mdlayher/netlink"
Expand Down Expand Up @@ -101,16 +102,16 @@ func buildBatches(cfg wgtypes.Config) []wgtypes.Config {
// Iterate until no more allowed IPs.
var done bool
for !done {
var tmp []net.IPNet
var tmp []netip.Prefix
if len(p.AllowedIPs) < ipBatchChunk {
// IPs all fit within a batch; we are done.
tmp = make([]net.IPNet, len(p.AllowedIPs))
tmp = make([]netip.Prefix, len(p.AllowedIPs))
copy(tmp, p.AllowedIPs)
done = true
} else {
// IPs are larger than a single batch, copy a batch out and
// advance the cursor.
tmp = make([]net.IPNet, ipBatchChunk)
tmp = make([]netip.Prefix, ipBatchChunk)
copy(tmp, p.AllowedIPs[:ipBatchChunk])

p.AllowedIPs = p.AllowedIPs[ipBatchChunk:]
Expand Down Expand Up @@ -214,59 +215,58 @@ func encodePeer(p wgtypes.PeerConfig) func(ae *netlink.AttributeEncoder) error {
// sockaddr_in or sockaddr_in6 bytes.
func encodeSockaddr(endpoint net.UDPAddr) func() ([]byte, error) {
return func() ([]byte, error) {
if !isValidIP(endpoint.IP) {
addrPort := endpoint.AddrPort()
if !addrPort.Addr().IsValid() {
return nil, fmt.Errorf("wglinux: invalid endpoint IP: %s", endpoint.IP.String())
}

// Is this an IPv6 address?
if isIPv6(endpoint.IP) {
var addr [16]byte
copy(addr[:], endpoint.IP.To16())
// Unmap 4in6 addresses to maintain previous compatibility
addr := addrPort.Addr().Unmap()

// Is this an IPv6 address?
if addr.Is6() {
sa := unix.RawSockaddrInet6{
Family: unix.AF_INET6,
Port: sockaddrPort(endpoint.Port),
Addr: addr,
Addr: addr.As16(),
}

return (*(*[unix.SizeofSockaddrInet6]byte)(unsafe.Pointer(&sa)))[:], nil
}

// IPv4 address handling.
var addr [4]byte
copy(addr[:], endpoint.IP.To4())

sa := unix.RawSockaddrInet4{
Family: unix.AF_INET,
Port: sockaddrPort(endpoint.Port),
Addr: addr,
Addr: addr.As4(),
}

return (*(*[unix.SizeofSockaddrInet4]byte)(unsafe.Pointer(&sa)))[:], nil
}
}

// encodeAllowedIPs returns a function to encode allowed IP nested attributes.
func encodeAllowedIPs(ipns []net.IPNet) func(ae *netlink.AttributeEncoder) error {
func encodeAllowedIPs(ipns []netip.Prefix) func(ae *netlink.AttributeEncoder) error {
return func(ae *netlink.AttributeEncoder) error {
for i, ipn := range ipns {
if !isValidIP(ipn.IP) {
return fmt.Errorf("wglinux: invalid allowed IP: %s", ipn.IP.String())
if !ipn.Addr().IsValid() {
return fmt.Errorf("wglinux: invalid allowed IP: %s", ipn.Addr())
}

// Unmap 4in6 addresses to maintain previous compatibility
addr := ipn.Addr().Unmap()

family := uint16(unix.AF_INET6)
if !isIPv6(ipn.IP) {
if addr.Is4() {
// Make sure address is 4 bytes if IPv4.
family = unix.AF_INET
ipn.IP = ipn.IP.To4()
}

// Netlink arrays use type as an array index.
ae.Nested(uint16(i), func(nae *netlink.AttributeEncoder) error {
nae.Uint16(unix.WGALLOWEDIP_A_FAMILY, family)
nae.Bytes(unix.WGALLOWEDIP_A_IPADDR, ipn.IP)
nae.Bytes(unix.WGALLOWEDIP_A_IPADDR, addr.AsSlice())

ones, _ := ipn.Mask.Size()
ones := ipn.Bits()
nae.Uint8(unix.WGALLOWEDIP_A_CIDR_MASK, uint8(ones))
return nil
})
Expand All @@ -276,16 +276,6 @@ func encodeAllowedIPs(ipns []net.IPNet) func(ae *netlink.AttributeEncoder) error
}
}

// isValidIP determines if IP is a valid IPv4 or IPv6 address.
func isValidIP(ip net.IP) bool {
return ip.To16() != nil
}

// isIPv6 determines if IP is a valid IPv6 address.
func isIPv6(ip net.IP) bool {
return isValidIP(ip) && ip.To4() == nil
}

// sockaddrPort interprets port as a big endian uint16 for use passing sockaddr
// structures to the kernel.
func sockaddrPort(port int) uint16 {
Expand Down
Loading