@@ -10,7 +10,6 @@ import (
1010 "io"
1111 "net/http/httptest"
1212 "os/exec"
13- "strings"
1413 "sync"
1514 "testing"
1615 "time"
@@ -30,21 +29,23 @@ func createTestServer(t *testing.T, maxClients int, shutdownDelay time.Duration)
3029 return httptest .NewServer (proxyServer )
3130}
3231
33- func createTestClient (t * testing.T , serverURL string , handoverTimeout time.Duration , errChan , createConnChan chan error ) (io.WriteCloser , * testBuffer ) {
32+ func createTestClient (t * testing.T , serverURL string , requestHandoverTick func () <- chan time.Time , errChan chan error ) (io.WriteCloser , * testBuffer ) {
3433 ctx := t .Context ()
3534 clientInput , clientInputWriter := io .Pipe ()
3635 clientOutput := newTestBuffer (t )
3736 wsURL := "ws" + serverURL [4 :]
3837 createConn := func (ctx context.Context , connID string ) (* websocket.Conn , error ) {
3938 url := fmt .Sprintf ("%s?id=%s" , wsURL , connID )
4039 conn , _ , err := websocket .DefaultDialer .Dial (url , nil ) // nolint:bodyclose
41- if createConnChan != nil {
42- createConnChan <- err
43- }
4440 return conn , err
4541 }
42+ if requestHandoverTick == nil {
43+ requestHandoverTick = func () <- chan time.Time {
44+ return time .After (time .Hour )
45+ }
46+ }
4647 go func () {
47- err := RunClientProxy (ctx , clientInput , clientOutput , handoverTimeout , createConn )
48+ err := RunClientProxy (ctx , clientInput , clientOutput , requestHandoverTick , createConn )
4849 if err != nil && ! isNormalClosure (err ) && ! errors .Is (err , context .Canceled ) {
4950 if errChan != nil {
5051 errChan <- err
@@ -59,7 +60,7 @@ func createTestClient(t *testing.T, serverURL string, handoverTimeout time.Durat
5960func TestClientServerEcho (t * testing.T ) {
6061 server := createTestServer (t , 2 , time .Hour )
6162 defer server .Close ()
62- clientInputWriter , clientOutput := createTestClient (t , server .URL , time . Hour , nil , nil )
63+ clientInputWriter , clientOutput := createTestClient (t , server .URL , nil , nil )
6364 defer clientInputWriter .Close ()
6465
6566 testMsg1 := []byte ("test message 1\n " )
@@ -81,9 +82,9 @@ func TestClientServerEcho(t *testing.T) {
8182func TestMultipleClients (t * testing.T ) {
8283 server := createTestServer (t , 2 , time .Hour )
8384 defer server .Close ()
84- clientInputWriter1 , clientOutput1 := createTestClient (t , server .URL , time . Hour , nil , nil )
85+ clientInputWriter1 , clientOutput1 := createTestClient (t , server .URL , nil , nil )
8586 defer clientInputWriter1 .Close ()
86- clientInputWriter2 , clientOutput2 := createTestClient (t , server .URL , time . Hour , nil , nil )
87+ clientInputWriter2 , clientOutput2 := createTestClient (t , server .URL , nil , nil )
8788 defer clientInputWriter2 .Close ()
8889
8990 messageCount := 10
@@ -113,9 +114,9 @@ func TestMaxClients(t *testing.T) {
113114 maxClients := 2
114115 server := createTestServer (t , maxClients , time .Hour )
115116 defer server .Close ()
116- clientInputWriter1 , clientOutput1 := createTestClient (t , server .URL , time . Hour , nil , nil )
117+ clientInputWriter1 , clientOutput1 := createTestClient (t , server .URL , nil , nil )
117118 defer clientInputWriter1 .Close ()
118- clientInputWriter2 , clientOutput2 := createTestClient (t , server .URL , time . Hour , nil , nil )
119+ clientInputWriter2 , clientOutput2 := createTestClient (t , server .URL , nil , nil )
119120 defer clientInputWriter2 .Close ()
120121
121122 testMsg1 := []byte ("test message 1\n " )
@@ -129,7 +130,7 @@ func TestMaxClients(t *testing.T) {
129130 require .NoError (t , err )
130131
131132 errChan := make (chan error , 1 )
132- clientInputWriter3 , _ := createTestClient (t , server .URL , time . Hour , errChan , nil )
133+ clientInputWriter3 , _ := createTestClient (t , server .URL , nil , errChan )
133134 defer clientInputWriter3 .Close ()
134135 select {
135136 case err = <- errChan :
@@ -143,60 +144,35 @@ func TestHandover(t *testing.T) {
143144 server := createTestServer (t , 2 , time .Hour )
144145 defer server .Close ()
145146
146- maxHandoverCount := 3
147- handoverTimeout := 500 * time .Millisecond
148- createConnChan := make (chan error , 1 )
149- clientInputWriter , clientOutput := createTestClient (t , server .URL , handoverTimeout , nil , createConnChan )
147+ handoverChan := make (chan time.Time )
148+ requestHandoverTick := func () <- chan time.Time {
149+ return handoverChan
150+ }
151+ clientInputWriter , clientOutput := createTestClient (t , server .URL , requestHandoverTick , nil )
150152 defer clientInputWriter .Close ()
151153
152- messageCount := 0
153154 expectedOutput := ""
154- sendMessage := func () {
155- message := fmt .Appendf (nil , "message %d\n " , messageCount )
156- _ , err := clientInputWriter .Write (message )
157- if err != nil {
158- t .Errorf ("failed to write message %d: %v" , messageCount , err )
159- }
160- messageCount ++
161- if messageCount > TOTAL_MESSAGE_COUNT {
162- t .Errorf ("exceeded total message count, test buffer won't work correctly" )
163- }
164- expectedOutput += string (message )
165- }
166155
167156 wg := sync.WaitGroup {}
168157 wg .Go (func () {
169- handoverCount := 0
170- for {
171- select {
172- case <- createConnChan :
173- sendMessage ()
174- handoverCount ++
175- if handoverCount >= maxHandoverCount {
176- return
177- }
178- default :
179- sendMessage ()
180- time .Sleep (time .Millisecond )
158+ for i := range TOTAL_MESSAGE_COUNT {
159+ if i > 0 && i % MESSAGES_PER_CHUNK == 0 {
160+ handoverChan <- time .Now ()
181161 }
162+ message := fmt .Appendf (nil , "message %d\n " , i )
163+ _ , err := clientInputWriter .Write (message )
164+ if err != nil {
165+ t .Errorf ("failed to write message %d: %v" , i , err )
166+ }
167+ expectedOutput += string (message )
182168 }
183169 })
184170
185- wg .Wait ()
171+ err := clientOutput .WaitForWrite (fmt .Appendf (nil , "message %d\n " , TOTAL_MESSAGE_COUNT - 1 ))
172+ require .NoError (t , err , "failed to receive the last message (%d)" , TOTAL_MESSAGE_COUNT - 1 )
186173
187- for i := 0 ; i < messageCount ; {
188- // Client can receive multiple echo messages in one response,
189- // so we split them again and verify each one.
190- data , err := clientOutput .WaitForWrite ()
191- require .NoError (t , err , "failed to receive message %d" , i )
192- lines := strings .SplitSeq (string (data ), "\n " )
193- for line := range lines {
194- if line != "" {
195- assert .Equal (t , fmt .Sprintf ("message %d\n " , i ), line + "\n " )
196- i ++
197- }
198- }
199- }
174+ wg .Wait ()
200175
176+ // clientOutput is created by appending incoming messages as they arrive, so we are also test correct order here
201177 assert .Equal (t , expectedOutput , clientOutput .String ())
202178}
0 commit comments