Skip to content

Commit 5820f0e

Browse files
committed
Fix reactor race conditions and resource leaks
1 parent 3bacbe6 commit 5820f0e

File tree

3 files changed

+85
-31
lines changed

3 files changed

+85
-31
lines changed

common/bufio/fd_poller_darwin.go

Lines changed: 43 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -6,25 +6,29 @@ import (
66
"context"
77
"sync"
88
"sync/atomic"
9+
"unsafe"
910

1011
"golang.org/x/sys/unix"
1112
)
1213

1314
type fdDemuxEntry struct {
14-
fd int
15-
handler FDHandler
15+
fd int
16+
registrationID uint64
17+
handler FDHandler
1618
}
1719

1820
type FDPoller struct {
19-
ctx context.Context
20-
cancel context.CancelFunc
21-
kqueueFD int
22-
mutex sync.Mutex
23-
entries map[int]*fdDemuxEntry
24-
running bool
25-
closed atomic.Bool
26-
wg sync.WaitGroup
27-
pipeFDs [2]int
21+
ctx context.Context
22+
cancel context.CancelFunc
23+
kqueueFD int
24+
mutex sync.Mutex
25+
entries map[int]*fdDemuxEntry
26+
registrationCounter uint64
27+
registrationToFD map[uint64]int
28+
running bool
29+
closed atomic.Bool
30+
wg sync.WaitGroup
31+
pipeFDs [2]int
2832
}
2933

3034
func NewFDPoller(ctx context.Context) (*FDPoller, error) {
@@ -69,11 +73,12 @@ func NewFDPoller(ctx context.Context) (*FDPoller, error) {
6973

7074
ctx, cancel := context.WithCancel(ctx)
7175
poller := &FDPoller{
72-
ctx: ctx,
73-
cancel: cancel,
74-
kqueueFD: kqueueFD,
75-
entries: make(map[int]*fdDemuxEntry),
76-
pipeFDs: pipeFDs,
76+
ctx: ctx,
77+
cancel: cancel,
78+
kqueueFD: kqueueFD,
79+
entries: make(map[int]*fdDemuxEntry),
80+
registrationToFD: make(map[uint64]int),
81+
pipeFDs: pipeFDs,
7782
}
7883
return poller, nil
7984
}
@@ -86,20 +91,26 @@ func (p *FDPoller) Add(handler FDHandler, fd int) error {
8691
return unix.EINVAL
8792
}
8893

94+
p.registrationCounter++
95+
registrationID := p.registrationCounter
96+
8997
_, err := unix.Kevent(p.kqueueFD, []unix.Kevent_t{{
9098
Ident: uint64(fd),
9199
Filter: unix.EVFILT_READ,
92-
Flags: unix.EV_ADD,
100+
Flags: unix.EV_ADD | unix.EV_ONESHOT,
101+
Udata: (*byte)(unsafe.Pointer(uintptr(registrationID))),
93102
}}, nil, nil)
94103
if err != nil {
95104
return err
96105
}
97106

98107
entry := &fdDemuxEntry{
99-
fd: fd,
100-
handler: handler,
108+
fd: fd,
109+
registrationID: registrationID,
110+
handler: handler,
101111
}
102112
p.entries[fd] = entry
113+
p.registrationToFD[registrationID] = fd
103114

104115
if !p.running {
105116
p.running = true
@@ -114,7 +125,7 @@ func (p *FDPoller) Remove(fd int) {
114125
p.mutex.Lock()
115126
defer p.mutex.Unlock()
116127

117-
_, ok := p.entries[fd]
128+
entry, ok := p.entries[fd]
118129
if !ok {
119130
return
120131
}
@@ -124,6 +135,7 @@ func (p *FDPoller) Remove(fd int) {
124135
Filter: unix.EVFILT_READ,
125136
Flags: unix.EV_DELETE,
126137
}}, nil, nil)
138+
delete(p.registrationToFD, entry.registrationID)
127139
delete(p.entries, fd)
128140
}
129141

@@ -196,18 +208,22 @@ func (p *FDPoller) run() {
196208
continue
197209
}
198210

