@@ -40,7 +40,8 @@ type Worker struct {
4040 metadataProvider MetadataProviderFunc
4141 msgProvider StreamMessageProviderFunc
4242
43- streamRecv StreamRecvMsgInterceptFunc
43+ streamRecv StreamRecvMsgInterceptFunc
44+ streamInterceptorProviderFunc StreamInterceptorProviderFunc
4445}
4546
4647func (w * Worker ) runWorker () error {
@@ -83,6 +84,13 @@ func (w *Worker) makeRequest(tv TickValue) error {
8384
8485 ctd := newCallData (w .mtd , w .workerID , reqNum , ! w .config .disableTemplateFuncs , ! w .config .disableTemplateData , w .config .funcs )
8586
87+ var streamInterceptor StreamInterceptor
88+ if w .mtd .IsClientStreaming () || w .mtd .IsServerStreaming () {
89+ if w .streamInterceptorProviderFunc != nil {
90+ streamInterceptor = w .streamInterceptorProviderFunc (ctd )
91+ }
92+ }
93+
8694 reqMD , err := w .metadataProvider (ctd )
8795 if err != nil {
8896 return err
@@ -115,6 +123,8 @@ func (w *Worker) makeRequest(tv TickValue) error {
115123 var msgProvider StreamMessageProviderFunc
116124 if w .msgProvider != nil {
117125 msgProvider = w .msgProvider
126+ } else if streamInterceptor != nil {
127+ msgProvider = streamInterceptor .Send
118128 } else if w .mtd .IsClientStreaming () {
119129 if w .config .streamDynamicMessages {
120130 mp , err := newDynamicMessageProvider (w .mtd , w .config .data , w .config .streamCallCount , ! w .config .disableTemplateFuncs , ! w .config .disableTemplateData )
@@ -155,11 +165,11 @@ func (w *Worker) makeRequest(tv TickValue) error {
155165
156166 // RPC errors are handled via stats handler
157167 if w .mtd .IsClientStreaming () && w .mtd .IsServerStreaming () {
158- _ = w .makeBidiRequest (& ctx , ctd , msgProvider )
168+ _ = w .makeBidiRequest (& ctx , ctd , msgProvider , streamInterceptor )
159169 } else if w .mtd .IsClientStreaming () {
160170 _ = w .makeClientStreamingRequest (& ctx , ctd , msgProvider )
161171 } else if w .mtd .IsServerStreaming () {
162- _ = w .makeServerStreamingRequest (& ctx , inputs [0 ])
172+ _ = w .makeServerStreamingRequest (& ctx , inputs [0 ], streamInterceptor )
163173 } else {
164174 _ = w .makeUnaryRequest (& ctx , reqMD , inputs [0 ])
165175 }
@@ -314,7 +324,7 @@ func (w *Worker) makeClientStreamingRequest(ctx *context.Context,
314324 return nil
315325}
316326
317- func (w * Worker ) makeServerStreamingRequest (ctx * context.Context , input * dynamic.Message ) error {
327+ func (w * Worker ) makeServerStreamingRequest (ctx * context.Context , input * dynamic.Message , streamInterceptor StreamInterceptor ) error {
318328 var callOptions = []grpc.CallOption {}
319329 if w .config .enableCompression {
320330 callOptions = append (callOptions , grpc .UseCompressor (gzip .Name ))
@@ -388,6 +398,18 @@ func (w *Worker) makeServerStreamingRequest(ctx *context.Context, input *dynamic
388398 }
389399 }
390400
401+ if streamInterceptor != nil {
402+ if converted , ok := res .(* dynamic.Message ); ok {
403+ err = streamInterceptor .Recv (converted , err )
404+ if errors .Is (err , ErrEndStream ) && ! interceptCanceled {
405+ interceptCanceled = true
406+ err = nil
407+
408+ callCancel ()
409+ }
410+ }
411+ }
412+
391413 if err != nil {
392414 if err == io .EOF {
393415 err = nil
@@ -415,7 +437,7 @@ func (w *Worker) makeServerStreamingRequest(ctx *context.Context, input *dynamic
415437}
416438
417439func (w * Worker ) makeBidiRequest (ctx * context.Context ,
418- ctd * CallData , messageProvider StreamMessageProviderFunc ) error {
440+ ctd * CallData , messageProvider StreamMessageProviderFunc , streamInterceptor StreamInterceptor ) error {
419441
420442 var callOptions = []grpc.CallOption {}
421443
@@ -494,6 +516,19 @@ func (w *Worker) makeBidiRequest(ctx *context.Context,
494516 }
495517 }
496518
519+ if streamInterceptor != nil {
520+ if converted , ok := res .(* dynamic.Message ); ok {
521+ iErr := streamInterceptor .Recv (converted , recvErr )
522+ if errors .Is (iErr , ErrEndStream ) && ! interceptCanceled {
523+ interceptCanceled = true
524+ if len (cancel ) == 0 {
525+ cancel <- struct {}{}
526+ }
527+ recvErr = nil
528+ }
529+ }
530+ }
531+
497532 if recvErr != nil {
498533 close (recvDone )
499534 break
0 commit comments