@@ -10,6 +10,7 @@ import (
10
10
"errors"
11
11
"net"
12
12
"net/netip"
13
+ "runtime"
13
14
"strconv"
14
15
"sync"
15
16
"syscall"
@@ -22,16 +23,21 @@ var (
22
23
_ Bind = (* StdNetBind )(nil )
23
24
)
24
25
25
- // StdNetBind implements Bind for all platforms except Windows.
26
+ // StdNetBind implements Bind for all platforms. While Windows has its own Bind
27
+ // (see bind_windows.go), it may fall back to StdNetBind.
28
+ // TODO: Remove usage of ipv{4,6}.PacketConn when net.UDPConn has comparable
29
+ // methods for sending and receiving multiple datagrams per-syscall. See the
30
+ // proposal in https://github.com/golang/go/issues/45886#issuecomment-1218301564.
26
31
type StdNetBind struct {
27
- mu sync.Mutex // protects following fields
28
- ipv4 * net.UDPConn
29
- ipv6 * net.UDPConn
30
- blackhole4 bool
31
- blackhole6 bool
32
- ipv4PC * ipv4.PacketConn
33
- ipv6PC * ipv6.PacketConn
34
- udpAddrPool sync.Pool
32
+ mu sync.Mutex // protects following fields
33
+ ipv4 * net.UDPConn
34
+ ipv6 * net.UDPConn
35
+ blackhole4 bool
36
+ blackhole6 bool
37
+ ipv4PC * ipv4.PacketConn // will be nil on non-Linux
38
+ ipv6PC * ipv6.PacketConn // will be nil on non-Linux
39
+
40
+ udpAddrPool sync.Pool // following fields are not guarded by mu
35
41
ipv4MsgsPool sync.Pool
36
42
ipv6MsgsPool sync.Pool
37
43
}
@@ -154,6 +160,8 @@ func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
154
160
again:
155
161
port := int (uport )
156
162
var v4conn , v6conn * net.UDPConn
163
+ var v4pc * ipv4.PacketConn
164
+ var v6pc * ipv6.PacketConn
157
165
158
166
v4conn , port , err = listenNet ("udp4" , port )
159
167
if err != nil && ! errors .Is (err , syscall .EAFNOSUPPORT ) {
@@ -173,63 +181,92 @@ again:
173
181
}
174
182
var fns []ReceiveFunc
175
183
if v4conn != nil {
176
- fns = append (fns , s .receiveIPv4 )
184
+ if runtime .GOOS == "linux" {
185
+ v4pc = ipv4 .NewPacketConn (v4conn )
186
+ s .ipv4PC = v4pc
187
+ }
188
+ fns = append (fns , s .makeReceiveIPv4 (v4pc , v4conn ))
177
189
s .ipv4 = v4conn
178
190
}
179
191
if v6conn != nil {
180
- fns = append (fns , s .receiveIPv6 )
192
+ if runtime .GOOS == "linux" {
193
+ v6pc = ipv6 .NewPacketConn (v6conn )
194
+ s .ipv6PC = v6pc
195
+ }
196
+ fns = append (fns , s .makeReceiveIPv6 (v6pc , v6conn ))
181
197
s .ipv6 = v6conn
182
198
}
183
199
if len (fns ) == 0 {
184
200
return nil , 0 , syscall .EAFNOSUPPORT
185
201
}
186
202
187
- s .ipv4PC = ipv4 .NewPacketConn (s .ipv4 )
188
- s .ipv6PC = ipv6 .NewPacketConn (s .ipv6 )
189
-
190
203
return fns , uint16 (port ), nil
191
204
}
192
205
193
- func (s * StdNetBind ) receiveIPv4 (buffs [][]byte , sizes []int , eps []Endpoint ) (n int , err error ) {
194
- msgs := s .ipv4MsgsPool .Get ().(* []ipv4.Message )
195
- defer s .ipv4MsgsPool .Put (msgs )
196
- for i := range buffs {
197
- (* msgs )[i ].Buffers [0 ] = buffs [i ]
198
- }
199
- numMsgs , err := s .ipv4PC .ReadBatch (* msgs , 0 )
200
- if err != nil {
201
- return 0 , err
202
- }
203
- for i := 0 ; i < numMsgs ; i ++ {
204
- msg := & (* msgs )[i ]
205
- sizes [i ] = msg .N
206
- addrPort := msg .Addr .(* net.UDPAddr ).AddrPort ()
207
- ep := asEndpoint (addrPort )
208
- getSrcFromControl (msg .OOB , ep )
209
- eps [i ] = ep
206
+ func (s * StdNetBind ) makeReceiveIPv4 (pc * ipv4.PacketConn , conn * net.UDPConn ) ReceiveFunc {
207
+ return func (buffs [][]byte , sizes []int , eps []Endpoint ) (n int , err error ) {
208
+ msgs := s .ipv4MsgsPool .Get ().(* []ipv4.Message )
209
+ defer s .ipv4MsgsPool .Put (msgs )
210
+ for i := range buffs {
211
+ (* msgs )[i ].Buffers [0 ] = buffs [i ]
212
+ }
213
+ var numMsgs int
214
+ if runtime .GOOS == "linux" {
215
+ numMsgs , err = pc .ReadBatch (* msgs , 0 )
216
+ if err != nil {
217
+ return 0 , err
218
+ }
219
+ } else {
220
+ msg := & (* msgs )[0 ]
221
+ msg .N , msg .NN , _ , msg .Addr , err = conn .ReadMsgUDP (msg .Buffers [0 ], msg .OOB )
222
+ if err != nil {
223
+ return 0 , err
224
+ }
225
+ numMsgs = 1
226
+ }
227
+ for i := 0 ; i < numMsgs ; i ++ {
228
+ msg := & (* msgs )[i ]
229
+ sizes [i ] = msg .N
230
+ addrPort := msg .Addr .(* net.UDPAddr ).AddrPort ()
231
+ ep := asEndpoint (addrPort )
232
+ getSrcFromControl (msg .OOB , ep )
233
+ eps [i ] = ep
234
+ }
235
+ return numMsgs , nil
210
236
}
211
- return numMsgs , nil
212
237
}
213
238
214
- func (s * StdNetBind ) receiveIPv6 (buffs [][]byte , sizes []int , eps []Endpoint ) (n int , err error ) {
215
- msgs := s .ipv6MsgsPool .Get ().(* []ipv6.Message )
216
- defer s .ipv6MsgsPool .Put (msgs )
217
- for i := range buffs {
218
- (* msgs )[i ].Buffers [0 ] = buffs [i ]
219
- }
220
- numMsgs , err := s .ipv6PC .ReadBatch (* msgs , 0 )
221
- if err != nil {
222
- return 0 , err
223
- }
224
- for i := 0 ; i < numMsgs ; i ++ {
225
- msg := & (* msgs )[i ]
226
- sizes [i ] = msg .N
227
- addrPort := msg .Addr .(* net.UDPAddr ).AddrPort ()
228
- ep := asEndpoint (addrPort )
229
- getSrcFromControl (msg .OOB , ep )
230
- eps [i ] = ep
239
+ func (s * StdNetBind ) makeReceiveIPv6 (pc * ipv6.PacketConn , conn * net.UDPConn ) ReceiveFunc {
240
+ return func (buffs [][]byte , sizes []int , eps []Endpoint ) (n int , err error ) {
241
+ msgs := s .ipv4MsgsPool .Get ().(* []ipv6.Message )
242
+ defer s .ipv4MsgsPool .Put (msgs )
243
+ for i := range buffs {
244
+ (* msgs )[i ].Buffers [0 ] = buffs [i ]
245
+ }
246
+ var numMsgs int
247
+ if runtime .GOOS == "linux" {
248
+ numMsgs , err = pc .ReadBatch (* msgs , 0 )
249
+ if err != nil {
250
+ return 0 , err
251
+ }
252
+ } else {
253
+ msg := & (* msgs )[0 ]
254
+ msg .N , msg .NN , _ , msg .Addr , err = conn .ReadMsgUDP (msg .Buffers [0 ], msg .OOB )
255
+ if err != nil {
256
+ return 0 , err
257
+ }
258
+ numMsgs = 1
259
+ }
260
+ for i := 0 ; i < numMsgs ; i ++ {
261
+ msg := & (* msgs )[i ]
262
+ sizes [i ] = msg .N
263
+ addrPort := msg .Addr .(* net.UDPAddr ).AddrPort ()
264
+ ep := asEndpoint (addrPort )
265
+ getSrcFromControl (msg .OOB , ep )
266
+ eps [i ] = ep
267
+ }
268
+ return numMsgs , nil
231
269
}
232
- return numMsgs , nil
233
270
}
234
271
235
272
// TODO: When all Binds handle IdealBatchSize, remove this dynamic function and
@@ -246,10 +283,12 @@ func (s *StdNetBind) Close() error {
246
283
if s .ipv4 != nil {
247
284
err1 = s .ipv4 .Close ()
248
285
s .ipv4 = nil
286
+ s .ipv4PC = nil
249
287
}
250
288
if s .ipv6 != nil {
251
289
err2 = s .ipv6 .Close ()
252
290
s .ipv6 = nil
291
+ s .ipv6PC = nil
253
292
}
254
293
s .blackhole4 = false
255
294
s .blackhole6 = false
@@ -263,11 +302,18 @@ func (s *StdNetBind) Send(buffs [][]byte, endpoint Endpoint) error {
263
302
s .mu .Lock ()
264
303
blackhole := s .blackhole4
265
304
conn := s .ipv4
305
+ var (
306
+ pc4 * ipv4.PacketConn
307
+ pc6 * ipv6.PacketConn
308
+ )
266
309
is6 := false
267
310
if endpoint .DstIP ().Is6 () {
268
311
blackhole = s .blackhole6
269
312
conn = s .ipv6
313
+ pc6 = s .ipv6PC
270
314
is6 = true
315
+ } else {
316
+ pc4 = s .ipv4PC
271
317
}
272
318
s .mu .Unlock ()
273
319
@@ -278,13 +324,13 @@ func (s *StdNetBind) Send(buffs [][]byte, endpoint Endpoint) error {
278
324
return syscall .EAFNOSUPPORT
279
325
}
280
326
if is6 {
281
- return s .send6 (s . ipv6PC , endpoint , buffs )
327
+ return s .send6 (conn , pc6 , endpoint , buffs )
282
328
} else {
283
- return s .send4 (s . ipv4PC , endpoint , buffs )
329
+ return s .send4 (conn , pc4 , endpoint , buffs )
284
330
}
285
331
}
286
332
287
- func (s * StdNetBind ) send4 (conn * ipv4.PacketConn , ep Endpoint , buffs [][]byte ) error {
333
+ func (s * StdNetBind ) send4 (conn * net. UDPConn , pc * ipv4.PacketConn , ep Endpoint , buffs [][]byte ) error {
288
334
ua := s .udpAddrPool .Get ().(* net.UDPAddr )
289
335
as4 := ep .DstIP ().As4 ()
290
336
copy (ua .IP , as4 [:])
@@ -301,19 +347,28 @@ func (s *StdNetBind) send4(conn *ipv4.PacketConn, ep Endpoint, buffs [][]byte) e
301
347
err error
302
348
start int
303
349
)
304
- for {
305
- n , err = conn .WriteBatch ((* msgs )[start :len (buffs )], 0 )
306
- if err != nil || n == len ((* msgs )[start :len (buffs )]) {
307
- break
350
+ if runtime .GOOS == "linux" {
351
+ for {
352
+ n , err = pc .WriteBatch ((* msgs )[start :len (buffs )], 0 )
353
+ if err != nil || n == len ((* msgs )[start :len (buffs )]) {
354
+ break
355
+ }
356
+ start += n
357
+ }
358
+ } else {
359
+ for i , buff := range buffs {
360
+ _ , _ , err = conn .WriteMsgUDP (buff , (* msgs )[i ].OOB , ua )
361
+ if err != nil {
362
+ break
363
+ }
308
364
}
309
- start += n
310
365
}
311
366
s .udpAddrPool .Put (ua )
312
367
s .ipv4MsgsPool .Put (msgs )
313
368
return err
314
369
}
315
370
316
- func (s * StdNetBind ) send6 (conn * ipv6.PacketConn , ep Endpoint , buffs [][]byte ) error {
371
+ func (s * StdNetBind ) send6 (conn * net. UDPConn , pc * ipv6.PacketConn , ep Endpoint , buffs [][]byte ) error {
317
372
ua := s .udpAddrPool .Get ().(* net.UDPAddr )
318
373
as16 := ep .DstIP ().As16 ()
319
374
copy (ua .IP , as16 [:])
@@ -330,12 +385,21 @@ func (s *StdNetBind) send6(conn *ipv6.PacketConn, ep Endpoint, buffs [][]byte) e
330
385
err error
331
386
start int
332
387
)
333
- for {
334
- n , err = conn .WriteBatch ((* msgs )[start :len (buffs )], 0 )
335
- if err != nil || n == len ((* msgs )[start :len (buffs )]) {
336
- break
388
+ if runtime .GOOS == "linux" {
389
+ for {
390
+ n , err = pc .WriteBatch ((* msgs )[start :len (buffs )], 0 )
391
+ if err != nil || n == len ((* msgs )[start :len (buffs )]) {
392
+ break
393
+ }
394
+ start += n
395
+ }
396
+ } else {
397
+ for i , buff := range buffs {
398
+ _ , _ , err = conn .WriteMsgUDP (buff , (* msgs )[i ].OOB , ua )
399
+ if err != nil {
400
+ break
401
+ }
337
402
}
338
- start += n
339
403
}
340
404
s .udpAddrPool .Put (ua )
341
405
s .ipv6MsgsPool .Put (msgs )
0 commit comments