Skip to content

Commit 5a237bf

Browse files
kv,rpc,server: consolidate RPC clients creation
To support DRPC and reuse the connection management logic in [Connection](https://github.com/cockroachdb/cockroach/blob/master/pkg/rpc/connection.go#L54) and [Peer](https://github.com/cockroachdb/cockroach/blob/8fed6e8481da1373020ab2cc95d222df02f61721/pkg/rpc/peer.go#L121), they were refactored to use generics. Previously, most code established connections to peers using [nodedialer](https://github.com/cockroachdb/cockroach/blob/8fed6e8481da1373020ab2cc95d222df02f61721/pkg/rpc/nodedialer/nodedialer.go#L37-L41) or [rpc.Context](https://github.com/cockroachdb/cockroach/blob/8fed6e8481da1373020ab2cc95d222df02f61721/pkg/rpc/context.go#L2004-L2016), subsequently creating an RPC client. Supporting DRPC would require updating numerous call-sites to conditionally dial `gRPC` or `DRPC` connections, introducing significant boilerplate code. This commit consolidates RPC client creation logic in the `kvserver` package. Call sites were updated to dial RPC clients rather than dialing connections. Subsequent work will ensure the appropriate connection type is dialed and the corresponding client is created. RPC clients will be unified to use a consistent interface for both `gRPC` and `DRPC` connections. Epic: CRDB-48923 Informs: #147757 Release note: none
1 parent 75c5230 commit 5a237bf

File tree

17 files changed

+185
-42
lines changed

17 files changed

+185
-42
lines changed

pkg/kv/kvserver/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ go_library(
7979
"replica_tscache.go",
8080
"replica_write.go",
8181
"replicate_queue.go",
82+
"rpc_clients.go",
8283
"scanner.go",
8384
"scheduler.go",
8485
"snapshot_apply_prepare.go",

pkg/kv/kvserver/client_raft_test.go

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4671,12 +4671,12 @@ func TestStoreRangeWaitForApplication(t *testing.T) {
46714671

46724672
var targets []target
46734673
for _, s := range tc.Servers {
4674-
conn, err := s.NodeDialer().(*nodedialer.Dialer).Dial(ctx, s.NodeID(), rpcbase.DefaultClass)
4674+
client, err := kvserver.DialPerReplicaClient(s.NodeDialer().(*nodedialer.Dialer), ctx, s.NodeID(), rpcbase.DefaultClass)
46754675
if err != nil {
46764676
t.Fatal(err)
46774677
}
46784678
targets = append(targets, target{
4679-
client: kvserver.NewPerReplicaClient(conn),
4679+
client: client,
46804680
header: kvserver.StoreRequestHeader{NodeID: s.NodeID(), StoreID: s.GetFirstStoreID()},
46814681
})
46824682
}
@@ -4799,11 +4799,10 @@ func TestStoreWaitForReplicaInit(t *testing.T) {
47994799
defer tc.Stopper().Stop(ctx)
48004800
store := tc.GetFirstStoreFromServer(t, 0)
48014801

4802-
conn, err := tc.Servers[0].NodeDialer().(*nodedialer.Dialer).Dial(ctx, store.Ident.NodeID, rpcbase.DefaultClass)
4802+
client, err := kvserver.DialPerReplicaClient(tc.Servers[0].NodeDialer().(*nodedialer.Dialer), ctx, store.Ident.NodeID, rpcbase.DefaultClass)
48034803
if err != nil {
48044804
t.Fatal(err)
48054805
}
4806-
client := kvserver.NewPerReplicaClient(conn)
48074806
storeHeader := kvserver.StoreRequestHeader{NodeID: store.Ident.NodeID, StoreID: store.Ident.StoreID}
48084807

48094808
// Test that WaitForReplicaInit returns successfully if the replica exists.

pkg/kv/kvserver/closedts/ctpb/BUILD.bazel

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,17 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library")
44

55
go_library(
66
name = "ctpb",
7-
srcs = ["service.go"],
7+
srcs = [
8+
"rpc_clients.go",
9+
"service.go",
10+
],
811
embed = [":ctpb_go_proto"],
912
importpath = "github.com/cockroachdb/cockroach/pkg/kv/kvserver/closedts/ctpb",
1013
visibility = ["//visibility:public"],
1114
deps = [
1215
"//pkg/kv/kvpb",
1316
"//pkg/roachpb",
17+
"//pkg/rpc/rpcbase",
1418
"//pkg/util/timeutil",
1519
],
1620
)
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
// Copyright 2025 The Cockroach Authors.
2+
//
3+
// Use of this software is governed by the CockroachDB Software License
4+
// included in the /LICENSE file.
5+
6+
package ctpb
7+
8+
import (
9+
context "context"
10+
11+
"github.com/cockroachdb/cockroach/pkg/roachpb"
12+
"github.com/cockroachdb/cockroach/pkg/rpc/rpcbase"
13+
)
14+
15+
// DialSideTransportClient returns a SideTransportClient for the node nodeID.
16+
// When rpcbase.TODODRPC is false it dials a connection of the given class via
17+
// nd.Dial and wraps it with NewSideTransportClient; a dial error is returned
// as-is together with a nil client.
//
// NOTE(review): when rpcbase.TODODRPC is true this returns (nil, nil) — no
// DRPC connection is dialed in this block, so callers would receive a nil
// client without an error. Confirm the flag stays false until the DRPC path
// is implemented.
18+
func DialSideTransportClient(
19+
nd rpcbase.NodeDialer, ctx context.Context, nodeID roachpb.NodeID, class rpcbase.ConnectionClass,
20+
) (SideTransportClient, error) {
21+
if !rpcbase.TODODRPC {
22+
conn, err := nd.Dial(ctx, nodeID, class)
23+
if err != nil {
24+
return nil, err
25+
}
26+
return NewSideTransportClient(conn), nil
27+
}
28+
return nil, nil
29+
}

pkg/kv/kvserver/closedts/sidetransport/sender.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -783,18 +783,18 @@ func (r *rpcConn) close() {
783783
atomic.StoreInt32(&r.closed, 1)
784784
}
785785

786-
func (r *rpcConn) maybeConnect(ctx context.Context, stopper *stop.Stopper) error {
786+
func (r *rpcConn) maybeConnect(ctx context.Context, _ *stop.Stopper) error {
787787
if r.stream != nil {
788788
// Already connected.
789789
return nil
790790
}
791791

792-
conn, err := r.dialer.Dial(ctx, r.nodeID, rpcbase.SystemClass)
792+
client, err := ctpb.DialSideTransportClient(r.dialer, ctx, r.nodeID, rpcbase.SystemClass)
793793
if err != nil {
794794
return err
795795
}
796796
streamCtx, cancel := context.WithCancel(ctx)
797-
stream, err := ctpb.NewSideTransportClient(conn).PushUpdates(streamCtx)
797+
stream, err := client.PushUpdates(streamCtx)
798798
if err != nil {
799799
cancel()
800800
return err

pkg/kv/kvserver/raft_transport.go

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -547,7 +547,7 @@ func (t *RaftTransport) StopOutgoingMessage(storeID roachpb.StoreID) {
547547
// lost and a new instance of processQueue will be started by the next message
548548
// to be sent.
549549
func (t *RaftTransport) processQueue(
550-
q *raftSendQueue, stream MultiRaft_RaftMessageBatchClient, class rpcbase.ConnectionClass,
550+
q *raftSendQueue, stream MultiRaft_RaftMessageBatchClient, _ rpcbase.ConnectionClass,
551551
) error {
552552
errCh := make(chan error, 1)
553553

@@ -844,13 +844,11 @@ func (t *RaftTransport) startProcessNewQueue(
844844
t.connectionMu.connectionTracker.markNodeDisconnected(toNodeID, class)
845845
t.connectionMu.Unlock()
846846
}()
847-
conn, err := t.dialer.Dial(ctx, toNodeID, class)
847+
client, err := DialMultiRaftClient(t.dialer, ctx, toNodeID, class)
848848
if err != nil {
849849
// DialNode already logs sufficiently, so just return.
850850
return
851851
}
852-
853-
client := NewMultiRaftClient(conn)
854852
batchCtx, cancel := context.WithCancel(ctx)
855853
defer cancel()
856854

@@ -977,11 +975,10 @@ func (t *RaftTransport) SendSnapshot(
977975
) (*kvserverpb.SnapshotResponse, error) {
978976
nodeID := header.RaftMessageRequest.ToReplica.NodeID
979977

980-
conn, err := t.dialer.Dial(ctx, nodeID, rpcbase.DefaultClass)
978+
client, err := DialMultiRaftClient(t.dialer, ctx, nodeID, rpcbase.DefaultClass)
981979
if err != nil {
982980
return nil, err
983981
}
984-
client := NewMultiRaftClient(conn)
985982
stream, err := client.RaftSnapshot(ctx)
986983
if err != nil {
987984
return nil, err
@@ -1001,11 +998,10 @@ func (t *RaftTransport) DelegateSnapshot(
1001998
ctx context.Context, req *kvserverpb.DelegateSendSnapshotRequest,
1002999
) (*kvserverpb.DelegateSnapshotResponse, error) {
10031000
nodeID := req.DelegatedSender.NodeID
1004-
conn, err := t.dialer.Dial(ctx, nodeID, rpcbase.DefaultClass)
1001+
client, err := DialMultiRaftClient(t.dialer, ctx, nodeID, rpcbase.DefaultClass)
10051002
if err != nil {
10061003
return nil, errors.Mark(err, errMarkSnapshotError)
10071004
}
1008-
client := NewMultiRaftClient(conn)
10091005

10101006
// Creates a rpc stream between the leaseholder and sender.
10111007
stream, err := client.DelegateRaftSnapshot(ctx)

pkg/kv/kvserver/replica_command.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1018,11 +1018,11 @@ func waitForApplication(
10181018
for _, repl := range replicas {
10191019
repl := repl // copy for goroutine
10201020
g.GoCtx(func(ctx context.Context) error {
1021-
conn, err := dialer.Dial(ctx, repl.NodeID, rpcbase.DefaultClass)
1021+
client, err := DialPerReplicaClient(dialer, ctx, repl.NodeID, rpcbase.DefaultClass)
10221022
if err != nil {
10231023
return errors.Wrapf(err, "could not dial n%d", repl.NodeID)
10241024
}
1025-
_, err = NewPerReplicaClient(conn).WaitForApplication(ctx, &WaitForApplicationRequest{
1025+
_, err = client.WaitForApplication(ctx, &WaitForApplicationRequest{
10261026
StoreRequestHeader: StoreRequestHeader{NodeID: repl.NodeID, StoreID: repl.StoreID},
10271027
RangeID: rangeID,
10281028
LeaseIndex: leaseIndex,
@@ -1048,11 +1048,11 @@ func waitForReplicasInit(
10481048
for _, repl := range replicas {
10491049
repl := repl // copy for goroutine
10501050
g.GoCtx(func(ctx context.Context) error {
1051-
conn, err := dialer.Dial(ctx, repl.NodeID, rpcbase.DefaultClass)
1051+
client, err := DialPerReplicaClient(dialer, ctx, repl.NodeID, rpcbase.DefaultClass)
10521052
if err != nil {
10531053
return errors.Wrapf(err, "could not dial n%d", repl.NodeID)
10541054
}
1055-
_, err = NewPerReplicaClient(conn).WaitForReplicaInit(ctx, &WaitForReplicaInitRequest{
1055+
_, err = client.WaitForReplicaInit(ctx, &WaitForReplicaInitRequest{
10561056
StoreRequestHeader: StoreRequestHeader{NodeID: repl.NodeID, StoreID: repl.StoreID},
10571057
RangeID: rangeID,
10581058
})

pkg/kv/kvserver/replica_consistency.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -269,12 +269,11 @@ type ConsistencyCheckResult struct {
269269
func (r *Replica) collectChecksumFromReplica(
270270
ctx context.Context, replica roachpb.ReplicaDescriptor, id uuid.UUID,
271271
) (CollectChecksumResponse, error) {
272-
conn, err := r.store.cfg.NodeDialer.Dial(ctx, replica.NodeID, rpcbase.DefaultClass)
272+
client, err := DialPerReplicaClient(r.store.cfg.NodeDialer, ctx, replica.NodeID, rpcbase.DefaultClass)
273273
if err != nil {
274274
return CollectChecksumResponse{},
275275
errors.Wrapf(err, "could not dial node ID %d", replica.NodeID)
276276
}
277-
client := NewPerReplicaClient(conn)
278277
req := &CollectChecksumRequest{
279278
StoreRequestHeader: StoreRequestHeader{NodeID: replica.NodeID, StoreID: replica.StoreID},
280279
RangeID: r.RangeID,

pkg/kv/kvserver/rpc_clients.go

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
// Copyright 2025 The Cockroach Authors.
2+
//
3+
// Use of this software is governed by the CockroachDB Software License
4+
// included in the /LICENSE file.
5+
6+
package kvserver
7+
8+
import (
9+
context "context"
10+
11+
"github.com/cockroachdb/cockroach/pkg/roachpb"
12+
"github.com/cockroachdb/cockroach/pkg/rpc/rpcbase"
13+
)
14+
15+
// DialMultiRaftClient returns a MultiRaftClient for the node nodeID. When
16+
// rpcbase.TODODRPC is false it dials a connection of the given class via
17+
// nd.Dial and wraps it with NewMultiRaftClient; a dial error is returned
// as-is together with a nil client.
//
// NOTE(review): when rpcbase.TODODRPC is true this returns (nil, nil) — no
// DRPC connection is dialed in this block, so callers would receive a nil
// client without an error. Confirm the flag stays false until the DRPC path
// is implemented.
18+
func DialMultiRaftClient(
19+
nd rpcbase.NodeDialer, ctx context.Context, nodeID roachpb.NodeID, class rpcbase.ConnectionClass,
20+
) (MultiRaftClient, error) {
21+
if !rpcbase.TODODRPC {
22+
conn, err := nd.Dial(ctx, nodeID, class)
23+
if err != nil {
24+
return nil, err
25+
}
26+
return NewMultiRaftClient(conn), nil
27+
}
28+
return nil, nil
29+
}
30+
31+
// DialPerReplicaClient returns a PerReplicaClient for the node nodeID. When
32+
// rpcbase.TODODRPC is false it dials a connection of the given class via
33+
// nd.Dial and wraps it with NewPerReplicaClient; a dial error is returned
// as-is together with a nil client.
//
// NOTE(review): when rpcbase.TODODRPC is true this returns (nil, nil) — no
// DRPC connection is dialed in this block, so callers would receive a nil
// client without an error. Confirm the flag stays false until the DRPC path
// is implemented.
34+
func DialPerReplicaClient(
35+
nd rpcbase.NodeDialer, ctx context.Context, nodeID roachpb.NodeID, class rpcbase.ConnectionClass,
36+
) (PerReplicaClient, error) {
37+
if !rpcbase.TODODRPC {
38+
conn, err := nd.Dial(ctx, nodeID, class)
39+
if err != nil {
40+
return nil, err
41+
}
42+
return NewPerReplicaClient(conn), nil
43+
}
44+
return nil, nil
45+
}
46+
47+
// DialPerStoreClient returns a PerStoreClient for the node nodeID. When
48+
// rpcbase.TODODRPC is false it dials a connection of the given class via
49+
// nd.Dial and wraps it with NewPerStoreClient; a dial error is returned
// as-is together with a nil client.
//
// NOTE(review): when rpcbase.TODODRPC is true this returns (nil, nil) — no
// DRPC connection is dialed in this block, so callers would receive a nil
// client without an error. Confirm the flag stays false until the DRPC path
// is implemented.
50+
func DialPerStoreClient(
51+
nd rpcbase.NodeDialer, ctx context.Context, nodeID roachpb.NodeID, class rpcbase.ConnectionClass,
52+
) (PerStoreClient, error) {
53+
if !rpcbase.TODODRPC {
54+
conn, err := nd.Dial(ctx, nodeID, class)
55+
if err != nil {
56+
return nil, err
57+
}
58+
return NewPerStoreClient(conn), nil
59+
}
60+
return nil, nil
61+
}

pkg/kv/kvserver/storage_engine_client.go

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,10 @@ func NewStorageEngineClient(nd *nodedialer.Dialer) *StorageEngineClient {
2929
func (c *StorageEngineClient) CompactEngineSpan(
3030
ctx context.Context, nodeID, storeID int32, startKey, endKey []byte,
3131
) error {
32-
conn, err := c.nd.Dial(ctx, roachpb.NodeID(nodeID), rpcbase.DefaultClass)
32+
client, err := DialPerStoreClient(c.nd, ctx, roachpb.NodeID(nodeID), rpcbase.DefaultClass)
3333
if err != nil {
3434
return errors.Wrapf(err, "could not dial node ID %d", nodeID)
3535
}
36-
client := NewPerStoreClient(conn)
3736
req := &CompactEngineSpanRequest{
3837
StoreRequestHeader: StoreRequestHeader{
3938
NodeID: roachpb.NodeID(nodeID),
@@ -49,12 +48,10 @@ func (c *StorageEngineClient) CompactEngineSpan(
4948
func (c *StorageEngineClient) GetTableMetrics(
5049
ctx context.Context, nodeID, storeID int32, startKey, endKey []byte,
5150
) ([]enginepb.SSTableMetricsInfo, error) {
52-
conn, err := c.nd.Dial(ctx, roachpb.NodeID(nodeID), rpcbase.DefaultClass)
51+
client, err := DialPerStoreClient(c.nd, ctx, roachpb.NodeID(nodeID), rpcbase.DefaultClass)
5352
if err != nil {
5453
return []enginepb.SSTableMetricsInfo{}, errors.Wrapf(err, "could not dial node ID %d", nodeID)
5554
}
56-
57-
client := NewPerStoreClient(conn)
5855
req := &GetTableMetricsRequest{
5956
StoreRequestHeader: StoreRequestHeader{
6057
NodeID: roachpb.NodeID(nodeID),
@@ -75,12 +72,10 @@ func (c *StorageEngineClient) GetTableMetrics(
7572
func (c *StorageEngineClient) ScanStorageInternalKeys(
7673
ctx context.Context, nodeID, storeID int32, startKey, endKey []byte, megabytesPerSecond int64,
7774
) ([]enginepb.StorageInternalKeysMetrics, error) {
78-
conn, err := c.nd.Dial(ctx, roachpb.NodeID(nodeID), rpcbase.DefaultClass)
75+
client, err := DialPerStoreClient(c.nd, ctx, roachpb.NodeID(nodeID), rpcbase.DefaultClass)
7976
if err != nil {
8077
return []enginepb.StorageInternalKeysMetrics{}, errors.Wrapf(err, "could not dial node ID %d", nodeID)
8178
}
82-
83-
client := NewPerStoreClient(conn)
8479
req := &ScanStorageInternalKeysRequest{
8580
StoreRequestHeader: StoreRequestHeader{
8681
NodeID: roachpb.NodeID(nodeID),
@@ -102,11 +97,10 @@ func (c *StorageEngineClient) ScanStorageInternalKeys(
10297
func (c *StorageEngineClient) SetCompactionConcurrency(
10398
ctx context.Context, nodeID, storeID int32, compactionConcurrency uint64,
10499
) error {
105-
conn, err := c.nd.Dial(ctx, roachpb.NodeID(nodeID), rpcbase.DefaultClass)
100+
client, err := DialPerStoreClient(c.nd, ctx, roachpb.NodeID(nodeID), rpcbase.DefaultClass)
106101
if err != nil {
107102
return errors.Wrapf(err, "could not dial node ID %d", nodeID)
108103
}
109-
client := NewPerStoreClient(conn)
110104
req := &CompactionConcurrencyRequest{
111105
StoreRequestHeader: StoreRequestHeader{
112106
NodeID: roachpb.NodeID(nodeID),

0 commit comments

Comments
 (0)