Skip to content

Commit 2cfceca

Browse files
committed
Add tests for wepoll to debug
1 parent 92f9285 commit 2cfceca

File tree

5 files changed

+937
-16
lines changed

5 files changed

+937
-16
lines changed

common/bufio/fd_demux_windows.go

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"context"
77
"sync"
88
"sync/atomic"
9+
"unsafe"
910

1011
"github.com/sagernet/sing/common/wepoll"
1112

@@ -36,7 +37,7 @@ type FDDemultiplexer struct {
3637
mutex sync.Mutex
3738
entries map[int]*fdDemuxEntry
3839
registrationCounter uint64
39-
registrationToFD map[uint64]int
40+
iosbToFD map[uintptr]int
4041
running bool
4142
closed atomic.Bool
4243
wg sync.WaitGroup
@@ -56,12 +57,12 @@ func NewFDDemultiplexer(ctx context.Context) (*FDDemultiplexer, error) {
5657

5758
ctx, cancel := context.WithCancel(ctx)
5859
demux := &FDDemultiplexer{
59-
ctx: ctx,
60-
cancel: cancel,
61-
iocp: iocp,
62-
afd: afd,
63-
entries: make(map[int]*fdDemuxEntry),
64-
registrationToFD: make(map[uint64]int),
60+
ctx: ctx,
61+
cancel: cancel,
62+
iocp: iocp,
63+
afd: afd,
64+
entries: make(map[int]*fdDemuxEntry),
65+
iosbToFD: make(map[uintptr]int),
6566
}
6667
return demux, nil
6768
}
@@ -94,14 +95,14 @@ func (p *FDDemultiplexer) Add(stream *reactorStream, fd int) error {
9495
entry.pinner.Pin(&entry.state)
9596

9697
events := uint32(wepoll.AFD_POLL_RECEIVE | wepoll.AFD_POLL_DISCONNECT | wepoll.AFD_POLL_ABORT | wepoll.AFD_POLL_LOCAL_CLOSE)
97-
err = p.afd.Poll(baseHandle, events, &entry.state.iosb, &entry.state.pollInfo, uintptr(regID))
98+
err = p.afd.Poll(baseHandle, events, &entry.state.iosb, &entry.state.pollInfo)
9899
if err != nil {
99100
entry.pinner.Unpin()
100101
return err
101102
}
102103

103104
p.entries[fd] = entry
104-
p.registrationToFD[regID] = fd
105+
p.iosbToFD[uintptr(unsafe.Pointer(&entry.state.iosb))] = fd
105106

106107
if !p.running {
107108
p.running = true
@@ -178,27 +179,28 @@ func (p *FDDemultiplexer) run() {
178179

179180
for i := uint32(0); i < numRemoved; i++ {
180181
ev := entries[i]
181-
regID := uint64(ev.CompletionKey)
182182

183-
if regID == 0 {
183+
if ev.Overlapped == nil {
184184
continue
185185
}
186186

187+
iosbPtr := uintptr(unsafe.Pointer(ev.Overlapped))
188+
187189
p.mutex.Lock()
188-
fd, ok := p.registrationToFD[regID]
190+
fd, ok := p.iosbToFD[iosbPtr]
189191
if !ok {
190192
p.mutex.Unlock()
191193
continue
192194
}
193195

194196
entry := p.entries[fd]
195-
if entry == nil || entry.registrationID != regID {
197+
if entry == nil {
196198
p.mutex.Unlock()
197199
continue
198200
}
199201

200202
entry.pinner.Unpin()
201-
delete(p.registrationToFD, regID)
203+
delete(p.iosbToFD, iosbPtr)
202204
delete(p.entries, fd)
203205

204206
if entry.cancelled {
Lines changed: 305 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,305 @@
1+
//go:build windows
2+
3+
package bufio
4+
5+
import (
6+
"context"
7+
"net"
8+
"sync"
9+
"sync/atomic"
10+
"syscall"
11+
"testing"
12+
"time"
13+
14+
"github.com/stretchr/testify/require"
15+
)
16+
17+
func getSocketFD(t *testing.T, conn net.PacketConn) int {
18+
syscallConn, ok := conn.(syscall.Conn)
19+
require.True(t, ok)
20+
rawConn, err := syscallConn.SyscallConn()
21+
require.NoError(t, err)
22+
var fd int
23+
err = rawConn.Control(func(f uintptr) { fd = int(f) })
24+
require.NoError(t, err)
25+
return fd
26+
}
27+
28+
func TestFDDemultiplexer_Create(t *testing.T) {
29+
t.Parallel()
30+
31+
demux, err := NewFDDemultiplexer(context.Background())
32+
require.NoError(t, err)
33+
34+
err = demux.Close()
35+
require.NoError(t, err)
36+
}
37+
38+
func TestFDDemultiplexer_CreateMultiple(t *testing.T) {
39+
t.Parallel()
40+
41+
demux1, err := NewFDDemultiplexer(context.Background())
42+
require.NoError(t, err)
43+
defer demux1.Close()
44+
45+
demux2, err := NewFDDemultiplexer(context.Background())
46+
require.NoError(t, err)
47+
defer demux2.Close()
48+
}
49+
50+
func TestFDDemultiplexer_AddRemove(t *testing.T) {
51+
t.Parallel()
52+
53+
demux, err := NewFDDemultiplexer(context.Background())
54+
require.NoError(t, err)
55+
defer demux.Close()
56+
57+
conn, err := net.ListenPacket("udp", "127.0.0.1:0")
58+
require.NoError(t, err)
59+
defer conn.Close()
60+
61+
fd := getSocketFD(t, conn)
62+
63+
stream := &reactorStream{}
64+
65+
err = demux.Add(stream, fd)
66+
require.NoError(t, err)
67+
68+
demux.Remove(fd)
69+
}
70+
71+
func TestFDDemultiplexer_RapidAddRemove(t *testing.T) {
72+
t.Parallel()
73+
74+
demux, err := NewFDDemultiplexer(context.Background())
75+
require.NoError(t, err)
76+
defer demux.Close()
77+
78+
const iterations = 50
79+
80+
for i := 0; i < iterations; i++ {
81+
conn, err := net.ListenPacket("udp", "127.0.0.1:0")
82+
require.NoError(t, err)
83+
84+
fd := getSocketFD(t, conn)
85+
stream := &reactorStream{}
86+
87+
err = demux.Add(stream, fd)
88+
require.NoError(t, err)
89+
90+
demux.Remove(fd)
91+
conn.Close()
92+
}
93+
}
94+
95+
func TestFDDemultiplexer_ConcurrentAccess(t *testing.T) {
96+
t.Parallel()
97+
98+
demux, err := NewFDDemultiplexer(context.Background())
99+
require.NoError(t, err)
100+
defer demux.Close()
101+
102+
const numGoroutines = 10
103+
const iterations = 20
104+
105+
var wg sync.WaitGroup
106+
wg.Add(numGoroutines)
107+
108+
for g := 0; g < numGoroutines; g++ {
109+
go func() {
110+
defer wg.Done()
111+
112+
for i := 0; i < iterations; i++ {
113+
conn, err := net.ListenPacket("udp", "127.0.0.1:0")
114+
if err != nil {
115+
continue
116+
}
117+
118+
fd := getSocketFD(t, conn)
119+
stream := &reactorStream{}
120+
121+
err = demux.Add(stream, fd)
122+
if err == nil {
123+
demux.Remove(fd)
124+
}
125+
conn.Close()
126+
}
127+
}()
128+
}
129+
130+
wg.Wait()
131+
}
132+
133+
func TestFDDemultiplexer_ReceiveEvent(t *testing.T) {
134+
t.Parallel()
135+
136+
ctx, cancel := context.WithCancel(context.Background())
137+
defer cancel()
138+
139+
demux, err := NewFDDemultiplexer(ctx)
140+
require.NoError(t, err)
141+
defer demux.Close()
142+
143+
conn, err := net.ListenPacket("udp", "127.0.0.1:0")
144+
require.NoError(t, err)
145+
defer conn.Close()
146+
147+
fd := getSocketFD(t, conn)
148+
149+
triggered := make(chan struct{}, 1)
150+
stream := &reactorStream{
151+
state: atomic.Int32{},
152+
}
153+
stream.connection = &reactorConnection{
154+
upload: stream,
155+
download: stream,
156+
done: make(chan struct{}),
157+
}
158+
159+
originalRunActiveLoop := stream.runActiveLoop
160+
_ = originalRunActiveLoop
161+
162+
err = demux.Add(stream, fd)
163+
require.NoError(t, err)
164+
165+
sender, err := net.ListenPacket("udp", "127.0.0.1:0")
166+
require.NoError(t, err)
167+
defer sender.Close()
168+
169+
_, err = sender.WriteTo([]byte("test data"), conn.LocalAddr())
170+
require.NoError(t, err)
171+
172+
time.Sleep(200 * time.Millisecond)
173+
174+
select {
175+
case <-triggered:
176+
default:
177+
}
178+
179+
demux.Remove(fd)
180+
}
181+
182+
func TestFDDemultiplexer_CloseWhilePolling(t *testing.T) {
183+
t.Parallel()
184+
185+
demux, err := NewFDDemultiplexer(context.Background())
186+
require.NoError(t, err)
187+
188+
conn, err := net.ListenPacket("udp", "127.0.0.1:0")
189+
require.NoError(t, err)
190+
defer conn.Close()
191+
192+
fd := getSocketFD(t, conn)
193+
stream := &reactorStream{}
194+
195+
err = demux.Add(stream, fd)
196+
require.NoError(t, err)
197+
198+
time.Sleep(50 * time.Millisecond)
199+
200+
done := make(chan struct{})
201+
go func() {
202+
demux.Close()
203+
close(done)
204+
}()
205+
206+
select {
207+
case <-done:
208+
case <-time.After(5 * time.Second):
209+
t.Fatal("Close blocked - possible deadlock")
210+
}
211+
}
212+
213+
func TestFDDemultiplexer_RemoveNonExistent(t *testing.T) {
214+
t.Parallel()
215+
216+
demux, err := NewFDDemultiplexer(context.Background())
217+
require.NoError(t, err)
218+
defer demux.Close()
219+
220+
demux.Remove(99999)
221+
}
222+
223+
func TestFDDemultiplexer_AddAfterClose(t *testing.T) {
224+
t.Parallel()
225+
226+
demux, err := NewFDDemultiplexer(context.Background())
227+
require.NoError(t, err)
228+
229+
err = demux.Close()
230+
require.NoError(t, err)
231+
232+
conn, err := net.ListenPacket("udp", "127.0.0.1:0")
233+
require.NoError(t, err)
234+
defer conn.Close()
235+
236+
fd := getSocketFD(t, conn)
237+
stream := &reactorStream{}
238+
239+
err = demux.Add(stream, fd)
240+
require.Error(t, err)
241+
}
242+
243+
func TestFDDemultiplexer_MultipleSocketsSimultaneous(t *testing.T) {
244+
t.Parallel()
245+
246+
demux, err := NewFDDemultiplexer(context.Background())
247+
require.NoError(t, err)
248+
defer demux.Close()
249+
250+
const numSockets = 5
251+
conns := make([]net.PacketConn, numSockets)
252+
fds := make([]int, numSockets)
253+
254+
for i := 0; i < numSockets; i++ {
255+
conn, err := net.ListenPacket("udp", "127.0.0.1:0")
256+
require.NoError(t, err)
257+
defer conn.Close()
258+
conns[i] = conn
259+
260+
fd := getSocketFD(t, conn)
261+
fds[i] = fd
262+
263+
stream := &reactorStream{}
264+
err = demux.Add(stream, fd)
265+
require.NoError(t, err)
266+
}
267+
268+
for i := 0; i < numSockets; i++ {
269+
demux.Remove(fds[i])
270+
}
271+
}
272+
273+
func TestFDDemultiplexer_ContextCancellation(t *testing.T) {
274+
t.Parallel()
275+
276+
ctx, cancel := context.WithCancel(context.Background())
277+
demux, err := NewFDDemultiplexer(ctx)
278+
require.NoError(t, err)
279+
280+
conn, err := net.ListenPacket("udp", "127.0.0.1:0")
281+
require.NoError(t, err)
282+
defer conn.Close()
283+
284+
fd := getSocketFD(t, conn)
285+
stream := &reactorStream{}
286+
287+
err = demux.Add(stream, fd)
288+
require.NoError(t, err)
289+
290+
cancel()
291+
292+
time.Sleep(100 * time.Millisecond)
293+
294+
done := make(chan struct{})
295+
go func() {
296+
demux.Close()
297+
close(done)
298+
}()
299+
300+
select {
301+
case <-done:
302+
case <-time.After(5 * time.Second):
303+
t.Fatal("Close blocked after context cancellation")
304+
}
305+
}

0 commit comments

Comments
 (0)