Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 32 additions & 18 deletions test/integration/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,35 +54,49 @@ func SendRequest(t *testing.T, client extProcPb.ExternalProcessor_ProcessClient,
return res, err
}

func StreamedRequest(t *testing.T, client extProcPb.ExternalProcessor_ProcessClient, requests []*extProcPb.ProcessingRequest, expectedResponses int) ([]*extProcPb.ProcessingResponse, error) {
// StreamedRequest sends a series of requests and collects the specified number of responses.
func StreamedRequest(
t *testing.T,
client extProcPb.ExternalProcessor_ProcessClient,
requests []*extProcPb.ProcessingRequest,
expectedResponses int,
) ([]*extProcPb.ProcessingResponse, error) {
for _, req := range requests {
t.Logf("Sending request: %v", req)
if err := client.Send(req); err != nil {
t.Logf("Failed to send request %+v: %v", req, err)
return nil, err
}
}

responses := []*extProcPb.ProcessingResponse{}
for i := range expectedResponses {
type recvResult struct {
res *extProcPb.ProcessingResponse
err error
}
recvChan := make(chan recvResult, 1)

// Make an incredible simple timeout func in the case where
// there is less than the expected amount of responses; bail and fail.
var simpleTimeout bool
go func() {
time.Sleep(10 * time.Second)
simpleTimeout = true
}()
go func() {
res, err := client.Recv()
recvChan <- recvResult{res, err}
}()

for range expectedResponses {
if simpleTimeout {
break
}
res, err := client.Recv()
if err != nil && err != io.EOF {
t.Logf("Failed to receive: %v", err)
return nil, err
select {
case <-time.After(10 * time.Second):
t.Logf("Timeout waiting for response %d of %d", i+1, expectedResponses)
return responses, nil
case result := <-recvChan:
if result.err != nil {
if result.err == io.EOF {
return responses, nil
}
t.Logf("Failed to receive: %v", result.err)
return nil, result.err
}
t.Logf("Received response %+v", result.res)
responses = append(responses, result.res)
}
t.Logf("Received response %+v", res)
responses = append(responses, res)
}
return responses, nil
}
Expand Down