211+
registrationID := uint64(uintptr(unsafe.Pointer(event.Udata)))
212+
199213
p.mutex.Lock()
200-
entry, ok := p.entries[fd]
201-
if !ok {
214+
mappedFD, ok := p.registrationToFD[registrationID]
215+
if !ok || mappedFD != fd {
216+
p.mutex.Unlock()
217+
continue
218+
}
219+
220+
entry := p.entries[fd]
221+
if entry == nil || entry.registrationID != registrationID {
202222
p.mutex.Unlock()
203223
continue
204224
}
205225

206-
unix.Kevent(p.kqueueFD, []unix.Kevent_t{{
207-
Ident: uint64(fd),
208-
Filter: unix.EVFILT_READ,
209-
Flags: unix.EV_DELETE,
210-
}}, nil, nil)
226+
delete(p.registrationToFD, registrationID)
211227
delete(p.entries, fd)
212228
p.mutex.Unlock()
213229

common/bufio/fd_poller_windows.go

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ type fdDemuxEntry struct {
2222
baseHandle windows.Handle
2323
registrationID uint64
2424
cancelled bool
25+
unpinned bool
2526
pinner wepoll.Pinner
2627
}
2728

@@ -138,7 +139,10 @@ func (p *FDPoller) Close() error {
138139
defer p.mutex.Unlock()
139140

140141
for fd, entry := range p.entries {
141-
entry.pinner.Unpin()
142+
if !entry.unpinned {
143+
entry.unpinned = true
144+
entry.pinner.Unpin()
145+
}
142146
delete(p.entries, fd)
143147
}
144148

@@ -153,6 +157,32 @@ func (p *FDPoller) Close() error {
153157
return nil
154158
}
155159

160+
func (p *FDPoller) drainCompletions(completions []wepoll.OverlappedEntry) {
161+
for {
162+
var numRemoved uint32
163+
err := wepoll.GetQueuedCompletionStatusEx(p.iocp, &completions[0], uint32(len(completions)), &numRemoved, 0, false)
164+
if err != nil || numRemoved == 0 {
165+
break
166+
}
167+
168+
for i := uint32(0); i < numRemoved; i++ {
169+
event := completions[i]
170+
if event.Overlapped == nil {
171+
continue
172+
}
173+
174+
entry := (*fdDemuxEntry)(unsafe.Pointer(event.Overlapped))
175+
p.mutex.Lock()
176+
if p.entries[entry.fd] == entry && !entry.unpinned {
177+
entry.unpinned = true
178+
entry.pinner.Unpin()
179+
}
180+
delete(p.entries, entry.fd)
181+
p.mutex.Unlock()
182+
}
183+
}
184+
}
185+
156186
func (p *FDPoller) run() {
157187
defer p.wg.Done()
158188

@@ -161,6 +191,7 @@ func (p *FDPoller) run() {
161191
for {
162192
select {
163193
case <-p.ctx.Done():
194+
p.drainCompletions(completions)
164195
p.mutex.Lock()
165196
p.running = false
166197
p.mutex.Unlock()
@@ -193,7 +224,10 @@ func (p *FDPoller) run() {
193224
continue
194225
}
195226

196-
entry.pinner.Unpin()
227+
if !entry.unpinned {
228+
entry.unpinned = true
229+
entry.pinner.Unpin()
230+
}
197231
delete(p.entries, entry.fd)
198232

199233
if entry.cancelled {

common/bufio/stream_reactor.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -331,9 +331,13 @@ func (d *streamDirection) handleEOFOrError(err error) {
331331

332332
// Try half-close on destination
333333
if d.isUpload {
334-
N.CloseWrite(d.connection.download.originSource)
334+
if d.connection.download != nil {
335+
N.CloseWrite(d.connection.download.originSource)
336+
}
335337
} else {
336-
N.CloseWrite(d.connection.upload.originSource)
338+
if d.connection.upload != nil {
339+
N.CloseWrite(d.connection.upload.originSource)
340+
}
337341
}
338342

339343
d.connection.checkBothClosed()

0 commit comments

Comments
 (0)