Skip to content

Commit 652d2a3

Browse files
committed
refactor(udp): enhance session management and cleanup logic
1 parent 55eada8 commit 652d2a3

File tree

1 file changed

+97
-15
lines changed

1 file changed

+97
-15
lines changed

server/udp.go

Lines changed: 97 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"net"
66
"strings"
77
"sync"
8+
"sync/atomic"
89
"time"
910

1011
"github.com/AppleBlockTeam/abmc-forwarder/config"
@@ -20,6 +21,30 @@ type UDPHandler struct {
2021
wg *sync.WaitGroup
2122
}
2223

24+
type udpSession struct {
25+
conn *net.UDPConn
26+
lastActive atomic.Int64
27+
}
28+
29+
func newUDPSession(conn *net.UDPConn) *udpSession {
30+
session := &udpSession{conn: conn}
31+
session.touch()
32+
return session
33+
}
34+
35+
func (s *udpSession) touch() {
36+
s.lastActive.Store(time.Now().UnixNano())
37+
}
38+
39+
func (s *udpSession) idleFor(now time.Time) time.Duration {
40+
lastActive := s.lastActive.Load()
41+
if lastActive == 0 {
42+
return 0
43+
}
44+
45+
return now.Sub(time.Unix(0, lastActive))
46+
}
47+
2348
// NewUDPHandler 创建新的UDP处理器
2449
func NewUDPHandler(cfg config.Config, wg *sync.WaitGroup, done chan struct{}) *UDPHandler {
2550
return &UDPHandler{
@@ -58,7 +83,7 @@ func (h *UDPHandler) Stop() {
5883
// handlePackets 处理UDP数据包
5984
func (h *UDPHandler) handlePackets() {
6085
// UDP 连接映射
61-
connMap := make(map[string]*net.UDPConn)
86+
connMap := make(map[string]*udpSession)
6287
var connMapMutex sync.Mutex
6388

6489
// 读取缓冲区
@@ -70,22 +95,65 @@ func (h *UDPHandler) handlePackets() {
7095
useProxyProto = false
7196
}
7297

98+
cleanupInterval := time.Duration(0)
99+
if h.config.Timeout > 0 {
100+
cleanupInterval = h.config.Timeout
101+
if cleanupInterval > time.Minute {
102+
cleanupInterval = time.Minute
103+
}
104+
if cleanupInterval < time.Second {
105+
cleanupInterval = time.Second
106+
}
107+
}
108+
109+
cleanupIdleSessions := func(now time.Time) {
110+
if h.config.Timeout <= 0 {
111+
return
112+
}
113+
114+
expiredSessions := make([]*udpSession, 0)
115+
116+
connMapMutex.Lock()
117+
for clientAddrStr, session := range connMap {
118+
if session.idleFor(now) < h.config.Timeout {
119+
continue
120+
}
121+
122+
delete(connMap, clientAddrStr)
123+
expiredSessions = append(expiredSessions, session)
124+
125+
if h.config.LogConnections {
126+
log.Printf("[%s] UDP 会话空闲超时,已关闭\n", clientAddrStr)
127+
}
128+
}
129+
connMapMutex.Unlock()
130+
131+
for _, session := range expiredSessions {
132+
session.conn.Close()
133+
}
134+
}
135+
73136
for {
74137
select {
75138
case <-h.done:
76139
// 关闭所有连接
140+
sessions := make([]*udpSession, 0)
77141
connMapMutex.Lock()
78-
for _, conn := range connMap {
79-
conn.Close()
142+
for _, session := range connMap {
143+
sessions = append(sessions, session)
80144
}
81145
connMapMutex.Unlock()
146+
147+
for _, session := range sessions {
148+
session.conn.Close()
149+
}
82150
return
83151
default:
84152
}
85153

86154
// 设置读取超时
87-
if h.config.Timeout > 0 {
88-
h.conn.SetReadDeadline(time.Now().Add(h.config.Timeout))
155+
if cleanupInterval > 0 {
156+
h.conn.SetReadDeadline(time.Now().Add(cleanupInterval))
89157
}
90158

91159
// 读取 UDP 数据包
@@ -94,6 +162,7 @@ func (h *UDPHandler) handlePackets() {
94162
// 检查是否是超时错误
95163
netErr, ok := err.(net.Error)
96164
if ok && netErr.Timeout() {
165+
cleanupIdleSessions(time.Now())
97166
continue
98167
}
99168
if strings.Contains(err.Error(), "use of closed network connection") {
@@ -112,7 +181,7 @@ func (h *UDPHandler) handlePackets() {
112181

113182
// 查找或创建到远程的连接
114183
connMapMutex.Lock()
115-
remoteConn, exists := connMap[clientAddrStr]
184+
session, exists := connMap[clientAddrStr]
116185

117186
if !exists {
118187
var err error
@@ -123,41 +192,48 @@ func (h *UDPHandler) handlePackets() {
123192
continue
124193
}
125194

126-
remoteConn, err = net.DialUDP("udp", nil, remoteAddr)
195+
remoteConn, err := net.DialUDP("udp", nil, remoteAddr)
127196
if err != nil {
128197
log.Printf("[%s] 连接远程服务器失败: %v\n", h.config.RemoteUDPAddr, err)
129198
connMapMutex.Unlock()
130199
continue
131200
}
132201

133-
connMap[clientAddrStr] = remoteConn
202+
session = newUDPSession(remoteConn)
203+
connMap[clientAddrStr] = session
134204

135205
if h.config.LogConnections {
136206
log.Printf("[%s] 新的 UDP 连接 -> [%s]\n", clientAddr, h.config.RemoteUDPAddr)
137207
}
138208

139209
// 启动一个 goroutine 来处理远程服务器返回的响应
140210
h.wg.Add(1)
141-
go func(clientAddr net.Addr, remoteConn *net.UDPConn, clientAddrStr string) {
211+
go func(clientAddr net.Addr, session *udpSession, clientAddrStr string) {
142212
defer h.wg.Done()
143213
defer func() {
144214
connMapMutex.Lock()
145-
delete(connMap, clientAddrStr)
146-
remoteConn.Close()
215+
if currentSession, ok := connMap[clientAddrStr]; ok && currentSession == session {
216+
delete(connMap, clientAddrStr)
217+
}
147218
connMapMutex.Unlock()
219+
220+
session.conn.Close()
148221
}()
149222

150223
responseBuffer := make([]byte, h.config.BufferSize)
151224
for {
152-
if h.config.Timeout > 0 {
153-
remoteConn.SetReadDeadline(time.Now().Add(h.config.Timeout))
225+
if cleanupInterval > 0 {
226+
session.conn.SetReadDeadline(time.Now().Add(cleanupInterval))
154227
}
155228

156229
// 从远程服务器读取响应
157-
n, _, err := remoteConn.ReadFrom(responseBuffer)
230+
n, _, err := session.conn.ReadFrom(responseBuffer)
158231
if err != nil {
159232
netErr, ok := err.(net.Error)
160233
if ok && netErr.Timeout() {
234+
if h.config.Timeout > 0 && session.idleFor(time.Now()) >= h.config.Timeout {
235+
break
236+
}
161237
continue
162238
}
163239
if !utils.IsConnectionClosed(err) {
@@ -168,6 +244,7 @@ func (h *UDPHandler) handlePackets() {
168244

169245
// 获取到响应数据
170246
responseData := responseBuffer[:n]
247+
session.touch()
171248

172249
// 将响应数据发送给客户端
173250
_, err = h.conn.WriteTo(responseData, clientAddr)
@@ -176,9 +253,12 @@ func (h *UDPHandler) handlePackets() {
176253
break
177254
}
178255
}
179-
}(clientAddr, remoteConn, clientAddrStr)
256+
}(clientAddr, session, clientAddrStr)
180257
}
181258

259+
session.touch()
260+
remoteConn := session.conn
261+
182262
// 转发数据到远程服务器
183263
data := buffer[:n]
184264

@@ -210,5 +290,7 @@ func (h *UDPHandler) handlePackets() {
210290
if err != nil && !utils.IsConnectionClosed(err) {
211291
log.Printf("转发 UDP 数据到远程服务器失败: %v\n", err)
212292
}
293+
294+
cleanupIdleSessions(time.Now())
213295
}
214296
}

0 commit comments

Comments
 (0)