diff --git a/test/integration/util.go b/test/integration/util.go index d78b76e28..9f0dcde7b 100644 --- a/test/integration/util.go +++ b/test/integration/util.go @@ -54,7 +54,13 @@ func SendRequest(t *testing.T, client extProcPb.ExternalProcessor_ProcessClient, return res, err } -func StreamedRequest(t *testing.T, client extProcPb.ExternalProcessor_ProcessClient, requests []*extProcPb.ProcessingRequest, expectedResponses int) ([]*extProcPb.ProcessingResponse, error) { +// StreamedRequest sends a series of requests and collects the specified number of responses. +func StreamedRequest( + t *testing.T, + client extProcPb.ExternalProcessor_ProcessClient, + requests []*extProcPb.ProcessingRequest, + expectedResponses int, +) ([]*extProcPb.ProcessingResponse, error) { for _, req := range requests { t.Logf("Sending request: %v", req) if err := client.Send(req); err != nil { @@ -62,27 +68,35 @@ func StreamedRequest(t *testing.T, client extProcPb.ExternalProcessor_ProcessCli return nil, err } } + responses := []*extProcPb.ProcessingResponse{} + for i := range expectedResponses { + type recvResult struct { + res *extProcPb.ProcessingResponse + err error + } + recvChan := make(chan recvResult, 1) - // Make an incredible simple timeout func in the case where - // there is less than the expected amount of responses; bail and fail. - var simpleTimeout bool - go func() { - time.Sleep(10 * time.Second) - simpleTimeout = true - }() + go func() { + res, err := client.Recv() + recvChan <- recvResult{res, err} + }() - for range expectedResponses { - if simpleTimeout { - break - } - res, err := client.Recv() - if err != nil && err != io.EOF { - t.Logf("Failed to receive: %v", err) - return nil, err + select { + case <-time.After(10 * time.Second): + t.Logf("Timeout waiting for response %d of %d", i+1, expectedResponses) + return responses, nil + case result := <-recvChan: + if result.err != nil { + if result.err == io.EOF { + return responses, nil + } + t.Logf("Failed to receive: %v", result.err) + return nil, result.err + } + t.Logf("Received response %+v", result.res) + responses = append(responses, result.res) } - t.Logf("Received response %+v", res) - responses = append(responses, res) } return responses, nil }