Skip to content

Commit 18220b6

Browse files
committed
feat: simplify client-side Stream using
1 parent fbade94 commit 18220b6

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

43 files changed

+2233
-122
lines changed

client/callopt/streamcall/call_options.go

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,12 @@
1717
package streamcall
1818

1919
import (
20+
"fmt"
2021
"strings"
2122
"time"
2223

2324
"github.com/cloudwego/kitex/client/callopt"
25+
"github.com/cloudwego/kitex/pkg/streaming"
2426
)
2527

2628
// These options are directly translated from callopt.Option(s). If you can't find the option with the
@@ -79,3 +81,29 @@ func WithStreamTimeout(d time.Duration) Option {
7981
o.StreamOptions.StreamTimeout = d
8082
}}
8183
}
84+
85+
// WithIndependentLifecycle ensures that newly created client-side Streams are not controlled
86+
// by the lifecycle of the passed ctx.
87+
func WithIndependentLifecycle(flag bool) Option {
88+
return Option{f: func(o *callopt.CallOptions, di *strings.Builder) {
89+
di.WriteString("WithIndependentLifecycle(")
90+
if flag {
91+
di.WriteString("true")
92+
} else {
93+
di.WriteString("false")
94+
}
95+
di.WriteString(")")
96+
97+
o.StreamOptions.IndependentLifecycle = flag
98+
}}
99+
}
100+
101+
// WithCallbackConfig adds and calls FinishCallback when client-side stream
102+
// has finished.
103+
func WithCallbackConfig(cfg *streaming.FinishCallback) Option {
104+
return Option{f: func(o *callopt.CallOptions, di *strings.Builder) {
105+
di.WriteString(fmt.Sprintf("WithCallbackConfig(%+v)", cfg))
106+
107+
o.StreamOptions.FinishCallback = cfg
108+
}}
109+
}

client/client.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -833,6 +833,12 @@ func initRPCInfo(ctx context.Context, method string, opt *client.Options, svcInf
833833
if sopt.StreamTimeout > 0 {
834834
cfg.SetStreamTimeout(sopt.StreamTimeout)
835835
}
836+
if sopt.FinishCallback != nil {
837+
cfg.SetStreamCallbackConfig(sopt.FinishCallback)
838+
}
839+
if sopt.IndependentLifecycle {
840+
cfg.SetStreamIndependentLifecycle(sopt.IndependentLifecycle)
841+
}
836842

837843
ctx = rpcinfo.NewCtxWithRPCInfo(ctx, ri)
838844

@@ -850,6 +856,12 @@ func initRPCInfo(ctx context.Context, method string, opt *client.Options, svcInf
850856
if callOpts.StreamOptions.StreamTimeout != 0 {
851857
cfg.SetStreamTimeout(callOpts.StreamOptions.StreamTimeout)
852858
}
859+
if callOpts.StreamOptions.FinishCallback != nil {
860+
cfg.SetStreamCallbackConfig(callOpts.StreamOptions.FinishCallback)
861+
}
862+
if callOpts.StreamOptions.IndependentLifecycle {
863+
cfg.SetStreamIndependentLifecycle(callOpts.StreamOptions.IndependentLifecycle)
864+
}
853865
}
854866

855867
return ctx, ri, callOpts

client/option_stream.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import (
2222

2323
"github.com/cloudwego/kitex/internal/client"
2424
"github.com/cloudwego/kitex/pkg/endpoint/cep"
25+
"github.com/cloudwego/kitex/pkg/streaming"
2526
"github.com/cloudwego/kitex/pkg/utils"
2627
)
2728

@@ -66,6 +67,26 @@ func WithStreamTimeout(d time.Duration) StreamOption {
6667
}}
6768
}
6869

