Skip to content

Commit 73a265f

Browse files
committed
TUN-5488: Close session after it's idle for a period defined by registerUdpSession RPC
1 parent 9bc59bc commit 73a265f

File tree

13 files changed

+454
-251
lines changed

13 files changed

+454
-251
lines changed

connection/quic.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"net/http"
1010
"strconv"
1111
"strings"
12+
"time"
1213

1314
"github.com/google/uuid"
1415
"github.com/lucas-clemente/quic-go"
@@ -167,7 +168,7 @@ func (q *QUICConnection) handleRPCStream(rpcStream *quicpogs.RPCServerStream) er
167168
return rpcStream.Serve(q, q.logger)
168169
}
169170

170-
func (q *QUICConnection) RegisterUdpSession(ctx context.Context, sessionID uuid.UUID, dstIP net.IP, dstPort uint16) error {
171+
func (q *QUICConnection) RegisterUdpSession(ctx context.Context, sessionID uuid.UUID, dstIP net.IP, dstPort uint16, closeAfterIdleHint time.Duration) error {
171172
// Each session is a series of datagram from an eyeball to a dstIP:dstPort.
172173
// (src port, dst IP, dst port) uniquely identifies a session, so it needs a dedicated connected socket.
173174
originProxy, err := ingress.DialUDP(dstIP, dstPort)
@@ -182,7 +183,7 @@ func (q *QUICConnection) RegisterUdpSession(ctx context.Context, sessionID uuid.
182183
}
183184
go func() {
184185
defer q.sessionManager.UnregisterSession(q.session.Context(), sessionID)
185-
if err := session.Serve(q.session.Context()); err != nil {
186+
if err := session.Serve(q.session.Context(), closeAfterIdleHint); err != nil {
186187
q.logger.Debug().Err(err).Str("sessionID", sessionID.String()).Msg("session terminated")
187188
}
188189
}()

datagramsession/manager.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ func (m *manager) sendToSession(datagram *newDatagram) {
127127
}
128128
// session writes to destination over a connected UDP socket, which should not be blocking, so this call doesn't
129129
// need to run in another go routine
130-
_, err := session.writeToDst(datagram.payload)
130+
_, err := session.transportToDst(datagram.payload)
131131
if err != nil {
132132
m.log.Err(err).Str("sessionID", datagram.sessionID.String()).Msg("Failed to write payload to session")
133133
}

datagramsession/manager_test.go

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"io"
88
"net"
99
"testing"
10+
"time"
1011

1112
"github.com/google/uuid"
1213
"github.com/rs/zerolog"
@@ -21,15 +22,15 @@ func TestManagerServe(t *testing.T) {
2122
)
2223
log := zerolog.Nop()
2324
transport := &mockQUICTransport{
24-
reqChan: newDatagramChannel(),
25-
respChan: newDatagramChannel(),
25+
reqChan: newDatagramChannel(1),
26+
respChan: newDatagramChannel(1),
2627
}
2728
mg := NewManager(transport, &log)
2829

2930
eyeballTracker := make(map[uuid.UUID]*datagramChannel)
3031
for i := 0; i < sessions; i++ {
3132
sessionID := uuid.New()
32-
eyeballTracker[sessionID] = newDatagramChannel()
33+
eyeballTracker[sessionID] = newDatagramChannel(1)
3334
}
3435

3536
ctx, cancel := context.WithCancel(context.Background())
@@ -88,7 +89,7 @@ func TestManagerServe(t *testing.T) {
8889

8990
sessionDone := make(chan struct{})
9091
go func() {
91-
session.Serve(ctx)
92+
session.Serve(ctx, time.Minute*2)
9293
close(sessionDone)
9394
}()
9495

@@ -179,9 +180,9 @@ type datagramChannel struct {
179180
closedChan chan struct{}
180181
}
181182

182-
func newDatagramChannel() *datagramChannel {
183+
func newDatagramChannel(capacity uint) *datagramChannel {
183184
return &datagramChannel{
184-
datagramChan: make(chan *newDatagram, 1),
185+
datagramChan: make(chan *newDatagram, capacity),
185186
closedChan: make(chan struct{}),
186187
}
187188
}

datagramsession/session.go

Lines changed: 65 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,15 @@ package datagramsession
33
import (
44
"context"
55
"io"
6+
"time"
67

78
"github.com/google/uuid"
89
)
910

11+
const (
12+
defaultCloseIdleAfter = time.Second * 210
13+
)
14+
1015
// Each Session is a bidirectional pipe of datagrams between transport and dstConn
1116
// Currently the only implementation of transport is quic DatagramMuxer
1217
// Destination can be a connection with origin or with eyeball
@@ -22,49 +27,91 @@ type Session struct {
2227
id uuid.UUID
2328
transport transport
2429
dstConn io.ReadWriteCloser
25-
doneChan chan struct{}
30+
// activeAtChan is used to communicate the last read/write time
31+
activeAtChan chan time.Time
32+
doneChan chan struct{}
2633
}
2734

2835
func newSession(id uuid.UUID, transport transport, dstConn io.ReadWriteCloser) *Session {
2936
return &Session{
3037
id: id,
3138
transport: transport,
3239
dstConn: dstConn,
33-
doneChan: make(chan struct{}),
40+
// activeAtChan has low capacity. It can be full when there are many concurrent read/write. markActive() will
41+
// drop instead of blocking because last active time only needs to be an approximation
42+
activeAtChan: make(chan time.Time, 2),
43+
doneChan: make(chan struct{}),
3444
}
3545
}
3646

37-
func (s *Session) Serve(ctx context.Context) error {
47+
func (s *Session) Serve(ctx context.Context, closeAfterIdle time.Duration) error {
3848
serveCtx, cancel := context.WithCancel(ctx)
3949
defer cancel()
40-
go func() {
41-
select {
42-
case <-serveCtx.Done():
43-
case <-s.doneChan:
44-
}
45-
s.dstConn.Close()
46-
}()
50+
go s.waitForCloseCondition(serveCtx, closeAfterIdle)
4751
// QUIC implementation copies data to another buffer before returning https://github.com/lucas-clemente/quic-go/blob/v0.24.0/session.go#L1967-L1975
4852
// This makes it safe to share readBuffer between iterations
49-
readBuffer := make([]byte, 1280)
53+
readBuffer := make([]byte, s.transport.MTU())
54+
for {
55+
if err := s.dstToTransport(readBuffer); err != nil {
56+
return err
57+
}
58+
}
59+
}
60+
61+
func (s *Session) waitForCloseCondition(ctx context.Context, closeAfterIdle time.Duration) {
62+
if closeAfterIdle == 0 {
63+
// provide deafult is caller doesn't specify one
64+
closeAfterIdle = defaultCloseIdleAfter
65+
}
66+
// Closing dstConn cancels read so Serve function can return
67+
defer s.dstConn.Close()
68+
69+
checkIdleFreq := closeAfterIdle / 8
70+
checkIdleTicker := time.NewTicker(checkIdleFreq)
71+
defer checkIdleTicker.Stop()
72+
73+
activeAt := time.Now()
5074
for {
51-
// TODO: TUN-5303: origin proxy should determine the buffer size
52-
n, err := s.dstConn.Read(readBuffer)
53-
if n > 0 {
54-
if err := s.transport.SendTo(s.id, readBuffer[:n]); err != nil {
55-
return err
75+
select {
76+
case <-ctx.Done():
77+
return
78+
case <-s.doneChan:
79+
return
80+
case <-checkIdleTicker.C:
81+
// The session is considered inactive if current time is after (last active time + allowed idle time)
82+
if time.Now().After(activeAt.Add(closeAfterIdle)) {
83+
return
5684
}
85+
case activeAt = <-s.activeAtChan: // Update last active time
5786
}
58-
if err != nil {
87+
}
88+
}
89+
90+
func (s *Session) dstToTransport(buffer []byte) error {
91+
n, err := s.dstConn.Read(buffer)
92+
s.markActive()
93+
if n > 0 {
94+
if err := s.transport.SendTo(s.id, buffer[:n]); err != nil {
5995
return err
6096
}
6197
}
98+
return err
6299
}
63100

64-
func (s *Session) writeToDst(payload []byte) (int, error) {
101+
func (s *Session) transportToDst(payload []byte) (int, error) {
102+
s.markActive()
65103
return s.dstConn.Write(payload)
66104
}
67105

106+
// Sends the last active time to the idle checker loop without blocking. activeAtChan will only be full when there
107+
// are many concurrent read/write. It is fine to lose some precision
108+
func (s *Session) markActive() {
109+
select {
110+
case s.activeAtChan <- time.Now():
111+
default:
112+
}
113+
}
114+
68115
func (s *Session) close() {
69116
close(s.doneChan)
70117
}

datagramsession/session_test.go

Lines changed: 127 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,54 @@
11
package datagramsession
22

33
import (
4+
"bytes"
45
"context"
6+
"fmt"
7+
"io"
58
"net"
9+
"sync"
610
"testing"
11+
"time"
712

813
"github.com/google/uuid"
914
"github.com/stretchr/testify/require"
15+
"golang.org/x/sync/errgroup"
1016
)
1117

1218
// TestCloseSession makes sure a session will stop after context is done
1319
func TestSessionCtxDone(t *testing.T) {
14-
testSessionReturns(t, true)
20+
testSessionReturns(t, closeByContext, time.Minute*2)
1521
}
1622

1723
// TestCloseSession makes sure a session will stop after close method is called
1824
func TestCloseSession(t *testing.T) {
19-
testSessionReturns(t, false)
25+
testSessionReturns(t, closeByCallingClose, time.Minute*2)
2026
}
2127

22-
func testSessionReturns(t *testing.T, closeByContext bool) {
28+
// TestCloseIdle makess sure a session will stop after there is no read/write for a period defined by closeAfterIdle
29+
func TestCloseIdle(t *testing.T) {
30+
testSessionReturns(t, closeByTimeout, time.Millisecond*100)
31+
}
32+
33+
func testSessionReturns(t *testing.T, closeBy closeMethod, closeAfterIdle time.Duration) {
2334
sessionID := uuid.New()
2435
cfdConn, originConn := net.Pipe()
2536
payload := testPayload(sessionID)
2637
transport := &mockQUICTransport{
27-
reqChan: newDatagramChannel(),
28-
respChan: newDatagramChannel(),
38+
reqChan: newDatagramChannel(1),
39+
respChan: newDatagramChannel(1),
2940
}
3041
session := newSession(sessionID, transport, cfdConn)
3142

3243
ctx, cancel := context.WithCancel(context.Background())
3344
sessionDone := make(chan struct{})
3445
go func() {
35-
session.Serve(ctx)
46+
session.Serve(ctx, closeAfterIdle)
3647
close(sessionDone)
3748
}()
3849

3950
go func() {
40-
n, err := session.writeToDst(payload)
51+
n, err := session.transportToDst(payload)
4152
require.NoError(t, err)
4253
require.Equal(t, len(payload), n)
4354
}()
@@ -47,13 +58,120 @@ func testSessionReturns(t *testing.T, closeByContext bool) {
4758
require.NoError(t, err)
4859
require.Equal(t, len(payload), n)
4960

50-
if closeByContext {
61+
lastRead := time.Now()
62+
63+
switch closeBy {
64+
case closeByContext:
5165
cancel()
52-
} else {
66+
case closeByCallingClose:
5367
session.close()
5468
}
5569

5670
<-sessionDone
71+
if closeBy == closeByTimeout {
72+
require.True(t, time.Now().After(lastRead.Add(closeAfterIdle)))
73+
}
5774
// call cancelled again otherwise the linter will warn about possible context leak
5875
cancel()
5976
}
77+
78+
type closeMethod int
79+
80+
const (
81+
closeByContext closeMethod = iota
82+
closeByCallingClose
83+
closeByTimeout
84+
)
85+
86+
func TestWriteToDstSessionPreventClosed(t *testing.T) {
87+
testActiveSessionNotClosed(t, false, true)
88+
}
89+
90+
func TestReadFromDstSessionPreventClosed(t *testing.T) {
91+
testActiveSessionNotClosed(t, true, false)
92+
}
93+
94+
func testActiveSessionNotClosed(t *testing.T, readFromDst bool, writeToDst bool) {
95+
const closeAfterIdle = time.Millisecond * 100
96+
const activeTime = time.Millisecond * 500
97+
98+
sessionID := uuid.New()
99+
cfdConn, originConn := net.Pipe()
100+
payload := testPayload(sessionID)
101+
transport := &mockQUICTransport{
102+
reqChan: newDatagramChannel(100),
103+
respChan: newDatagramChannel(100),
104+
}
105+
session := newSession(sessionID, transport, cfdConn)
106+
107+
startTime := time.Now()
108+
activeUntil := startTime.Add(activeTime)
109+
ctx, cancel := context.WithCancel(context.Background())
110+
errGroup, ctx := errgroup.WithContext(ctx)
111+
errGroup.Go(func() error {
112+
session.Serve(ctx, closeAfterIdle)
113+
if time.Now().Before(startTime.Add(activeTime)) {
114+
return fmt.Errorf("session closed while it's still active")
115+
}
116+
return nil
117+
})
118+
119+
if readFromDst {
120+
errGroup.Go(func() error {
121+
for {
122+
if time.Now().After(activeUntil) {
123+
return nil
124+
}
125+
if _, err := originConn.Write(payload); err != nil {
126+
return err
127+
}
128+
time.Sleep(closeAfterIdle / 2)
129+
}
130+
})
131+
}
132+
if writeToDst {
133+
errGroup.Go(func() error {
134+
readBuffer := make([]byte, len(payload))
135+
for {
136+
n, err := originConn.Read(readBuffer)
137+
if err != nil {
138+
if err == io.EOF || err == io.ErrClosedPipe {
139+
return nil
140+
}
141+
return err
142+
}
143+
if !bytes.Equal(payload, readBuffer[:n]) {
144+
return fmt.Errorf("payload %v is not equal to %v", readBuffer[:n], payload)
145+
}
146+
}
147+
})
148+
errGroup.Go(func() error {
149+
for {
150+
if time.Now().After(activeUntil) {
151+
return nil
152+
}
153+
if _, err := session.transportToDst(payload); err != nil {
154+
return err
155+
}
156+
time.Sleep(closeAfterIdle / 2)
157+
}
158+
})
159+
}
160+
161+
require.NoError(t, errGroup.Wait())
162+
cancel()
163+
}
164+
165+
func TestMarkActiveNotBlocking(t *testing.T) {
166+
const concurrentCalls = 50
167+
session := newSession(uuid.New(), nil, nil)
168+
var wg sync.WaitGroup
169+
wg.Add(concurrentCalls)
170+
for i := 0; i < concurrentCalls; i++ {
171+
go func() {
172+
session.markActive()
173+
wg.Done()
174+
}()
175+
}
176+
wg.Wait()
177+
}

0 commit comments

Comments
 (0)