Skip to content

Commit 149d520

Browse files
committed
Adds a stream interceptor to keep state between the send and receive calls
1 parent 811b53a commit 149d520

File tree

4 files changed

+78
-22
lines changed

4 files changed

+78
-22
lines changed

runner/data.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,16 @@ type StreamMessageProviderFunc func(*CallData) (*dynamic.Message, error)
4444
// Clients can return ErrEndStream to end the call early
4545
type StreamRecvMsgInterceptFunc func(*dynamic.Message, error) error
4646

47+
// StreamInterceptorProviderFunc is an interface for a function invoked to generate a stream interceptor
48+
type StreamInterceptorProviderFunc func(*CallData) StreamInterceptor
49+
50+
// StreamInterceptor is an interface for sending and receiving stream messages.
51+
// The interceptor can keep shared state for the send and receive calls.
52+
type StreamInterceptor interface {
53+
Recv(*dynamic.Message, error) error
54+
Send(*CallData) (*dynamic.Message, error)
55+
}
56+
4757
type dataProvider struct {
4858
binary bool
4959
data []byte

runner/options.go

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -129,12 +129,13 @@ type RunConfig struct {
129129
disableTemplateData bool
130130

131131
// misc
132-
name string
133-
cpus int
134-
tags []byte
135-
skipFirst int
136-
countErrors bool
137-
recvMsgFunc StreamRecvMsgInterceptFunc
132+
name string
133+
cpus int
134+
tags []byte
135+
skipFirst int
136+
countErrors bool
137+
recvMsgFunc StreamRecvMsgInterceptFunc
138+
streamInterceptorProviderFunc StreamInterceptorProviderFunc
138139
}
139140

140141
// Option controls some aspect of run
@@ -1034,6 +1035,15 @@ func WithStreamRecvMsgIntercept(fn StreamRecvMsgInterceptFunc) Option {
10341035
}
10351036
}
10361037

1038+
// WithStreamInterceptor specifies the stream interceptor provider function
1039+
func WithStreamInterceptorProviderFunc(interceptor StreamInterceptorProviderFunc) Option {
1040+
return func(o *RunConfig) error {
1041+
o.streamInterceptorProviderFunc = interceptor
1042+
1043+
return nil
1044+
}
1045+
}
1046+
10371047
// WithDataProvider provides custom data provider
10381048
//
10391049
// WithDataProvider(func(*CallData) ([]*dynamic.Message, error) {

runner/requester.go

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -389,17 +389,18 @@ func (b *Requester) runWorkers(wt load.WorkerTicker, p load.Pacer) error {
389389
}
390390

391391
w := Worker{
392-
ticks: ticks,
393-
active: true,
394-
stub: b.stubs[n],
395-
mtd: b.mtd,
396-
config: b.config,
397-
stopCh: make(chan bool),
398-
workerID: wID,
399-
dataProvider: b.dataProvider,
400-
metadataProvider: b.metadataProvider,
401-
streamRecv: b.config.recvMsgFunc,
402-
msgProvider: b.config.dataStreamFunc,
392+
ticks: ticks,
393+
active: true,
394+
stub: b.stubs[n],
395+
mtd: b.mtd,
396+
config: b.config,
397+
stopCh: make(chan bool),
398+
workerID: wID,
399+
dataProvider: b.dataProvider,
400+
metadataProvider: b.metadataProvider,
401+
streamRecv: b.config.recvMsgFunc,
402+
msgProvider: b.config.dataStreamFunc,
403+
streamInterceptorProviderFunc: b.config.streamInterceptorProviderFunc,
403404
}
404405

405406
wc++ // increment worker id

runner/worker.go

Lines changed: 40 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@ type Worker struct {
4040
metadataProvider MetadataProviderFunc
4141
msgProvider StreamMessageProviderFunc
4242

43-
streamRecv StreamRecvMsgInterceptFunc
43+
streamRecv StreamRecvMsgInterceptFunc
44+
streamInterceptorProviderFunc StreamInterceptorProviderFunc
4445
}
4546

