55 "net"
66 "strings"
77 "sync"
8+ "sync/atomic"
89 "time"
910
1011 "github.com/AppleBlockTeam/abmc-forwarder/config"
@@ -20,6 +21,30 @@ type UDPHandler struct {
2021 wg * sync.WaitGroup
2122}
2223
24+ type udpSession struct {
25+ conn * net.UDPConn
26+ lastActive atomic.Int64
27+ }
28+
29+ func newUDPSession (conn * net.UDPConn ) * udpSession {
30+ session := & udpSession {conn : conn }
31+ session .touch ()
32+ return session
33+ }
34+
35+ func (s * udpSession ) touch () {
36+ s .lastActive .Store (time .Now ().UnixNano ())
37+ }
38+
39+ func (s * udpSession ) idleFor (now time.Time ) time.Duration {
40+ lastActive := s .lastActive .Load ()
41+ if lastActive == 0 {
42+ return 0
43+ }
44+
45+ return now .Sub (time .Unix (0 , lastActive ))
46+ }
47+
2348// NewUDPHandler 创建新的UDP处理器
2449func NewUDPHandler (cfg config.Config , wg * sync.WaitGroup , done chan struct {}) * UDPHandler {
2550 return & UDPHandler {
@@ -58,7 +83,7 @@ func (h *UDPHandler) Stop() {
5883// handlePackets 处理UDP数据包
5984func (h * UDPHandler ) handlePackets () {
6085 // UDP 连接映射
61- connMap := make (map [string ]* net. UDPConn )
86+ connMap := make (map [string ]* udpSession )
6287 var connMapMutex sync.Mutex
6388
6489 // 读取缓冲区
@@ -70,22 +95,65 @@ func (h *UDPHandler) handlePackets() {
7095 useProxyProto = false
7196 }
7297
98+ cleanupInterval := time .Duration (0 )
99+ if h .config .Timeout > 0 {
100+ cleanupInterval = h .config .Timeout
101+ if cleanupInterval > time .Minute {
102+ cleanupInterval = time .Minute
103+ }
104+ if cleanupInterval < time .Second {
105+ cleanupInterval = time .Second
106+ }
107+ }
108+
109+ cleanupIdleSessions := func (now time.Time ) {
110+ if h .config .Timeout <= 0 {
111+ return
112+ }
113+
114+ expiredSessions := make ([]* udpSession , 0 )
115+
116+ connMapMutex .Lock ()
117+ for clientAddrStr , session := range connMap {
118+ if session .idleFor (now ) < h .config .Timeout {
119+ continue
120+ }
121+
122+ delete (connMap , clientAddrStr )
123+ expiredSessions = append (expiredSessions , session )
124+
125+ if h .config .LogConnections {
126+ log .Printf ("[%s] UDP 会话空闲超时,已关闭\n " , clientAddrStr )
127+ }
128+ }
129+ connMapMutex .Unlock ()
130+
131+ for _ , session := range expiredSessions {
132+ session .conn .Close ()
133+ }
134+ }
135+
73136 for {
74137 select {
75138 case <- h .done :
76139 // 关闭所有连接
140+ sessions := make ([]* udpSession , 0 )
77141 connMapMutex .Lock ()
78- for _ , conn := range connMap {
79- conn . Close ( )
142+ for _ , session := range connMap {
143+ sessions = append ( sessions , session )
80144 }
81145 connMapMutex .Unlock ()
146+
147+ for _ , session := range sessions {
148+ session .conn .Close ()
149+ }
82150 return
83151 default :
84152 }
85153
86154 // 设置读取超时
87- if h . config . Timeout > 0 {
88- h .conn .SetReadDeadline (time .Now ().Add (h . config . Timeout ))
155+ if cleanupInterval > 0 {
156+ h .conn .SetReadDeadline (time .Now ().Add (cleanupInterval ))
89157 }
90158
91159 // 读取 UDP 数据包
@@ -94,6 +162,7 @@ func (h *UDPHandler) handlePackets() {
94162 // 检查是否是超时错误
95163 netErr , ok := err .(net.Error )
96164 if ok && netErr .Timeout () {
165+ cleanupIdleSessions (time .Now ())
97166 continue
98167 }
99168 if strings .Contains (err .Error (), "use of closed network connection" ) {
@@ -112,7 +181,7 @@ func (h *UDPHandler) handlePackets() {
112181
113182 // 查找或创建到远程的连接
114183 connMapMutex .Lock ()
115- remoteConn , exists := connMap [clientAddrStr ]
184+ session , exists := connMap [clientAddrStr ]
116185
117186 if ! exists {
118187 var err error
@@ -123,41 +192,48 @@ func (h *UDPHandler) handlePackets() {
123192 continue
124193 }
125194
126- remoteConn , err = net .DialUDP ("udp" , nil , remoteAddr )
195+ remoteConn , err : = net .DialUDP ("udp" , nil , remoteAddr )
127196 if err != nil {
128197 log .Printf ("[%s] 连接远程服务器失败: %v\n " , h .config .RemoteUDPAddr , err )
129198 connMapMutex .Unlock ()
130199 continue
131200 }
132201
133- connMap [clientAddrStr ] = remoteConn
202+ session = newUDPSession (remoteConn )
203+ connMap [clientAddrStr ] = session
134204
135205 if h .config .LogConnections {
136206 log .Printf ("[%s] 新的 UDP 连接 -> [%s]\n " , clientAddr , h .config .RemoteUDPAddr )
137207 }
138208
139209 // 启动一个 goroutine 来处理远程服务器返回的响应
140210 h .wg .Add (1 )
141- go func (clientAddr net.Addr , remoteConn * net. UDPConn , clientAddrStr string ) {
211+ go func (clientAddr net.Addr , session * udpSession , clientAddrStr string ) {
142212 defer h .wg .Done ()
143213 defer func () {
144214 connMapMutex .Lock ()
145- delete (connMap , clientAddrStr )
146- remoteConn .Close ()
215+ if currentSession , ok := connMap [clientAddrStr ]; ok && currentSession == session {
216+ delete (connMap , clientAddrStr )
217+ }
147218 connMapMutex .Unlock ()
219+
220+ session .conn .Close ()
148221 }()
149222
150223 responseBuffer := make ([]byte , h .config .BufferSize )
151224 for {
152- if h . config . Timeout > 0 {
153- remoteConn . SetReadDeadline (time .Now ().Add (h . config . Timeout ))
225+ if cleanupInterval > 0 {
226+ session . conn . SetReadDeadline (time .Now ().Add (cleanupInterval ))
154227 }
155228
156229 // 从远程服务器读取响应
157- n , _ , err := remoteConn .ReadFrom (responseBuffer )
230+ n , _ , err := session . conn .ReadFrom (responseBuffer )
158231 if err != nil {
159232 netErr , ok := err .(net.Error )
160233 if ok && netErr .Timeout () {
234+ if h .config .Timeout > 0 && session .idleFor (time .Now ()) >= h .config .Timeout {
235+ break
236+ }
161237 continue
162238 }
163239 if ! utils .IsConnectionClosed (err ) {
@@ -168,6 +244,7 @@ func (h *UDPHandler) handlePackets() {
168244
169245 // 获取到响应数据
170246 responseData := responseBuffer [:n ]
247+ session .touch ()
171248
172249 // 将响应数据发送给客户端
173250 _ , err = h .conn .WriteTo (responseData , clientAddr )
@@ -176,9 +253,12 @@ func (h *UDPHandler) handlePackets() {
176253 break
177254 }
178255 }
179- }(clientAddr , remoteConn , clientAddrStr )
256+ }(clientAddr , session , clientAddrStr )
180257 }
181258
259+ session .touch ()
260+ remoteConn := session .conn
261+
182262 // 转发数据到远程服务器
183263 data := buffer [:n ]
184264
@@ -210,5 +290,7 @@ func (h *UDPHandler) handlePackets() {
210290 if err != nil && ! utils .IsConnectionClosed (err ) {
211291 log .Printf ("转发 UDP 数据到远程服务器失败: %v\n " , err )
212292 }
293+
294+ cleanupIdleSessions (time .Now ())
213295 }
214296}
0 commit comments