Skip to content

Commit 137b596

Browse files
authored
Wait for cleanup in ssh-tunnel tests (#4028)
## Changes That wouldn't fix the tests, but should make it easier to debug ## Why <!-- Why are these changes needed? Provide the context that the reviewer might be missing. For example, were there any decisions behind the change that are not reflected in the code itself? --> ## Tests <!-- How have you tested the changes? --> <!-- If your PR needs to be included in the release notes for next release, add a separate entry in NEXT_CHANGELOG.md as part of your PR. -->
1 parent f102a4c commit 137b596

File tree

2 files changed

+65
-47
lines changed

2 files changed

+65
-47
lines changed

experimental/ssh/internal/proxy/client_server_test.go

Lines changed: 58 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,13 @@ func createTestServer(t *testing.T, maxClients int, shutdownDelay time.Duration)
3131
return httptest.NewServer(proxyServer)
3232
}
3333

34-
func createTestClient(t *testing.T, serverURL string, requestHandoverTick func() <-chan time.Time, errChan chan error) (io.WriteCloser, *testBuffer) {
34+
type testClient struct {
35+
InputWriter io.WriteCloser
36+
Output *testBuffer
37+
Cleanup func()
38+
}
39+
40+
func createTestClient(t *testing.T, serverURL string, requestHandoverTick func() <-chan time.Time, errChan chan error) *testClient {
3541
ctx := cmdio.MockDiscard(t.Context())
3642
clientInput, clientInputWriter := io.Pipe()
3743
clientOutput := newTestBuffer(t)
@@ -46,94 +52,103 @@ func createTestClient(t *testing.T, serverURL string, requestHandoverTick func()
4652
return time.After(time.Hour)
4753
}
4854
}
49-
go func() {
55+
wg := sync.WaitGroup{}
56+
wg.Go(func() {
5057
err := RunClientProxy(ctx, clientInput, clientOutput, requestHandoverTick, createConn)
51-
if err != nil && !errors.Is(err, context.Canceled) {
58+
if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, io.ErrClosedPipe) {
5259
if errChan != nil {
5360
errChan <- err
5461
} else {
5562
t.Errorf("client error: %v", err)
5663
}
5764
}
58-
}()
59-
return clientInputWriter, clientOutput
65+
})
66+
return &testClient{
67+
InputWriter: clientInputWriter,
68+
Output: clientOutput,
69+
Cleanup: func() {
70+
clientInput.Close()
71+
clientInputWriter.Close()
72+
wg.Wait()
73+
},
74+
}
6075
}
6176

6277
func TestClientServerEcho(t *testing.T) {
6378
server := createTestServer(t, 2, time.Hour)
6479
defer server.Close()
65-
clientInputWriter, clientOutput := createTestClient(t, server.URL, nil, nil)
66-
defer clientInputWriter.Close()
80+
client := createTestClient(t, server.URL, nil, nil)
81+
defer client.Cleanup()
6782

6883
testMsg1 := []byte("test message 1\n")
69-
_, err := clientInputWriter.Write(testMsg1)
84+
_, err := client.InputWriter.Write(testMsg1)
7085
require.NoError(t, err)
71-
err = clientOutput.AssertWrite(testMsg1)
86+
err = client.Output.AssertWrite(testMsg1)
7287
require.NoError(t, err)
7388

7489
testMsg2 := []byte("test message 2\n")
75-
_, err = clientInputWriter.Write(testMsg2)
90+
_, err = client.InputWriter.Write(testMsg2)
7691
require.NoError(t, err)
77-
err = clientOutput.AssertWrite(testMsg2)
92+
err = client.Output.AssertWrite(testMsg2)
7893
require.NoError(t, err)
7994

8095
expectedOutput := fmt.Sprintf("%s%s", testMsg1, testMsg2)
81-
assert.Equal(t, expectedOutput, clientOutput.String())
96+
assert.Equal(t, expectedOutput, client.Output.String())
8297
}
8398

