Skip to content

Commit 4776166

Browse files
committed
global: switch to netip
Signed-off-by: Jason A. Donenfeld <[email protected]>
1 parent 539979e commit 4776166

17 files changed

+284
-356
lines changed

conf/config.go

Lines changed: 33 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,13 @@ import (
99
"crypto/rand"
1010
"crypto/subtle"
1111
"encoding/base64"
12+
"errors"
1213
"fmt"
13-
"net"
1414
"strings"
1515
"time"
1616

17+
"golang.zx2c4.com/go118/netip"
18+
1719
"golang.org/x/crypto/curve25519"
1820

1921
"golang.zx2c4.com/wireguard/windows/l18n"
@@ -22,8 +24,7 @@ import (
2224
const KeyLength = 32
2325

2426
type IPCidr struct {
25-
IP net.IP
26-
Cidr uint8
27+
netip.Prefix
2728
}
2829

2930
type Endpoint struct {
@@ -46,7 +47,7 @@ type Interface struct {
4647
Addresses []IPCidr
4748
ListenPort uint16
4849
MTU uint16
49-
DNS []net.IP
50+
DNS []netip.Addr
5051
DNSSearch []string
5152
PreUp string
5253
PostUp string
@@ -67,62 +68,28 @@ type Peer struct {
6768
LastHandshakeTime HandshakeTime
6869
}
6970

70-
func (r *IPCidr) String() string {
71-
return fmt.Sprintf("%s/%d", r.IP.String(), r.Cidr)
72-
}
73-
74-
func (r *IPCidr) Bits() uint8 {
75-
if r.IP.To4() != nil {
76-
return 32
77-
}
78-
return 128
79-
}
80-
81-
func (r *IPCidr) IPNet() net.IPNet {
82-
return net.IPNet{
83-
IP: r.IP,
84-
Mask: net.CIDRMask(int(r.Cidr), int(r.Bits())),
85-
}
86-
}
87-
88-
func (r *IPCidr) MaskSelf() {
89-
bits := int(r.Bits())
90-
mask := net.CIDRMask(int(r.Cidr), bits)
91-
for i := 0; i < bits/8; i++ {
92-
r.IP[i] &= mask[i]
93-
}
94-
}
95-
9671
func (conf *Config) IntersectsWith(other *Config) bool {
97-
type hashableIPCidr struct {
98-
ip string
99-
cidr byte
100-
}
101-
allRoutes := make(map[hashableIPCidr]bool, len(conf.Interface.Addresses)*2+len(conf.Peers)*3)
72+
allRoutes := make(map[netip.Prefix]bool, len(conf.Interface.Addresses)*2+len(conf.Peers)*3)
10273
for _, a := range conf.Interface.Addresses {
103-
allRoutes[hashableIPCidr{string(a.IP), byte(len(a.IP) * 8)}] = true
104-
a.MaskSelf()
105-
allRoutes[hashableIPCidr{string(a.IP), a.Cidr}] = true
74+
allRoutes[netip.PrefixFrom(a.Addr(), a.Addr().BitLen())] = true
75+
allRoutes[a.Masked()] = true
10676
}
10777
for i := range conf.Peers {
10878
for _, a := range conf.Peers[i].AllowedIPs {
109-
a.MaskSelf()
110-
allRoutes[hashableIPCidr{string(a.IP), a.Cidr}] = true
79+
allRoutes[a.Masked()] = true
11180
}
11281
}
11382
for _, a := range other.Interface.Addresses {
114-
if allRoutes[hashableIPCidr{string(a.IP), byte(len(a.IP) * 8)}] {
83+
if allRoutes[netip.PrefixFrom(a.Addr(), a.Addr().BitLen())] {
11584
return true
11685
}
117-
a.MaskSelf()
118-
if allRoutes[hashableIPCidr{string(a.IP), a.Cidr}] {
86+
if allRoutes[a.Masked()] {
11987
return true
12088
}
12189
}
12290
for i := range other.Peers {
12391
for _, a := range other.Peers[i].AllowedIPs {
124-
a.MaskSelf()
125-
if allRoutes[hashableIPCidr{string(a.IP), a.Cidr}] {
92+
if allRoutes[a.Masked()] {
12693
return true
12794
}
12895
}
@@ -233,6 +200,27 @@ func (b Bytes) String() string {
233200
return l18n.Sprintf("%.2f\u00a0TiB", float64(b)/(1024*1024*1024)/1024)
234201
}
235202

203+
func (p IPCidr) MarshalBinary() ([]byte, error) {
204+
b, err := p.Addr().MarshalBinary()
205+
if err != nil {
206+
return nil, err
207+
}
208+
return append(b, uint8(p.Bits())), nil
209+
}
210+
211+
func (p *IPCidr) UnmarshalBinary(b []byte) error {
212+
if len(b) < 1 {
213+
return errors.New("unexpected byte slice")
214+
}
215+
var addr netip.Addr
216+
err := addr.UnmarshalBinary(b[:len(b)-1])
217+
if err != nil {
218+
return err
219+
}
220+
*p = IPCidr{netip.PrefixFrom(addr, int(b[len(b)-1]))}
221+
return nil
222+
}
223+
236224
func (conf *Config) DeduplicateNetworkEntries() {
237225
m := make(map[string]bool, len(conf.Interface.Addresses))
238226
i := 0

conf/dnsresolver_windows.go

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,12 @@ package conf
88
import (
99
"fmt"
1010
"log"
11-
"net"
1211
"syscall"
1312
"time"
1413
"unsafe"
1514

15+
"golang.zx2c4.com/go118/netip"
16+
1617
"golang.org/x/sys/windows"
1718
"golang.zx2c4.com/wireguard/windows/services"
1819
)
@@ -66,24 +67,24 @@ func resolveHostnameOnce(name string) (resolvedIPString string, err error) {
6667
return
6768
}
6869
defer windows.FreeAddrInfoW(result)
69-
ipv6 := ""
70+
var v6 netip.Addr
7071
for ; result != nil; result = result.Next {
7172
switch result.Family {
7273
case windows.AF_INET:
73-
return (net.IP)((*syscall.RawSockaddrInet4)(unsafe.Pointer(result.Addr)).Addr[:]).String(), nil
74+
return netip.AddrFrom4((*syscall.RawSockaddrInet4)(unsafe.Pointer(result.Addr)).Addr).String(), nil
7475
case windows.AF_INET6:
75-
if len(ipv6) != 0 {
76+
if v6.IsValid() {
7677
continue
7778
}
7879
a := (*syscall.RawSockaddrInet6)(unsafe.Pointer(result.Addr))
79-
ipv6 = (net.IP)(a.Addr[:]).String()
80+
v6 = netip.AddrFrom16(a.Addr)
8081
if a.Scope_id != 0 {
81-
ipv6 += fmt.Sprintf("%%%d", a.Scope_id)
82+
v6 = v6.WithZone(fmt.Sprint(a.Scope_id))
8283
}
8384
}
8485
}
85-
if len(ipv6) != 0 {
86-
return ipv6, nil
86+
if v6.IsValid() {
87+
return v6.String(), nil
8788
}
8889
err = windows.WSAHOST_NOT_FOUND
8990
return

conf/parser.go

Lines changed: 22 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,11 @@ package conf
77

88
import (
99
"encoding/base64"
10-
"net"
1110
"strconv"
1211
"strings"
1312

13+
"golang.zx2c4.com/go118/netip"
14+
1415
"golang.org/x/sys/windows"
1516
"golang.org/x/text/encoding/unicode"
1617

@@ -27,43 +28,16 @@ func (e *ParseError) Error() string {
2728
return l18n.Sprintf("%s: %q", e.why, e.offender)
2829
}
2930

30-
func parseIPCidr(s string) (ipcidr *IPCidr, err error) {
31-
var addrStr, cidrStr string
32-
var cidr int
33-
34-
i := strings.IndexByte(s, '/')
35-
if i < 0 {
36-
addrStr = s
37-
} else {
38-
addrStr, cidrStr = s[:i], s[i+1:]
39-
}
40-
41-
err = &ParseError{l18n.Sprintf("Invalid IP address"), s}
42-
addr := net.ParseIP(addrStr)
43-
if addr == nil {
44-
return
45-
}
46-
maybeV4 := addr.To4()
47-
if maybeV4 != nil {
48-
addr = maybeV4
31+
func parseIPCidr(s string) (IPCidr, error) {
32+
ipcidr, err := netip.ParsePrefix(s)
33+
if err == nil {
34+
return IPCidr{ipcidr}, nil
4935
}
50-
if len(cidrStr) > 0 {
51-
err = &ParseError{l18n.Sprintf("Invalid network prefix length"), s}
52-
cidr, err = strconv.Atoi(cidrStr)
53-
if err != nil || cidr < 0 || cidr > 128 {
54-
return
55-
}
56-
if cidr > 32 && maybeV4 != nil {
57-
return
58-
}
59-
} else {
60-
if maybeV4 != nil {
61-
cidr = 32
62-
} else {
63-
cidr = 128
64-
}
36+
addr, err := netip.ParseAddr(s)
37+
if err != nil {
38+
return IPCidr{}, &ParseError{l18n.Sprintf("Invalid IP address: "), s}
6539
}
66-
return &IPCidr{addr, uint8(cidr)}, nil
40+
return IPCidr{netip.PrefixFrom(addr, addr.BitLen())}, nil
6741
}
6842

6943
func parseEndpoint(s string) (*Endpoint, error) {
@@ -87,16 +61,16 @@ func parseEndpoint(s string) (*Endpoint, error) {
8761
if i := strings.LastIndexByte(host, '%'); i > 1 {
8862
end = i
8963
}
90-
maybeV6 := net.ParseIP(host[1:end])
91-
if maybeV6 == nil || len(maybeV6) != net.IPv6len {
64+
maybeV6, err2 := netip.ParseAddr(host[1:end])
65+
if err2 != nil || !maybeV6.Is6() {
9266
return nil, err
9367
}
9468
} else {
9569
return nil, err
9670
}
9771
host = host[1 : len(host)-1]
9872
}
99-
return &Endpoint{host, uint16(port)}, nil
73+
return &Endpoint{host, port}, nil
10074
}
10175

10276
func parseMTU(s string) (uint16, error) {
@@ -256,16 +230,16 @@ func FromWgQuick(s string, name string) (*Config, error) {
256230
if err != nil {
257231
return nil, err
258232
}
259-
conf.Interface.Addresses = append(conf.Interface.Addresses, *a)
233+
conf.Interface.Addresses = append(conf.Interface.Addresses, a)
260234
}
261235
case "dns":
262236
addresses, err := splitList(val)
263237
if err != nil {
264238
return nil, err
265239
}
266240
for _, address := range addresses {
267-
a := net.ParseIP(address)
268-
if a == nil {
241+
a, err := netip.ParseAddr(address)
242+
if err != nil {
269243
conf.Interface.DNSSearch = append(conf.Interface.DNSSearch, address)
270244
} else {
271245
conf.Interface.DNS = append(conf.Interface.DNS, a)
@@ -312,7 +286,7 @@ func FromWgQuick(s string, name string) (*Config, error) {
312286
if err != nil {
313287
return nil, err
314288
}
315-
peer.AllowedIPs = append(peer.AllowedIPs, *a)
289+
peer.AllowedIPs = append(peer.AllowedIPs, a)
316290
}
317291
case "persistentkeepalive":
318292
p, err := parsePersistentKeepalive(val)
@@ -399,7 +373,7 @@ func FromDriverConfiguration(interfaze *driver.Interface, existingConfig *Config
399373
}
400374
if p.Flags&driver.PeerHasEndpoint != 0 {
401375
peer.Endpoint.Port = p.Endpoint.Port()
402-
peer.Endpoint.Host = p.Endpoint.IP().String()
376+
peer.Endpoint.Host = p.Endpoint.Addr().String()
403377
}
404378
if p.Flags&driver.PeerHasPersistentKeepalive != 0 {
405379
peer.PersistentKeepalive = p.PersistentKeepalive
@@ -416,16 +390,13 @@ func FromDriverConfiguration(interfaze *driver.Interface, existingConfig *Config
416390
} else {
417391
a = a.NextAllowedIP()
418392
}
419-
var ip net.IP
393+
var ip netip.Addr
420394
if a.AddressFamily == windows.AF_INET {
421-
ip = a.Address[:4]
395+
ip = netip.AddrFrom4(*(*[4]byte)(a.Address[:4]))
422396
} else if a.AddressFamily == windows.AF_INET6 {
423-
ip = a.Address[:16]
397+
ip = netip.AddrFrom16(*(*[16]byte)(a.Address[:16]))
424398
}
425-
peer.AllowedIPs = append(peer.AllowedIPs, IPCidr{
426-
IP: ip,
427-
Cidr: a.Cidr,
428-
})
399+
peer.AllowedIPs = append(peer.AllowedIPs, IPCidr{netip.PrefixFrom(ip, int(a.Cidr))})
429400
}
430401
conf.Peers = append(conf.Peers, peer)
431402
}

conf/parser_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,11 @@
66
package conf
77

88
import (
9-
"net"
109
"reflect"
1110
"runtime"
1211
"testing"
12+
13+
"golang.zx2c4.com/go118/netip"
1314
)
1415

1516
const testInput = `
@@ -77,10 +78,9 @@ func contains(t *testing.T, list, element interface{}) bool {
7778
func TestFromWgQuick(t *testing.T) {
7879
conf, err := FromWgQuick(testInput, "test")
7980
if noError(t, err) {
80-
8181
lenTest(t, conf.Interface.Addresses, 2)
82-
contains(t, conf.Interface.Addresses, IPCidr{net.IPv4(10, 10, 0, 1), uint8(16)})
83-
contains(t, conf.Interface.Addresses, IPCidr{net.IPv4(10, 192, 122, 1), uint8(24)})
82+
contains(t, conf.Interface.Addresses, netip.PrefixFrom(netip.AddrFrom4([4]byte{0, 10, 0, 1}), 16))
83+
contains(t, conf.Interface.Addresses, netip.PrefixFrom(netip.AddrFrom4([4]byte{10, 192, 122, 1}), 24))
8484
equal(t, "yAnz5TF+lXXJte14tji3zlMNq+hd2rYUIgJBgB3fBmk=", conf.Interface.PrivateKey.String())
8585
equal(t, uint16(51820), conf.Interface.ListenPort)
8686

0 commit comments

Comments
 (0)