Skip to content

Commit 25eb973

Browse files
raggizx2c4
authored andcommitted
conn: store IP_PKTINFO cmsg in StdNetendpoint src
Replace the src storage inside StdNetEndpoint with a copy of the raw control message buffer, to reduce allocation and perform less work on a per-packet basis. Signed-off-by: James Tucker <[email protected]> Signed-off-by: Jason A. Donenfeld <[email protected]>
1 parent b7cd547 commit 25eb973

File tree

4 files changed

+128
-98
lines changed

4 files changed

+128
-98
lines changed

conn/bind_std.go

Lines changed: 9 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -81,11 +81,10 @@ func NewStdNetBind() Bind {
8181
type StdNetEndpoint struct {
8282
// AddrPort is the endpoint destination.
8383
netip.AddrPort
84-
// src is the current sticky source address and interface index, if supported.
85-
src struct {
86-
netip.Addr
87-
ifidx int32
88-
}
84+
// src is the current sticky source address and interface index, if
85+
// supported. Typically this is a PKTINFO structure from/for control
86+
// messages, see unix.PKTINFO for an example.
87+
src []byte
8988
}
9089

9190
var (
@@ -104,21 +103,17 @@ func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) {
104103
}
105104

106105
func (e *StdNetEndpoint) ClearSrc() {
107-
e.src.ifidx = 0
108-
e.src.Addr = netip.Addr{}
106+
if e.src != nil {
107+
// Truncate src, no need to reallocate.
108+
e.src = e.src[:0]
109+
}
109110
}
110111

111112
func (e *StdNetEndpoint) DstIP() netip.Addr {
112113
return e.AddrPort.Addr()
113114
}
114115

115-
func (e *StdNetEndpoint) SrcIP() netip.Addr {
116-
return e.src.Addr
117-
}
118-
119-
func (e *StdNetEndpoint) SrcIfidx() int32 {
120-
return e.src.ifidx
121-
}
116+
// See sticky_default,linux, etc for implementations of SrcIP and SrcIfidx.
122117

123118
func (e *StdNetEndpoint) DstToBytes() []byte {
124119
b, _ := e.AddrPort.MarshalBinary()
@@ -129,10 +124,6 @@ func (e *StdNetEndpoint) DstToString() string {
129124
return e.AddrPort.String()
130125
}
131126

132-
func (e *StdNetEndpoint) SrcToString() string {
133-
return e.src.Addr.String()
134-
}
135-
136127
func listenNet(network string, port int) (*net.UDPConn, int, error) {
137128
conn, err := listenConfig().ListenPacket(context.Background(), network, ":"+strconv.Itoa(port))
138129
if err != nil {

conn/sticky_default.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,20 @@
77

88
package conn
99

10+
import "net/netip"
11+
12+
func (e *StdNetEndpoint) SrcIP() netip.Addr {
13+
return netip.Addr{}
14+
}
15+
16+
func (e *StdNetEndpoint) SrcIfidx() int32 {
17+
return 0
18+
}
19+
20+
func (e *StdNetEndpoint) SrcToString() string {
21+
return ""
22+
}
23+
1024
// TODO: macOS, FreeBSD and other BSDs likely do support this feature set, but
1125
// use alternatively named flags and need ports and require testing.
1226

conn/sticky_linux.go

Lines changed: 49 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,37 @@ import (
1414
"golang.org/x/sys/unix"
1515
)
1616

17+
func (e *StdNetEndpoint) SrcIP() netip.Addr {
18+
switch len(e.src) {
19+
case unix.CmsgSpace(unix.SizeofInet4Pktinfo):
20+
info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
21+
return netip.AddrFrom4(info.Spec_dst)
22+
case unix.CmsgSpace(unix.SizeofInet6Pktinfo):
23+
info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
24+
// TODO: set zone. in order to do so we need to check if the address is
25+
// link local, and if it is perform a syscall to turn the ifindex into a
26+
// zone string because netip uses string zones.
27+
return netip.AddrFrom16(info.Addr)
28+
}
29+
return netip.Addr{}
30+
}
31+
32+
func (e *StdNetEndpoint) SrcIfidx() int32 {
33+
switch len(e.src) {
34+
case unix.CmsgSpace(unix.SizeofInet4Pktinfo):
35+
info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
36+
return info.Ifindex
37+
case unix.CmsgSpace(unix.SizeofInet6Pktinfo):
38+
info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
39+
return int32(info.Ifindex)
40+
}
41+
return 0
42+
}
43+
44+
func (e *StdNetEndpoint) SrcToString() string {
45+
return e.SrcIP().String()
46+
}
47+
1748
// getSrcFromControl parses the control for PKTINFO and if found updates ep with
1849
// the source information found.
1950
func getSrcFromControl(control []byte, ep *StdNetEndpoint) {
@@ -35,81 +66,43 @@ func getSrcFromControl(control []byte, ep *StdNetEndpoint) {
3566
if hdr.Level == unix.IPPROTO_IP &&
3667
hdr.Type == unix.IP_PKTINFO {
3768

38-
info := pktInfoFromBuf[unix.Inet4Pktinfo](data)
39-
ep.src.Addr = netip.AddrFrom4(info.Spec_dst)
40-
ep.src.ifidx = info.Ifindex
69+
if ep.src == nil || cap(ep.src) < unix.CmsgSpace(unix.SizeofInet4Pktinfo) {
70+
ep.src = make([]byte, 0, unix.CmsgSpace(unix.SizeofInet4Pktinfo))
71+
}
72+
ep.src = ep.src[:unix.CmsgSpace(unix.SizeofInet4Pktinfo)]
4173

74+
hdrBuf := unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), unix.SizeofCmsghdr)
75+
copy(ep.src, hdrBuf)
76+
copy(ep.src[unix.CmsgLen(0):], data)
4277
return
4378
}
4479

4580
if hdr.Level == unix.IPPROTO_IPV6 &&
4681
hdr.Type == unix.IPV6_PKTINFO {
4782

48-
info := pktInfoFromBuf[unix.Inet6Pktinfo](data)
49-
ep.src.Addr = netip.AddrFrom16(info.Addr)
50-
ep.src.ifidx = int32(info.Ifindex)
83+
if ep.src == nil || cap(ep.src) < unix.CmsgSpace(unix.SizeofInet6Pktinfo) {
84+
ep.src = make([]byte, 0, unix.CmsgSpace(unix.SizeofInet6Pktinfo))
85+
}
86+
87+
ep.src = ep.src[:unix.CmsgSpace(unix.SizeofInet6Pktinfo)]
5188

89+
hdrBuf := unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), unix.SizeofCmsghdr)
90+
copy(ep.src, hdrBuf)
91+
copy(ep.src[unix.CmsgLen(0):], data)
5292
return
5393
}
5494
}
5595
}
5696

