Skip to content

Commit 2fcdaf9

Browse files
jwhitedzx2c4
authored andcommitted
conn: fix StdNetBind fallback on Windows
If RIO is unavailable, NewWinRingBind() falls back to StdNetBind. StdNetBind uses x/net/ipv{4,6}.PacketConn for sending and receiving datagrams, specifically via the {Read,Write}Batch methods. These methods are unimplemented on Windows and will return runtime errors as a result. Additionally, only Linux benefits from these x/net types for reading and writing, so we update StdNetBind to fall back to the standard library net package for all platforms other than Linux. Reviewed-by: James Tucker <[email protected]> Signed-off-by: Jordan Whited <[email protected]> Signed-off-by: Jason A. Donenfeld <[email protected]>
1 parent dbd9493 commit 2fcdaf9

File tree

2 files changed

+150
-64
lines changed

2 files changed

+150
-64
lines changed

conn/bind_std.go

Lines changed: 128 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"errors"
1111
"net"
1212
"net/netip"
13+
"runtime"
1314
"strconv"
1415
"sync"
1516
"syscall"
@@ -22,16 +23,21 @@ var (
2223
_ Bind = (*StdNetBind)(nil)
2324
)
2425

25-
// StdNetBind implements Bind for all platforms except Windows.
26+
// StdNetBind implements Bind for all platforms. While Windows has its own Bind
27+
// (see bind_windows.go), it may fall back to StdNetBind.
28+
// TODO: Remove usage of ipv{4,6}.PacketConn when net.UDPConn has comparable
29+
// methods for sending and receiving multiple datagrams per-syscall. See the
30+
// proposal in https://github.com/golang/go/issues/45886#issuecomment-1218301564.
2631
type StdNetBind struct {
27-
mu sync.Mutex // protects following fields
28-
ipv4 *net.UDPConn
29-
ipv6 *net.UDPConn
30-
blackhole4 bool
31-
blackhole6 bool
32-
ipv4PC *ipv4.PacketConn
33-
ipv6PC *ipv6.PacketConn
34-
udpAddrPool sync.Pool
32+
mu sync.Mutex // protects following fields
33+
ipv4 *net.UDPConn
34+
ipv6 *net.UDPConn
35+
blackhole4 bool
36+
blackhole6 bool
37+
ipv4PC *ipv4.PacketConn // will be nil on non-Linux
38+
ipv6PC *ipv6.PacketConn // will be nil on non-Linux
39+
40+
udpAddrPool sync.Pool // following fields are not guarded by mu
3541
ipv4MsgsPool sync.Pool
3642
ipv6MsgsPool sync.Pool
3743
}
@@ -154,6 +160,8 @@ func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
154160
again:
155161
port := int(uport)
156162
var v4conn, v6conn *net.UDPConn
163+
var v4pc *ipv4.PacketConn
164+
var v6pc *ipv6.PacketConn
157165

158166
v4conn, port, err = listenNet("udp4", port)
159167
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
@@ -173,63 +181,92 @@ again:
173181
}
174182
var fns []ReceiveFunc
175183
if v4conn != nil {
176-
fns = append(fns, s.receiveIPv4)
184+
if runtime.GOOS == "linux" {
185+
v4pc = ipv4.NewPacketConn(v4conn)
186+
s.ipv4PC = v4pc
187+
}
188+
fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn))
177189
s.ipv4 = v4conn
178190
}
179191
if v6conn != nil {
180-
fns = append(fns, s.receiveIPv6)
192+
if runtime.GOOS == "linux" {
193+
v6pc = ipv6.NewPacketConn(v6conn)
194+
s.ipv6PC = v6pc
195+
}
196+
fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn))
181197
s.ipv6 = v6conn
182198
}
183199
if len(fns) == 0 {
184200
return nil, 0, syscall.EAFNOSUPPORT
185201
}
186202

187-
s.ipv4PC = ipv4.NewPacketConn(s.ipv4)
188-
s.ipv6PC = ipv6.NewPacketConn(s.ipv6)
189-
190203
return fns, uint16(port), nil
191204
}
192205

