Skip to content

Commit 593eded

Browse files
authored
Trojan-UoT & UDP-nameserver: Fix forgotten release buffer; UDP dispatcher: Simplified and optimized (#5050)
1 parent 82ea7a3 commit 593eded

File tree

11 files changed

+83
-45
lines changed

11 files changed

+83
-45
lines changed

app/dns/nameserver_udp.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,9 @@ func (s *ClassicNameServer) RequestsCleanup() error {
9090

9191
// HandleResponse handles udp response packet from remote DNS server.
9292
func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_proto.Packet) {
93-
ipRec, err := parseResponse(packet.Payload.Bytes())
93+
payload := packet.Payload
94+
ipRec, err := parseResponse(payload.Bytes())
95+
payload.Release()
9496
if err != nil {
9597
errors.LogError(ctx, s.Name(), " fail to parse responded DNS udp")
9698
return
@@ -125,6 +127,8 @@ func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_prot
125127
newReq.msg = &newMsg
126128
s.addPendingRequest(&newReq)
127129
b, _ := dns.PackMessage(newReq.msg)
130+
copyDest := net.UDPDestination(s.address.Address, s.address.Port)
131+
b.UDP = &copyDest
128132
s.udpServer.Dispatch(toDnsContext(newReq.ctx, s.address.String()), *s.address, b)
129133
return
130134
}
@@ -158,6 +162,8 @@ func (s *ClassicNameServer) sendQuery(ctx context.Context, _ chan<- error, domai
158162
}
159163
s.addPendingRequest(udpReq)
160164
b, _ := dns.PackMessage(req.msg)
165+
copyDest := net.UDPDestination(s.address.Address, s.address.Port)
166+
b.UDP = &copyDest
161167
s.udpServer.Dispatch(toDnsContext(ctx, s.address.String()), *s.address, b)
162168
}
163169
}

app/proxyman/outbound/handler.go

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -239,8 +239,10 @@ func (h *Handler) Dispatch(ctx context.Context, link *transport.Link) {
239239
}
240240
out:
241241
err := h.proxy.Process(ctx, link, h)
242+
var errC error
242243
if err != nil {
243-
if goerrors.Is(err, io.EOF) || goerrors.Is(err, io.ErrClosedPipe) || goerrors.Is(err, context.Canceled) {
244+
errC = errors.Cause(err)
245+
if goerrors.Is(errC, io.EOF) || goerrors.Is(errC, io.ErrClosedPipe) || goerrors.Is(errC, context.Canceled) {
244246
err = nil
245247
}
246248
}
@@ -251,7 +253,11 @@ out:
251253
errors.LogInfo(ctx, err.Error())
252254
common.Interrupt(link.Writer)
253255
} else {
254-
common.Close(link.Writer)
256+
if errC != nil && goerrors.Is(errC, io.ErrClosedPipe) {
257+
common.Interrupt(link.Writer)
258+
} else {
259+
common.Close(link.Writer)
260+
}
255261
}
256262
common.Interrupt(link.Reader)
257263
}

common/mux/client.go

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package mux
22