57-
// pktInfoFromBuf returns type T populated from the provided buf via copy(). It
58-
// panics if buf is of insufficient size.
59-
func pktInfoFromBuf[T unix.Inet4Pktinfo | unix.Inet6Pktinfo](buf []byte) (t T) {
60-
size := int(unsafe.Sizeof(t))
61-
if len(buf) < size {
62-
panic("pktInfoFromBuf: buffer too small")
63-
}
64-
copy(unsafe.Slice((*byte)(unsafe.Pointer(&t)), size), buf)
65-
return t
66-
}
67-
6897
// setSrcControl sets an IP{V6}_PKTINFO in control based on the source address
6998
// and source ifindex found in ep. control's len will be set to 0 in the event
7099
// that ep is a default value.
71100
func setSrcControl(control *[]byte, ep *StdNetEndpoint) {
72-
*control = (*control)[:cap(*control)]
73-
if len(*control) < int(unsafe.Sizeof(unix.Cmsghdr{})) {
74-
*control = (*control)[:0]
101+
if cap(*control) < len(ep.src) {
75102
return
76103
}
77-
78-
if ep.src.ifidx == 0 && !ep.SrcIP().IsValid() {
79-
*control = (*control)[:0]
80-
return
81-
}
82-
83-
if len(*control) < srcControlSize {
84-
*control = (*control)[:0]
85-
return
86-
}
87-
88-
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&(*control)[0]))
89-
if ep.SrcIP().Is4() {
90-
hdr.Level = unix.IPPROTO_IP
91-
hdr.Type = unix.IP_PKTINFO
92-
hdr.SetLen(unix.CmsgLen(unix.SizeofInet4Pktinfo))
93-
94-
info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&(*control)[unix.SizeofCmsghdr]))
95-
info.Ifindex = ep.src.ifidx
96-
if ep.SrcIP().IsValid() {
97-
info.Spec_dst = ep.SrcIP().As4()
98-
}
99-
*control = (*control)[:unix.CmsgSpace(unix.SizeofInet4Pktinfo)]
100-
} else {
101-
hdr.Level = unix.IPPROTO_IPV6
102-
hdr.Type = unix.IPV6_PKTINFO
103-
hdr.SetLen(unix.CmsgLen(unix.SizeofInet6Pktinfo))
104-
105-
info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&(*control)[unix.SizeofCmsghdr]))
106-
info.Ifindex = uint32(ep.src.ifidx)
107-
if ep.SrcIP().IsValid() {
108-
info.Addr = ep.SrcIP().As16()
109-
}
110-
*control = (*control)[:unix.CmsgSpace(unix.SizeofInet6Pktinfo)]
111-
}
112-
104+
*control = (*control)[:0]
105+
*control = append(*control, ep.src...)
113106
}
114107

