diff --git a/client/option.go b/client/option.go index 29a0161908..3c31f8ec3e 100644 --- a/client/option.go +++ b/client/option.go @@ -19,6 +19,7 @@ package client import ( "context" "fmt" + "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2" "net" "reflect" "strings" @@ -497,6 +498,14 @@ func WithGRPCKeepaliveParams(kp grpc.ClientKeepalive) Option { }} } +// WithGRPCConnectionGetter configures the way to get a connection in gRPC transport +func WithGRPCConnectionGetter(getter nphttp2.ConnectionGetter) Option { + return Option{F: func(o *client.Options, di *utils.Slice) { + di.Push("WithGRPCConnectionGetter") + o.GRPCConnectionGetter = getter + }} +} + // WithWarmingUp forces the client to do some warm-ups at the end of the initialization. func WithWarmingUp(wuo *warmup.ClientOption) Option { return Option{F: func(o *client.Options, di *utils.Slice) { diff --git a/client/option_test.go b/client/option_test.go index 883f67481c..650824d89c 100644 --- a/client/option_test.go +++ b/client/option_test.go @@ -20,6 +20,9 @@ import ( "context" "crypto/tls" "fmt" + "github.com/cloudwego/kitex/pkg/remote" + "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2" + "net" "reflect" "testing" "time" @@ -501,6 +504,26 @@ func TestWithGRPCKeepaliveParams(t *testing.T) { test.Assert(t, opts.GRPCConnectOpts.KeepaliveParams.PermitWithoutStream) } +type mockConnGetter struct{} + +func (mg *mockConnGetter) Get(ctx context.Context, network, address string, opt remote.ConnOption) (net.Conn, error) { + return nil, nil +} + +func TestWithGRPCConnectionGetter(t *testing.T) { + var connGetter nphttp2.ConnectionGetter = &mockConnGetter{} + cliOpt := []client.Option{ + WithGRPCConnectionGetter(connGetter), + } + opts := client.NewOptions(cliOpt) + test.Assert(t, opts.GRPCConnectionGetter == connGetter) + + // no configuration + cliOpt = []client.Option{} + opts = client.NewOptions(cliOpt) + test.Assert(t, opts.GRPCConnectionGetter == nil) +} + func TestWithHTTPConnection(t *testing.T) { opts := client.NewOptions([]client.Option{WithHTTPConnection()}) test.Assert(t, opts.RemoteOpt.CliHandlerFactory != nil) diff --git a/internal/client/option.go b/internal/client/option.go index 5d3b30635a..f61fcc6319 100644 --- a/internal/client/option.go +++ b/internal/client/option.go @@ -109,8 +109,9 @@ type Options struct { WarmUpOption *warmup.ClientOption // GRPC - GRPCConnPoolSize uint32 - GRPCConnectOpts *grpc.ConnectOptions + GRPCConnPoolSize uint32 + GRPCConnectOpts *grpc.ConnectOptions + GRPCConnectionGetter nphttp2.ConnectionGetter // XDS XDSEnabled bool @@ -180,7 +181,7 @@ func (o *Options) initRemoteOpt() { // grpc unary short connection o.GRPCConnectOpts.ShortConn = true } - o.RemoteOpt.ConnPool = nphttp2.NewConnPool(o.Svr.ServiceName, o.GRPCConnPoolSize, *o.GRPCConnectOpts) + o.RemoteOpt.ConnPool = nphttp2.NewConnPool(o.Svr.ServiceName, o.GRPCConnPoolSize, *o.GRPCConnectOpts, o.GRPCConnectionGetter) o.RemoteOpt.CliHandlerFactory = nphttp2.NewCliTransHandlerFactory() } if o.RemoteOpt.ConnPool == nil { diff --git a/pkg/remote/trans/nphttp2/conn_pool.go b/pkg/remote/trans/nphttp2/conn_pool.go index cdd1a2f785..af4978122f 100644 --- a/pkg/remote/trans/nphttp2/conn_pool.go +++ b/pkg/remote/trans/nphttp2/conn_pool.go @@ -32,6 +32,11 @@ import ( "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/grpc" ) +// ConnectionGetter is used to get a connection based on address and connection option. +type ConnectionGetter interface { + Get(ctx context.Context, network, address string, opt remote.ConnOption) (net.Conn, error) +} + var _ remote.LongConnPool = &connPool{} func poolSize() uint32 { @@ -42,7 +47,7 @@ func poolSize() uint32 { } // NewConnPool ... -func NewConnPool(remoteService string, size uint32, connOpts grpc.ConnectOptions) *connPool { +func NewConnPool(remoteService string, size uint32, connOpts grpc.ConnectOptions, connGetter ConnectionGetter) *connPool { if size == 0 { size = poolSize() } @@ -50,6 +55,7 @@ func NewConnPool(remoteService string, size uint32, connOpts grpc.ConnectOptions remoteService: remoteService, size: size, connOpts: connOpts, + connGetter: connGetter, } } @@ -60,6 +66,7 @@ type connPool struct { conns sync.Map // key: address, value: *transports remoteService string // remote service name connOpts grpc.ConnectOptions + connGetter ConnectionGetter // use this to get a connection if specified } type transports struct { @@ -102,18 +109,28 @@ func (t *transports) close() { var _ remote.LongConnPool = (*connPool)(nil) func (p *connPool) newTransport(ctx context.Context, dialer remote.Dialer, network, address string, - connectTimeout time.Duration, opts grpc.ConnectOptions, -) (grpc.ClientTransport, error) { - conn, err := dialer.DialTimeout(network, address, connectTimeout) - if err != nil { - return nil, err - } - if opts.TLSConfig != nil { - tlsConn, err := newTLSConn(conn, opts.TLSConfig) + connectTimeout time.Duration, opts grpc.ConnectOptions, opt remote.ConnOption) (grpc.ClientTransport, error) { + var ( + conn net.Conn + err error + ) + if p.connGetter != nil { + conn, err = p.connGetter.Get(ctx, network, address, opt) + if err != nil { + return nil, err + } + } else { + conn, err = dialer.DialTimeout(network, address, connectTimeout) if err != nil { return nil, err } - conn = tlsConn + if opts.TLSConfig != nil { + tlsConn, err := newTLSConn(conn, opts.TLSConfig) + if err != nil { + return nil, err + } + conn = tlsConn + } } return grpc.NewClientTransport( ctx, @@ -158,7 +175,7 @@ func (p *connPool) Get(ctx context.Context, network, address string, opt remote. tr, err, _ := p.sfg.Do(address, func() (i interface{}, e error) { // Notice: newTransport means new a connection, the timeout of connection cannot be set, // so using context.Background() but not the ctx passed in as the parameter. - tr, err := p.newTransport(context.Background(), opt.Dialer, network, address, opt.ConnectTimeout, p.connOpts) + tr, err := p.newTransport(context.Background(), opt.Dialer, network, address, opt.ConnectTimeout, p.connOpts, opt) if err != nil { return nil, err } @@ -197,7 +214,7 @@ func (p *connPool) release(conn net.Conn) error { func (p *connPool) createShortConn(ctx context.Context, network, address string, opt remote.ConnOption) (net.Conn, error) { // Notice: newTransport means new a connection, the timeout of connection cannot be set, // so using context.Background() but not the ctx passed in as the parameter. - tr, err := p.newTransport(context.Background(), opt.Dialer, network, address, opt.ConnectTimeout, p.connOpts) + tr, err := p.newTransport(context.Background(), opt.Dialer, network, address, opt.ConnectTimeout, p.connOpts, opt) if err != nil { return nil, err } diff --git a/pkg/remote/trans/nphttp2/mocks_test.go b/pkg/remote/trans/nphttp2/mocks_test.go index e027235869..225b9ffef3 100644 --- a/pkg/remote/trans/nphttp2/mocks_test.go +++ b/pkg/remote/trans/nphttp2/mocks_test.go @@ -286,7 +286,7 @@ func newMockConnPool() *connPool { WriteBufferSize: defaultMockReadWriteBufferSize, ReadBufferSize: defaultMockReadWriteBufferSize, MaxHeaderListSize: nil, - }) + }, nil) return connPool } diff --git a/pkg/remote/trans/nphttp2/server_handler.go b/pkg/remote/trans/nphttp2/server_handler.go index dd99c38915..e9a98e8990 100644 --- a/pkg/remote/trans/nphttp2/server_handler.go +++ b/pkg/remote/trans/nphttp2/server_handler.go @@ -83,13 +83,27 @@ var prefaceReadAtMost = func() int { return grpcTransport.ClientPrefaceLen }() +type PeekReader interface { + Peek(size int) ([]byte, error) +} + func (t *svrTransHandler) ProtocolMatch(ctx context.Context, conn net.Conn) (err error) { // Check the validity of client preface. - npReader := conn.(interface{ Reader() netpoll.Reader }).Reader() - // read at most avoid block - preface, err := npReader.Peek(prefaceReadAtMost) - if err != nil { - return err + var preface []byte + if withReader, ok := conn.(interface{ Reader() netpoll.Reader }); ok { + npReader := withReader.Reader() + // read at most avoid block + preface, err = npReader.Peek(prefaceReadAtMost) + if err != nil { + return err + } + } else if peekReader, ok := conn.(PeekReader); ok { + preface, err = peekReader.Peek(prefaceReadAtMost) + if err != nil { + return err + } + } else { + return errors.New("read protocol info failed") } if bytes.Equal(preface[:prefaceReadAtMost], grpcTransport.ClientPreface[:prefaceReadAtMost]) { return nil