33
import (
44
"context"
5+
goerrors "errors"
56
"io"
67
"sync"
78
"time"
@@ -154,8 +155,11 @@ func (f *DialingWorkerFactory) Create() (*ClientWorker, error) {
154155
ctx := session.ContextWithOutbounds(context.Background(), outbounds)
155156
ctx, cancel := context.WithCancel(ctx)
156157

157-
if err := p.Process(ctx, &transport.Link{Reader: uplinkReader, Writer: downlinkWriter}, d); err != nil {
158-
errors.LogInfoInner(ctx, err, "failed to handler mux client connection")
158+
if errP := p.Process(ctx, &transport.Link{Reader: uplinkReader, Writer: downlinkWriter}, d); errP != nil {
159+
errC := errors.Cause(errP)
160+
if !(goerrors.Is(errC, io.EOF) || goerrors.Is(errC, io.ErrClosedPipe) || goerrors.Is(errC, context.Canceled)) {
161+
errors.LogInfoInner(ctx, errP, "failed to handler mux client connection")
162+
}
159163
}
160164
common.Must(c.Close())
161165
cancel()
@@ -222,7 +226,7 @@ func (m *ClientWorker) monitor() {
222226
select {
223227
case <-m.done.Wait():
224228
m.sessionManager.Close()
225-
common.Close(m.link.Writer)
229+
common.Interrupt(m.link.Writer)
226230
common.Interrupt(m.link.Reader)
227231
return
228232
case <-m.timer.C:
@@ -247,7 +251,7 @@ func writeFirstPayload(reader buf.Reader, writer *Writer) error {
247251
return nil
248252
}
249253

250-
func fetchInput(ctx context.Context, s *Session, output buf.Writer) {
254+
func fetchInput(ctx context.Context, s *Session, output buf.Writer, timer *time.Ticker) {
251255
outbounds := session.OutboundsFromContext(ctx)
252256
ob := outbounds[len(outbounds)-1]
253257
transferType := protocol.TransferTypeStream
@@ -258,6 +262,7 @@ func fetchInput(ctx context.Context, s *Session, output buf.Writer) {
258262
writer := NewWriter(s.ID, ob.Target, output, transferType, xudp.GetGlobalID(ctx))
259263
defer s.Close(false)
260264
defer writer.Close()
265+
defer timer.Reset(time.Second * 16)
261266

262267
errors.LogInfo(ctx, "dispatching request to ", ob.Target)
263268
if err := writeFirstPayload(s.input, writer); err != nil {
@@ -308,9 +313,9 @@ func (m *ClientWorker) Dispatch(ctx context.Context, link *transport.Link) bool
308313
s.input = link.Reader
309314
s.output = link.Writer
310315
if _, ok := link.Reader.(*pipe.Reader); ok {
311-
go fetchInput(ctx, s, m.link.Writer)
316+
go fetchInput(ctx, s, m.link.Writer, m.timer)
312317
} else {
313-
fetchInput(ctx, s, m.link.Writer)
318+
fetchInput(ctx, s, m.link.Writer, m.timer)
314319
}
315320
return true
316321
}

common/mux/server.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -318,8 +318,8 @@ func (w *ServerWorker) run(ctx context.Context) {
318318
reader := &buf.BufferedReader{Reader: w.link.Reader}
319319

320320
defer w.sessionManager.Close()
321-
defer common.Close(w.link.Writer)
322321
defer common.Interrupt(w.link.Reader)
322+
defer common.Interrupt(w.link.Writer)
323323

324324
for {
325325
select {

proxy/freedom/freedom.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ func isValidAddress(addr *net.IPOrDomain) bool {
7373
}
7474

7575
a := addr.AsAddress()
76-
return a != net.AnyIP
76+
return a != net.AnyIP && a != net.AnyIPv6
7777
}
7878

7979
// Process implements proxy.Outbound.
@@ -418,7 +418,7 @@ func (w *PacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
418418
}
419419
}
420420
}
421-
destAddr, _ := net.ResolveUDPAddr("udp", b.UDP.NetAddr())
421+
destAddr := b.UDP.RawNetAddr()
422422
if destAddr == nil {
423423
b.Release()
424424
continue

proxy/proxy.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -636,6 +636,9 @@ func CopyRawConnIfExist(ctx context.Context, readerConn net.Conn, writerConn net
636636
}
637637
}
638638
if err != nil {
639+
if errors.Cause(err) == io.EOF {
640+
return nil
641+
}
639642
return err
640643
}
641644
}

proxy/shadowsocks/server.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,12 +104,12 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Con
104104
func (s *Server) handleUDPPayload(ctx context.Context, conn stat.Connection, dispatcher routing.Dispatcher) error {
105105
udpServer := udp.NewDispatcher(dispatcher, func(ctx context.Context, packet *udp_proto.Packet) {
106106
request := protocol.RequestHeaderFromContext(ctx)
107+
payload := packet.Payload
107108
if request == nil {
109+
payload.Release()
108110
return
109111
}
110112

111-
payload := packet.Payload
112-
113113
if payload.UDP != nil {
114114
request = &protocol.RequestHeader{
115115
User: request.User,
@@ -124,9 +124,9 @@ func (s *Server) handleUDPPayload(ctx context.Context, conn stat.Connection, dis
124124
errors.LogWarningInner(ctx, err, "failed to encode UDP packet")
125125
return
126126
}
127-
defer data.Release()
128127

129128
conn.Write(data.Bytes())
129+
data.Release()
130130
})
131131
defer udpServer.RemoveRay()
132132

proxy/socks/server.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,7 @@ func (s *Server) handleUDPPayload(ctx context.Context, conn stat.Connection, dis
196196

197197
request := protocol.RequestHeaderFromContext(ctx)
198198
if request == nil {
199+
payload.Release()
199200
return
200201
}
201202

@@ -214,9 +215,9 @@ func (s *Server) handleUDPPayload(ctx context.Context, conn stat.Connection, dis
214215
errors.LogWarningInner(ctx, err, "failed to write UDP response")
215216
return
216217
}
217-
defer udpMessage.Release()
218218

219219
conn.Write(udpMessage.Bytes())
220+
udpMessage.Release()
220221
})
221222
defer udpServer.RemoveRay()
222223

proxy/trojan/protocol.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,9 +113,11 @@ func (w *PacketWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
113113
target = b.UDP
114114
}
115115
if _, err := w.writePacket(b.Bytes(), *target); err != nil {
116+
b.Release()
116117
buf.ReleaseMulti(mb)
117118
return err
118119
}
120+
b.Release()
119121
}
120122
return nil
121123
}

transport/internet/udp/dispatcher.go

Lines changed: 36 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,24 @@ type ResponseCallback func(ctx context.Context, packet *udp.Packet)
2222

2323
type connEntry struct {
2424
link *transport.Link
25-
timer signal.ActivityUpdater
25+
timer *signal.ActivityTimer
2626
cancel context.CancelFunc
27+
closed bool
28+
}
29+
30+
func (c *connEntry) Close() error {
31+
c.timer.SetTimeout(0)
32+
return nil
33+
}
34+
35+
func (c *connEntry) terminate() {
36+
if c.closed {
37+
panic("terminate called more than once")
38+
}
39+
c.closed = true
40+
c.cancel()
41+
common.Interrupt(c.link.Reader)
42+
common.Interrupt(c.link.Writer)
2743
}
2844

2945
type Dispatcher struct {
@@ -32,6 +48,7 @@ type Dispatcher struct {
3248
dispatcher routing.Dispatcher
3349
callback ResponseCallback
3450
callClose func() error
51+
closed bool
3552
}
3653

3754
func NewDispatcher(dispatcher routing.Dispatcher, callback ResponseCallback) *Dispatcher {
@@ -44,13 +61,9 @@ func NewDispatcher(dispatcher routing.Dispatcher, callback ResponseCallback) *Di
4461
func (v *Dispatcher) RemoveRay() {
4562
v.Lock()
4663
defer v.Unlock()
47-
v.removeRay()
48-
}
49-
50-
func (v *Dispatcher) removeRay() {
64+
v.closed = true
5165
if v.conn != nil {
52-
common.Interrupt(v.conn.link.Reader)
53-
common.Close(v.conn.link.Writer)
66+
v.conn.Close()
5467
v.conn = nil
5568
}
5669
}
@@ -59,35 +72,34 @@ func (v *Dispatcher) getInboundRay(ctx context.Context, dest net.Destination) (*
5972
v.Lock()
6073
defer v.Unlock()
6174

75+
if v.closed {
76+
return nil, errors.New("dispatcher is closed")
77+
}
78+
6279
if v.conn != nil {
63-
return v.conn, nil
80+
if v.conn.closed {
81+
v.conn = nil
82+
} else {
83+
return v.conn, nil
84+
}
6485
}
6586

6687
errors.LogInfo(ctx, "establishing new connection for ", dest)
6788

6889
ctx, cancel := context.WithCancel(ctx)
69-
entry := &connEntry{}
70-
removeRay := func() {
71-
v.Lock()
72-
defer v.Unlock()
73-
// sometimes the entry is already removed by others, don't close again
74-
if entry == v.conn {
75-
cancel()
76-
v.removeRay()
77-
}
78-
}
79-
timer := signal.CancelAfterInactivity(ctx, removeRay, time.Minute)
8090

8191
link, err := v.dispatcher.Dispatch(ctx, dest)
8292
if err != nil {
93+
cancel()
8394
return nil, errors.New("failed to dispatch request to ", dest).Base(err)
8495
}
8596

86-
*entry = connEntry{
97+
entry := &connEntry{
8798
link: link,
88-
timer: timer,
89-
cancel: removeRay,
99+
cancel: cancel,
90100
}
101+
102+
entry.timer = signal.CancelAfterInactivity(ctx, entry.terminate, time.Minute)
91103
v.conn = entry
92104
go handleInput(ctx, entry, dest, v.callback, v.callClose)
93105
return entry, nil
@@ -106,15 +118,15 @@ func (v *Dispatcher) Dispatch(ctx context.Context, destination net.Destination,
106118
if outputStream != nil {
107119
if err := outputStream.WriteMultiBuffer(buf.MultiBuffer{payload}); err != nil {
108120
errors.LogInfoInner(ctx, err, "failed to write first UDP payload")
109-
conn.cancel()
121+
conn.Close()
110122
return
111123
}
112124
}
113125
}
114126

115127
func handleInput(ctx context.Context, conn *connEntry, dest net.Destination, callback ResponseCallback, callClose func() error) {
116128
defer func() {
117-
conn.cancel()
129+
conn.Close()
118130
if callClose != nil {
119131
callClose()
120132
}

0 commit comments

Comments
 (0)