diff --git a/benchmark/benchmark.go b/benchmark/benchmark.go index 5c3dcd51db9d..9f637deebcd5 100644 --- a/benchmark/benchmark.go +++ b/benchmark/benchmark.go @@ -23,6 +23,7 @@ package benchmark import ( "context" + "errors" "fmt" "io" "log" @@ -290,6 +291,23 @@ func DoUnaryCall(tc testgrpc.BenchmarkServiceClient, reqSize, respSize int) erro return nil } +// DoUnaryCallWithContext performs a unary RPC with propagated context and given stub and request and response sizes. +func DoUnaryCallWithContext(ctx context.Context, tc testgrpc.BenchmarkServiceClient, reqSize, respSize int) error { + pl := NewPayload(testpb.PayloadType_COMPRESSABLE, reqSize) + req := &testpb.SimpleRequest{ + ResponseType: pl.Type, + ResponseSize: int32(respSize), + Payload: pl, + } + if _, err := tc.UnaryCall(ctx, req); err != nil { + if status.Code(err) == codes.Canceled || errors.Is(err, context.Canceled) { + return err + } + return fmt.Errorf("/BenchmarkService/UnaryCall(_, _) = _, %v, want _, ", err) + } + return nil +} + // DoStreamingRoundTrip performs a round trip for a single streaming rpc. func DoStreamingRoundTrip(stream testgrpc.BenchmarkService_StreamingCallClient, reqSize, respSize int) error { pl := NewPayload(testpb.PayloadType_COMPRESSABLE, reqSize) diff --git a/benchmark/worker/benchmark_client.go b/benchmark/worker/benchmark_client.go index c28312dd6aab..35908b28fa75 100644 --- a/benchmark/worker/benchmark_client.go +++ b/benchmark/worker/benchmark_client.go @@ -20,6 +20,7 @@ package main import ( "context" + "errors" "flag" "math" rand "math/rand/v2" @@ -73,7 +74,8 @@ func (h *lockingHistogram) mergeInto(merged *stats.Histogram) { type benchmarkClient struct { closeConns func() - stop chan bool + ctx context.Context + cancel context.CancelFunc lastResetTime time.Time histogramOptions stats.HistogramOptions lockingHistograms []lockingHistogram @@ -223,6 +225,10 @@ func performRPCs(config *testpb.ClientConfig, conns []*grpc.ClientConn, bc *benc } func startBenchmarkClient(config *testpb.ClientConfig) (*benchmarkClient, error) { + return startBenchmarkClientWithContext(context.Background(), config) +} + +func startBenchmarkClientWithContext(ctx context.Context, config *testpb.ClientConfig) (*benchmarkClient, error) { printClientConfig(config) // Set running environment like how many cores to use. @@ -233,6 +239,7 @@ func startBenchmarkClient(config *testpb.ClientConfig) (*benchmarkClient, error) return nil, err } + ctx, cancel := context.WithCancel(ctx) rpcCountPerConn := int(config.OutstandingRpcsPerChannel) bc := &benchmarkClient{ histogramOptions: stats.HistogramOptions{ @@ -243,7 +250,8 @@ func startBenchmarkClient(config *testpb.ClientConfig) (*benchmarkClient, error) }, lockingHistograms: make([]lockingHistogram, rpcCountPerConn*len(conns)), - stop: make(chan bool), + ctx: ctx, + cancel: cancel, lastResetTime: time.Now(), closeConns: closeConns, rusageLastReset: syscall.GetRusage(), @@ -274,13 +282,14 @@ func (bc *benchmarkClient) unaryLoop(conns []*grpc.ClientConn, rpcCountPerConn i // before starting benchmark. if poissonLambda == nil { // Closed loop. for { - select { - case <-bc.stop: + if bc.ctx.Err() != nil { return - default: } start := time.Now() - if err := benchmark.DoUnaryCall(client, reqSize, respSize); err != nil { + if err := benchmark.DoUnaryCallWithContext(bc.ctx, client, reqSize, respSize); err != nil { + if status.Code(err) == codes.Canceled || errors.Is(err, context.Canceled) { + return + } continue } elapse := time.Since(start) @@ -289,10 +298,12 @@ func (bc *benchmarkClient) unaryLoop(conns []*grpc.ClientConn, rpcCountPerConn i } else { // Open loop. timeBetweenRPCs := time.Duration((rand.ExpFloat64() / *poissonLambda) * float64(time.Second)) time.AfterFunc(timeBetweenRPCs, func() { + if bc.ctx.Err() != nil { + return + } bc.poissonUnary(client, idx, reqSize, respSize, *poissonLambda) }) } - }(idx) } } @@ -309,8 +320,11 @@ func (bc *benchmarkClient) streamingLoop(conns []*grpc.ClientConn, rpcCountPerCo // For each connection, create rpcCountPerConn goroutines to do rpc. for j := 0; j < rpcCountPerConn; j++ { c := testgrpc.NewBenchmarkServiceClient(conn) - stream, err := c.StreamingCall(context.Background()) + stream, err := c.StreamingCall(bc.ctx) if err != nil { + if status.Code(err) == codes.Canceled || errors.Is(err, context.Canceled) { + return + } logger.Fatalf("%v.StreamingCall(_) = _, %v", c, err) } idx := ic*rpcCountPerConn + j @@ -323,22 +337,26 @@ func (bc *benchmarkClient) streamingLoop(conns []*grpc.ClientConn, rpcCountPerCo // The worker client needs to wait for some time after client is created, // before starting benchmark. for { + if bc.ctx.Err() != nil { + return + } start := time.Now() if err := doRPC(stream, reqSize, respSize); err != nil { - return + if status.Code(err) == codes.Canceled || errors.Is(err, context.Canceled) { + return + } + continue } elapse := time.Since(start) bc.lockingHistograms[idx].add(int64(elapse)) - select { - case <-bc.stop: - return - default: - } } }(idx) } else { // Open loop. timeBetweenRPCs := time.Duration((rand.ExpFloat64() / *poissonLambda) * float64(time.Second)) time.AfterFunc(timeBetweenRPCs, func() { + if bc.ctx.Err() != nil { + return + } bc.poissonStreaming(stream, idx, reqSize, respSize, *poissonLambda, doRPC) }) } @@ -349,7 +367,12 @@ func (bc *benchmarkClient) streamingLoop(conns []*grpc.ClientConn, rpcCountPerCo func (bc *benchmarkClient) poissonUnary(client testgrpc.BenchmarkServiceClient, idx int, reqSize int, respSize int, lambda float64) { go func() { start := time.Now() - if err := benchmark.DoUnaryCall(client, reqSize, respSize); err != nil { + + if bc.ctx.Err() != nil { + return + } + + if err := benchmark.DoUnaryCallWithContext(bc.ctx, client, reqSize, respSize); err != nil { return } elapse := time.Since(start) @@ -357,6 +380,9 @@ func (bc *benchmarkClient) poissonUnary(client testgrpc.BenchmarkServiceClient, }() timeBetweenRPCs := time.Duration((rand.ExpFloat64() / lambda) * float64(time.Second)) time.AfterFunc(timeBetweenRPCs, func() { + if bc.ctx.Err() != nil { + return + } bc.poissonUnary(client, idx, reqSize, respSize, lambda) }) } @@ -364,6 +390,11 @@ func (bc *benchmarkClient) poissonUnary(client testgrpc.BenchmarkServiceClient, func (bc *benchmarkClient) poissonStreaming(stream testgrpc.BenchmarkService_StreamingCallClient, idx int, reqSize int, respSize int, lambda float64, doRPC func(testgrpc.BenchmarkService_StreamingCallClient, int, int) error) { go func() { start := time.Now() + + if bc.ctx.Err() != nil { + return + } + if err := doRPC(stream, reqSize, respSize); err != nil { return } @@ -372,6 +403,9 @@ func (bc *benchmarkClient) poissonStreaming(stream testgrpc.BenchmarkService_Str }() timeBetweenRPCs := time.Duration((rand.ExpFloat64() / lambda) * float64(time.Second)) time.AfterFunc(timeBetweenRPCs, func() { + if bc.ctx.Err() != nil { + return + } bc.poissonStreaming(stream, idx, reqSize, respSize, lambda, doRPC) }) } @@ -430,6 +464,8 @@ func (bc *benchmarkClient) getStats(reset bool) *testpb.ClientStats { } func (bc *benchmarkClient) shutdown() { - close(bc.stop) + if bc.cancel != nil { + bc.cancel() + } bc.closeConns() }