Skip to content

Commit 8f0498f

Browse files
committed
TUN-6123: For a given connection with edge, close all datagram sessions through this connection when it's closed
1 parent a97233b commit 8f0498f

File tree

4 files changed

+103
-48
lines changed

4 files changed

+103
-48
lines changed

datagramsession/manager.go

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package datagramsession
22

33
import (
44
"context"
5+
"fmt"
56
"io"
67
"time"
78

@@ -16,6 +17,10 @@ const (
1617
defaultReqTimeout = time.Second * 5
1718
)
1819

20+
var (
21+
errSessionManagerClosed = fmt.Errorf("session manager closed")
22+
)
23+
1924
// Manager defines the APIs to manage sessions from the same transport.
2025
type Manager interface {
2126
// Serve starts the event loop
@@ -30,6 +35,7 @@ type manager struct {
3035
registrationChan chan *registerSessionEvent
3136
unregistrationChan chan *unregisterSessionEvent
3237
datagramChan chan *newDatagram
38+
closedChan chan struct{}
3339
transport transport
3440
sessions map[uuid.UUID]*Session
3541
log *zerolog.Logger
@@ -43,6 +49,7 @@ func NewManager(transport transport, log *zerolog.Logger) *manager {
4349
unregistrationChan: make(chan *unregisterSessionEvent),
4450
// datagramChan is buffered, so it can read more datagrams from transport while the event loop is processing other events
4551
datagramChan: make(chan *newDatagram, requestChanCapacity),
52+
closedChan: make(chan struct{}),
4653
transport: transport,
4754
sessions: make(map[uuid.UUID]*Session),
4855
log: log,
@@ -90,7 +97,24 @@ func (m *manager) Serve(ctx context.Context) error {
9097
}
9198
}
9299
})
93-
return errGroup.Wait()
100+
err := errGroup.Wait()
101+
close(m.closedChan)
102+
m.shutdownSessions(err)
103+
return err
104+
}
105+
106+
func (m *manager) shutdownSessions(err error) {
107+
if err == nil {
108+
err = errSessionManagerClosed
109+
}
110+
closeSessionErr := &errClosedSession{
111+
message: err.Error(),
112+
// Usually connection with remote has been closed, so set this to true to skip unregistering from remote
113+
byRemote: true,
114+
}
115+
for _, s := range m.sessions {
116+
s.close(closeSessionErr)
117+
}
94118
}
95119

96120
func (m *manager) RegisterSession(ctx context.Context, sessionID uuid.UUID, originProxy io.ReadWriteCloser) (*Session, error) {
@@ -104,15 +128,33 @@ func (m *manager) RegisterSession(ctx context.Context, sessionID uuid.UUID, orig
104128
case m.registrationChan <- event:
105129
session := <-event.resultChan
106130
return session, nil
131+
// Once closedChan is closed, manager won't accept more registration because nothing is
132+
// reading from registrationChan and it's an unbuffered channel
133+
case <-m.closedChan:
134+
return nil, errSessionManagerClosed
107135
}
108136
}
109137

110138
func (m *manager) registerSession(ctx context.Context, registration *registerSessionEvent) {
111-
session := newSession(registration.sessionID, m.transport, registration.originProxy, m.log)
139+
session := m.newSession(registration.sessionID, registration.originProxy)
112140
m.sessions[registration.sessionID] = session
113141
registration.resultChan <- session
114142
}
115143