8499
func TestMultipleClients(t *testing.T) {
85100
server := createTestServer(t, 2, time.Hour)
86101
defer server.Close()
87-
clientInputWriter1, clientOutput1 := createTestClient(t, server.URL, nil, nil)
88-
defer clientInputWriter1.Close()
89-
clientInputWriter2, clientOutput2 := createTestClient(t, server.URL, nil, nil)
90-
defer clientInputWriter2.Close()
102+
client1 := createTestClient(t, server.URL, nil, nil)
103+
defer client1.Cleanup()
104+
client2 := createTestClient(t, server.URL, nil, nil)
105+
defer client2.Cleanup()
91106

92107
messageCount := 10
93108
expectedClientOutput1 := ""
94109
expectedClientOutput2 := ""
95110
for i := range messageCount {
96111
message := fmt.Appendf(nil, "client 1 message %d\n", i)
97-
_, err := clientInputWriter1.Write(message)
112+
_, err := client1.InputWriter.Write(message)
98113
require.NoError(t, err)
99-
err = clientOutput1.AssertWrite(message)
114+
err = client1.Output.AssertWrite(message)
100115
require.NoError(t, err)
101116
expectedClientOutput1 += string(message)
102117

103118
message = fmt.Appendf(nil, "client 2 message %d\n", i)
104-
_, err = clientInputWriter2.Write(message)
119+
_, err = client2.InputWriter.Write(message)
105120
require.NoError(t, err)
106-
err = clientOutput2.AssertWrite(message)
121+
err = client2.Output.AssertWrite(message)
107122
require.NoError(t, err)
108123
expectedClientOutput2 += string(message)
109124
}
110125

111-
assert.Equal(t, expectedClientOutput1, clientOutput1.String())
112-
assert.Equal(t, expectedClientOutput2, clientOutput2.String())
126+
assert.Equal(t, expectedClientOutput1, client1.Output.String())
127+
assert.Equal(t, expectedClientOutput2, client2.Output.String())
113128
}
114129

115130
func TestMaxClients(t *testing.T) {
116131
maxClients := 2
117132
server := createTestServer(t, maxClients, time.Hour)
118133
defer server.Close()
119-
clientInputWriter1, clientOutput1 := createTestClient(t, server.URL, nil, nil)
120-
defer clientInputWriter1.Close()
121-
clientInputWriter2, clientOutput2 := createTestClient(t, server.URL, nil, nil)
122-
defer clientInputWriter2.Close()
134+
client1 := createTestClient(t, server.URL, nil, nil)
135+
defer client1.Cleanup()
136+
client2 := createTestClient(t, server.URL, nil, nil)
137+
defer client2.Cleanup()
123138

124139
testMsg1 := []byte("test message 1\n")
125-
_, err := clientInputWriter1.Write(testMsg1)
140+
_, err := client1.InputWriter.Write(testMsg1)
126141
require.NoError(t, err)
127-
err = clientOutput1.AssertWrite(testMsg1)
142+
err = client1.Output.AssertWrite(testMsg1)
128143
require.NoError(t, err)
129-
_, err = clientInputWriter2.Write(testMsg1)
144+
_, err = client2.InputWriter.Write(testMsg1)
130145
require.NoError(t, err)
131-
err = clientOutput2.AssertWrite(testMsg1)
146+
err = client2.Output.AssertWrite(testMsg1)
132147
require.NoError(t, err)
133148

134149
errChan := make(chan error, 1)
135-
clientInputWriter3, _ := createTestClient(t, server.URL, nil, errChan)
136-
defer clientInputWriter3.Close()
150+
client3 := createTestClient(t, server.URL, nil, errChan)
151+
defer client3.Cleanup()
137152
select {
138153
case err = <-errChan:
139154
require.Error(t, err)
@@ -150,8 +165,8 @@ func TestHandover(t *testing.T) {
150165
requestHandoverTick := func() <-chan time.Time {
151166
return handoverChan
152167
}
153-
clientInputWriter, clientOutput := createTestClient(t, server.URL, requestHandoverTick, nil)
154-
defer clientInputWriter.Close()
168+
client := createTestClient(t, server.URL, requestHandoverTick, nil)
169+
defer client.Cleanup()
155170

156171
expectedOutput := ""
157172

@@ -162,21 +177,21 @@ func TestHandover(t *testing.T) {
162177
handoverChan <- time.Now()
163178
}
164179
message := fmt.Appendf(nil, "message %d\n", i)
165-
_, err := clientInputWriter.Write(message)
180+
_, err := client.InputWriter.Write(message)
166181
if err != nil {
167182
t.Errorf("failed to write message %d: %v", i, err)
168183
}
169184
expectedOutput += string(message)
170185
}
171186
})
172187

