Skip to content

Commit 7304665

Browse files
rpc: refactor Peer to support DRPC and gRPC conns
Refactor Peer further to make it easy to support both gRPC and DRPC peers. In this change, gRPC Peer continues to dial DRPC connections to keep the changes backward compatible. As part of enabling DRPC support across the codebase, future commits will remove the DRPC-specific code and replace it with DRPC generic peer. Epic: CRDB-48923 Informs: #148383 Release note: none
1 parent dfd17d1 commit 7304665

File tree

6 files changed

+129
-67
lines changed

6 files changed

+129
-67
lines changed

pkg/rpc/connection.go

Lines changed: 36 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"io"
1111

1212
"github.com/cockroachdb/cockroach/pkg/kv/kvpb"
13+
"github.com/cockroachdb/cockroach/pkg/rpc/rpcbase"
1314
"github.com/cockroachdb/cockroach/pkg/util/circuit"
1415
"github.com/cockroachdb/cockroach/pkg/util/stop"
1516
"github.com/cockroachdb/errors"
@@ -21,10 +22,13 @@ import (
2122
// must implement. It is used as a type constraint for rpc connections and allows
2223
// the Connection and Peer structs to work seamlessly with both gRPC and DRPC
2324
// connections.
24-
type rpcConn interface {
25-
io.Closer
26-
comparable
27-
}
25+
type rpcConn io.Closer
26+
27+
// dialFunc is a generic function type used to establish an RPC connection.
28+
// It takes a context for cancellation/timeouts, a target string (e.g., address),
29+
// and a ConnectionClass to categorize the connection's purpose or priority.
30+
// It returns the established connection (of type Conn) or an error if dialing fails.
31+
type dialFunc[Conn rpcConn] func(ctx context.Context, target string, class rpcbase.ConnectionClass) (Conn, error)
2832

2933
// heartbeatClientConstructor is a function type that creates a HeartbeatClient
3034
// for a given rpc connection. This allows us to use different implementations of
@@ -37,6 +41,23 @@ type closeNotifier interface {
3741
CloseNotify(ctx context.Context) <-chan struct{}
3842
}
3943

44+
// ConnectionOptions allow for customization of connection behaviors such as:
45+
// - Establishing a new connection.
46+
// - Client constuctors for clients like batch streams.
47+
// - Comparing two connections for equivalence.
48+
type ConnectionOptions[Conn rpcConn] struct {
49+
// dial function to open a new connection.
50+
dial dialFunc[Conn]
51+
// connEquals defines the equivalence function for two RPC connections.
52+
connEquals equalsFunc[Conn]
53+
// newBatchStreamClient is a constructor function for creating a new batch
54+
// stream client associated with a specific RPC connection.
55+
newBatchStreamClient streamConstructor[*kvpb.BatchRequest, *kvpb.BatchResponse, Conn]
56+
// newCloseNotifier is a constructor function for creating a new
57+
// closeNotifier associated with a specific RPC connection.
58+
newCloseNotifier closeNotifierConstructor[Conn]
59+
}
60+
4061
// closeNotifierConstructor is a function type that creates a closeNotifier
4162
// for a given rpc connection. This allows us to use different implementations of
4263
// closeNotifier for different types of connections (e.g., gRPC and DRPC).
@@ -81,16 +102,17 @@ func newConnectionToNodeID[Conn rpcConn](
81102
opts *ContextOptions,
82103
k peerKey,
83104
breakerSignal func() circuit.Signal,
84-
newBatchStreamClient streamConstructor[*kvpb.BatchRequest, *kvpb.BatchResponse, Conn],
105+
connOptions *ConnectionOptions[Conn],
85106
) *Connection[Conn] {
107+
drpcConnEquals := func(a, b drpc.Conn) bool { return a == b }
86108
c := &Connection[Conn]{
87109
breakerSignalFn: breakerSignal,
88110
k: k,
89111
connFuture: connFuture[Conn]{
90112
ready: make(chan struct{}),
91113
},
92-
batchStreamPool: makeStreamPool(opts.Stopper, newBatchStreamClient),
93-
drpcBatchStreamPool: makeStreamPool(opts.Stopper, newDRPCBatchStream),
114+
batchStreamPool: makeStreamPool(opts.Stopper, connOptions.newBatchStreamClient, connOptions.connEquals),
115+
drpcBatchStreamPool: makeStreamPool(opts.Stopper, newDRPCBatchStream, drpcConnEquals),
94116
}
95117
return c
96118
}
@@ -107,10 +129,10 @@ func (c *Connection[Conn]) waitOrDefault(
107129
// want it to take precedence over connFuture below (which is closed in
108130
// the common case of a connection going bad after having been healthy
109131
// for a while).
110-
var cc Conn
132+
var nilConn Conn
111133
select {
112134
case <-sig.C():
113-
return cc, nil, sig.Err()
135+
return nilConn, nil, sig.Err()
114136
default:
115137
}
116138

@@ -121,19 +143,19 @@ func (c *Connection[Conn]) waitOrDefault(
121143
select {
122144
case <-c.connFuture.C():
123145
case <-sig.C():
124-
return cc, nil, sig.Err()
146+
return nilConn, nil, sig.Err()
125147
case <-ctx.Done():
126-
return cc, nil, errors.Wrapf(ctx.Err(), "while connecting to n%d at %s", c.k.NodeID, c.k.TargetAddr)
148+
return nilConn, nil, errors.Wrapf(ctx.Err(), "while connecting to n%d at %s", c.k.NodeID, c.k.TargetAddr)
127149
}
128150
} else {
129151
select {
130152
case <-c.connFuture.C():
131153
case <-sig.C():
132-
return cc, nil, sig.Err()
154+
return nilConn, nil, sig.Err()
133155
case <-ctx.Done():
134-
return cc, nil, errors.Wrapf(ctx.Err(), "while connecting to n%d at %s", c.k.NodeID, c.k.TargetAddr)
156+
return nilConn, nil, errors.Wrapf(ctx.Err(), "while connecting to n%d at %s", c.k.NodeID, c.k.TargetAddr)
135157
default:
136-
return cc, nil, defErr
158+
return nilConn, nil, defErr
137159
}
138160
}
139161

pkg/rpc/context.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2065,7 +2065,7 @@ func (rpcCtx *Context) grpcDialNodeInternal(
20652065
conns.mu.m = map[peerKey]*peer[*grpc.ClientConn]{}
20662066
}
20672067

2068-
p := rpcCtx.newPeer(k, remoteLocality)
2068+
p := newPeer(rpcCtx, k, newGRPCPeerOptions(rpcCtx, k, remoteLocality))
20692069
// (Asynchronously) Start the probe (= heartbeat loop). The breaker is healthy
20702070
// right now (it was just created) but the call to `.Probe` will launch the
20712071
// probe[1] regardless.

pkg/rpc/grpc.go

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@ package rpc
88
import (
99
"context"
1010

11+
"github.com/cockroachdb/cockroach/pkg/kv/kvpb"
12+
"github.com/cockroachdb/cockroach/pkg/roachpb"
13+
"github.com/cockroachdb/cockroach/pkg/rpc/rpcbase"
1114
"github.com/cockroachdb/cockroach/pkg/util/stop"
1215
"google.golang.org/grpc"
1316
"google.golang.org/grpc/connectivity"
@@ -47,3 +50,34 @@ func (g *grpcCloseNotifier) CloseNotify(ctx context.Context) <-chan struct{} {
4750
}
4851

4952
type GRPCConnection = Connection[*grpc.ClientConn]
53+
54+
// newGRPCPeerOptions creates peerOptions for gRPC peers.
55+
func newGRPCPeerOptions(
56+
rpcCtx *Context, k peerKey, locality roachpb.Locality,
57+
) *peerOptions[*grpc.ClientConn] {
58+
pm, lm := rpcCtx.metrics.acquire(k, locality)
59+
return &peerOptions[*grpc.ClientConn]{
60+
locality: locality,
61+
peers: &rpcCtx.peers,
62+
connOptions: &ConnectionOptions[*grpc.ClientConn]{
63+
dial: func(ctx context.Context, target string, class rpcbase.ConnectionClass) (*grpc.ClientConn, error) {
64+
additionalDialOpts := []grpc.DialOption{grpc.WithStatsHandler(&statsTracker{lm})}
65+
additionalDialOpts = append(additionalDialOpts, rpcCtx.testingDialOpts...)
66+
return rpcCtx.grpcDialRaw(ctx, target, class, additionalDialOpts...)
67+
},
68+
connEquals: func(a, b *grpc.ClientConn) bool {
69+
return a == b
70+
},
71+
newBatchStreamClient: func(ctx context.Context, cc *grpc.ClientConn) (BatchStreamClient, error) {
72+
return kvpb.NewInternalClient(cc).BatchStream(ctx)
73+
},
74+
newCloseNotifier: func(stopper *stop.Stopper, cc *grpc.ClientConn) closeNotifier {
75+
return &grpcCloseNotifier{stopper: stopper, conn: cc}
76+
},
77+
},
78+
newHeartbeatClient: func(cc *grpc.ClientConn) RPCHeartbeatClient {
79+
return NewGRPCHeartbeatClientAdapter(cc)
80+
},
81+
pm: pm,
82+
}
83+
}

pkg/rpc/peer.go

Lines changed: 30 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,15 @@ import (
1414
"github.com/VividCortex/ewma"
1515
"github.com/cockroachdb/cockroach/pkg/kv/kvpb"
1616
"github.com/cockroachdb/cockroach/pkg/roachpb"
17-
"github.com/cockroachdb/cockroach/pkg/rpc/rpcbase"
1817
"github.com/cockroachdb/cockroach/pkg/util/circuit"
1918
"github.com/cockroachdb/cockroach/pkg/util/grpcutil"
2019
"github.com/cockroachdb/cockroach/pkg/util/log"
2120
"github.com/cockroachdb/cockroach/pkg/util/netutil"
22-
"github.com/cockroachdb/cockroach/pkg/util/stop"
2321
"github.com/cockroachdb/cockroach/pkg/util/syncutil"
2422
"github.com/cockroachdb/cockroach/pkg/util/timeutil"
2523
"github.com/cockroachdb/errors"
2624
"github.com/cockroachdb/logtags"
2725
"github.com/cockroachdb/redact"
28-
"google.golang.org/grpc"
2926
"google.golang.org/grpc/status"
3027
"storj.io/drpc"
3128
)
@@ -121,15 +118,13 @@ func (p *peer[Conn]) releaseMetricsLocked() {
121118
// See (*peer).launch for details on the probe (heartbeat loop) itself.
122119
type peer[Conn rpcConn] struct {
123120
peerMetrics
124-
k peerKey
125-
opts *ContextOptions
126-
heartbeatInterval time.Duration
127-
heartbeatTimeout time.Duration
128-
dial func(ctx context.Context, target string, class rpcbase.ConnectionClass) (Conn, error)
129-
dialDRPC func(ctx context.Context, target string, class rpcbase.ConnectionClass) (drpc.Conn, error)
130-
newHeartbeatClient heartbeatClientConstructor[Conn]
131-
newBatchStreamClient streamConstructor[*kvpb.BatchRequest, *kvpb.BatchResponse, Conn]
132-
newCloseNotifier closeNotifierConstructor[Conn]
121+
k peerKey
122+
opts *ContextOptions
123+
newHeartbeatClient heartbeatClientConstructor[Conn]
124+
heartbeatInterval time.Duration
125+
heartbeatTimeout time.Duration
126+
connOptions *ConnectionOptions[Conn]
127+
drpcDial dialFunc[drpc.Conn]
133128
// b maintains connection health. This breaker's async probe is always
134129
// active - it is the heartbeat loop and manages `mu.c.` (including
135130
// recreating it after the connection fails and has to be redialed).
@@ -210,6 +205,14 @@ func (p *peer[Conn]) snap() PeerSnap[Conn] {
210205
return p.mu.PeerSnap
211206
}
212207

208+
type peerOptions[Conn rpcConn] struct {
209+
locality roachpb.Locality
210+
pm peerMetrics
211+
newHeartbeatClient heartbeatClientConstructor[Conn]
212+
connOptions *ConnectionOptions[Conn]
213+
peers *peerMap[Conn]
214+
}
215+
213216
// newPeer returns circuit breaker that trips when connection (associated
214217
// with provided peerKey) is failed. The breaker's probe *is* the heartbeat loop
215218
// and is thus running at all times. The exception is a decommissioned node, for
@@ -229,39 +232,26 @@ func (p *peer[Conn]) snap() PeerSnap[Conn] {
229232
// map, the next attempt to dial the node will start from a blank slate. In
230233
// other words, even with this theoretical race, the situation will sort itself
231234
// out quickly.
232-
func (rpcCtx *Context) newPeer(k peerKey, locality roachpb.Locality) *peer[*grpc.ClientConn] {
235+
func newPeer[Conn rpcConn](rpcCtx *Context, k peerKey, peerOpts *peerOptions[Conn]) *peer[Conn] {
233236
// Initialization here is a bit circular. The peer holds the breaker. The
234237
// breaker probe references the peer because it needs to replace the one-shot
235238
// Connection when it makes a new connection in the probe. And (all but the
236239
// first incarnation of) the Connection also holds on to the breaker since the
237240
// Connect method needs to do the short-circuiting (if a Connection is created
238241
// while the breaker is tripped, we want to block in Connect only once we've
239242
// seen the first heartbeat succeed).
240-
pm, lm := rpcCtx.metrics.acquire(k, locality)
241-
p := &peer[*grpc.ClientConn]{
242-
peerMetrics: pm,
243+
p := &peer[Conn]{
244+
peerMetrics: peerOpts.pm,
243245
logDisconnectEvery: log.Every(time.Minute),
244246
k: k,
245247
remoteClocks: rpcCtx.RemoteClocks,
246248
opts: &rpcCtx.ContextOptions,
247-
peers: &rpcCtx.peers,
248-
dial: func(ctx context.Context, target string, class rpcbase.ConnectionClass) (*grpc.ClientConn, error) {
249-
additionalDialOpts := []grpc.DialOption{grpc.WithStatsHandler(&statsTracker{lm})}
250-
additionalDialOpts = append(additionalDialOpts, rpcCtx.testingDialOpts...)
251-
return rpcCtx.grpcDialRaw(ctx, target, class, additionalDialOpts...)
252-
},
253-
dialDRPC: dialDRPC(rpcCtx),
254-
newHeartbeatClient: func(cc *grpc.ClientConn) RPCHeartbeatClient {
255-
return NewGRPCHeartbeatClientAdapter(cc)
256-
},
257-
newBatchStreamClient: func(ctx context.Context, cc *grpc.ClientConn) (BatchStreamClient, error) {
258-
return kvpb.NewInternalClient(cc).BatchStream(ctx)
259-
},
260-
newCloseNotifier: func(stopper *stop.Stopper, cc *grpc.ClientConn) closeNotifier {
261-
return &grpcCloseNotifier{stopper: stopper, conn: cc}
262-
},
263-
heartbeatInterval: rpcCtx.RPCHeartbeatInterval,
264-
heartbeatTimeout: rpcCtx.RPCHeartbeatTimeout,
249+
peers: peerOpts.peers,
250+
connOptions: peerOpts.connOptions,
251+
drpcDial: dialDRPC(rpcCtx),
252+
newHeartbeatClient: peerOpts.newHeartbeatClient,
253+
heartbeatInterval: rpcCtx.RPCHeartbeatInterval,
254+
heartbeatTimeout: rpcCtx.RPCHeartbeatTimeout,
265255
}
266256
var b *circuit.Breaker
267257

@@ -275,8 +265,8 @@ func (rpcCtx *Context) newPeer(k peerKey, locality roachpb.Locality) *peer[*grpc
275265
},
276266
})
277267
p.b = b
278-
c := newConnectionToNodeID(p.opts, k, b.Signal, p.newBatchStreamClient)
279-
p.mu.PeerSnap = PeerSnap[*grpc.ClientConn]{c: c}
268+
c := newConnectionToNodeID(p.opts, k, b.Signal, p.connOptions)
269+
p.mu.PeerSnap = PeerSnap[Conn]{c: c}
280270

281271
return p
282272
}
@@ -375,7 +365,7 @@ func (p *peer[Conn]) run(ctx context.Context, report func(error), done func()) {
375365
func() {
376366
p.mu.Lock()
377367
defer p.mu.Unlock()
378-
p.mu.c = newConnectionToNodeID(p.opts, p.k, p.mu.c.breakerSignalFn, p.newBatchStreamClient)
368+
p.mu.c = newConnectionToNodeID(p.opts, p.k, p.mu.c.breakerSignalFn, p.connOptions)
379369
}()
380370

381371
if p.snap().deleteAfter != 0 {
@@ -388,14 +378,14 @@ func (p *peer[Conn]) run(ctx context.Context, report func(error), done func()) {
388378
}
389379

390380
func (p *peer[Conn]) runOnce(ctx context.Context, report func(error)) error {
391-
cc, err := p.dial(ctx, p.k.TargetAddr, p.k.Class)
381+
cc, err := p.connOptions.dial(ctx, p.k.TargetAddr, p.k.Class)
392382
if err != nil {
393383
return err
394384
}
395385
defer func() {
396386
_ = cc.Close() // nolint:grpcconnclose
397387
}()
398-
dc, err := p.dialDRPC(ctx, p.k.TargetAddr, p.k.Class)
388+
dc, err := p.drpcDial(ctx, p.k.TargetAddr, p.k.Class)
399389
if err != nil {
400390
return err
401391
}
@@ -406,7 +396,7 @@ func (p *peer[Conn]) runOnce(ctx context.Context, report func(error)) error {
406396
// Set up notifications on a channel when gRPC tears down, so that we
407397
// can trigger another instant heartbeat for expedited circuit breaker
408398
// tripping.
409-
connClosedCh := p.newCloseNotifier(p.opts.Stopper, cc).CloseNotify(ctx)
399+
connClosedCh := p.connOptions.newCloseNotifier(p.opts.Stopper, cc).CloseNotify(ctx)
410400

411401
if p.remoteClocks != nil {
412402
p.remoteClocks.OnConnect(ctx, p.k.NodeID)

pkg/rpc/stream_pool.go

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,10 @@ type result[Resp any] struct {
4343
// that stream pools eventually shrink when the load decreases.
4444
const defaultPooledStreamIdleTimeout = 10 * time.Second
4545

46+
// equalsFunc is a generic function type used to compare two RPC connections
47+
// for equivalence.
48+
type equalsFunc[Conn rpcConn] func(a, b Conn) bool
49+
4650
// pooledStream is a wrapper around a grpc.ClientStream that is managed by a
4751
// streamPool. It is responsible for sending a single request and receiving a
4852
// single response on the stream at a time, mimicking the behavior of a gRPC
@@ -194,6 +198,7 @@ type streamPool[Req, Resp any, Conn rpcConn] struct {
194198
stopper *stop.Stopper
195199
idleTimeout time.Duration
196200
newStream streamConstructor[Req, Resp, Conn]
201+
connEquals equalsFunc[Conn]
197202

198203
// cc and ccCtx are set on bind, when the gRPC connection is established.
199204
cc Conn
@@ -207,12 +212,13 @@ type streamPool[Req, Resp any, Conn rpcConn] struct {
207212
}
208213

209214
func makeStreamPool[Req, Resp any, Conn rpcConn](
210-
stopper *stop.Stopper, newStream streamConstructor[Req, Resp, Conn],
215+
stopper *stop.Stopper, newStream streamConstructor[Req, Resp, Conn], connEquals equalsFunc[Conn],
211216
) streamPool[Req, Resp, Conn] {
212217
return streamPool[Req, Resp, Conn]{
213218
stopper: stopper,
214219
idleTimeout: defaultPooledStreamIdleTimeout,
215220
newStream: newStream,
221+
connEquals: connEquals,
216222
}
217223
}
218224

@@ -280,7 +286,7 @@ func (p *streamPool[Req, Resp, Conn]) remove(s *pooledStream[Req, Resp, Conn]) b
280286

281287
func (p *streamPool[Req, Resp, Conn]) newPooledStream() (*pooledStream[Req, Resp, Conn], error) {
282288
var zero Conn
283-
if p.cc == zero {
289+
if p.connEquals(p.cc, zero) {
284290
return nil, errors.AssertionFailedf("streamPool not bound to a client conn")
285291
}
286292

0 commit comments

Comments
 (0)