Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions client/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package client
import (
"context"
"fmt"
"github.com/cloudwego/kitex/pkg/remote/trans/nphttp2"
"net"
"reflect"
"strings"
Expand Down Expand Up @@ -497,6 +498,14 @@ func WithGRPCKeepaliveParams(kp grpc.ClientKeepalive) Option {
}}
}

// WithGRPCConnectionGetter configures the way to get a connection in gRPC transport
func WithGRPCConnectionGetter(getter nphttp2.ConnectionGetter) Option {
return Option{F: func(o *client.Options, di *utils.Slice) {
di.Push("WithGRPCConnectionGetter")
o.GRPCConnectionGetter = getter
}}
}

// WithWarmingUp forces the client to do some warm-ups at the end of the initialization.
func WithWarmingUp(wuo *warmup.ClientOption) Option {
return Option{F: func(o *client.Options, di *utils.Slice) {
Expand Down
23 changes: 23 additions & 0 deletions client/option_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ import (
"context"
"crypto/tls"
"fmt"
"github.com/cloudwego/kitex/pkg/remote"
"github.com/cloudwego/kitex/pkg/remote/trans/nphttp2"
"net"
"reflect"
"testing"
"time"
Expand Down Expand Up @@ -501,6 +504,26 @@ func TestWithGRPCKeepaliveParams(t *testing.T) {
test.Assert(t, opts.GRPCConnectOpts.KeepaliveParams.PermitWithoutStream)
}

type mockConnGetter struct{}

func (mg *mockConnGetter) Get(ctx context.Context, network, address string, opt remote.ConnOption) (net.Conn, error) {
return nil, nil
}

func TestWithGRPCConnectionGetter(t *testing.T) {
var connGetter nphttp2.ConnectionGetter = &mockConnGetter{}
cliOpt := []client.Option{
WithGRPCConnectionGetter(connGetter),
}
opts := client.NewOptions(cliOpt)
test.Assert(t, opts.GRPCConnectionGetter == connGetter)

// no configuration
cliOpt = []client.Option{}
opts = client.NewOptions(cliOpt)
test.Assert(t, opts.GRPCConnectionGetter == nil)
}

func TestWithHTTPConnection(t *testing.T) {
opts := client.NewOptions([]client.Option{WithHTTPConnection()})
test.Assert(t, opts.RemoteOpt.CliHandlerFactory != nil)
Expand Down
7 changes: 4 additions & 3 deletions internal/client/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,9 @@ type Options struct {
WarmUpOption *warmup.ClientOption

// GRPC
GRPCConnPoolSize uint32
GRPCConnectOpts *grpc.ConnectOptions
GRPCConnPoolSize uint32
GRPCConnectOpts *grpc.ConnectOptions
GRPCConnectionGetter nphttp2.ConnectionGetter

// XDS
XDSEnabled bool
Expand Down Expand Up @@ -180,7 +181,7 @@ func (o *Options) initRemoteOpt() {
// grpc unary short connection
o.GRPCConnectOpts.ShortConn = true
}
o.RemoteOpt.ConnPool = nphttp2.NewConnPool(o.Svr.ServiceName, o.GRPCConnPoolSize, *o.GRPCConnectOpts)
o.RemoteOpt.ConnPool = nphttp2.NewConnPool(o.Svr.ServiceName, o.GRPCConnPoolSize, *o.GRPCConnectOpts, o.GRPCConnectionGetter)
o.RemoteOpt.CliHandlerFactory = nphttp2.NewCliTransHandlerFactory()
}
if o.RemoteOpt.ConnPool == nil {
Expand Down
41 changes: 29 additions & 12 deletions pkg/remote/trans/nphttp2/conn_pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ import (
"github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/grpc"
)

// ConnectionGetter is used to get a connection based on address and connection option.
type ConnectionGetter interface {
Get(ctx context.Context, network, address string, opt remote.ConnOption) (net.Conn, error)
}

var _ remote.LongConnPool = &connPool{}

func poolSize() uint32 {
Expand All @@ -42,14 +47,15 @@ func poolSize() uint32 {
}

// NewConnPool ...
func NewConnPool(remoteService string, size uint32, connOpts grpc.ConnectOptions) *connPool {
func NewConnPool(remoteService string, size uint32, connOpts grpc.ConnectOptions, connGetter ConnectionGetter) *connPool {
if size == 0 {
size = poolSize()
}
return &connPool{
remoteService: remoteService,
size: size,
connOpts: connOpts,
connGetter: connGetter,
}
}

Expand All @@ -60,6 +66,7 @@ type connPool struct {
conns sync.Map // key: address, value: *transports
remoteService string // remote service name
connOpts grpc.ConnectOptions
connGetter ConnectionGetter // use this to get a connection if specified
}

type transports struct {
Expand Down Expand Up @@ -102,18 +109,28 @@ func (t *transports) close() {
var _ remote.LongConnPool = (*connPool)(nil)

func (p *connPool) newTransport(ctx context.Context, dialer remote.Dialer, network, address string,
connectTimeout time.Duration, opts grpc.ConnectOptions,
) (grpc.ClientTransport, error) {
conn, err := dialer.DialTimeout(network, address, connectTimeout)
if err != nil {
return nil, err
}
if opts.TLSConfig != nil {
tlsConn, err := newTLSConn(conn, opts.TLSConfig)
connectTimeout time.Duration, opts grpc.ConnectOptions, opt remote.ConnOption) (grpc.ClientTransport, error) {
var (
conn net.Conn
err error
)
if p.connGetter != nil {
conn, err = p.connGetter.Get(ctx, network, address, opt)
if err != nil {
return nil, err
}
} else {
conn, err = dialer.DialTimeout(network, address, connectTimeout)
if err != nil {
return nil, err
}
conn = tlsConn
if opts.TLSConfig != nil {
tlsConn, err := newTLSConn(conn, opts.TLSConfig)
if err != nil {
return nil, err
}
conn = tlsConn
}
}
return grpc.NewClientTransport(
ctx,
Expand Down Expand Up @@ -158,7 +175,7 @@ func (p *connPool) Get(ctx context.Context, network, address string, opt remote.
tr, err, _ := p.sfg.Do(address, func() (i interface{}, e error) {
// Notice: newTransport means new a connection, the timeout of connection cannot be set,
// so using context.Background() but not the ctx passed in as the parameter.
tr, err := p.newTransport(context.Background(), opt.Dialer, network, address, opt.ConnectTimeout, p.connOpts)
tr, err := p.newTransport(context.Background(), opt.Dialer, network, address, opt.ConnectTimeout, p.connOpts, opt)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -197,7 +214,7 @@ func (p *connPool) release(conn net.Conn) error {
func (p *connPool) createShortConn(ctx context.Context, network, address string, opt remote.ConnOption) (net.Conn, error) {
// Notice: newTransport means new a connection, the timeout of connection cannot be set,
// so using context.Background() but not the ctx passed in as the parameter.
tr, err := p.newTransport(context.Background(), opt.Dialer, network, address, opt.ConnectTimeout, p.connOpts)
tr, err := p.newTransport(context.Background(), opt.Dialer, network, address, opt.ConnectTimeout, p.connOpts, opt)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/remote/trans/nphttp2/mocks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ func newMockConnPool() *connPool {
WriteBufferSize: defaultMockReadWriteBufferSize,
ReadBufferSize: defaultMockReadWriteBufferSize,
MaxHeaderListSize: nil,
})
}, nil)
return connPool
}

Expand Down
24 changes: 19 additions & 5 deletions pkg/remote/trans/nphttp2/server_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,13 +83,27 @@ var prefaceReadAtMost = func() int {
return grpcTransport.ClientPrefaceLen
}()

type PeekReader interface {
Peek(size int) ([]byte, error)
}

func (t *svrTransHandler) ProtocolMatch(ctx context.Context, conn net.Conn) (err error) {
// Check the validity of client preface.
npReader := conn.(interface{ Reader() netpoll.Reader }).Reader()
// read at most avoid block
preface, err := npReader.Peek(prefaceReadAtMost)
if err != nil {
return err
var preface []byte
if withReader, ok := conn.(interface{ Reader() netpoll.Reader }); ok {
npReader := withReader.Reader()
// read at most avoid block
preface, err = npReader.Peek(prefaceReadAtMost)
if err != nil {
return err
}
} else if peekReader, ok := conn.(PeekReader); ok {
preface, err = peekReader.Peek(prefaceReadAtMost)
if err != nil {
return err
}
} else {
return errors.New("read protocol info failed")
}
if bytes.Equal(preface[:prefaceReadAtMost], grpcTransport.ClientPreface[:prefaceReadAtMost]) {
return nil
Expand Down