144+
func (m *manager) newSession(id uuid.UUID, dstConn io.ReadWriteCloser) *Session {
145+
return &Session{
146+
ID: id,
147+
transport: m.transport,
148+
dstConn: dstConn,
149+
// activeAtChan has low capacity. It can be full when there are many concurrent read/write. markActive() will
150+
// drop instead of blocking because last active time only needs to be an approximation
151+
activeAtChan: make(chan time.Time, 2),
152+
// capacity is 2 because close() and dstToTransport routine in Serve() can write to this channel
153+
closeChan: make(chan error, 2),
154+
log: m.log,
155+
}
156+
}
157+
116158
func (m *manager) UnregisterSession(ctx context.Context, sessionID uuid.UUID, message string, byRemote bool) error {
117159
ctx, cancel := context.WithTimeout(ctx, m.timeout)
118160
defer cancel()
@@ -129,6 +171,8 @@ func (m *manager) UnregisterSession(ctx context.Context, sessionID uuid.UUID, me
129171
return ctx.Err()
130172
case m.unregistrationChan <- event:
131173
return nil
174+
case <-m.closedChan:
175+
return errSessionManagerClosed
132176
}
133177
}
134178

datagramsession/manager_test.go

Lines changed: 46 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"fmt"
77
"io"
88
"net"
9+
"sync"
910
"testing"
1011
"time"
1112

@@ -21,12 +22,8 @@ func TestManagerServe(t *testing.T) {
2122
msgs = 50
2223
remoteUnregisterMsg = "eyeball closed connection"
2324
)
24-
log := zerolog.Nop()
25-
transport := &mockQUICTransport{
26-
reqChan: newDatagramChannel(1),
27-
respChan: newDatagramChannel(1),
28-
}
29-
mg := NewManager(transport, &log)
25+
26+
mg, transport := newTestManager(1)
3027

