Skip to content

Commit 678b367

Browse files
committed
feat: simplify client-side Stream using
1 parent 9807a06 commit 678b367

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

43 files changed

+2174
-122
lines changed

client/callopt/streamcall/call_options.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,12 @@
1717
package streamcall
1818

1919
import (
20+
"fmt"
2021
"strings"
2122
"time"
2223

2324
"github.com/cloudwego/kitex/client/callopt"
25+
"github.com/cloudwego/kitex/pkg/streaming"
2426
)
2527

2628
// These options are directly translated from callopt.Option(s). If you can't find the option with the
@@ -79,3 +81,13 @@ func WithStreamTimeout(d time.Duration) Option {
7981
o.StreamOptions.StreamTimeout = d
8082
}}
8183
}
84+
85+
// WithCallbackConfig adds and calls FinishCallback when client-side stream
86+
// has finished.
87+
func WithCallbackConfig(cfg *streaming.FinishCallback) Option {
88+
return Option{f: func(o *callopt.CallOptions, di *strings.Builder) {
89+
fmt.Fprintf(di, "WithCallbackConfig(%+v)", cfg)
90+
91+
o.StreamOptions.FinishCallback = cfg
92+
}}
93+
}

client/client.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -833,6 +833,9 @@ func initRPCInfo(ctx context.Context, method string, opt *client.Options, svcInf
833833
if sopt.StreamTimeout > 0 {
834834
cfg.SetStreamTimeout(sopt.StreamTimeout)
835835
}
836+
if sopt.FinishCallback != nil {
837+
cfg.SetStreamCallbackConfig(sopt.FinishCallback)
838+
}
836839

837840
ctx = rpcinfo.NewCtxWithRPCInfo(ctx, ri)
838841

@@ -850,6 +853,9 @@ func initRPCInfo(ctx context.Context, method string, opt *client.Options, svcInf
850853
if callOpts.StreamOptions.StreamTimeout != 0 {
851854
cfg.SetStreamTimeout(callOpts.StreamOptions.StreamTimeout)
852855
}
856+
if callOpts.StreamOptions.FinishCallback != nil {
857+
cfg.SetStreamCallbackConfig(callOpts.StreamOptions.FinishCallback)
858+
}
853859
}
854860

855861
return ctx, ri, callOpts

client/option_stream.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import (
2222

2323
"github.com/cloudwego/kitex/internal/client"
2424
"github.com/cloudwego/kitex/pkg/endpoint/cep"
25+
"github.com/cloudwego/kitex/pkg/streaming"
2526
"github.com/cloudwego/kitex/pkg/utils"
2627
)
2728

@@ -66,6 +67,16 @@ func WithStreamTimeout(d time.Duration) StreamOption {
6667
}}
6768
}
6869

70+
// WithStreamCallbackConfig adds and calls FinishCallback when client-side stream
71+
// has finished.
72+
func WithStreamCallbackConfig(cfg *streaming.FinishCallback) StreamOption {
73+
return StreamOption{F: func(o *StreamOptions, di *utils.Slice) {
74+
di.Push(fmt.Sprintf("WithStreamCallbackConfig(%+v)", cfg))
75+
76+
o.FinishCallback = cfg
77+
}}
78+
}
79+
6980
// WithStreamMiddleware add middleware for stream.
7081
func WithStreamMiddleware(mw cep.StreamMiddleware) StreamOption {
7182
return StreamOption{F: func(o *StreamOptions, di *utils.Slice) {

client/stream.go

Lines changed: 33 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
"context"
2121
"fmt"
2222
"io"
23+
"runtime/debug"
2324
"sync/atomic"
2425
"time"
2526

@@ -29,6 +30,7 @@ import (
2930
"github.com/cloudwego/kitex/pkg/endpoint"
3031
"github.com/cloudwego/kitex/pkg/endpoint/cep"
3132
"github.com/cloudwego/kitex/pkg/kerrors"
33+
"github.com/cloudwego/kitex/pkg/klog"
3234
"github.com/cloudwego/kitex/pkg/remote"
3335
"github.com/cloudwego/kitex/pkg/remote/remotecli"
3436
"github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/codes"
@@ -262,9 +264,6 @@ func newStream(ctx context.Context, cancel context.CancelFunc, s streaming.Clien
262264
st.grpcStream.st = st
263265
}
264266
}
265-
if register, ok := s.(streaming.CloseCallbackRegister); ok {
266-
register.RegisterCloseCallback(st.DoFinish)
267-
}
268267
return st
269268
}
270269

@@ -363,18 +362,44 @@ func (s *stream) DoFinish(err error) {
363362
// already called
364363
return
365364
}
365+
366366
// release stream timeout cancel
367367
if s.cancelFunc != nil {
368368
s.cancelFunc()
369369
}
370-
if !isRPCError(err) {
371-
// only rpc errors are reported
372-
err = nil
370+
371+
// reporting and release connection
372+
reportErr := err
373+
if err == nil || err == io.EOF {
374+
reportErr = nil
375+
}
376+
// If the client-side callback returns a biz error or manually calls streaming.FinishStream/streaming.FinishClientStream with a biz error,
377+
// it needs to be set.
378+
if bizErr, bizOk := err.(kerrors.BizStatusErrorIface); bizOk {
379+
if setter, ok := s.ri.Invocation().(rpcinfo.InvocationSetter); ok {
380+
setter.SetBizStatusErr(bizErr)
381+
}
373382
}
374383
if s.scm != nil {
375-
s.scm.ReleaseConn(err, s.ri)
384+
s.scm.ReleaseConn(reportErr, s.ri)
385+
}
386+
s.kc.opt.TracerCtl.DoFinish(s.ctx, s.ri, reportErr)
387+
388+
// processing callback with original err
389+
stCfg := s.ri.Config().StreamCallbackConfig()
390+
if err == nil || err == io.EOF || stCfg == nil {
391+
return
392+
}
393+
defer func() {
394+
if r := recover(); r != nil {
395+
klog.CtxWarnf(s.ctx, "Panic happened during stream DoFinish. This may caused by injected stream Callback: error=%v, stack=%s", r, string(debug.Stack()))
396+
}
397+
}()
398+
if s.isGRPC {
399+
handleGRPC(s.ctx, s.ri, err, stCfg)
400+
} else {
401+
handleTTStream(s.ctx, s.ri, err, stCfg)
376402
}
377-
s.kc.opt.TracerCtl.DoFinish(s.ctx, s.ri, err)
378403
}
379404

380405
func (s *stream) GetGRPCStream() streaming.Stream {
@@ -478,18 +503,6 @@ func (s *grpcStream) DoFinish(err error) {
478503
s.st.DoFinish(err)
479504
}
480505

481-
func isRPCError(err error) bool {
482-
if err == nil {
483-
return false
484-
}
485-
if err == io.EOF {
486-
return false
487-
}
488-
_, isBizStatusError := err.(kerrors.BizStatusErrorIface)
489-
// if a tracer needs to get the BizStatusError, it should read from rpcinfo.invocation.bizStatusErr
490-
return !isBizStatusError
491-
}
492-
493506
func callWithTimeout(tm time.Duration, call func() error, buildTmErr func(time.Duration) error, cancel func(error)) error {
494507
if tm <= 0 {
495508
return call()

0 commit comments

Comments
 (0)