115108
var srcControlSize = unix.CmsgSpace(unix.SizeofInet6Pktinfo)

conn/sticky_linux_test.go

Lines changed: 56 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,47 @@ import (
1818
"golang.org/x/sys/unix"
1919
)
2020

21+
func setSrc(ep *StdNetEndpoint, addr netip.Addr, ifidx int32) {
22+
var buf []byte
23+
if addr.Is4() {
24+
buf = make([]byte, unix.CmsgSpace(unix.SizeofInet4Pktinfo))
25+
hdr := unix.Cmsghdr{
26+
Level: unix.IPPROTO_IP,
27+
Type: unix.IP_PKTINFO,
28+
}
29+
hdr.SetLen(unix.CmsgLen(unix.SizeofInet4Pktinfo))
30+
copy(buf, unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), int(unsafe.Sizeof(hdr))))
31+
32+
info := unix.Inet4Pktinfo{
33+
Ifindex: ifidx,
34+
Spec_dst: addr.As4(),
35+
}
36+
copy(buf[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&info)), unix.SizeofInet4Pktinfo))
37+
} else {
38+
buf = make([]byte, unix.CmsgSpace(unix.SizeofInet6Pktinfo))
39+
hdr := unix.Cmsghdr{
40+
Level: unix.IPPROTO_IPV6,
41+
Type: unix.IPV6_PKTINFO,
42+
}
43+
hdr.SetLen(unix.CmsgLen(unix.SizeofInet6Pktinfo))
44+
copy(buf, unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), int(unsafe.Sizeof(hdr))))
45+
46+
info := unix.Inet6Pktinfo{
47+
Ifindex: uint32(ifidx),
48+
Addr: addr.As16(),
49+
}
50+
copy(buf[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&info)), unix.SizeofInet6Pktinfo))
51+
}
52+
53+
ep.src = buf
54+
}
55+
2156
func Test_setSrcControl(t *testing.T) {
2257
t.Run("IPv4", func(t *testing.T) {
2358
ep := &StdNetEndpoint{
2459
AddrPort: netip.MustParseAddrPort("127.0.0.1:1234"),
2560
}
26-
ep.src.Addr = netip.MustParseAddr("127.0.0.1")
27-
ep.src.ifidx = 5
61+
setSrc(ep, netip.MustParseAddr("127.0.0.1"), 5)
2862

2963
control := make([]byte, srcControlSize)
3064

@@ -53,8 +87,7 @@ func Test_setSrcControl(t *testing.T) {
5387
ep := &StdNetEndpoint{
5488
AddrPort: netip.MustParseAddrPort("[::1]:1234"),
5589
}
56-
ep.src.Addr = netip.MustParseAddr("::1")
57-
ep.src.ifidx = 5
90+
setSrc(ep, netip.MustParseAddr("::1"), 5)
5891

5992
control := make([]byte, srcControlSize)
6093

@@ -80,7 +113,7 @@ func Test_setSrcControl(t *testing.T) {
80113
})
81114

82115
t.Run("ClearOnNoSrc", func(t *testing.T) {
83-
control := make([]byte, srcControlSize)
116+
control := make([]byte, unix.CmsgLen(0))
84117
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
85118
hdr.Level = 1
86119
hdr.Type = 2
@@ -96,7 +129,7 @@ func Test_setSrcControl(t *testing.T) {
96129

97130
func Test_getSrcFromControl(t *testing.T) {
98131
t.Run("IPv4", func(t *testing.T) {
99-
control := make([]byte, srcControlSize)
132+
control := make([]byte, unix.CmsgSpace(unix.SizeofInet4Pktinfo))
100133
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
101134
hdr.Level = unix.IPPROTO_IP
102135
hdr.Type = unix.IP_PKTINFO
@@ -108,15 +141,15 @@ func Test_getSrcFromControl(t *testing.T) {
108141
ep := &StdNetEndpoint{}
109142
getSrcFromControl(control, ep)
110143

111-
if ep.src.Addr != netip.MustParseAddr("127.0.0.1") {
112-
t.Errorf("unexpected address: %v", ep.src.Addr)
144+
if ep.SrcIP() != netip.MustParseAddr("127.0.0.1") {
145+
t.Errorf("unexpected address: %v", ep.SrcIP())
113146
}
114-
if ep.src.ifidx != 5 {
115-
t.Errorf("unexpected ifindex: %d", ep.src.ifidx)
147+
if ep.SrcIfidx() != 5 {
148+
t.Errorf("unexpected ifindex: %d", ep.SrcIfidx())
116149
}
117150
})
118151
t.Run("IPv6", func(t *testing.T) {
119-
control := make([]byte, srcControlSize)
152+
control := make([]byte, unix.CmsgSpace(unix.SizeofInet6Pktinfo))
120153
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
121154
hdr.Level = unix.IPPROTO_IPV6
122155
hdr.Type = unix.IPV6_PKTINFO
@@ -131,30 +164,29 @@ func Test_getSrcFromControl(t *testing.T) {
131164
if ep.SrcIP() != netip.MustParseAddr("::1") {
132165
t.Errorf("unexpected address: %v", ep.SrcIP())
133166
}
134-
if ep.src.ifidx != 5 {
135-
t.Errorf("unexpected ifindex: %d", ep.src.ifidx)
167+
if ep.SrcIfidx() != 5 {
168+
t.Errorf("unexpected ifindex: %d", ep.SrcIfidx())
136169
}
137170
})
138171
t.Run("ClearOnEmpty", func(t *testing.T) {
139-
control := make([]byte, srcControlSize)
172+
var control []byte
140173
ep := &StdNetEndpoint{}
141-
ep.src.Addr = netip.MustParseAddr("::1")
142-
ep.src.ifidx = 5
174+
setSrc(ep, netip.MustParseAddr("::1"), 5)
143175

144176
getSrcFromControl(control, ep)
145177
if ep.SrcIP().IsValid() {
146-
t.Errorf("unexpected address: %v", ep.src.Addr)
178+
t.Errorf("unexpected address: %v", ep.SrcIP())
147179
}
148-
if ep.src.ifidx != 0 {
149-
t.Errorf("unexpected ifindex: %d", ep.src.ifidx)
180+
if ep.SrcIfidx() != 0 {
181+
t.Errorf("unexpected ifindex: %d", ep.SrcIfidx())
150182
}
151183
})
152184
t.Run("Multiple", func(t *testing.T) {
153185
zeroControl := make([]byte, unix.CmsgSpace(0))
154186
zeroHdr := (*unix.Cmsghdr)(unsafe.Pointer(&zeroControl[0]))
155187
zeroHdr.SetLen(unix.CmsgLen(0))
156188

157-
control := make([]byte, srcControlSize)
189+
control := make([]byte, unix.CmsgSpace(unix.SizeofInet4Pktinfo))
158190
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
159191
hdr.Level = unix.IPPROTO_IP
160192
hdr.Type = unix.IP_PKTINFO
@@ -170,11 +202,11 @@ func Test_getSrcFromControl(t *testing.T) {
170202
ep := &StdNetEndpoint{}
171203
getSrcFromControl(combined, ep)
172204

173-
if ep.src.Addr != netip.MustParseAddr("127.0.0.1") {
174-
t.Errorf("unexpected address: %v", ep.src.Addr)
205+
if ep.SrcIP() != netip.MustParseAddr("127.0.0.1") {
206+
t.Errorf("unexpected address: %v", ep.SrcIP())
175207
}
176-
if ep.src.ifidx != 5 {
177-
t.Errorf("unexpected ifindex: %d", ep.src.ifidx)
208+
if ep.SrcIfidx() != 5 {
209+
t.Errorf("unexpected ifindex: %d", ep.SrcIfidx())
178210
}
179211
})
180212
}

0 commit comments

Comments
 (0)