3128
eyeballTracker := make(map[uuid.UUID]*datagramChannel)
3229
for i := 0; i < sessions; i++ {
@@ -124,12 +121,8 @@ func TestTimeout(t *testing.T) {
124121
const (
125122
testTimeout = time.Millisecond * 50
126123
)
127-
log := zerolog.Nop()
128-
transport := &mockQUICTransport{
129-
reqChan: newDatagramChannel(1),
130-
respChan: newDatagramChannel(1),
131-
}
132-
mg := NewManager(transport, &log)
124+
125+
mg, _ := newTestManager(1)
133126
mg.timeout = testTimeout
134127
ctx := context.Background()
135128
sessionID := uuid.New()
@@ -142,6 +135,47 @@ func TestTimeout(t *testing.T) {
142135
require.ErrorIs(t, err, context.DeadlineExceeded)
143136
}
144137

138+
func TestCloseTransportCloseSessions(t *testing.T) {
139+
mg, transport := newTestManager(1)
140+
ctx := context.Background()
141+
142+
var wg sync.WaitGroup
143+
wg.Add(1)
144+
go func() {
145+
defer wg.Done()
146+
err := mg.Serve(ctx)
147+
require.Error(t, err)
148+
}()
149+
150+
cfdConn, eyeballConn := net.Pipe()
151+
session, err := mg.RegisterSession(ctx, uuid.New(), cfdConn)
152+
require.NoError(t, err)
153+
require.NotNil(t, session)
154+
155+
wg.Add(1)
156+
go func() {
157+
defer wg.Done()
158+
_, err := eyeballConn.Write([]byte(t.Name()))
159+
require.NoError(t, err)
160+
transport.close()
161+
}()
162+
163+
closedByRemote, err := session.Serve(ctx, time.Minute)
164+
require.True(t, closedByRemote)
165+
require.Error(t, err)
166+
167+
wg.Wait()
168+
}
169+
170+
func newTestManager(capacity uint) (*manager, *mockQUICTransport) {
171+
log := zerolog.Nop()
172+
transport := &mockQUICTransport{
173+
reqChan: newDatagramChannel(capacity),
174+
respChan: newDatagramChannel(capacity),
175+
}
176+
return NewManager(transport, &log), transport
177+
}
178+
145179
type mockOrigin struct {
146180
expectMsgCount int
147181
expectedMsg []byte

datagramsession/session.go

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -39,20 +39,6 @@ type Session struct {
3939
log *zerolog.Logger
4040
}
4141

42-
func newSession(id uuid.UUID, transport transport, dstConn io.ReadWriteCloser, log *zerolog.Logger) *Session {
43-
return &Session{
44-
ID: id,
45-
transport: transport,
46-
dstConn: dstConn,
47-
// activeAtChan has low capacity. It can be full when there are many concurrent read/write. markActive() will
48-
// drop instead of blocking because last active time only needs to be an approximation
49-
activeAtChan: make(chan time.Time, 2),
50-
// capacity is 2 because close() and dstToTransport routine in Serve() can write to this channel
51-
closeChan: make(chan error, 2),
52-
log: log,
53-
}
54-
}
55-
5642
func (s *Session) Serve(ctx context.Context, closeAfterIdle time.Duration) (closedByRemote bool, err error) {
5743
go func() {
5844
// QUIC implementation copies data to another buffer before returning https://github.com/lucas-clemente/quic-go/blob/v0.24.0/session.go#L1967-L1975

datagramsession/session_test.go

Lines changed: 11 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ import (
1111
"time"
1212

1313
"github.com/google/uuid"
14-
"github.com/rs/zerolog"
1514
"github.com/stretchr/testify/require"
1615
"golang.org/x/sync/errgroup"
1716
)
@@ -41,12 +40,9 @@ func testSessionReturns(t *testing.T, closeBy closeMethod, closeAfterIdle time.D
4140
sessionID := uuid.New()
4241
cfdConn, originConn := net.Pipe()
4342
payload := testPayload(sessionID)
44-
transport := &mockQUICTransport{
45-
reqChan: newDatagramChannel(1),
46-
respChan: newDatagramChannel(1),
47-
}
48-
log := zerolog.Nop()
49-
session := newSession(sessionID, transport, cfdConn, &log)
43+
44+
mg, _ := newTestManager(1)
45+
session := mg.newSession(sessionID, cfdConn)
5046

5147
ctx, cancel := context.WithCancel(context.Background())
5248
sessionDone := make(chan struct{})
@@ -117,12 +113,9 @@ func testActiveSessionNotClosed(t *testing.T, readFromDst bool, writeToDst bool)
117113
sessionID := uuid.New()
118114
cfdConn, originConn := net.Pipe()
119115
payload := testPayload(sessionID)
120-
transport := &mockQUICTransport{
121-
reqChan: newDatagramChannel(100),
122-
respChan: newDatagramChannel(100),
123-
}
124-
log := zerolog.Nop()
125-
session := newSession(sessionID, transport, cfdConn, &log)
116+
117+
mg, _ := newTestManager(100)
118+
session := mg.newSession(sessionID, cfdConn)
126119

127120
startTime := time.Now()
128121
activeUntil := startTime.Add(activeTime)
@@ -184,7 +177,8 @@ func testActiveSessionNotClosed(t *testing.T, readFromDst bool, writeToDst bool)
184177

185178
func TestMarkActiveNotBlocking(t *testing.T) {
186179
const concurrentCalls = 50
187-
session := newSession(uuid.New(), nil, nil, nil)
180+
mg, _ := newTestManager(1)
181+
session := mg.newSession(uuid.New(), nil)
188182
var wg sync.WaitGroup
189183
wg.Add(concurrentCalls)
190184
for i := 0; i < concurrentCalls; i++ {
@@ -199,12 +193,9 @@ func TestMarkActiveNotBlocking(t *testing.T) {
199193
func TestZeroBytePayload(t *testing.T) {
200194
sessionID := uuid.New()
201195
cfdConn, originConn := net.Pipe()
202-
transport := &mockQUICTransport{
203-
reqChan: newDatagramChannel(1),
204-
respChan: newDatagramChannel(1),
205-
}
206-
log := zerolog.Nop()
207-
session := newSession(sessionID, transport, cfdConn, &log)
196+
197+
mg, transport := newTestManager(1)
198+
session := mg.newSession(sessionID, cfdConn)
208199

209200
ctx, cancel := context.WithCancel(context.Background())
210201
errGroup, ctx := errgroup.WithContext(ctx)

0 commit comments

Comments
 (0)