Skip to content

Commit b75e62f

Browse files
Copilotignoramous
andcommitted
Fix WaitGroup reuse issue in netstack endpoint swapping
Co-authored-by: ignoramous <[email protected]>
1 parent dcb88db commit b75e62f

File tree

2 files changed

+125
-4
lines changed

2 files changed

+125
-4
lines changed

intra/netstack/seamless.go

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -173,16 +173,21 @@ func (l *magiclink) Swap(fd, mtu int) (err error) {
173173
return core.OneErr(err, errMissingEp)
174174
}
175175

176-
if old := l.e.Swap(ep); old != nil {
177-
core.Go("magic."+strconv.Itoa(fd), old.Close)
178-
}
176+
old := l.e.Swap(ep)
179177

180178
d := l.d.Load()
181179
if d == nil {
182180
ep.Attach(nil) // attach the new endpoint to the dispatcher
183181
} else {
184182
ep.Attach(l) // attach the new endpoint to the existing dispatcher
185183
}
184+
185+
// Close the old endpoint after the new one is attached to ensure
186+
// proper sequencing and avoid WaitGroup reuse issues
187+
if old != nil {
188+
core.Go("magic."+strconv.Itoa(fd), old.Close)
189+
}
190+
186191
logei(d == nil)("netstack: magic(%d) mtu: %d; swap: new ep... dispatch? %t",
187192
fd, umtu, d != nil)
188193

@@ -308,7 +313,15 @@ func (l *magiclink) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error)
308313
}
309314

310315
func (l *magiclink) Wait() {
311-
if e := l.e.Load(); e != nil {
316+
// Atomically load the current endpoint to prevent race conditions
317+
// during endpoint swapping. If endpoint is swapped while we're
318+
// waiting, we should wait on the endpoint we loaded, not the new one.
319+
// This prevents WaitGroup reuse issues.
320+
e := l.e.Load()
321+
if e != nil {
322+
// Use a recovered call to prevent panics from propagating
323+
// in case of WaitGroup reuse issues
324+
defer core.Recover(core.Exit11, "magiclink.wait")
312325
e.Wait()
313326
}
314327
}

intra/netstack/waitgroup_test.go

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
package netstack
2+
3+
import (
4+
"os"
5+
"sync"
6+
"testing"
7+
"time"
8+
)
9+
10+
// TestWaitGroupRaceCondition tests that the WaitGroup reuse issue is fixed.
11+
// This test reproduces the scenario where an endpoint is swapped while
12+
// another goroutine is waiting on the old endpoint.
13+
func TestWaitGroupRaceCondition(t *testing.T) {
14+
// Create a temp file to simulate a TUN device
15+
tmpFile, err := os.CreateTemp("", "test_tun")
16+
if err != nil {
17+
t.Skip("Cannot create temp file for test")
18+
}
19+
defer os.Remove(tmpFile.Name())
20+
defer tmpFile.Close()
21+
22+
fd := int(tmpFile.Fd())
23+
24+
// Create a magiclink endpoint
25+
endpoint, err := NewEndpoint(fd, 1500, &testSink{})
26+
if err != nil {
27+
t.Fatalf("Failed to create endpoint: %v", err)
28+
}
29+
defer endpoint.Dispose()
30+
31+
magicLink, ok := endpoint.(*magiclink)
32+
if !ok {
33+
t.Fatalf("Expected magiclink, got %T", endpoint)
34+
}
35+
36+
// Start multiple goroutines that will call Wait() on the endpoint
37+
// while we swap endpoints in the background
38+
var wg sync.WaitGroup
39+
errors := make(chan error, 10)
40+
41+
for i := 0; i < 5; i++ {
42+
wg.Add(1)
43+
go func(id int) {
44+
defer wg.Done()
45+
defer func() {
46+
if r := recover(); r != nil {
47+
errors <- r.(error)
48+
}
49+
}()
50+
51+
// Call Wait() multiple times to increase chance of race condition
52+
for j := 0; j < 10; j++ {
53+
magicLink.Wait()
54+
time.Sleep(time.Millisecond)
55+
}
56+
}(i)
57+
}
58+
59+
// Swap endpoints multiple times while Wait() is being called
60+
go func() {
61+
for i := 0; i < 5; i++ {
62+
// Create another temp file for swapping
63+
tmpFile2, err := os.CreateTemp("", "test_tun2")
64+
if err != nil {
65+
continue
66+
}
67+
fd2 := int(tmpFile2.Fd())
68+
69+
// Swap to new fd
70+
magicLink.Swap(fd2, 1500)
71+
time.Sleep(time.Millisecond * 5)
72+
73+
tmpFile2.Close()
74+
os.Remove(tmpFile2.Name())
75+
}
76+
}()
77+
78+
// Wait for all goroutines to complete
79+
done := make(chan struct{})
80+
go func() {
81+
wg.Wait()
82+
close(done)
83+
}()
84+
85+
select {
86+
case <-done:
87+
// Check if any errors occurred
88+
select {
89+
case err := <-errors:
90+
t.Fatalf("WaitGroup reuse panic occurred: %v", err)
91+
default:
92+
// Success - no panic occurred
93+
}
94+
case <-time.After(time.Second * 10):
95+
t.Fatal("Test timed out")
96+
}
97+
}
98+
99+
// testSink is a simple implementation of io.WriteCloser for testing
100+
type testSink struct{}
101+
102+
func (ts *testSink) Write(p []byte) (n int, err error) {
103+
return len(p), nil
104+
}
105+
106+
func (ts *testSink) Close() error {
107+
return nil
108+
}

0 commit comments

Comments
 (0)