193-
func (s *StdNetBind) receiveIPv4(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
194-
msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message)
195-
defer s.ipv4MsgsPool.Put(msgs)
196-
for i := range buffs {
197-
(*msgs)[i].Buffers[0] = buffs[i]
198-
}
199-
numMsgs, err := s.ipv4PC.ReadBatch(*msgs, 0)
200-
if err != nil {
201-
return 0, err
202-
}
203-
for i := 0; i < numMsgs; i++ {
204-
msg := &(*msgs)[i]
205-
sizes[i] = msg.N
206-
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
207-
ep := asEndpoint(addrPort)
208-
getSrcFromControl(msg.OOB, ep)
209-
eps[i] = ep
206+
func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn) ReceiveFunc {
207+
return func(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
208+
msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message)
209+
defer s.ipv4MsgsPool.Put(msgs)
210+
for i := range buffs {
211+
(*msgs)[i].Buffers[0] = buffs[i]
212+
}
213+
var numMsgs int
214+
if runtime.GOOS == "linux" {
215+
numMsgs, err = pc.ReadBatch(*msgs, 0)
216+
if err != nil {
217+
return 0, err
218+
}
219+
} else {
220+
msg := &(*msgs)[0]
221+
msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
222+
if err != nil {
223+
return 0, err
224+
}
225+
numMsgs = 1
226+
}
227+
for i := 0; i < numMsgs; i++ {
228+
msg := &(*msgs)[i]
229+
sizes[i] = msg.N
230+
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
231+
ep := asEndpoint(addrPort)
232+
getSrcFromControl(msg.OOB, ep)
233+
eps[i] = ep
234+
}
235+
return numMsgs, nil
210236
}
211-
return numMsgs, nil
212237
}
213238

214-
func (s *StdNetBind) receiveIPv6(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
215-
msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message)
216-
defer s.ipv6MsgsPool.Put(msgs)
217-
for i := range buffs {
218-
(*msgs)[i].Buffers[0] = buffs[i]
219-
}
220-
numMsgs, err := s.ipv6PC.ReadBatch(*msgs, 0)
221-
if err != nil {
222-
return 0, err
223-
}
224-
for i := 0; i < numMsgs; i++ {
225-
msg := &(*msgs)[i]
226-
sizes[i] = msg.N
227-
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
228-
ep := asEndpoint(addrPort)
229-
getSrcFromControl(msg.OOB, ep)
230-
eps[i] = ep
239+
func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn) ReceiveFunc {
240+
return func(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
241+
msgs := s.ipv4MsgsPool.Get().(*[]ipv6.Message)
242+
defer s.ipv4MsgsPool.Put(msgs)
243+
for i := range buffs {
244+
(*msgs)[i].Buffers[0] = buffs[i]
245+
}
246+
var numMsgs int
247+
if runtime.GOOS == "linux" {
248+
numMsgs, err = pc.ReadBatch(*msgs, 0)
249+
if err != nil {
250+
return 0, err
251+
}
252+
} else {
253+
msg := &(*msgs)[0]
254+
msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
255+
if err != nil {
256+
return 0, err
257+
}
258+
numMsgs = 1
259+
}
260+
for i := 0; i < numMsgs; i++ {
261+
msg := &(*msgs)[i]
262+
sizes[i] = msg.N
263+
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
264+
ep := asEndpoint(addrPort)
265+
getSrcFromControl(msg.OOB, ep)
266+
eps[i] = ep
267+
}
268+
return numMsgs, nil
231269
}
232-
return numMsgs, nil
233270
}
234271

