11package datagramsession
22
33import (
4+ "bytes"
45 "context"
6+ "fmt"
7+ "io"
58 "net"
9+ "sync"
610 "testing"
11+ "time"
712
813 "github.com/google/uuid"
914 "github.com/stretchr/testify/require"
15+ "golang.org/x/sync/errgroup"
1016)
1117
1218// TestCloseSession makes sure a session will stop after context is done
1319func TestSessionCtxDone (t * testing.T ) {
14- testSessionReturns (t , true )
20+ testSessionReturns (t , closeByContext , time . Minute * 2 )
1521}
1622
1723// TestCloseSession makes sure a session will stop after close method is called
1824func TestCloseSession (t * testing.T ) {
19- testSessionReturns (t , false )
25+ testSessionReturns (t , closeByCallingClose , time . Minute * 2 )
2026}
2127
22- func testSessionReturns (t * testing.T , closeByContext bool ) {
28+ // TestCloseIdle makess sure a session will stop after there is no read/write for a period defined by closeAfterIdle
29+ func TestCloseIdle (t * testing.T ) {
30+ testSessionReturns (t , closeByTimeout , time .Millisecond * 100 )
31+ }
32+
33+ func testSessionReturns (t * testing.T , closeBy closeMethod , closeAfterIdle time.Duration ) {
2334 sessionID := uuid .New ()
2435 cfdConn , originConn := net .Pipe ()
2536 payload := testPayload (sessionID )
2637 transport := & mockQUICTransport {
27- reqChan : newDatagramChannel (),
28- respChan : newDatagramChannel (),
38+ reqChan : newDatagramChannel (1 ),
39+ respChan : newDatagramChannel (1 ),
2940 }
3041 session := newSession (sessionID , transport , cfdConn )
3142
3243 ctx , cancel := context .WithCancel (context .Background ())
3344 sessionDone := make (chan struct {})
3445 go func () {
35- session .Serve (ctx )
46+ session .Serve (ctx , closeAfterIdle )
3647 close (sessionDone )
3748 }()
3849
3950 go func () {
40- n , err := session .writeToDst (payload )
51+ n , err := session .transportToDst (payload )
4152 require .NoError (t , err )
4253 require .Equal (t , len (payload ), n )
4354 }()
@@ -47,13 +58,120 @@ func testSessionReturns(t *testing.T, closeByContext bool) {
4758 require .NoError (t , err )
4859 require .Equal (t , len (payload ), n )
4960
50- if closeByContext {
61+ lastRead := time .Now ()
62+
63+ switch closeBy {
64+ case closeByContext :
5165 cancel ()
52- } else {
66+ case closeByCallingClose :
5367 session .close ()
5468 }
5569
5670 <- sessionDone
71+ if closeBy == closeByTimeout {
72+ require .True (t , time .Now ().After (lastRead .Add (closeAfterIdle )))
73+ }
5774 // call cancelled again otherwise the linter will warn about possible context leak
5875 cancel ()
5976}
77+
78+ type closeMethod int
79+
80+ const (
81+ closeByContext closeMethod = iota
82+ closeByCallingClose
83+ closeByTimeout
84+ )
85+
86+ func TestWriteToDstSessionPreventClosed (t * testing.T ) {
87+ testActiveSessionNotClosed (t , false , true )
88+ }
89+
90+ func TestReadFromDstSessionPreventClosed (t * testing.T ) {
91+ testActiveSessionNotClosed (t , true , false )
92+ }
93+
94+ func testActiveSessionNotClosed (t * testing.T , readFromDst bool , writeToDst bool ) {
95+ const closeAfterIdle = time .Millisecond * 100
96+ const activeTime = time .Millisecond * 500
97+
98+ sessionID := uuid .New ()
99+ cfdConn , originConn := net .Pipe ()
100+ payload := testPayload (sessionID )
101+ transport := & mockQUICTransport {
102+ reqChan : newDatagramChannel (100 ),
103+ respChan : newDatagramChannel (100 ),
104+ }
105+ session := newSession (sessionID , transport , cfdConn )
106+
107+ startTime := time .Now ()
108+ activeUntil := startTime .Add (activeTime )
109+ ctx , cancel := context .WithCancel (context .Background ())
110+ errGroup , ctx := errgroup .WithContext (ctx )
111+ errGroup .Go (func () error {
112+ session .Serve (ctx , closeAfterIdle )
113+ if time .Now ().Before (startTime .Add (activeTime )) {
114+ return fmt .Errorf ("session closed while it's still active" )
115+ }
116+ return nil
117+ })
118+
119+ if readFromDst {
120+ errGroup .Go (func () error {
121+ for {
122+ if time .Now ().After (activeUntil ) {
123+ return nil
124+ }
125+ if _ , err := originConn .Write (payload ); err != nil {
126+ return err
127+ }
128+ time .Sleep (closeAfterIdle / 2 )
129+ }
130+ })
131+ }
132+ if writeToDst {
133+ errGroup .Go (func () error {
134+ readBuffer := make ([]byte , len (payload ))
135+ for {
136+ n , err := originConn .Read (readBuffer )
137+ if err != nil {
138+ if err == io .EOF || err == io .ErrClosedPipe {
139+ return nil
140+ }
141+ return err
142+ }
143+ if ! bytes .Equal (payload , readBuffer [:n ]) {
144+ return fmt .Errorf ("payload %v is not equal to %v" , readBuffer [:n ], payload )
145+ }
146+ }
147+ })
148+ errGroup .Go (func () error {
149+ for {
150+ if time .Now ().After (activeUntil ) {
151+ return nil
152+ }
153+ if _ , err := session .transportToDst (payload ); err != nil {
154+ return err
155+ }
156+ time .Sleep (closeAfterIdle / 2 )
157+ }
158+ })
159+ }
160+
161+ require .NoError (t , errGroup .Wait ())
162+ cancel ()
163+ }
164+
165+ func TestMarkActiveNotBlocking (t * testing.T ) {
166+ const concurrentCalls = 50
167+ session := newSession (uuid .New (), nil , nil )
168+ var wg sync.WaitGroup
169+ wg .Add (concurrentCalls )
170+ for i := 0 ; i < concurrentCalls ; i ++ {
171+ go func () {
172+ session .markActive ()
173+ wg .Done ()
174+ }()
175+ }
176+ wg .Wait ()
177+ }
0 commit comments