From 5a43767f612e76ba0d227a49d197c4a48b9e7684 Mon Sep 17 00:00:00 2001 From: wenxichang Date: Tue, 29 Jul 2025 16:05:36 +0800 Subject: [PATCH 1/3] Add extra control functions for StdNetBind --- conn/bind_std.go | 14 +++++++++----- conn/controlfns.go | 12 +++++++++--- conn/default.go | 2 +- conn/sticky_linux_test.go | 4 ++-- 4 files changed, 21 insertions(+), 11 deletions(-) diff --git a/conn/bind_std.go b/conn/bind_std.go index f5c88160e..d6e0a21be 100644 --- a/conn/bind_std.go +++ b/conn/bind_std.go @@ -46,9 +46,11 @@ type StdNetBind struct { blackhole4 bool blackhole6 bool + + extraFns []ControlFn } -func NewStdNetBind() Bind { +func NewStdNetBind(fns []ControlFn) Bind { return &StdNetBind{ udpAddrPool: sync.Pool{ New: func() any { @@ -70,6 +72,8 @@ func NewStdNetBind() Bind { return &msgs }, }, + + extraFns: fns, } } @@ -119,8 +123,8 @@ func (e *StdNetEndpoint) DstToString() string { return e.AddrPort.String() } -func listenNet(network string, port int) (*net.UDPConn, int, error) { - conn, err := listenConfig().ListenPacket(context.Background(), network, ":"+strconv.Itoa(port)) +func listenNet(network string, port int, fns []ControlFn) (*net.UDPConn, int, error) { + conn, err := listenConfig(fns).ListenPacket(context.Background(), network, ":"+strconv.Itoa(port)) if err != nil { return nil, 0, err } @@ -156,13 +160,13 @@ again: var v4pc *ipv4.PacketConn var v6pc *ipv6.PacketConn - v4conn, port, err = listenNet("udp4", port) + v4conn, port, err = listenNet("udp4", port, s.extraFns) if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { return nil, 0, err } // Listen on the same port as we're using for ipv4. - v6conn, port, err = listenNet("udp6", port) + v6conn, port, err = listenNet("udp6", port, s.extraFns) if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 { v4conn.Close() tries++ diff --git a/conn/controlfns.go b/conn/controlfns.go index 27421bd26..61b7233f4 100644 --- a/conn/controlfns.go +++ b/conn/controlfns.go @@ -20,16 +20,16 @@ const socketBufferSize = 7 << 20 // controlFn is the callback function signature from net.ListenConfig.Control. // It is used to apply platform specific configuration to the socket prior to // bind. -type controlFn func(network, address string, c syscall.RawConn) error +type ControlFn func(network, address string, c syscall.RawConn) error // controlFns is a list of functions that are called from the listen config // that can apply socket options. -var controlFns = []controlFn{} +var controlFns = []ControlFn{} // listenConfig returns a net.ListenConfig that applies the controlFns to the // socket prior to bind. This is used to apply socket buffer sizing and packet // information OOB configuration for sticky sockets. -func listenConfig() *net.ListenConfig { +func listenConfig(extraFns []ControlFn) *net.ListenConfig { return &net.ListenConfig{ Control: func(network, address string, c syscall.RawConn) error { for _, fn := range controlFns { @@ -37,6 +37,12 @@ func listenConfig() *net.ListenConfig { return err } } + + for _, fn := range extraFns { + if err := fn(network, address, c); err != nil { + return err + } + } return nil }, } diff --git a/conn/default.go b/conn/default.go index 2ce157956..6fdcabd04 100644 --- a/conn/default.go +++ b/conn/default.go @@ -7,4 +7,4 @@ package conn -func NewDefaultBind() Bind { return NewStdNetBind() } +func NewDefaultBind() Bind { return NewStdNetBind(nil) } diff --git a/conn/sticky_linux_test.go b/conn/sticky_linux_test.go index 1b1ee6833..9a5f6e372 100644 --- a/conn/sticky_linux_test.go +++ b/conn/sticky_linux_test.go @@ -213,7 +213,7 @@ func Test_getSrcFromControl(t *testing.T) { func Test_listenConfig(t *testing.T) { t.Run("IPv4", func(t *testing.T) { - conn, err := listenConfig().ListenPacket(context.Background(), "udp4", ":0") + conn, err := listenConfig(nil).ListenPacket(context.Background(), "udp4", ":0") if err != nil { t.Fatal(err) } @@ -239,7 +239,7 @@ func Test_listenConfig(t *testing.T) { } }) t.Run("IPv6", func(t *testing.T) { - conn, err := listenConfig().ListenPacket(context.Background(), "udp6", ":0") + conn, err := listenConfig(nil).ListenPacket(context.Background(), "udp6", ":0") if err != nil { t.Fatal(err) } From ffbeae2dbeff597f675dda95b757191d1a10be4e Mon Sep 17 00:00:00 2001 From: wenxichang Date: Mon, 29 Sep 2025 16:45:51 +0800 Subject: [PATCH 2/3] Add a service identification mechanism so that the Bind object can obtain the service information of the inner data packet when sending encapsulated packets. --- conn/bind_std.go | 2 +- conn/bind_windows.go | 4 ++-- conn/bindtest/bindtest.go | 2 +- conn/conn.go | 2 +- device/device_test.go | 10 +++++----- device/peer.go | 4 ++-- device/send.go | 24 ++++++++++++++++++++---- device/service.go | 26 ++++++++++++++++++++++++++ 8 files changed, 58 insertions(+), 16 deletions(-) create mode 100644 device/service.go diff --git a/conn/bind_std.go b/conn/bind_std.go index d6e0a21be..5c24bedb4 100644 --- a/conn/bind_std.go +++ b/conn/bind_std.go @@ -342,7 +342,7 @@ func (e ErrUDPGSODisabled) Unwrap() error { return e.RetryErr } -func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error { +func (s *StdNetBind) Send(bufs [][]byte, services []uint64, endpoint Endpoint) error { s.mu.Lock() blackhole := s.blackhole4 conn := s.ipv4 diff --git a/conn/bind_windows.go b/conn/bind_windows.go index a3b846067..4f32623a4 100644 --- a/conn/bind_windows.go +++ b/conn/bind_windows.go @@ -81,7 +81,7 @@ func NewDefaultBind() Bind { return NewWinRingBind() } func NewWinRingBind() Bind { if !winrio.Initialize() { - return NewStdNetBind() + return NewStdNetBind([]ControlFn{}) } return new(WinRingBind) } @@ -486,7 +486,7 @@ func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *atomi return winrio.SendEx(bind.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0) } -func (bind *WinRingBind) Send(bufs [][]byte, endpoint Endpoint) error { +func (bind *WinRingBind) Send(bufs [][]byte, services []uint64, endpoint Endpoint) error { nend, ok := endpoint.(*WinRingEndpoint) if !ok { return ErrWrongEndpointType diff --git a/conn/bindtest/bindtest.go b/conn/bindtest/bindtest.go index 46e20e68c..3769f5700 100644 --- a/conn/bindtest/bindtest.go +++ b/conn/bindtest/bindtest.go @@ -107,7 +107,7 @@ func (c *ChannelBind) makeReceiveFunc(ch chan []byte) conn.ReceiveFunc { } } -func (c *ChannelBind) Send(bufs [][]byte, ep conn.Endpoint) error { +func (c *ChannelBind) Send(bufs [][]byte, services []uint64, ep conn.Endpoint) error { for _, b := range bufs { select { case <-c.closeSignal: diff --git a/conn/conn.go b/conn/conn.go index 1304657e5..6a8b6893a 100644 --- a/conn/conn.go +++ b/conn/conn.go @@ -47,7 +47,7 @@ type Bind interface { // Send writes one or more packets in bufs to address ep. The length of // bufs must not exceed BatchSize(). - Send(bufs [][]byte, ep Endpoint) error + Send(bufs [][]byte, services []uint64, ep Endpoint) error // ParseEndpoint creates a new endpoint from a string. ParseEndpoint(s string) (Endpoint, error) diff --git a/device/device_test.go b/device/device_test.go index 0091e2052..dbfe416fa 100644 --- a/device/device_test.go +++ b/device/device_test.go @@ -426,11 +426,11 @@ type fakeBindSized struct { func (b *fakeBindSized) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) { return nil, 0, nil } -func (b *fakeBindSized) Close() error { return nil } -func (b *fakeBindSized) SetMark(mark uint32) error { return nil } -func (b *fakeBindSized) Send(bufs [][]byte, ep conn.Endpoint) error { return nil } -func (b *fakeBindSized) ParseEndpoint(s string) (conn.Endpoint, error) { return nil, nil } -func (b *fakeBindSized) BatchSize() int { return b.size } +func (b *fakeBindSized) Close() error { return nil } +func (b *fakeBindSized) SetMark(mark uint32) error { return nil } +func (b *fakeBindSized) Send(bufs [][]byte, services []uint64, ep conn.Endpoint) error { return nil } +func (b *fakeBindSized) ParseEndpoint(s string) (conn.Endpoint, error) { return nil, nil } +func (b *fakeBindSized) BatchSize() int { return b.size } type fakeTUNDeviceSized struct { size int diff --git a/device/peer.go b/device/peer.go index ebf25f941..1c9469546 100644 --- a/device/peer.go +++ b/device/peer.go @@ -113,7 +113,7 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { return peer, nil } -func (peer *Peer) SendBuffers(buffers [][]byte) error { +func (peer *Peer) SendBuffers(buffers [][]byte, services []uint64) error { peer.device.net.RLock() defer peer.device.net.RUnlock() @@ -133,7 +133,7 @@ func (peer *Peer) SendBuffers(buffers [][]byte) error { } peer.endpoint.Unlock() - err := peer.device.net.bind.Send(buffers, endpoint) + err := peer.device.net.bind.Send(buffers, services, endpoint) if err == nil { var totalLen uint64 for _, b := range buffers { diff --git a/device/send.go b/device/send.go index ff8f7da50..17b05693a 100644 --- a/device/send.go +++ b/device/send.go @@ -50,6 +50,8 @@ type QueueOutboundElement struct { nonce uint64 // nonce for encryption keypair *Keypair // keypair for encryption peer *Peer // related peer + service uint64 // inner packet service identifier + drop bool // service identifier result, should drop this packet } type QueueOutboundElementsContainer struct { @@ -130,7 +132,7 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { peer.timersAnyAuthenticatedPacketTraversal() peer.timersAnyAuthenticatedPacketSent() - err = peer.SendBuffers([][]byte{packet}) + err = peer.SendBuffers([][]byte{packet}, []uint64{0}) if err != nil { peer.device.log.Errorf("%v - Failed to send handshake initiation: %v", peer, err) } @@ -167,7 +169,7 @@ func (peer *Peer) SendHandshakeResponse() error { peer.timersAnyAuthenticatedPacketSent() // TODO: allocation could be avoided - err = peer.SendBuffers([][]byte{packet}) + err = peer.SendBuffers([][]byte{packet}, []uint64{0}) if err != nil { peer.device.log.Errorf("%v - Failed to send handshake response: %v", peer, err) } @@ -187,7 +189,7 @@ func (device *Device) SendHandshakeCookie(initiatingElem *QueueHandshakeElement) packet := make([]byte, MessageCookieReplySize) _ = reply.marshal(packet) // TODO: allocation could be avoided - device.net.bind.Send([][]byte{packet}, initiatingElem.endpoint) + device.net.bind.Send([][]byte{packet}, []uint64{0}, initiatingElem.endpoint) return nil } @@ -445,6 +447,14 @@ func (device *Device) RoutineEncryption(id int) { for elemsContainer := range device.queue.encryption.c { for _, elem := range elemsContainer.elems { + // identify inner packet + service, shouldDrop := ExecuteServiceFns(elem.packet) + if shouldDrop { + elem.drop = true + continue + } + elem.service = service + // populate header fields header := elem.buffer[:MessageTransportHeaderSize] @@ -483,9 +493,11 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) { device.log.Verbosef("%v - Routine: sequential sender - started", peer) bufs := make([][]byte, 0, maxBatchSize) + services := make([]uint64, 0, maxBatchSize) for elemsContainer := range peer.queue.outbound.c { bufs = bufs[:0] + services = services[:0] if elemsContainer == nil { return } @@ -507,16 +519,20 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) { dataSent := false elemsContainer.Lock() for _, elem := range elemsContainer.elems { + if elem.drop { + continue + } if len(elem.packet) != MessageKeepaliveSize { dataSent = true } bufs = append(bufs, elem.packet) + services = append(services, elem.service) } peer.timersAnyAuthenticatedPacketTraversal() peer.timersAnyAuthenticatedPacketSent() - err := peer.SendBuffers(bufs) + err := peer.SendBuffers(bufs, services) if dataSent { peer.timersDataSent() } diff --git a/device/service.go b/device/service.go new file mode 100644 index 000000000..9a2471f24 --- /dev/null +++ b/device/service.go @@ -0,0 +1,26 @@ +package device + +// ServiceFn process packet and return serivce id and drop flag +type ServiceFn func(buff []byte) (service uint64, shouldDrop bool) + +var serviceFns []ServiceFn + +// RegisterServiceFn register service function to identify packet +func RegisterServiceFn(fn ServiceFn) { + serviceFns = append(serviceFns, fn) +} + +// ExecuteServiceFns to process packet data +func ExecuteServiceFns(buff []byte) (service uint64, shouldDrop bool) { + finalService := uint64(0) + for _, fn := range serviceFns { + service, shouldDrop = fn(buff) + if service != 0 { + finalService = service + } + if shouldDrop { + return finalService, true + } + } + return finalService, false +} From 356455abacaf333081f417fee061c420dbf50941 Mon Sep 17 00:00:00 2001 From: wenxichang Date: Thu, 23 Oct 2025 09:31:03 +0800 Subject: [PATCH 3/3] use conn.Service to replace uint64 service ID --- conn/bind_std.go | 2 +- conn/bind_windows.go | 2 +- conn/bindtest/bindtest.go | 2 +- conn/conn.go | 2 +- {device => conn}/service.go | 17 +++++++++++------ device/device_test.go | 12 +++++++----- device/peer.go | 2 +- device/send.go | 12 ++++++------ 8 files changed, 29 insertions(+), 22 deletions(-) rename {device => conn}/service.go (52%) diff --git a/conn/bind_std.go b/conn/bind_std.go index 5c24bedb4..39145d6a4 100644 --- a/conn/bind_std.go +++ b/conn/bind_std.go @@ -342,7 +342,7 @@ func (e ErrUDPGSODisabled) Unwrap() error { return e.RetryErr } -func (s *StdNetBind) Send(bufs [][]byte, services []uint64, endpoint Endpoint) error { +func (s *StdNetBind) Send(bufs [][]byte, services []Service, endpoint Endpoint) error { s.mu.Lock() blackhole := s.blackhole4 conn := s.ipv4 diff --git a/conn/bind_windows.go b/conn/bind_windows.go index 4f32623a4..620200bf1 100644 --- a/conn/bind_windows.go +++ b/conn/bind_windows.go @@ -486,7 +486,7 @@ func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *atomi return winrio.SendEx(bind.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0) } -func (bind *WinRingBind) Send(bufs [][]byte, services []uint64, endpoint Endpoint) error { +func (bind *WinRingBind) Send(bufs [][]byte, services []Service, endpoint Endpoint) error { nend, ok := endpoint.(*WinRingEndpoint) if !ok { return ErrWrongEndpointType diff --git a/conn/bindtest/bindtest.go b/conn/bindtest/bindtest.go index 3769f5700..d81591fe5 100644 --- a/conn/bindtest/bindtest.go +++ b/conn/bindtest/bindtest.go @@ -107,7 +107,7 @@ func (c *ChannelBind) makeReceiveFunc(ch chan []byte) conn.ReceiveFunc { } } -func (c *ChannelBind) Send(bufs [][]byte, services []uint64, ep conn.Endpoint) error { +func (c *ChannelBind) Send(bufs [][]byte, services []conn.Service, ep conn.Endpoint) error { for _, b := range bufs { select { case <-c.closeSignal: diff --git a/conn/conn.go b/conn/conn.go index 6a8b6893a..263d8250e 100644 --- a/conn/conn.go +++ b/conn/conn.go @@ -47,7 +47,7 @@ type Bind interface { // Send writes one or more packets in bufs to address ep. The length of // bufs must not exceed BatchSize(). - Send(bufs [][]byte, services []uint64, ep Endpoint) error + Send(bufs [][]byte, services []Service, ep Endpoint) error // ParseEndpoint creates a new endpoint from a string. ParseEndpoint(s string) (Endpoint, error) diff --git a/device/service.go b/conn/service.go similarity index 52% rename from device/service.go rename to conn/service.go index 9a2471f24..9f26d8152 100644 --- a/device/service.go +++ b/conn/service.go @@ -1,7 +1,12 @@ -package device +package conn -// ServiceFn process packet and return serivce id and drop flag -type ServiceFn func(buff []byte) (service uint64, shouldDrop bool) +// Service pass inner packet info to outer bind +type Service interface { + ID() uint64 +} + +// ServiceFn process inner packet and return service info and drop flag +type ServiceFn func(buff []byte) (service Service, shouldDrop bool) var serviceFns []ServiceFn @@ -11,11 +16,11 @@ func RegisterServiceFn(fn ServiceFn) { } // ExecuteServiceFns to process packet data -func ExecuteServiceFns(buff []byte) (service uint64, shouldDrop bool) { - finalService := uint64(0) +func ExecuteServiceFns(buff []byte) (service Service, shouldDrop bool) { + finalService := Service(nil) for _, fn := range serviceFns { service, shouldDrop = fn(buff) - if service != 0 { + if service != nil { finalService = service } if shouldDrop { diff --git a/device/device_test.go b/device/device_test.go index dbfe416fa..72311b7b2 100644 --- a/device/device_test.go +++ b/device/device_test.go @@ -426,11 +426,13 @@ type fakeBindSized struct { func (b *fakeBindSized) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) { return nil, 0, nil } -func (b *fakeBindSized) Close() error { return nil } -func (b *fakeBindSized) SetMark(mark uint32) error { return nil } -func (b *fakeBindSized) Send(bufs [][]byte, services []uint64, ep conn.Endpoint) error { return nil } -func (b *fakeBindSized) ParseEndpoint(s string) (conn.Endpoint, error) { return nil, nil } -func (b *fakeBindSized) BatchSize() int { return b.size } +func (b *fakeBindSized) Close() error { return nil } +func (b *fakeBindSized) SetMark(mark uint32) error { return nil } +func (b *fakeBindSized) Send(bufs [][]byte, services []conn.Service, ep conn.Endpoint) error { + return nil +} +func (b *fakeBindSized) ParseEndpoint(s string) (conn.Endpoint, error) { return nil, nil } +func (b *fakeBindSized) BatchSize() int { return b.size } type fakeTUNDeviceSized struct { size int diff --git a/device/peer.go b/device/peer.go index 1c9469546..501b2b628 100644 --- a/device/peer.go +++ b/device/peer.go @@ -113,7 +113,7 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { return peer, nil } -func (peer *Peer) SendBuffers(buffers [][]byte, services []uint64) error { +func (peer *Peer) SendBuffers(buffers [][]byte, services []conn.Service) error { peer.device.net.RLock() defer peer.device.net.RUnlock() diff --git a/device/send.go b/device/send.go index 17b05693a..5f207858d 100644 --- a/device/send.go +++ b/device/send.go @@ -50,7 +50,7 @@ type QueueOutboundElement struct { nonce uint64 // nonce for encryption keypair *Keypair // keypair for encryption peer *Peer // related peer - service uint64 // inner packet service identifier + service conn.Service // inner packet service drop bool // service identifier result, should drop this packet } @@ -132,7 +132,7 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { peer.timersAnyAuthenticatedPacketTraversal() peer.timersAnyAuthenticatedPacketSent() - err = peer.SendBuffers([][]byte{packet}, []uint64{0}) + err = peer.SendBuffers([][]byte{packet}, []conn.Service{nil}) if err != nil { peer.device.log.Errorf("%v - Failed to send handshake initiation: %v", peer, err) } @@ -169,7 +169,7 @@ func (peer *Peer) SendHandshakeResponse() error { peer.timersAnyAuthenticatedPacketSent() // TODO: allocation could be avoided - err = peer.SendBuffers([][]byte{packet}, []uint64{0}) + err = peer.SendBuffers([][]byte{packet}, []conn.Service{nil}) if err != nil { peer.device.log.Errorf("%v - Failed to send handshake response: %v", peer, err) } @@ -189,7 +189,7 @@ func (device *Device) SendHandshakeCookie(initiatingElem *QueueHandshakeElement) packet := make([]byte, MessageCookieReplySize) _ = reply.marshal(packet) // TODO: allocation could be avoided - device.net.bind.Send([][]byte{packet}, []uint64{0}, initiatingElem.endpoint) + device.net.bind.Send([][]byte{packet}, []conn.Service{nil}, initiatingElem.endpoint) return nil } @@ -448,7 +448,7 @@ func (device *Device) RoutineEncryption(id int) { for elemsContainer := range device.queue.encryption.c { for _, elem := range elemsContainer.elems { // identify inner packet - service, shouldDrop := ExecuteServiceFns(elem.packet) + service, shouldDrop := conn.ExecuteServiceFns(elem.packet) if shouldDrop { elem.drop = true continue @@ -493,7 +493,7 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) { device.log.Verbosef("%v - Routine: sequential sender - started", peer) bufs := make([][]byte, 0, maxBatchSize) - services := make([]uint64, 0, maxBatchSize) + services := make([]conn.Service, 0, maxBatchSize) for elemsContainer := range peer.queue.outbound.c { bufs = bufs[:0]