Skip to content

Commit ea1a3ae

Browse files
authored
Trojan UoT: Fix memory/goroutine leak (#5064)
1 parent 593eded commit ea1a3ae

File tree

1 file changed

+52
-36
lines changed

1 file changed

+52
-36
lines changed

proxy/trojan/server.go

Lines changed: 52 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Con
233233
sessionPolicy = s.policyManager.ForLevel(user.Level)
234234

235235
if destination.Network == net.Network_UDP { // handle udp request
236-
return s.handleUDPPayload(ctx, &PacketReader{Reader: clientReader}, &PacketWriter{Writer: conn}, dispatcher)
236+
return s.handleUDPPayload(ctx, sessionPolicy, &PacketReader{Reader: clientReader}, &PacketWriter{Writer: conn}, dispatcher)
237237
}
238238

239239
ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
@@ -248,7 +248,11 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Con
248248
return s.handleConnection(ctx, sessionPolicy, destination, clientReader, buf.NewWriter(conn), dispatcher)
249249
}
250250

251-
func (s *Server) handleUDPPayload(ctx context.Context, clientReader *PacketReader, clientWriter *PacketWriter, dispatcher routing.Dispatcher) error {
251+
func (s *Server) handleUDPPayload(ctx context.Context, sessionPolicy policy.Session, clientReader *PacketReader, clientWriter *PacketWriter, dispatcher routing.Dispatcher) error {
252+
ctx, cancel := context.WithCancel(ctx)
253+
defer cancel()
254+
timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle)
255+
defer timer.SetTimeout(0)
252256
udpServer := udp.NewDispatcher(dispatcher, func(ctx context.Context, packet *udp_proto.Packet) {
253257
udpPayload := packet.Payload
254258
if udpPayload.UDP == nil {
@@ -257,6 +261,9 @@ func (s *Server) handleUDPPayload(ctx context.Context, clientReader *PacketReade
257261

258262
if err := clientWriter.WriteMultiBuffer(buf.MultiBuffer{udpPayload}); err != nil {
259263
errors.LogWarningInner(ctx, err, "failed to write response")
264+
cancel()
265+
} else {
266+
timer.Update()
260267
}
261268
})
262269
defer udpServer.RemoveRay()
@@ -266,47 +273,56 @@ func (s *Server) handleUDPPayload(ctx context.Context, clientReader *PacketReade
266273

267274
var dest *net.Destination
268275

269-
for {
270-
select {
271-
case <-ctx.Done():
272-
return nil
273-
default:
274-
mb, err := clientReader.ReadMultiBuffer()
275-
if err != nil {
276-
if errors.Cause(err) != io.EOF {
277-
return errors.New("unexpected EOF").Base(err)
278-
}
276+
requestDone := func() error {
277+
for {
278+
select {
279+
case <-ctx.Done():
279280
return nil
280-
}
281+
default:
282+
mb, err := clientReader.ReadMultiBuffer()
283+
if err != nil {
284+
if errors.Cause(err) != io.EOF {
285+
return errors.New("unexpected EOF").Base(err)
286+
}
287+
return nil
288+
}
281289

282-
mb2, b := buf.SplitFirst(mb)
283-
if b == nil {
284-
continue
285-
}
286-
destination := *b.UDP
287-
288-
currentPacketCtx := ctx
289-
if inbound.Source.IsValid() {
290-
currentPacketCtx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
291-
From: inbound.Source,
292-
To: destination,
293-
Status: log.AccessAccepted,
294-
Reason: "",
295-
Email: user.Email,
296-
})
297-
}
298-
errors.LogInfo(ctx, "tunnelling request to ", destination)
290+
mb2, b := buf.SplitFirst(mb)
291+
if b == nil {
292+
continue
293+
}
294+
timer.Update()
295+
destination := *b.UDP
296+
297+
currentPacketCtx := ctx
298+
if inbound.Source.IsValid() {
299+
currentPacketCtx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
300+
From: inbound.Source,
301+
To: destination,
302+
Status: log.AccessAccepted,
303+
Reason: "",
304+
Email: user.Email,
305+
})
306+
}
307+
errors.LogInfo(ctx, "tunnelling request to ", destination)
299308

300-
if !s.cone || dest == nil {
301-
dest = &destination
302-
}
309+
if !s.cone || dest == nil {
310+
dest = &destination
311+
}
303312

304-
udpServer.Dispatch(currentPacketCtx, *dest, b) // first packet
305-
for _, payload := range mb2 {
306-
udpServer.Dispatch(currentPacketCtx, *dest, payload)
313+
udpServer.Dispatch(currentPacketCtx, *dest, b) // first packet
314+
for _, payload := range mb2 {
315+
udpServer.Dispatch(currentPacketCtx, *dest, payload)
316+
}
307317
}
308318
}
319+
320+
}
321+
322+
if err := task.Run(ctx, requestDone); err != nil {
323+
return err
309324
}
325+
return nil
310326
}
311327

312328
func (s *Server) handleConnection(ctx context.Context, sessionPolicy policy.Session,

0 commit comments

Comments
 (0)