4647
func (w *Worker) runWorker() error {
@@ -83,6 +84,13 @@ func (w *Worker) makeRequest(tv TickValue) error {
8384

8485
ctd := newCallData(w.mtd, w.workerID, reqNum, !w.config.disableTemplateFuncs, !w.config.disableTemplateData, w.config.funcs)
8586

87+
var streamInterceptor StreamInterceptor
88+
if w.mtd.IsClientStreaming() || w.mtd.IsServerStreaming() {
89+
if w.streamInterceptorProviderFunc != nil {
90+
streamInterceptor = w.streamInterceptorProviderFunc(ctd)
91+
}
92+
}
93+
8694
reqMD, err := w.metadataProvider(ctd)
8795
if err != nil {
8896
return err
@@ -115,6 +123,8 @@ func (w *Worker) makeRequest(tv TickValue) error {
115123
var msgProvider StreamMessageProviderFunc
116124
if w.msgProvider != nil {
117125
msgProvider = w.msgProvider
126+
} else if streamInterceptor != nil {
127+
msgProvider = streamInterceptor.Send
118128
} else if w.mtd.IsClientStreaming() {
119129
if w.config.streamDynamicMessages {
120130
mp, err := newDynamicMessageProvider(w.mtd, w.config.data, w.config.streamCallCount, !w.config.disableTemplateFuncs, !w.config.disableTemplateData)
@@ -155,11 +165,11 @@ func (w *Worker) makeRequest(tv TickValue) error {
155165

156166
// RPC errors are handled via stats handler
157167
if w.mtd.IsClientStreaming() && w.mtd.IsServerStreaming() {
158-
_ = w.makeBidiRequest(&ctx, ctd, msgProvider)
168+
_ = w.makeBidiRequest(&ctx, ctd, msgProvider, streamInterceptor)
159169
} else if w.mtd.IsClientStreaming() {
160170
_ = w.makeClientStreamingRequest(&ctx, ctd, msgProvider)
161171
} else if w.mtd.IsServerStreaming() {
162-
_ = w.makeServerStreamingRequest(&ctx, inputs[0])
172+
_ = w.makeServerStreamingRequest(&ctx, inputs[0], streamInterceptor)
163173
} else {
164174
_ = w.makeUnaryRequest(&ctx, reqMD, inputs[0])
165175
}
@@ -314,7 +324,7 @@ func (w *Worker) makeClientStreamingRequest(ctx *context.Context,
314324
return nil
315325
}
316326

317-
func (w *Worker) makeServerStreamingRequest(ctx *context.Context, input *dynamic.Message) error {
327+
func (w *Worker) makeServerStreamingRequest(ctx *context.Context, input *dynamic.Message, streamInterceptor StreamInterceptor) error {
318328
var callOptions = []grpc.CallOption{}
319329
if w.config.enableCompression {
320330
callOptions = append(callOptions, grpc.UseCompressor(gzip.Name))
@@ -388,6 +398,18 @@ func (w *Worker) makeServerStreamingRequest(ctx *context.Context, input *dynamic
388398
}
389399
}
390400

401+
if streamInterceptor != nil {
402+
if converted, ok := res.(*dynamic.Message); ok {
403+
err = streamInterceptor.Recv(converted, err)
404+
if errors.Is(err, ErrEndStream) && !interceptCanceled {
405+
interceptCanceled = true
406+
err = nil
407+
408+
callCancel()
409+
}
410+
}
411+
}
412+
391413
if err != nil {
392414
if err == io.EOF {
393415
err = nil
@@ -415,7 +437,7 @@ func (w *Worker) makeServerStreamingRequest(ctx *context.Context, input *dynamic
415437
}
416438

417439
func (w *Worker) makeBidiRequest(ctx *context.Context,
418-
ctd *CallData, messageProvider StreamMessageProviderFunc) error {
440+
ctd *CallData, messageProvider StreamMessageProviderFunc, streamInterceptor StreamInterceptor) error {
419441

420442
var callOptions = []grpc.CallOption{}
421443

@@ -494,6 +516,19 @@ func (w *Worker) makeBidiRequest(ctx *context.Context,
494516
}
495517
}
496518

519+
if streamInterceptor != nil {
520+
if converted, ok := res.(*dynamic.Message); ok {
521+
iErr := streamInterceptor.Recv(converted, recvErr)
522+
if errors.Is(iErr, ErrEndStream) && !interceptCanceled {
523+
interceptCanceled = true
524+
if len(cancel) == 0 {
525+
cancel <- struct{}{}
526+
}
527+
recvErr = nil
528+
}
529+
}
530+
}
531+
497532
if recvErr != nil {
498533
close(recvDone)
499534
break

0 commit comments

Comments
 (0)