Skip to content

Commit d2f7a9c

Browse files
committed
Fix reactor issues and add context propagation
1 parent 91d05d2 commit d2f7a9c

File tree

7 files changed

+119
-25
lines changed

7 files changed

+119
-25
lines changed

common/bufio/fd_poller_windows.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,12 @@ func (p *FDPoller) Remove(fd int) {
120120
if p.afd != nil {
121121
p.afd.Cancel(&entry.ioStatusBlock)
122122
}
123+
124+
if !entry.unpinned {
125+
entry.unpinned = true
126+
entry.pinner.Unpin()
127+
}
128+
delete(p.entries, fd)
123129
}
124130

125131
func (p *FDPoller) wakeup() {

common/bufio/packet_reactor.go

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -55,12 +55,13 @@ func (r *PacketReactor) Close() error {
5555
}
5656

5757
type reactorConnection struct {
58-
ctx context.Context
59-
cancel context.CancelFunc
60-
reactor *PacketReactor
61-
onClose N.CloseHandlerFunc
62-
upload *reactorStream
63-
download *reactorStream
58+
ctx context.Context
59+
cancel context.CancelFunc
60+
reactor *PacketReactor
61+
onClose N.CloseHandlerFunc
62+
upload *reactorStream
63+
download *reactorStream
64+
stopReactorWatch func() bool
6465

6566
closeOnce sync.Once
6667
done chan struct{}
@@ -93,6 +94,9 @@ func (r *PacketReactor) Copy(ctx context.Context, source N.PacketConn, destinati
9394
onClose: onClose,
9495
done: make(chan struct{}),
9596
}
97+
conn.stopReactorWatch = common.ContextAfterFunc(r.ctx, func() {
98+
conn.closeWithError(r.ctx.Err())
99+
})
96100

97101
conn.upload = r.prepareStream(conn, source, destination)
98102
select {
@@ -126,10 +130,12 @@ func (r *PacketReactor) prepareStream(conn *reactorConnection, source N.PacketRe
126130
if cachedReader, isCached := source.(N.CachedPacketReader); isCached {
127131
packet := cachedReader.ReadCachedPacket()
128132
if packet != nil {
129-
dataLen := packet.Buffer.Len()
130-
err := destination.WritePacket(packet.Buffer, packet.Destination)
133+
buffer := packet.Buffer
134+
dataLen := buffer.Len()
135+
err := destination.WritePacket(buffer, packet.Destination)
131136
N.PutPacketBuffer(packet)
132137
if err != nil {
138+
buffer.Leak()
133139
conn.closeWithError(err)
134140
return stream
135141
}
@@ -151,7 +157,10 @@ func (r *PacketReactor) prepareStream(conn *reactorConnection, source N.PacketRe
151157

152158
stream.readWaiter, _ = CreatePacketReadWaiter(source)
153159
if stream.readWaiter != nil {
154-
stream.readWaiter.InitializeReadWaiter(stream.options)
160+
needCopy := stream.readWaiter.InitializeReadWaiter(stream.options)
161+
if needCopy {
162+
stream.readWaiter = nil
163+
}
155164
}
156165

157166
if pushable, ok := source.(N.PacketPushable); ok {
@@ -343,7 +352,7 @@ func (s *reactorStream) HandleFDEvent() {
343352
}
344353

345354
func (s *reactorStream) runLegacyCopy() {
346-
_, err := CopyPacket(s.destination, s.source)
355+
_, err := CopyPacketWithCounters(s.destination, s.source, s.originSource, s.readCounters, s.writeCounters)
347356
s.closeWithError(err)
348357
}
349358

@@ -353,6 +362,12 @@ func (s *reactorStream) closeWithError(err error) {
353362

354363
func (c *reactorConnection) closeWithError(err error) {
355364
c.closeOnce.Do(func() {
365+
defer close(c.done)
366+
367+
if c.stopReactorWatch != nil {
368+
c.stopReactorWatch()
369+
}
370+
356371
c.err = err
357372
c.cancel()
358373

@@ -375,8 +390,6 @@ func (c *reactorConnection) closeWithError(err error) {
375390
if c.onClose != nil {
376391
c.onClose(c.err)
377392
}
378-
379-
close(c.done)
380393
})
381394
}
382395

common/bufio/stream_reactor.go

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -50,12 +50,13 @@ func (r *StreamReactor) Close() error {
5050
}
5151

5252
type streamConnection struct {
53-
ctx context.Context
54-
cancel context.CancelFunc
55-
reactor *StreamReactor
56-
onClose N.CloseHandlerFunc
57-
upload *streamDirection
58-
download *streamDirection
53+
ctx context.Context
54+
cancel context.CancelFunc
55+
reactor *StreamReactor
56+
onClose N.CloseHandlerFunc
57+
upload *streamDirection
58+
download *streamDirection
59+
stopReactorWatch func() bool
5960

6061
closeOnce sync.Once
6162
done chan struct{}
@@ -96,6 +97,9 @@ func (r *StreamReactor) Copy(ctx context.Context, source net.Conn, destination n
9697
onClose: onClose,
9798
done: make(chan struct{}),
9899
}
100+
conn.stopReactorWatch = common.ContextAfterFunc(r.ctx, func() {
101+
conn.closeWithError(r.ctx.Err())
102+
})
99103

100104
conn.upload = r.prepareDirection(conn, source, destination, source, true)
101105
select {
@@ -171,7 +175,10 @@ func (r *StreamReactor) prepareDirection(conn *streamConnection, source io.Reade
171175

172176
direction.readWaiter, _ = CreateReadWaiter(source)
173177
if direction.readWaiter != nil {
174-
direction.readWaiter.InitializeReadWaiter(direction.options)
178+
needCopy := direction.readWaiter.InitializeReadWaiter(direction.options)
179+
if needCopy {
180+
direction.readWaiter = nil
181+
}
175182
}
176183

177184
// Try to get stream pollable for FD-based idle detection
@@ -320,7 +327,7 @@ func (d *streamDirection) HandleFDEvent() {
320327
}
321328

322329
func (d *streamDirection) runLegacyCopy() {
323-
_, err := Copy(d.destination, d.source)
330+
_, err := CopyWithCounters(d.destination, d.source, d.originSource, d.readCounters, d.writeCounters, DefaultIncreaseBufferAfter, DefaultBatchSize)
324331
d.handleEOFOrError(err)
325332
}
326333

@@ -358,6 +365,12 @@ func (c *streamConnection) checkBothClosed() {
358365

359366
if uploadClosed && downloadClosed {
360367
c.closeOnce.Do(func() {
368+
defer close(c.done)
369+
370+
if c.stopReactorWatch != nil {
371+
c.stopReactorWatch()
372+
}
373+
361374
c.cancel()
362375
c.removeFromPoller()
363376

@@ -367,14 +380,18 @@ func (c *streamConnection) checkBothClosed() {
367380
if c.onClose != nil {
368381
c.onClose(nil)
369382
}
370-
371-
close(c.done)
372383
})
373384
}
374385
}
375386

376387
func (c *streamConnection) closeWithError(err error) {
377388
c.closeOnce.Do(func() {
389+
defer close(c.done)
390+
391+
if c.stopReactorWatch != nil {
392+
c.stopReactorWatch()
393+
}
394+
378395
c.err = err
379396
c.cancel()
380397

@@ -397,8 +414,6 @@ func (c *streamConnection) closeWithError(err error) {
397414
if c.onClose != nil {
398415
c.onClose(c.err)
399416
}
400-
401-
close(c.done)
402417
})
403418
}
404419

common/context_afterfunc.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
//go:build go1.21
2+
3+
package common
4+
5+
import "context"
6+
7+
// ContextAfterFunc arranges to call f in its own goroutine after ctx is done.
8+
// Returns a stop function that prevents f from being run.
9+
func ContextAfterFunc(ctx context.Context, f func()) (stop func() bool) {
10+
return context.AfterFunc(ctx, f)
11+
}

common/context_afterfunc_compat.go

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
//go:build go1.20 && !go1.21
2+
3+
package common
4+
5+
import (
6+
"context"
7+
"sync"
8+
)
9+
10+
// ContextAfterFunc arranges to call f in its own goroutine after ctx is done.
11+
// Returns a stop function that prevents f from being run.
12+
func ContextAfterFunc(ctx context.Context, f func()) (stop func() bool) {
13+
stopCh := make(chan struct{})
14+
var once sync.Once
15+
stopped := false
16+
17+
go func() {
18+
select {
19+
case <-ctx.Done():
20+
once.Do(func() {
21+
if !stopped {
22+
f()
23+
}
24+
})
25+
case <-stopCh:
26+
}
27+
}()
28+
29+
return func() bool {
30+
select {
31+
case <-ctx.Done():
32+
return false
33+
default:
34+
stopped = true
35+
once.Do(func() {
36+
close(stopCh)
37+
})
38+
return true
39+
}
40+
}
41+
}

common/network/handshake.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ func ReportHandshakeFailure(reporter any, err error) error {
3030
return E.Cause(err, "write handshake failure")
3131
})
3232
}
33-
return nil
33+
return err
3434
}
3535

3636
func CloseOnHandshakeFailure(reporter io.Closer, onClose CloseHandlerFunc, err error) error {

common/udpnat2/conn.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,14 @@ func (c *natConn) SetTimeout(timeout time.Duration) bool {
161161
func (c *natConn) Close() error {
162162
c.closeOnce.Do(func() {
163163
close(c.doneChan)
164+
165+
c.queueMutex.Lock()
166+
pending := c.dataQueue
167+
c.dataQueue = nil
168+
c.onDataReady = nil
169+
c.queueMutex.Unlock()
170+
171+
N.ReleaseMultiPacketBuffer(pending)
164172
common.Close(c.handler)
165173
})
166174
return nil

0 commit comments

Comments
 (0)