Skip to content

Commit 25d879e

Browse files
committed
global: switch to netip
Signed-off-by: Jason A. Donenfeld <[email protected]>
1 parent 539979e commit 25d879e

File tree

17 files changed

+260
-367
lines changed

17 files changed

+260
-367
lines changed

conf/config.go

Lines changed: 12 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -10,22 +10,18 @@ import (
1010
"crypto/subtle"
1111
"encoding/base64"
1212
"fmt"
13-
"net"
1413
"strings"
1514
"time"
1615

16+
"golang.zx2c4.com/go118/netip"
17+
1718
"golang.org/x/crypto/curve25519"
1819

1920
"golang.zx2c4.com/wireguard/windows/l18n"
2021
)
2122

2223
const KeyLength = 32
2324

24-
type IPCidr struct {
25-
IP net.IP
26-
Cidr uint8
27-
}
28-
2925
type Endpoint struct {
3026
Host string
3127
Port uint16
@@ -43,10 +39,10 @@ type Config struct {
4339

4440
type Interface struct {
4541
PrivateKey Key
46-
Addresses []IPCidr
42+
Addresses []netip.Prefix
4743
ListenPort uint16
4844
MTU uint16
49-
DNS []net.IP
45+
DNS []netip.Addr
5046
DNSSearch []string
5147
PreUp string
5248
PostUp string
@@ -58,7 +54,7 @@ type Interface struct {
5854
type Peer struct {
5955
PublicKey Key
6056
PresharedKey Key
61-
AllowedIPs []IPCidr
57+
AllowedIPs []netip.Prefix
6258
Endpoint Endpoint
6359
PersistentKeepalive uint16
6460

@@ -67,62 +63,28 @@ type Peer struct {
6763
LastHandshakeTime HandshakeTime
6864
}
6965

70-
func (r *IPCidr) String() string {
71-
return fmt.Sprintf("%s/%d", r.IP.String(), r.Cidr)
72-
}
73-
74-
func (r *IPCidr) Bits() uint8 {
75-
if r.IP.To4() != nil {
76-
return 32
77-
}
78-
return 128
79-
}
80-
81-
func (r *IPCidr) IPNet() net.IPNet {
82-
return net.IPNet{
83-
IP: r.IP,
84-
Mask: net.CIDRMask(int(r.Cidr), int(r.Bits())),
85-
}
86-
}
87-
88-
func (r *IPCidr) MaskSelf() {
89-
bits := int(r.Bits())
90-
mask := net.CIDRMask(int(r.Cidr), bits)
91-
for i := 0; i < bits/8; i++ {
92-
r.IP[i] &= mask[i]
93-
}
94-
}
95-
9666
func (conf *Config) IntersectsWith(other *Config) bool {
97-
type hashableIPCidr struct {
98-
ip string
99-
cidr byte
100-
}
101-
allRoutes := make(map[hashableIPCidr]bool, len(conf.Interface.Addresses)*2+len(conf.Peers)*3)
67+
allRoutes := make(map[netip.Prefix]bool, len(conf.Interface.Addresses)*2+len(conf.Peers)*3)
10268
for _, a := range conf.Interface.Addresses {
103-
allRoutes[hashableIPCidr{string(a.IP), byte(len(a.IP) * 8)}] = true
104-
a.MaskSelf()
105-
allRoutes[hashableIPCidr{string(a.IP), a.Cidr}] = true
69+
allRoutes[netip.PrefixFrom(a.Addr(), a.Addr().BitLen())] = true
70+
allRoutes[a.Masked()] = true
10671
}
10772
for i := range conf.Peers {
10873
for _, a := range conf.Peers[i].AllowedIPs {
109-
a.MaskSelf()
110-
allRoutes[hashableIPCidr{string(a.IP), a.Cidr}] = true
74+
allRoutes[a.Masked()] = true
11175
}
11276
}
11377
for _, a := range other.Interface.Addresses {
114-
if allRoutes[hashableIPCidr{string(a.IP), byte(len(a.IP) * 8)}] {
78+
if allRoutes[netip.PrefixFrom(a.Addr(), a.Addr().BitLen())] {
11579
return true
11680
}
117-
a.MaskSelf()
118-
if allRoutes[hashableIPCidr{string(a.IP), a.Cidr}] {
81+
if allRoutes[a.Masked()] {
11982
return true
12083
}
12184
}
12285
for i := range other.Peers {
12386
for _, a := range other.Peers[i].AllowedIPs {
124-
a.MaskSelf()
125-
if allRoutes[hashableIPCidr{string(a.IP), a.Cidr}] {
87+
if allRoutes[a.Masked()] {
12688
return true
12789
}
12890
}

conf/dnsresolver_windows.go

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,13 @@
66
package conf
77

88
import (
9-
"fmt"
109
"log"
11-
"net"
12-
"syscall"
10+
"strconv"
1311
"time"
1412
"unsafe"
1513

14+
"golang.zx2c4.com/go118/netip"
15+
1616
"golang.org/x/sys/windows"
1717
"golang.zx2c4.com/wireguard/windows/services"
1818
)
@@ -66,24 +66,24 @@ func resolveHostnameOnce(name string) (resolvedIPString string, err error) {
6666
return
6767
}
6868
defer windows.FreeAddrInfoW(result)
69-
ipv6 := ""
69+
var v6 netip.Addr
7070
for ; result != nil; result = result.Next {
7171
switch result.Family {
7272
case windows.AF_INET:
73-
return (net.IP)((*syscall.RawSockaddrInet4)(unsafe.Pointer(result.Addr)).Addr[:]).String(), nil
73+
return netip.AddrFrom4((*windows.RawSockaddrInet4)(unsafe.Pointer(result.Addr)).Addr).String(), nil
7474
case windows.AF_INET6:
75-
if len(ipv6) != 0 {
75+
if v6.IsValid() {
7676
continue
7777
}
78-
a := (*syscall.RawSockaddrInet6)(unsafe.Pointer(result.Addr))
79-
ipv6 = (net.IP)(a.Addr[:]).String()
78+
a := (*windows.RawSockaddrInet6)(unsafe.Pointer(result.Addr))
79+
v6 = netip.AddrFrom16(a.Addr)
8080
if a.Scope_id != 0 {
81-
ipv6 += fmt.Sprintf("%%%d", a.Scope_id)
81+
v6 = v6.WithZone(strconv.FormatUint(uint64(a.Scope_id), 10))
8282
}
8383
}
8484
}
85-
if len(ipv6) != 0 {
86-
return ipv6, nil
85+
if v6.IsValid() {
86+
return v6.String(), nil
8787
}
8888
err = windows.WSAHOST_NOT_FOUND
8989
return

conf/parser.go

Lines changed: 22 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,11 @@ package conf
77

88
import (
99
"encoding/base64"
10-
"net"
1110
"strconv"
1211
"strings"
1312

13+
"golang.zx2c4.com/go118/netip"
14+
1415
"golang.org/x/sys/windows"
1516
"golang.org/x/text/encoding/unicode"
1617

@@ -27,43 +28,16 @@ func (e *ParseError) Error() string {
2728
return l18n.Sprintf("%s: %q", e.why, e.offender)
2829
}
2930

30-
func parseIPCidr(s string) (ipcidr *IPCidr, err error) {
31-
var addrStr, cidrStr string
32-
var cidr int
33-
34-
i := strings.IndexByte(s, '/')
35-
if i < 0 {
36-
addrStr = s
37-
} else {
38-
addrStr, cidrStr = s[:i], s[i+1:]
39-
}
40-
41-
err = &ParseError{l18n.Sprintf("Invalid IP address"), s}
42-
addr := net.ParseIP(addrStr)
43-
if addr == nil {
44-
return
45-
}
46-
maybeV4 := addr.To4()
47-
if maybeV4 != nil {
48-
addr = maybeV4
31+
func parseIPCidr(s string) (netip.Prefix, error) {
32+
ipcidr, err := netip.ParsePrefix(s)
33+
if err == nil {
34+
return ipcidr, nil
4935
}
50-
if len(cidrStr) > 0 {
51-
err = &ParseError{l18n.Sprintf("Invalid network prefix length"), s}
52-
cidr, err = strconv.Atoi(cidrStr)
53-
if err != nil || cidr < 0 || cidr > 128 {
54-
return
55-
}
56-
if cidr > 32 && maybeV4 != nil {
57-
return
58-
}
59-
} else {
60-
if maybeV4 != nil {
61-
cidr = 32
62-
} else {
63-
cidr = 128
64-
}
36+
addr, err := netip.ParseAddr(s)
37+
if err != nil {
38+
return netip.Prefix{}, &ParseError{l18n.Sprintf("Invalid IP address: "), s}
6539
}
66-
return &IPCidr{addr, uint8(cidr)}, nil
40+
return netip.PrefixFrom(addr, addr.BitLen()), nil
6741
}
6842

6943
func parseEndpoint(s string) (*Endpoint, error) {
@@ -87,16 +61,16 @@ func parseEndpoint(s string) (*Endpoint, error) {
8761
if i := strings.LastIndexByte(host, '%'); i > 1 {
8862
end = i
8963
}
90-
maybeV6 := net.ParseIP(host[1:end])
91-
if maybeV6 == nil || len(maybeV6) != net.IPv6len {
64+
maybeV6, err2 := netip.ParseAddr(host[1:end])
65+
if err2 != nil || !maybeV6.Is6() {
9266
return nil, err
9367
}
9468
} else {
9569
return nil, err
9670
}
9771
host = host[1 : len(host)-1]
9872
}
99-
return &Endpoint{host, uint16(port)}, nil
73+
return &Endpoint{host, port}, nil
10074
}
10175

10276
func parseMTU(s string) (uint16, error) {
@@ -256,16 +230,16 @@ func FromWgQuick(s string, name string) (*Config, error) {
256230
if err != nil {
257231
return nil, err
258232
}
259-
conf.Interface.Addresses = append(conf.Interface.Addresses, *a)
233+
conf.Interface.Addresses = append(conf.Interface.Addresses, a)
260234
}
261235
case "dns":
262236
addresses, err := splitList(val)
263237
if err != nil {
264238
return nil, err
265239
}
266240
for _, address := range addresses {
267-
a := net.ParseIP(address)
268-
if a == nil {
241+
a, err := netip.ParseAddr(address)
242+
if err != nil {
269243
conf.Interface.DNSSearch = append(conf.Interface.DNSSearch, address)
270244
} else {
271245
conf.Interface.DNS = append(conf.Interface.DNS, a)
@@ -312,7 +286,7 @@ func FromWgQuick(s string, name string) (*Config, error) {
312286
if err != nil {
313287
return nil, err
314288
}
315-
peer.AllowedIPs = append(peer.AllowedIPs, *a)
289+
peer.AllowedIPs = append(peer.AllowedIPs, a)
316290
}
317291
case "persistentkeepalive":
318292
p, err := parsePersistentKeepalive(val)
@@ -399,7 +373,7 @@ func FromDriverConfiguration(interfaze *driver.Interface, existingConfig *Config
399373
}
400374
if p.Flags&driver.PeerHasEndpoint != 0 {
401375
peer.Endpoint.Port = p.Endpoint.Port()
402-
peer.Endpoint.Host = p.Endpoint.IP().String()
376+
peer.Endpoint.Host = p.Endpoint.Addr().String()
403377
}
404378
if p.Flags&driver.PeerHasPersistentKeepalive != 0 {
405379
peer.PersistentKeepalive = p.PersistentKeepalive
@@ -416,16 +390,13 @@ func FromDriverConfiguration(interfaze *driver.Interface, existingConfig *Config
416390
} else {
417391
a = a.NextAllowedIP()
418392
}
419-
var ip net.IP
393+
var ip netip.Addr
420394
if a.AddressFamily == windows.AF_INET {
421-
ip = a.Address[:4]
395+
ip = netip.AddrFrom4(*(*[4]byte)(a.Address[:4]))
422396
} else if a.AddressFamily == windows.AF_INET6 {
423-
ip = a.Address[:16]
397+
ip = netip.AddrFrom16(*(*[16]byte)(a.Address[:16]))
424398
}
425-
peer.AllowedIPs = append(peer.AllowedIPs, IPCidr{
426-
IP: ip,
427-
Cidr: a.Cidr,
428-
})
399+
peer.AllowedIPs = append(peer.AllowedIPs, netip.PrefixFrom(ip, int(a.Cidr)))
429400
}
430401
conf.Peers = append(conf.Peers, peer)
431402
}

conf/parser_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,11 @@
66
package conf
77

88
import (
9-
"net"
109
"reflect"
1110
"runtime"
1211
"testing"
12+
13+
"golang.zx2c4.com/go118/netip"
1314
)
1415

1516
const testInput = `
@@ -77,10 +78,9 @@ func contains(t *testing.T, list, element interface{}) bool {
7778
func TestFromWgQuick(t *testing.T) {
7879
conf, err := FromWgQuick(testInput, "test")
7980
if noError(t, err) {
80-
8181
lenTest(t, conf.Interface.Addresses, 2)
82-
contains(t, conf.Interface.Addresses, IPCidr{net.IPv4(10, 10, 0, 1), uint8(16)})
83-
contains(t, conf.Interface.Addresses, IPCidr{net.IPv4(10, 192, 122, 1), uint8(24)})
82+
contains(t, conf.Interface.Addresses, netip.PrefixFrom(netip.AddrFrom4([4]byte{0, 10, 0, 1}), 16))
83+
contains(t, conf.Interface.Addresses, netip.PrefixFrom(netip.AddrFrom4([4]byte{10, 192, 122, 1}), 24))
8484
equal(t, "yAnz5TF+lXXJte14tji3zlMNq+hd2rYUIgJBgB3fBmk=", conf.Interface.PrivateKey.String())
8585
equal(t, uint16(51820), conf.Interface.ListenPort)
8686

conf/writer.go

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,11 @@ package conf
77

88
import (
99
"fmt"
10-
"net"
1110
"strings"
1211
"unsafe"
1312

13+
"golang.zx2c4.com/go118/netip"
14+
1415
"golang.org/x/sys/windows"
1516
"golang.zx2c4.com/wireguard/windows/driver"
1617
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
@@ -111,8 +112,11 @@ func (config *Config) ToDriverConfiguration() (*driver.Interface, uint32) {
111112
}
112113
var endpoint winipcfg.RawSockaddrInet
113114
if !config.Peers[i].Endpoint.IsEmpty() {
114-
flags |= driver.PeerHasEndpoint
115-
endpoint.SetIP(net.ParseIP(config.Peers[i].Endpoint.Host), config.Peers[i].Endpoint.Port)
115+
addr, err := netip.ParseAddr(config.Peers[i].Endpoint.Host)
116+
if err == nil {
117+
flags |= driver.PeerHasEndpoint
118+
endpoint.SetAddrPort(netip.AddrPortFrom(addr, config.Peers[i].Endpoint.Port))
119+
}
116120
}
117121
c.AppendPeer(&driver.Peer{
118122
Flags: flags,
@@ -123,20 +127,13 @@ func (config *Config) ToDriverConfiguration() (*driver.Interface, uint32) {
123127
AllowedIPsCount: uint32(len(config.Peers[i].AllowedIPs)),
124128
})
125129
for j := range config.Peers[i].AllowedIPs {
126-
var family winipcfg.AddressFamily
127-
var ip net.IP
128-
if ip = config.Peers[i].AllowedIPs[j].IP.To4(); ip != nil {
129-
family = windows.AF_INET
130-
} else if ip = config.Peers[i].AllowedIPs[j].IP.To16(); ip != nil {
131-
family = windows.AF_INET6
132-
} else {
133-
ip = config.Peers[i].AllowedIPs[j].IP
134-
}
135-
a := &driver.AllowedIP{
136-
AddressFamily: family,
137-
Cidr: config.Peers[i].AllowedIPs[j].Cidr,
130+
a := &driver.AllowedIP{Cidr: uint8(config.Peers[i].AllowedIPs[j].Bits())}
131+
copy(a.Address[:], config.Peers[i].AllowedIPs[j].Addr().AsSlice())
132+
if config.Peers[i].AllowedIPs[j].Addr().Is4() {
133+
a.AddressFamily = windows.AF_INET
134+
} else if config.Peers[i].AllowedIPs[j].Addr().Is6() {
135+
a.AddressFamily = windows.AF_INET6
138136
}
139-
copy(a.Address[:], ip)
140137
c.AppendAllowedIP(a)
141138
}
142139
}

0 commit comments

Comments
 (0)