Skip to content

Commit bc9c5d2

Browse files
committed
TUN-8817: Increase close session channel by one since there are two writers
When closing a session, there are two possible signals that will occur, one from the outside, indicating that the session is idle and needs to be closed, and the internal error condition that will be unblocked with a net.ErrClosed when the connection underneath is closed. Both of these routines write to the session's closeChan. Once the reader for the closeChan reads one value, it will immediately return. This means that the channel is a one-shot and one of the two writers will get stuck unless the size of the channel is increased to accomodate for the second write to the channel. With the channel size increased to two, the second writer (whichever loses the race to write) will now be unblocked to end their go routine and return. Closes TUN-8817
1 parent 1859d74 commit bc9c5d2

File tree

2 files changed

+80
-92
lines changed

2 files changed

+80
-92
lines changed

quic/v3/session.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,10 @@ func NewSession(
8989
log *zerolog.Logger,
9090
) Session {
9191
logger := log.With().Str(logFlowID, id.String()).Logger()
92-
closeChan := make(chan error, 1)
92+
// closeChan has two slots to allow for both writers (the closeFn and the Serve routine) to both be able to
93+
// write to the channel without blocking since there is only ever one value read from the closeChan by the
94+
// waitForCloseCondition.
95+
closeChan := make(chan error, 2)
9396
session := &session{
9497
id: id,
9598
closeAfterIdle: closeAfterIdle,

quic/v3/session_test.go

Lines changed: 76 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,14 @@ package v3_test
33
import (
44
"context"
55
"errors"
6+
"io"
67
"net"
78
"net/netip"
89
"slices"
9-
"sync/atomic"
1010
"testing"
1111
"time"
1212

13+
"github.com/fortytw2/leaktest"
1314
"github.com/rs/zerolog"
1415

1516
v3 "github.com/cloudflare/cloudflared/quic/v3"
@@ -32,45 +33,64 @@ func TestSessionNew(t *testing.T) {
3233

3334
func testSessionWrite(t *testing.T, payload []byte) {
3435
log := zerolog.Nop()
35-
origin := newTestOrigin(makePayload(1280))
36-
session := v3.NewSession(testRequestID, 5*time.Second, &origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &log)
36+
origin, server := net.Pipe()
37+
defer origin.Close()
38+
defer server.Close()
39+
// Start origin server read
40+
serverRead := make(chan []byte, 1)
41+
go func() {
42+
read := make([]byte, 1500)
43+
server.Read(read[:])
44+
serverRead <- read
45+
}()
46+
// Create session and write to origin
47+
session := v3.NewSession(testRequestID, 5*time.Second, origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &log)
3748
n, err := session.Write(payload)
49+
defer session.Close()
3850
if err != nil {
3951
t.Fatal(err)
4052
}
4153
if n != len(payload) {
4254
t.Fatal("unable to write the whole payload")
4355
}
44-
if !slices.Equal(payload, origin.write[:len(payload)]) {
56+
57+
read := <-serverRead
58+
if !slices.Equal(payload, read[:len(payload)]) {
4559
t.Fatal("payload provided from origin and read value are not the same")
4660
}
4761
}
4862

4963
func TestSessionWrite_Max(t *testing.T) {
64+
defer leaktest.Check(t)()
5065
payload := makePayload(1280)
5166
testSessionWrite(t, payload)
5267
}
5368

5469
func TestSessionWrite_Min(t *testing.T) {
70+
defer leaktest.Check(t)()
5571
payload := makePayload(0)
5672
testSessionWrite(t, payload)
5773
}
5874

5975
func TestSessionServe_OriginMax(t *testing.T) {
76+
defer leaktest.Check(t)()
6077
payload := makePayload(1280)
6178
testSessionServe_Origin(t, payload)
6279
}
6380

6481
func TestSessionServe_OriginMin(t *testing.T) {
82+
defer leaktest.Check(t)()
6583
payload := makePayload(0)
6684
testSessionServe_Origin(t, payload)
6785
}
6886

6987
func testSessionServe_Origin(t *testing.T, payload []byte) {
7088
log := zerolog.Nop()
89+
origin, server := net.Pipe()
90+
defer origin.Close()
91+
defer server.Close()
7192
eyeball := newMockEyeball()
72-
origin := newTestOrigin(payload)
73-
session := v3.NewSession(testRequestID, 3*time.Second, &origin, testOriginAddr, testLocalAddr, &eyeball, &noopMetrics{}, &log)
93+
session := v3.NewSession(testRequestID, 3*time.Second, origin, testOriginAddr, testLocalAddr, &eyeball, &noopMetrics{}, &log)
7494
defer session.Close()
7595

7696
ctx, cancel := context.WithCancelCause(context.Background())
@@ -80,13 +100,19 @@ func testSessionServe_Origin(t *testing.T, payload []byte) {
80100
done <- session.Serve(ctx)
81101
}()
82102

103+
// Write from the origin server
104+
_, err := server.Write(payload)
105+
if err != nil {
106+
t.Fatal(err)
107+
}
108+
83109
select {
84110
case data := <-eyeball.recvData:
85111
// check received data matches provided from origin
86112
expectedData := makePayload(1500)
87113
v3.MarshalPayloadHeaderTo(testRequestID, expectedData[:])
88114
copy(expectedData[17:], payload)
89-
if !slices.Equal(expectedData[:17+len(payload)], data) {
115+
if !slices.Equal(expectedData[:v3.DatagramPayloadHeaderLen+len(payload)], data) {
90116
t.Fatal("expected datagram did not equal expected")
91117
}
92118
cancel(expectedContextCanceled)
@@ -95,7 +121,7 @@ func testSessionServe_Origin(t *testing.T, payload []byte) {
95121
t.Fatal(err)
96122
}
97123

98-
err := <-done
124+
err = <-done
99125
if !errors.Is(err, context.Canceled) {
100126
t.Fatal(err)
101127
}
@@ -105,18 +131,27 @@ func testSessionServe_Origin(t *testing.T, payload []byte) {
105131
}
106132

107133
func TestSessionServe_OriginTooLarge(t *testing.T) {
134+
defer leaktest.Check(t)()
108135
log := zerolog.Nop()
109136
eyeball := newMockEyeball()
110137
payload := makePayload(1281)
111-
origin := newTestOrigin(payload)
112-
session := v3.NewSession(testRequestID, 2*time.Second, &origin, testOriginAddr, testLocalAddr, &eyeball, &noopMetrics{}, &log)
138+
origin, server := net.Pipe()
139+
defer origin.Close()
140+
defer server.Close()
141+
session := v3.NewSession(testRequestID, 2*time.Second, origin, testOriginAddr, testLocalAddr, &eyeball, &noopMetrics{}, &log)
113142
defer session.Close()
114143

115144
done := make(chan error)
116145
go func() {
117146
done <- session.Serve(context.Background())
118147
}()
119148

149+
// Attempt to write a payload too large from the origin
150+
_, err := server.Write(payload)
151+
if err != nil {
152+
t.Fatal(err)
153+
}
154+
120155
select {
121156
case data := <-eyeball.recvData:
122157
// we never expect a read to make it here because the origin provided a payload that is too large
@@ -130,6 +165,7 @@ func TestSessionServe_OriginTooLarge(t *testing.T) {
130165
}
131166

132167
func TestSessionServe_Migrate(t *testing.T) {
168+
defer leaktest.Check(t)()
133169
log := zerolog.Nop()
134170
eyeball := newMockEyeball()
135171
pipe1, pipe2 := net.Pipe()
@@ -186,6 +222,7 @@ func TestSessionServe_Migrate(t *testing.T) {
186222
}
187223

188224
func TestSessionServe_Migrate_CloseContext2(t *testing.T) {
225+
defer leaktest.Check(t)()
189226
log := zerolog.Nop()
190227
eyeball := newMockEyeball()
191228
pipe1, pipe2 := net.Pipe()
@@ -245,39 +282,48 @@ func TestSessionServe_Migrate_CloseContext2(t *testing.T) {
245282
}
246283

247284
func TestSessionClose_Multiple(t *testing.T) {
285+
defer leaktest.Check(t)()
248286
log := zerolog.Nop()
249-
origin := newTestOrigin(makePayload(128))
250-
session := v3.NewSession(testRequestID, 5*time.Second, &origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &log)
287+
origin, server := net.Pipe()
288+
defer origin.Close()
289+
defer server.Close()
290+
session := v3.NewSession(testRequestID, 5*time.Second, origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &log)
251291
err := session.Close()
252292
if err != nil {
253293
t.Fatal(err)
254294
}
255-
if !origin.closed.Load() {
256-
t.Fatal("origin wasn't closed")
295+
b := [1500]byte{}
296+
_, err = server.Read(b[:])
297+
if !errors.Is(err, io.EOF) {
298+
t.Fatalf("origin server connection should be closed: %s", err)
257299
}
258-
// Reset the closed status to make sure it isn't closed again
259-
origin.closed.Store(false)
260300
// subsequent closes shouldn't call close again or cause any errors
261301
err = session.Close()
262302
if err != nil {
263303
t.Fatal(err)
264304
}
265-
if origin.closed.Load() {
266-
t.Fatal("origin was incorrectly closed twice")
305+
_, err = server.Read(b[:])
306+
if !errors.Is(err, io.EOF) {
307+
t.Fatalf("origin server connection should still be closed: %s", err)
267308
}
268309
}
269310

270311
func TestSessionServe_IdleTimeout(t *testing.T) {
312+
defer leaktest.Check(t)()
271313
log := zerolog.Nop()
272-
origin := newTestIdleOrigin(10 * time.Second) // Make idle time longer than closeAfterIdle
314+
origin, server := net.Pipe()
315+
defer origin.Close()
316+
defer server.Close()
273317
closeAfterIdle := 2 * time.Second
274-
session := v3.NewSession(testRequestID, closeAfterIdle, &origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &log)
318+
session := v3.NewSession(testRequestID, closeAfterIdle, origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &log)
275319
err := session.Serve(context.Background())
276320
if !errors.Is(err, v3.SessionIdleErr{}) {
277321
t.Fatal(err)
278322
}
279323
// session should be closed
280-
if !origin.closed {
324+
b := [1500]byte{}
325+
_, err = server.Read(b[:])
326+
if !errors.Is(err, io.EOF) {
281327
t.Fatalf("session should be closed after Serve returns")
282328
}
283329
// closing a session again should not return an error
@@ -288,20 +334,24 @@ func TestSessionServe_IdleTimeout(t *testing.T) {
288334
}
289335

290336
func TestSessionServe_ParentContextCanceled(t *testing.T) {
337+
defer leaktest.Check(t)()
291338
log := zerolog.Nop()
292-
// Make idle time and idle timeout longer than closeAfterIdle
293-
origin := newTestIdleOrigin(10 * time.Second)
339+
origin, server := net.Pipe()
340+
defer origin.Close()
341+
defer server.Close()
294342
closeAfterIdle := 10 * time.Second
295343

296-
session := v3.NewSession(testRequestID, closeAfterIdle, &origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &log)
344+
session := v3.NewSession(testRequestID, closeAfterIdle, origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &log)
297345
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
298346
defer cancel()
299347
err := session.Serve(ctx)
300348
if !errors.Is(err, context.DeadlineExceeded) {
301349
t.Fatal(err)
302350
}
303351
// session should be closed
304-
if !origin.closed {
352+
b := [1500]byte{}
353+
_, err = server.Read(b[:])
354+
if !errors.Is(err, io.EOF) {
305355
t.Fatalf("session should be closed after Serve returns")
306356
}
307357
// closing a session again should not return an error
@@ -312,6 +362,7 @@ func TestSessionServe_ParentContextCanceled(t *testing.T) {
312362
}
313363

314364
func TestSessionServe_ReadErrors(t *testing.T) {
365+
defer leaktest.Check(t)()
315366
log := zerolog.Nop()
316367
origin := newTestErrOrigin(net.ErrClosed, nil)
317368
session := v3.NewSession(testRequestID, 30*time.Second, &origin, testOriginAddr, testLocalAddr, &noopEyeball{}, &noopMetrics{}, &log)
@@ -321,72 +372,6 @@ func TestSessionServe_ReadErrors(t *testing.T) {
321372
}
322373
}
323374

324-
type testOrigin struct {
325-
// bytes from Write
326-
write []byte
327-
// bytes provided to Read
328-
read []byte
329-
readOnce atomic.Bool
330-
closed atomic.Bool
331-
}
332-
333-
func newTestOrigin(payload []byte) testOrigin {
334-
return testOrigin{
335-
read: payload,
336-
}
337-
}
338-
339-
func (o *testOrigin) Read(p []byte) (n int, err error) {
340-
if o.closed.Load() {
341-
return -1, net.ErrClosed
342-
}
343-
if o.readOnce.Load() {
344-
// We only want to provide one read so all other reads will be blocked
345-
time.Sleep(10 * time.Second)
346-
}
347-
o.readOnce.Store(true)
348-
return copy(p, o.read), nil
349-
}
350-
351-
func (o *testOrigin) Write(p []byte) (n int, err error) {
352-
if o.closed.Load() {
353-
return -1, net.ErrClosed
354-
}
355-
o.write = make([]byte, len(p))
356-
copy(o.write, p)
357-
return len(p), nil
358-
}
359-
360-
func (o *testOrigin) Close() error {
361-
o.closed.Store(true)
362-
return nil
363-
}
364-
365-
type testIdleOrigin struct {
366-
duration time.Duration
367-
closed bool
368-
}
369-
370-
func newTestIdleOrigin(d time.Duration) testIdleOrigin {
371-
return testIdleOrigin{
372-
duration: d,
373-
}
374-
}
375-
376-
func (o *testIdleOrigin) Read(p []byte) (n int, err error) {
377-
time.Sleep(o.duration)
378-
return -1, nil
379-
}
380-
381-
func (o *testIdleOrigin) Write(p []byte) (n int, err error) {
382-
return 0, nil
383-
}
384-
385-
func (o *testIdleOrigin) Close() error {
386-
o.closed = true
387-
return nil
388-
}
389-
390375
type testErrOrigin struct {
391376
readErr error
392377
writeErr error

0 commit comments

Comments
 (0)