70+
// WithStreamCallbackConfig adds and calls FinishCallback when client-side stream
71+
// has finished.
72+
func WithStreamCallbackConfig(cfg *streaming.FinishCallback) StreamOption {
73+
return StreamOption{F: func(o *StreamOptions, di *utils.Slice) {
74+
di.Push(fmt.Sprintf("WithStreamCallbackConfig(%+v)", cfg))
75+
76+
o.FinishCallback = cfg
77+
}}
78+
}
79+
80+
// WithStreamIndependentLifecycle ensures that newly created client-side Streams are not controlled
81+
// by the lifecycle of the passed ctx.
82+
func WithStreamIndependentLifecycle(flag bool) StreamOption {
83+
return StreamOption{F: func(o *StreamOptions, di *utils.Slice) {
84+
di.Push(fmt.Sprintf("WithStreamIndependentLifecycle(%+v)", flag))
85+
86+
o.IndependentLifecycle = flag
87+
}}
88+
}
89+
6990
// WithStreamMiddleware add middleware for stream.
7091
func WithStreamMiddleware(mw cep.StreamMiddleware) StreamOption {
7192
return StreamOption{F: func(o *StreamOptions, di *utils.Slice) {

client/stream.go

Lines changed: 39 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
"context"
2121
"fmt"
2222
"io"
23+
"runtime/debug"
2324
"sync/atomic"
2425
"time"
2526

@@ -29,6 +30,7 @@ import (
2930
"github.com/cloudwego/kitex/pkg/endpoint"
3031
"github.com/cloudwego/kitex/pkg/endpoint/cep"
3132
"github.com/cloudwego/kitex/pkg/kerrors"
33+
"github.com/cloudwego/kitex/pkg/klog"
3234
"github.com/cloudwego/kitex/pkg/remote"
3335
"github.com/cloudwego/kitex/pkg/remote/remotecli"
3436
"github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/codes"
@@ -64,6 +66,9 @@ func (kc *kClient) Stream(ctx context.Context, method string, request, response
6466
}
6567
var ri rpcinfo.RPCInfo
6668
ctx, ri, _ = kc.initRPCInfo(ctx, method, 0, nil, true)
69+
if ri.Config().StreamIndependentLifecycle() {
70+
ctx = context.WithoutCancel(ctx)
71+
}
6772

6873
ctx = kc.opt.TracerCtl.DoStart(ctx, ri)
6974
var reportErr error
@@ -128,6 +133,9 @@ func (kc *kClient) StreamX(ctx context.Context, method string) (streaming.Client
128133
}
129134
var ri rpcinfo.RPCInfo
130135
ctx, ri, _ = kc.initRPCInfo(ctx, method, 0, nil, true)
136+
if ri.Config().StreamIndependentLifecycle() {
137+
ctx = context.WithoutCancel(ctx)
138+
}
131139

132140
ctx = kc.opt.TracerCtl.DoStart(ctx, ri)
133141
var reportErr error
@@ -262,9 +270,6 @@ func newStream(ctx context.Context, cancel context.CancelFunc, s streaming.Clien
262270
st.grpcStream.st = st
263271
}
264272
}
265-
if register, ok := s.(streaming.CloseCallbackRegister); ok {
266-
register.RegisterCloseCallback(st.DoFinish)
267-
}
268273
return st
269274
}
270275

@@ -363,18 +368,44 @@ func (s *stream) DoFinish(err error) {
363368
// already called
364369
return
365370
}
371+
366372
// release stream timeout cancel
367373
if s.cancelFunc != nil {
368374
s.cancelFunc()
369375
}
370-
if !isRPCError(err) {
371-
// only rpc errors are reported
372-
err = nil
376+
377+
// reporting and release connection
378+
reportErr := err
379+
if err == nil || err == io.EOF {
380+
reportErr = nil
381+
}
382+
// If the client-side callback returns a biz error or manually calls streaming.FinishStream/streaming.FinishClientStream with a biz error,
383+
// it needs to be set.
384+
if bizErr, bizOk := err.(kerrors.BizStatusErrorIface); bizOk {
385+
if setter, ok := s.ri.Invocation().(rpcinfo.InvocationSetter); ok {
386+
setter.SetBizStatusErr(bizErr)
387+
}
373388
}
374389
if s.scm != nil {
375-
s.scm.ReleaseConn(err, s.ri)
390+
s.scm.ReleaseConn(reportErr, s.ri)
391+
}
392+
s.kc.opt.TracerCtl.DoFinish(s.ctx, s.ri, reportErr)
393+
394+
// processing callback with original err
395+
stCfg := s.ri.Config().StreamCallbackConfig()
396+
if err == nil || err == io.EOF || stCfg == nil {
397+
return
398+
}
399+
defer func() {
400+
if r := recover(); r != nil {
401+
klog.CtxWarnf(s.ctx, "Panic happened during stream DoFinish. This may caused by injected stream Callback: error=%v, stack=%s", r, string(debug.Stack()))
402+
}
403+
}()
404+
if s.isGRPC {
405+
handleGRPC(s.ctx, s.ri, err, stCfg)
406+
} else {
407+
handleTTStream(s.ctx, s.ri, err, stCfg)
376408
}
377-
s.kc.opt.TracerCtl.DoFinish(s.ctx, s.ri, err)
378409
}
379410

380411
func (s *stream) GetGRPCStream() streaming.Stream {
@@ -478,18 +509,6 @@ func (s *grpcStream) DoFinish(err error) {
478509
s.st.DoFinish(err)
479510
}
480511

481-
func isRPCError(err error) bool {
482-
if err == nil {
483-
return false
484-
}
485-
if err == io.EOF {
486-
return false
487-
}
488-
_, isBizStatusError := err.(kerrors.BizStatusErrorIface)
489-
// if a tracer needs to get the BizStatusError, it should read from rpcinfo.invocation.bizStatusErr
490-
return !isBizStatusError
491-
}
492-
493512
func callWithTimeout(tm time.Duration, call func() error, buildTmErr func(time.Duration) error, cancel func(error)) error {
494513
if tm <= 0 {
495514
return call()

0 commit comments

Comments
 (0)