@@ -31,7 +31,13 @@ func createTestServer(t *testing.T, maxClients int, shutdownDelay time.Duration)
3131 return httptest .NewServer (proxyServer )
3232}
3333
34- func createTestClient (t * testing.T , serverURL string , requestHandoverTick func () <- chan time.Time , errChan chan error ) (io.WriteCloser , * testBuffer ) {
34+ type testClient struct {
35+ InputWriter io.WriteCloser
36+ Output * testBuffer
37+ Cleanup func ()
38+ }
39+
40+ func createTestClient (t * testing.T , serverURL string , requestHandoverTick func () <- chan time.Time , errChan chan error ) * testClient {
3541 ctx := cmdio .MockDiscard (t .Context ())
3642 clientInput , clientInputWriter := io .Pipe ()
3743 clientOutput := newTestBuffer (t )
@@ -46,94 +52,103 @@ func createTestClient(t *testing.T, serverURL string, requestHandoverTick func()
4652 return time .After (time .Hour )
4753 }
4854 }
49- go func () {
55+ wg := sync.WaitGroup {}
56+ wg .Go (func () {
5057 err := RunClientProxy (ctx , clientInput , clientOutput , requestHandoverTick , createConn )
51- if err != nil && ! errors .Is (err , context .Canceled ) {
58+ if err != nil && ! errors .Is (err , context .Canceled ) && ! errors . Is ( err , io . ErrClosedPipe ) {
5259 if errChan != nil {
5360 errChan <- err
5461 } else {
5562 t .Errorf ("client error: %v" , err )
5663 }
5764 }
58- }()
59- return clientInputWriter , clientOutput
65+ })
66+ return & testClient {
67+ InputWriter : clientInputWriter ,
68+ Output : clientOutput ,
69+ Cleanup : func () {
70+ clientInput .Close ()
71+ clientInputWriter .Close ()
72+ wg .Wait ()
73+ },
74+ }
6075}
6176
6277func TestClientServerEcho (t * testing.T ) {
6378 server := createTestServer (t , 2 , time .Hour )
6479 defer server .Close ()
65- clientInputWriter , clientOutput := createTestClient (t , server .URL , nil , nil )
66- defer clientInputWriter . Close ()
80+ client := createTestClient (t , server .URL , nil , nil )
81+ defer client . Cleanup ()
6782
6883 testMsg1 := []byte ("test message 1\n " )
69- _ , err := clientInputWriter .Write (testMsg1 )
84+ _ , err := client . InputWriter .Write (testMsg1 )
7085 require .NoError (t , err )
71- err = clientOutput .AssertWrite (testMsg1 )
86+ err = client . Output .AssertWrite (testMsg1 )
7287 require .NoError (t , err )
7388
7489 testMsg2 := []byte ("test message 2\n " )
75- _ , err = clientInputWriter .Write (testMsg2 )
90+ _ , err = client . InputWriter .Write (testMsg2 )
7691 require .NoError (t , err )
77- err = clientOutput .AssertWrite (testMsg2 )
92+ err = client . Output .AssertWrite (testMsg2 )
7893 require .NoError (t , err )
7994
8095 expectedOutput := fmt .Sprintf ("%s%s" , testMsg1 , testMsg2 )
81- assert .Equal (t , expectedOutput , clientOutput .String ())
96+ assert .Equal (t , expectedOutput , client . Output .String ())
8297}
8398
8499func TestMultipleClients (t * testing.T ) {
85100 server := createTestServer (t , 2 , time .Hour )
86101 defer server .Close ()
87- clientInputWriter1 , clientOutput1 := createTestClient (t , server .URL , nil , nil )
88- defer clientInputWriter1 . Close ()
89- clientInputWriter2 , clientOutput2 := createTestClient (t , server .URL , nil , nil )
90- defer clientInputWriter2 . Close ()
102+ client1 := createTestClient (t , server .URL , nil , nil )
103+ defer client1 . Cleanup ()
104+ client2 := createTestClient (t , server .URL , nil , nil )
105+ defer client2 . Cleanup ()
91106
92107 messageCount := 10
93108 expectedClientOutput1 := ""
94109 expectedClientOutput2 := ""
95110 for i := range messageCount {
96111 message := fmt .Appendf (nil , "client 1 message %d\n " , i )
97- _ , err := clientInputWriter1 .Write (message )
112+ _ , err := client1 . InputWriter .Write (message )
98113 require .NoError (t , err )
99- err = clientOutput1 .AssertWrite (message )
114+ err = client1 . Output .AssertWrite (message )
100115 require .NoError (t , err )
101116 expectedClientOutput1 += string (message )
102117
103118 message = fmt .Appendf (nil , "client 2 message %d\n " , i )
104- _ , err = clientInputWriter2 .Write (message )
119+ _ , err = client2 . InputWriter .Write (message )
105120 require .NoError (t , err )
106- err = clientOutput2 .AssertWrite (message )
121+ err = client2 . Output .AssertWrite (message )
107122 require .NoError (t , err )
108123 expectedClientOutput2 += string (message )
109124 }
110125
111- assert .Equal (t , expectedClientOutput1 , clientOutput1 .String ())
112- assert .Equal (t , expectedClientOutput2 , clientOutput2 .String ())
126+ assert .Equal (t , expectedClientOutput1 , client1 . Output .String ())
127+ assert .Equal (t , expectedClientOutput2 , client2 . Output .String ())
113128}
114129
115130func TestMaxClients (t * testing.T ) {
116131 maxClients := 2
117132 server := createTestServer (t , maxClients , time .Hour )
118133 defer server .Close ()
119- clientInputWriter1 , clientOutput1 := createTestClient (t , server .URL , nil , nil )
120- defer clientInputWriter1 . Close ()
121- clientInputWriter2 , clientOutput2 := createTestClient (t , server .URL , nil , nil )
122- defer clientInputWriter2 . Close ()
134+ client1 := createTestClient (t , server .URL , nil , nil )
135+ defer client1 . Cleanup ()
136+ client2 := createTestClient (t , server .URL , nil , nil )
137+ defer client2 . Cleanup ()
123138
124139 testMsg1 := []byte ("test message 1\n " )
125- _ , err := clientInputWriter1 .Write (testMsg1 )
140+ _ , err := client1 . InputWriter .Write (testMsg1 )
126141 require .NoError (t , err )
127- err = clientOutput1 .AssertWrite (testMsg1 )
142+ err = client1 . Output .AssertWrite (testMsg1 )
128143 require .NoError (t , err )
129- _ , err = clientInputWriter2 .Write (testMsg1 )
144+ _ , err = client2 . InputWriter .Write (testMsg1 )
130145 require .NoError (t , err )
131- err = clientOutput2 .AssertWrite (testMsg1 )
146+ err = client2 . Output .AssertWrite (testMsg1 )
132147 require .NoError (t , err )
133148
134149 errChan := make (chan error , 1 )
135- clientInputWriter3 , _ := createTestClient (t , server .URL , nil , errChan )
136- defer clientInputWriter3 . Close ()
150+ client3 := createTestClient (t , server .URL , nil , errChan )
151+ defer client3 . Cleanup ()
137152 select {
138153 case err = <- errChan :
139154 require .Error (t , err )
@@ -150,8 +165,8 @@ func TestHandover(t *testing.T) {
150165 requestHandoverTick := func () <- chan time.Time {
151166 return handoverChan
152167 }
153- clientInputWriter , clientOutput := createTestClient (t , server .URL , requestHandoverTick , nil )
154- defer clientInputWriter . Close ()
168+ client := createTestClient (t , server .URL , requestHandoverTick , nil )
169+ defer client . Cleanup ()
155170
156171 expectedOutput := ""
157172
@@ -162,21 +177,21 @@ func TestHandover(t *testing.T) {
162177 handoverChan <- time .Now ()
163178 }
164179 message := fmt .Appendf (nil , "message %d\n " , i )
165- _ , err := clientInputWriter .Write (message )
180+ _ , err := client . InputWriter .Write (message )
166181 if err != nil {
167182 t .Errorf ("failed to write message %d: %v" , i , err )
168183 }
169184 expectedOutput += string (message )
170185 }
171186 })
172187
173- err := clientOutput .WaitForWrite (fmt .Appendf (nil , "message %d\n " , TOTAL_MESSAGE_COUNT - 1 ))
188+ err := client . Output .WaitForWrite (fmt .Appendf (nil , "message %d\n " , TOTAL_MESSAGE_COUNT - 1 ))
174189 require .NoError (t , err , "failed to receive the last message (%d)" , TOTAL_MESSAGE_COUNT - 1 )
175190
176191 wg .Wait ()
177192
178- // clientOutput is created by appending incoming messages as they arrive, so we are also test correct order here
179- assert .Equal (t , expectedOutput , clientOutput .String ())
193+ // client.Output is created by appending incoming messages as they arrive, so we are also test correct order here
194+ assert .Equal (t , expectedOutput , client . Output .String ())
180195}
181196
182197// Tests handovers in quick succession with few messages in between.
@@ -189,8 +204,8 @@ func TestQuickHandover(t *testing.T) {
189204 requestHandoverTick := func () <- chan time.Time {
190205 return handoverChan
191206 }
192- clientInputWriter , clientOutput := createTestClient (t , server .URL , requestHandoverTick , nil )
193- defer clientInputWriter . Close ()
207+ client := createTestClient (t , server .URL , requestHandoverTick , nil )
208+ defer client . Cleanup ()
194209
195210 expectedOutput := ""
196211
@@ -201,18 +216,18 @@ func TestQuickHandover(t *testing.T) {
201216 handoverChan <- time .Now ()
202217 }
203218 message := fmt .Appendf (nil , "message %d\n " , i )
204- _ , err := clientInputWriter .Write (message )
219+ _ , err := client . InputWriter .Write (message )
205220 if err != nil {
206221 t .Errorf ("failed to write message %d: %v" , i , err )
207222 }
208223 expectedOutput += string (message )
209224 }
210225 })
211226
212- err := clientOutput .WaitForWrite (fmt .Appendf (nil , "message %d\n " , 15 ))
227+ err := client . Output .WaitForWrite (fmt .Appendf (nil , "message %d\n " , 15 ))
213228 require .NoError (t , err , "failed to receive the last message (%d)" , 15 )
214229
215230 wg .Wait ()
216231
217- assert .Equal (t , expectedOutput , clientOutput .String ())
232+ assert .Equal (t , expectedOutput , client . Output .String ())
218233}
0 commit comments