173-
err := clientOutput.WaitForWrite(fmt.Appendf(nil, "message %d\n", TOTAL_MESSAGE_COUNT-1))
188+
err := client.Output.WaitForWrite(fmt.Appendf(nil, "message %d\n", TOTAL_MESSAGE_COUNT-1))
174189
require.NoError(t, err, "failed to receive the last message (%d)", TOTAL_MESSAGE_COUNT-1)
175190

176191
wg.Wait()
177192

178-
// clientOutput is created by appending incoming messages as they arrive, so we are also test correct order here
179-
assert.Equal(t, expectedOutput, clientOutput.String())
193+
// client.Output is created by appending incoming messages as they arrive, so we are also test correct order here
194+
assert.Equal(t, expectedOutput, client.Output.String())
180195
}
181196

182197
// Tests handovers in quick succession with few messages in between.
@@ -189,8 +204,8 @@ func TestQuickHandover(t *testing.T) {
189204
requestHandoverTick := func() <-chan time.Time {
190205
return handoverChan
191206
}
192-
clientInputWriter, clientOutput := createTestClient(t, server.URL, requestHandoverTick, nil)
193-
defer clientInputWriter.Close()
207+
client := createTestClient(t, server.URL, requestHandoverTick, nil)
208+
defer client.Cleanup()
194209

195210
expectedOutput := ""
196211

@@ -201,18 +216,18 @@ func TestQuickHandover(t *testing.T) {
201216
handoverChan <- time.Now()
202217
}
203218
message := fmt.Appendf(nil, "message %d\n", i)
204-
_, err := clientInputWriter.Write(message)
219+
_, err := client.InputWriter.Write(message)
205220
if err != nil {
206221
t.Errorf("failed to write message %d: %v", i, err)
207222
}
208223
expectedOutput += string(message)
209224
}
210225
})
211226

212-
err := clientOutput.WaitForWrite(fmt.Appendf(nil, "message %d\n", 15))
227+
err := client.Output.WaitForWrite(fmt.Appendf(nil, "message %d\n", 15))
213228
require.NoError(t, err, "failed to receive the last message (%d)", 15)
214229

215230
wg.Wait()
216231

217-
assert.Equal(t, expectedOutput, clientOutput.String())
232+
assert.Equal(t, expectedOutput, client.Output.String())
218233
}

experimental/ssh/internal/proxy/proxy_test.go

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ func setupTestServer(ctx context.Context, t *testing.T) *TestProxy {
115115
}
116116
defer serverProxy.close()
117117
err = serverProxy.start(ctx, serverInput, serverOutput)
118-
if err != nil && !errors.Is(err, context.Canceled) {
118+
if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, io.ErrClosedPipe) {
119119
t.Errorf("server error: %v", err)
120120
return
121121
}
@@ -150,16 +150,19 @@ func setupTestClient(ctx context.Context, t *testing.T, serverURL string) *TestP
150150
err := clientProxy.connect(ctx)
151151
require.NoError(t, err)
152152

153-
go func() {
153+
wg := sync.WaitGroup{}
154+
wg.Go(func() {
154155
err := clientProxy.start(ctx, clientInput, clientOutput)
155-
if err != nil && !errors.Is(err, context.Canceled) {
156+
if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, io.ErrClosedPipe) {
156157
t.Errorf("proxy error: %v", err)
157158
}
158-
}()
159+
})
159160

160161
cleanup := func() {
161162
clientProxy.close()
163+
clientInput.Close()
162164
clientInputWriter.Close()
165+
wg.Wait()
163166
}
164167

165168
return &TestProxy{

0 commit comments

Comments
 (0)