235272
// TODO: When all Binds handle IdealBatchSize, remove this dynamic function and
@@ -246,10 +283,12 @@ func (s *StdNetBind) Close() error {
246283
if s.ipv4 != nil {
247284
err1 = s.ipv4.Close()
248285
s.ipv4 = nil
286+
s.ipv4PC = nil
249287
}
250288
if s.ipv6 != nil {
251289
err2 = s.ipv6.Close()
252290
s.ipv6 = nil
291+
s.ipv6PC = nil
253292
}
254293
s.blackhole4 = false
255294
s.blackhole6 = false
@@ -263,11 +302,18 @@ func (s *StdNetBind) Send(buffs [][]byte, endpoint Endpoint) error {
263302
s.mu.Lock()
264303
blackhole := s.blackhole4
265304
conn := s.ipv4
305+
var (
306+
pc4 *ipv4.PacketConn
307+
pc6 *ipv6.PacketConn
308+
)
266309
is6 := false
267310
if endpoint.DstIP().Is6() {
268311
blackhole = s.blackhole6
269312
conn = s.ipv6
313+
pc6 = s.ipv6PC
270314
is6 = true
315+
} else {
316+
pc4 = s.ipv4PC
271317
}
272318
s.mu.Unlock()
273319

@@ -278,13 +324,13 @@ func (s *StdNetBind) Send(buffs [][]byte, endpoint Endpoint) error {
278324
return syscall.EAFNOSUPPORT
279325
}
280326
if is6 {
281-
return s.send6(s.ipv6PC, endpoint, buffs)
327+
return s.send6(conn, pc6, endpoint, buffs)
282328
} else {
283-
return s.send4(s.ipv4PC, endpoint, buffs)
329+
return s.send4(conn, pc4, endpoint, buffs)
284330
}
285331
}
286332

287-
func (s *StdNetBind) send4(conn *ipv4.PacketConn, ep Endpoint, buffs [][]byte) error {
333+
func (s *StdNetBind) send4(conn *net.UDPConn, pc *ipv4.PacketConn, ep Endpoint, buffs [][]byte) error {
288334
ua := s.udpAddrPool.Get().(*net.UDPAddr)
289335
as4 := ep.DstIP().As4()
290336
copy(ua.IP, as4[:])
@@ -301,19 +347,28 @@ func (s *StdNetBind) send4(conn *ipv4.PacketConn, ep Endpoint, buffs [][]byte) e
301347
err error
302348
start int
303349
)
304-
for {
305-
n, err = conn.WriteBatch((*msgs)[start:len(buffs)], 0)
306-
if err != nil || n == len((*msgs)[start:len(buffs)]) {
307-
break
350+
if runtime.GOOS == "linux" {
351+
for {
352+
n, err = pc.WriteBatch((*msgs)[start:len(buffs)], 0)
353+
if err != nil || n == len((*msgs)[start:len(buffs)]) {
354+
break
355+
}
356+
start += n
357+
}
358+
} else {
359+
for i, buff := range buffs {
360+
_, _, err = conn.WriteMsgUDP(buff, (*msgs)[i].OOB, ua)
361+
if err != nil {
362+
break
363+
}
308364
}
309-
start += n
310365
}
311366
s.udpAddrPool.Put(ua)
312367
s.ipv4MsgsPool.Put(msgs)
313368
return err
314369
}
315370

316-
func (s *StdNetBind) send6(conn *ipv6.PacketConn, ep Endpoint, buffs [][]byte) error {
371+
func (s *StdNetBind) send6(conn *net.UDPConn, pc *ipv6.PacketConn, ep Endpoint, buffs [][]byte) error {
317372
ua := s.udpAddrPool.Get().(*net.UDPAddr)
318373
as16 := ep.DstIP().As16()
319374
copy(ua.IP, as16[:])
@@ -330,12 +385,21 @@ func (s *StdNetBind) send6(conn *ipv6.PacketConn, ep Endpoint, buffs [][]byte) e
330385
err error
331386
start int
332387
)
333-
for {
334-
n, err = conn.WriteBatch((*msgs)[start:len(buffs)], 0)
335-
if err != nil || n == len((*msgs)[start:len(buffs)]) {
336-
break
388+
if runtime.GOOS == "linux" {
389+
for {
390+
n, err = pc.WriteBatch((*msgs)[start:len(buffs)], 0)
391+
if err != nil || n == len((*msgs)[start:len(buffs)]) {
392+
break
393+
}
394+
start += n
395+
}
396+
} else {
397+
for i, buff := range buffs {
398+
_, _, err = conn.WriteMsgUDP(buff, (*msgs)[i].OOB, ua)
399+
if err != nil {
400+
break
401+
}
337402
}
338-
start += n
339403
}
340404
s.udpAddrPool.Put(ua)
341405
s.ipv6MsgsPool.Put(msgs)

conn/bind_std_test.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
package conn
2+
3+
import "testing"
4+
5+
func TestStdNetBindReceiveFuncAfterClose(t *testing.T) {
6+
bind := NewStdNetBind().(*StdNetBind)
7+
fns, _, err := bind.Open(0)
8+
if err != nil {
9+
t.Fatal(err)
10+
}
11+
bind.Close()
12+
buffs := make([][]byte, 1)
13+
buffs[0] = make([]byte, 1)
14+
sizes := make([]int, 1)
15+
eps := make([]Endpoint, 1)
16+
for _, fn := range fns {
17+
// The ReceiveFuncs must not access conn-related fields on StdNetBind
18+
// unguarded. Close() nils the conn-related fields resulting in a panic
19+
// if they violate the mutex.
20+
fn(buffs, sizes, eps)
21+
}
22+
}

0 commit comments

Comments
 (0)