diff --git a/experimental/ssh/internal/proxy/client_server_test.go b/experimental/ssh/internal/proxy/client_server_test.go index daf3ab2b58..1701675fe4 100644 --- a/experimental/ssh/internal/proxy/client_server_test.go +++ b/experimental/ssh/internal/proxy/client_server_test.go @@ -31,7 +31,13 @@ func createTestServer(t *testing.T, maxClients int, shutdownDelay time.Duration) return httptest.NewServer(proxyServer) } -func createTestClient(t *testing.T, serverURL string, requestHandoverTick func() <-chan time.Time, errChan chan error) (io.WriteCloser, *testBuffer) { +type testClient struct { + InputWriter io.WriteCloser + Output *testBuffer + Cleanup func() +} + +func createTestClient(t *testing.T, serverURL string, requestHandoverTick func() <-chan time.Time, errChan chan error) *testClient { ctx := cmdio.MockDiscard(t.Context()) clientInput, clientInputWriter := io.Pipe() clientOutput := newTestBuffer(t) @@ -46,94 +52,103 @@ func createTestClient(t *testing.T, serverURL string, requestHandoverTick func() return time.After(time.Hour) } } - go func() { + wg := sync.WaitGroup{} + wg.Go(func() { err := RunClientProxy(ctx, clientInput, clientOutput, requestHandoverTick, createConn) - if err != nil && !errors.Is(err, context.Canceled) { + if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, io.ErrClosedPipe) { if errChan != nil { errChan <- err } else { t.Errorf("client error: %v", err) } } - }() - return clientInputWriter, clientOutput + }) + return &testClient{ + InputWriter: clientInputWriter, + Output: clientOutput, + Cleanup: func() { + clientInput.Close() + clientInputWriter.Close() + wg.Wait() + }, + } } func TestClientServerEcho(t *testing.T) { server := createTestServer(t, 2, time.Hour) defer server.Close() - clientInputWriter, clientOutput := createTestClient(t, server.URL, nil, nil) - defer clientInputWriter.Close() + client := createTestClient(t, server.URL, nil, nil) + defer client.Cleanup() testMsg1 := []byte("test message 1\n") - _, err := clientInputWriter.Write(testMsg1) + _, err := client.InputWriter.Write(testMsg1) require.NoError(t, err) - err = clientOutput.AssertWrite(testMsg1) + err = client.Output.AssertWrite(testMsg1) require.NoError(t, err) testMsg2 := []byte("test message 2\n") - _, err = clientInputWriter.Write(testMsg2) + _, err = client.InputWriter.Write(testMsg2) require.NoError(t, err) - err = clientOutput.AssertWrite(testMsg2) + err = client.Output.AssertWrite(testMsg2) require.NoError(t, err) expectedOutput := fmt.Sprintf("%s%s", testMsg1, testMsg2) - assert.Equal(t, expectedOutput, clientOutput.String()) + assert.Equal(t, expectedOutput, client.Output.String()) } func TestMultipleClients(t *testing.T) { server := createTestServer(t, 2, time.Hour) defer server.Close() - clientInputWriter1, clientOutput1 := createTestClient(t, server.URL, nil, nil) - defer clientInputWriter1.Close() - clientInputWriter2, clientOutput2 := createTestClient(t, server.URL, nil, nil) - defer clientInputWriter2.Close() + client1 := createTestClient(t, server.URL, nil, nil) + defer client1.Cleanup() + client2 := createTestClient(t, server.URL, nil, nil) + defer client2.Cleanup() messageCount := 10 expectedClientOutput1 := "" expectedClientOutput2 := "" for i := range messageCount { message := fmt.Appendf(nil, "client 1 message %d\n", i) - _, err := clientInputWriter1.Write(message) + _, err := client1.InputWriter.Write(message) require.NoError(t, err) - err = clientOutput1.AssertWrite(message) + err = client1.Output.AssertWrite(message) require.NoError(t, err) expectedClientOutput1 += string(message) message = fmt.Appendf(nil, "client 2 message %d\n", i) - _, err = clientInputWriter2.Write(message) + _, err = client2.InputWriter.Write(message) require.NoError(t, err) - err = clientOutput2.AssertWrite(message) + err = client2.Output.AssertWrite(message) require.NoError(t, err) expectedClientOutput2 += string(message) } - assert.Equal(t, expectedClientOutput1, clientOutput1.String()) - assert.Equal(t, expectedClientOutput2, clientOutput2.String()) + assert.Equal(t, expectedClientOutput1, client1.Output.String()) + assert.Equal(t, expectedClientOutput2, client2.Output.String()) } func TestMaxClients(t *testing.T) { maxClients := 2 server := createTestServer(t, maxClients, time.Hour) defer server.Close() - clientInputWriter1, clientOutput1 := createTestClient(t, server.URL, nil, nil) - defer clientInputWriter1.Close() - clientInputWriter2, clientOutput2 := createTestClient(t, server.URL, nil, nil) - defer clientInputWriter2.Close() + client1 := createTestClient(t, server.URL, nil, nil) + defer client1.Cleanup() + client2 := createTestClient(t, server.URL, nil, nil) + defer client2.Cleanup() testMsg1 := []byte("test message 1\n") - _, err := clientInputWriter1.Write(testMsg1) + _, err := client1.InputWriter.Write(testMsg1) require.NoError(t, err) - err = clientOutput1.AssertWrite(testMsg1) + err = client1.Output.AssertWrite(testMsg1) require.NoError(t, err) - _, err = clientInputWriter2.Write(testMsg1) + _, err = client2.InputWriter.Write(testMsg1) require.NoError(t, err) - err = clientOutput2.AssertWrite(testMsg1) + err = client2.Output.AssertWrite(testMsg1) require.NoError(t, err) errChan := make(chan error, 1) - clientInputWriter3, _ := createTestClient(t, server.URL, nil, errChan) - defer clientInputWriter3.Close() + client3 := createTestClient(t, server.URL, nil, errChan) + defer client3.Cleanup() select { case err = <-errChan: require.Error(t, err) @@ -150,8 +165,8 @@ func TestHandover(t *testing.T) { requestHandoverTick := func() <-chan time.Time { return handoverChan } - clientInputWriter, clientOutput := createTestClient(t, server.URL, requestHandoverTick, nil) - defer clientInputWriter.Close() + client := createTestClient(t, server.URL, requestHandoverTick, nil) + defer client.Cleanup() expectedOutput := "" @@ -162,7 +177,7 @@ func TestHandover(t *testing.T) { handoverChan <- time.Now() } message := fmt.Appendf(nil, "message %d\n", i) - _, err := clientInputWriter.Write(message) + _, err := client.InputWriter.Write(message) if err != nil { t.Errorf("failed to write message %d: %v", i, err) } @@ -170,13 +185,13 @@ func TestHandover(t *testing.T) { } }) - err := clientOutput.WaitForWrite(fmt.Appendf(nil, "message %d\n", TOTAL_MESSAGE_COUNT-1)) + err := client.Output.WaitForWrite(fmt.Appendf(nil, "message %d\n", TOTAL_MESSAGE_COUNT-1)) require.NoError(t, err, "failed to receive the last message (%d)", TOTAL_MESSAGE_COUNT-1) wg.Wait() - // clientOutput is created by appending incoming messages as they arrive, so we are also test correct order here - assert.Equal(t, expectedOutput, clientOutput.String()) + // client.Output is created by appending incoming messages as they arrive, so we are also test correct order here + assert.Equal(t, expectedOutput, client.Output.String()) } // Tests handovers in quick succession with few messages in between. @@ -189,8 +204,8 @@ func TestQuickHandover(t *testing.T) { requestHandoverTick := func() <-chan time.Time { return handoverChan } - clientInputWriter, clientOutput := createTestClient(t, server.URL, requestHandoverTick, nil) - defer clientInputWriter.Close() + client := createTestClient(t, server.URL, requestHandoverTick, nil) + defer client.Cleanup() expectedOutput := "" @@ -201,7 +216,7 @@ func TestQuickHandover(t *testing.T) { handoverChan <- time.Now() } message := fmt.Appendf(nil, "message %d\n", i) - _, err := clientInputWriter.Write(message) + _, err := client.InputWriter.Write(message) if err != nil { t.Errorf("failed to write message %d: %v", i, err) } @@ -209,10 +224,10 @@ func TestQuickHandover(t *testing.T) { } }) - err := clientOutput.WaitForWrite(fmt.Appendf(nil, "message %d\n", 15)) + err := client.Output.WaitForWrite(fmt.Appendf(nil, "message %d\n", 15)) require.NoError(t, err, "failed to receive the last message (%d)", 15) wg.Wait() - assert.Equal(t, expectedOutput, clientOutput.String()) + assert.Equal(t, expectedOutput, client.Output.String()) } diff --git a/experimental/ssh/internal/proxy/proxy_test.go b/experimental/ssh/internal/proxy/proxy_test.go index 2e0a534d1f..0e1db9021e 100644 --- a/experimental/ssh/internal/proxy/proxy_test.go +++ b/experimental/ssh/internal/proxy/proxy_test.go @@ -115,7 +115,7 @@ func setupTestServer(ctx context.Context, t *testing.T) *TestProxy { } defer serverProxy.close() err = serverProxy.start(ctx, serverInput, serverOutput) - if err != nil && !errors.Is(err, context.Canceled) { + if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, io.ErrClosedPipe) { t.Errorf("server error: %v", err) return } @@ -150,16 +150,19 @@ func setupTestClient(ctx context.Context, t *testing.T, serverURL string) *TestP err := clientProxy.connect(ctx) require.NoError(t, err) - go func() { + wg := sync.WaitGroup{} + wg.Go(func() { err := clientProxy.start(ctx, clientInput, clientOutput) - if err != nil && !errors.Is(err, context.Canceled) { + if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, io.ErrClosedPipe) { t.Errorf("proxy error: %v", err) } - }() + }) cleanup := func() { clientProxy.close() + clientInput.Close() clientInputWriter.Close() + wg.Wait() } return &TestProxy{