Skip to content

Commit 14ff833

Browse files
craig[bot]stevendanna
andcommitted
Merge #154842
154842: rpc: respect interceptors ServerOption in NewDRPCServer r=tbg a=stevendanna We were not adding the passed interceptor to the list of unary or stream interceptors. Now we do. This in turn addresses a race condition that we intended to solve in e005120. Part of this PR was authored by Claude Code. After a hint about DRPC vs gRPC, it was able to diagnose the problem given race detector output and generate the fix included here. Its initial attempts at creating a test either didn't actually assert anything relevant or was very large. After some prompting, it created a test similar to what is included here, but with many hallucinated functions. I then just fixed it by hand. Epic: none Release note: None Co-authored-by: Steven Danna <[email protected]>
2 parents 5d15bbf + a96c854 commit 14ff833

File tree

2 files changed

+72
-0
lines changed

2 files changed

+72
-0
lines changed

pkg/rpc/drpc.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,26 @@ func NewDRPCServer(_ context.Context, rpcCtx *Context, opts ...ServerOption) (DR
236236
streamInterceptors = append(streamInterceptors, a.AuthDRPCStream())
237237
}
238238

239+
if o.interceptor != nil {
240+
unaryInterceptors = append(unaryInterceptors, func(
241+
ctx context.Context, req interface{}, fullMethod string, handler drpcmux.UnaryHandler,
242+
) (interface{}, error) {
243+
if err := o.interceptor(fullMethod); err != nil {
244+
return nil, err
245+
}
246+
return handler(ctx, req)
247+
})
248+
249+
streamInterceptors = append(streamInterceptors, func(
250+
stream drpc.Stream, fullMethod string, handler drpcmux.StreamHandler,
251+
) (interface{}, error) {
252+
if err := o.interceptor(fullMethod); err != nil {
253+
return nil, err
254+
}
255+
return handler(stream)
256+
})
257+
}
258+
239259
if tracer := rpcCtx.Stopper.Tracer(); tracer != nil {
240260
unaryInterceptors = append(unaryInterceptors, drpcinterceptor.ServerInterceptor(tracer))
241261
streamInterceptors = append(streamInterceptors, drpcinterceptor.StreamServerInterceptor(tracer))

pkg/rpc/drpc_test.go

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,19 @@ package rpc
77

88
import (
99
"context"
10+
"crypto/tls"
11+
"sync/atomic"
1012
"testing"
13+
"time"
1114

15+
"github.com/cockroachdb/cockroach/pkg/roachpb"
16+
"github.com/cockroachdb/cockroach/pkg/rpc/rpcbase"
1217
"github.com/cockroachdb/cockroach/pkg/util/leaktest"
1318
"github.com/cockroachdb/cockroach/pkg/util/log"
19+
"github.com/cockroachdb/cockroach/pkg/util/netutil"
1420
"github.com/cockroachdb/cockroach/pkg/util/stop"
21+
"github.com/cockroachdb/cockroach/pkg/util/timeutil"
22+
"github.com/cockroachdb/cockroach/pkg/util/uuid"
1523
"github.com/cockroachdb/errors"
1624
"github.com/stretchr/testify/require"
1725
"storj.io/drpc"
@@ -129,3 +137,47 @@ func TestGatewayRequestDRPCRecoveryInterceptor(t *testing.T) {
129137
require.ErrorIs(t, err, expectedErr)
130138
})
131139
}
140+
141+
// TestDRPCServerWithInterceptor verifies that configured interceptors are
142+
// invoked.
143+
func TestDRPCServerWithInterceptor(t *testing.T) {
144+
defer leaktest.AfterTest(t)()
145+
defer log.Scope(t).Close(t)
146+
147+
ctx := context.Background()
148+
stopper := stop.NewStopper()
149+
defer stopper.Stop(ctx)
150+
151+
loopbackL := netutil.NewLoopbackListener(ctx, stopper)
152+
153+
serverCtx := newTestContext(uuid.MakeV4(), timeutil.NewManualTime(timeutil.Unix(0, 1)), 0, stopper)
154+
serverCtx.NodeID.Set(ctx, 1)
155+
serverCtx.SetLoopbackDRPCDialer(loopbackL.Connect)
156+
serverCtx.AdvertiseAddr = "127.0.0.1:1"
157+
serverCtx.RPCHeartbeatInterval = 1 * time.Hour
158+
159+
var shouldBlock atomic.Bool
160+
blockedErr := errors.New("RPC blocked by interceptor")
161+
drpcServer, err := NewDRPCServer(ctx, serverCtx, WithInterceptor(func(rpcName string) error {
162+
if shouldBlock.Load() && rpcName == "/cockroach.rpc.Heartbeat/Ping" {
163+
return blockedErr
164+
}
165+
return nil
166+
}))
167+
require.NoError(t, err)
168+
require.NoError(t, DRPCRegisterHeartbeat(drpcServer, serverCtx.NewHeartbeatService()))
169+
170+
tlsConfig, err := serverCtx.GetServerTLSConfig()
171+
require.NoError(t, err)
172+
tlsListener := tls.NewListener(loopbackL, tlsConfig)
173+
require.NoError(t, stopper.RunAsyncTask(ctx, "drpc-server", func(ctx context.Context) {
174+
netutil.FatalIfUnexpected(drpcServer.Serve(ctx, tlsListener))
175+
}))
176+
177+
conn, err := serverCtx.DRPCDialNode("127.0.0.1:1", 1, roachpb.Locality{}, rpcbase.DefaultClass).Connect(ctx)
178+
require.NoError(t, err)
179+
client := NewDRPCHeartbeatClient(conn)
180+
shouldBlock.Store(true)
181+
_, err = client.Ping(ctx, &PingRequest{})
182+
require.Contains(t, err.Error(), "RPC blocked by interceptor")
183+
}

0 commit comments

Comments
 (0)