Skip to content

Commit 249793c

Browse files
craig[bot]Nukitt
andcommitted
Merge #153012
153012: rpc: drpc interceptors run task within stopper r=shubhamdhama a=Nukitt This change adds support for running tasks within stoppers for DRPC interceptors. Now, DRPC unary and streaming RPCs run their handlers inside Stopper tasks, ensuring requests are tracked and rejected once draining begins. Epic: CRDB-49359 Fixes: #144371 Release note: None Co-authored-by: Nukitt <[email protected]>
2 parents 8b39430 + e109096 commit 249793c

File tree

3 files changed

+126
-13
lines changed

3 files changed

+126
-13
lines changed

pkg/rpc/BUILD.bazel

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ go_test(
114114
"context_test.go",
115115
"datadriven_test.go",
116116
"down_node_test.go",
117+
"drpc_test.go",
117118
"heartbeat_test.go",
118119
"helpers_test.go",
119120
"main_test.go",
@@ -180,6 +181,7 @@ go_test(
180181
"@com_github_prometheus_client_model//go",
181182
"@com_github_stretchr_testify//assert",
182183
"@com_github_stretchr_testify//require",
184+
"@io_storj_drpc//:drpc",
183185
"@io_storj_drpc//drpcctx",
184186
"@io_storj_drpc//drpcmetadata",
185187
"@io_storj_drpc//drpcmux",

pkg/rpc/drpc.go

Lines changed: 45 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,41 @@ type drpcServer struct {
161161
drpc.Mux
162162
}
163163

164+
// makeStopperInterceptors returns unary and stream interceptors that run
165+
// incoming RPCs in stopper tasks.
166+
func makeStopperInterceptors(
167+
rpcCtx *Context,
168+
) (drpcmux.UnaryServerInterceptor, drpcmux.StreamServerInterceptor) {
169+
unary := func(
170+
ctx context.Context, req interface{}, rpc string, handler drpcmux.UnaryHandler,
171+
) (interface{}, error) {
172+
var resp interface{}
173+
if err := rpcCtx.Stopper.RunTaskWithErr(ctx, rpc, func(ctx context.Context) error {
174+
var err error
175+
resp, err = handler(ctx, req)
176+
return err
177+
}); err != nil {
178+
return nil, err
179+
}
180+
return resp, nil
181+
}
182+
183+
stream := func(
184+
stream drpc.Stream, rpc string, handler drpcmux.StreamHandler,
185+
) (interface{}, error) {
186+
var resp interface{}
187+
if err := rpcCtx.Stopper.RunTaskWithErr(stream.Context(), rpc, func(ctx context.Context) error {
188+
var err error
189+
resp, err = handler(stream)
190+
return err
191+
}); err != nil {
192+
return nil, err
193+
}
194+
return resp, nil
195+
}
196+
return unary, stream
197+
}
198+
164199
// NewDRPCServer creates a new DRPCServer with the provided rpc context.
165200
func NewDRPCServer(_ context.Context, rpcCtx *Context, opts ...ServerOption) (DRPCServer, error) {
166201
d := &drpcServer{}
@@ -173,6 +208,16 @@ func NewDRPCServer(_ context.Context, rpcCtx *Context, opts ...ServerOption) (DR
173208
var unaryInterceptors []drpcmux.UnaryServerInterceptor
174209
var streamInterceptors []drpcmux.StreamServerInterceptor
175210

211+
// These interceptors run in the order they're appended. The first
212+
// interceptor added becomes the outermost wrapper around the handler.
213+
214+
// We start with an interceptor that ensures every RPC executes inside a
215+
// stopper task. Running the handler in a stopper task lets the stopper
216+
// keep track of in-flight RPCs and reject new ones once draining begins.
217+
stopUnary, stopStream := makeStopperInterceptors(rpcCtx)
218+
unaryInterceptors = append(unaryInterceptors, stopUnary)
219+
streamInterceptors = append(streamInterceptors, stopStream)
220+
176221
if !rpcCtx.ContextOptions.Insecure {
177222
a := kvAuth{
178223
sv: &rpcCtx.Settings.SV,
@@ -199,19 +244,6 @@ func NewDRPCServer(_ context.Context, rpcCtx *Context, opts ...ServerOption) (DR
199244
})
200245
d.Mux = mux
201246

202-
// NB: any server middleware (server interceptors in gRPC parlance) would go
203-
// here:
204-
// dmux = whateverMiddleware1(dmux)
205-
// dmux = whateverMiddleware2(dmux)
206-
// ...
207-
//
208-
// Each middleware must implement the Handler interface:
209-
//
210-
// HandleRPC(stream Stream, rpc string) error
211-
//
212-
// where Stream
213-
// See here for an example:
214-
// https://github.com/bryk-io/pkg/blob/4da5fbfef47770be376e4022eab5c6c324984bf7/net/drpc/server.go#L91-L101
215247
return d, nil
216248
}
217249

pkg/rpc/drpc_test.go

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
// Copyright 2025 The Cockroach Authors.
2+
//
3+
// Use of this software is governed by the CockroachDB Software License
4+
// included in the /LICENSE file.
5+
6+
package rpc
7+
8+
import (
9+
"context"
10+
"testing"
11+
12+
"github.com/cockroachdb/cockroach/pkg/util/leaktest"
13+
"github.com/cockroachdb/cockroach/pkg/util/log"
14+
"github.com/cockroachdb/cockroach/pkg/util/stop"
15+
"github.com/stretchr/testify/require"
16+
"storj.io/drpc"
17+
)
18+
19+
// dummyStream is a minimal implementation of drpc.Stream used for testing the
20+
// stream interceptor.
21+
type dummyStream struct {
22+
ctx context.Context
23+
}
24+
25+
func (d dummyStream) Context() context.Context { return d.ctx }
26+
func (d dummyStream) MsgSend(drpc.Message, drpc.Encoding) error { return nil }
27+
func (d dummyStream) MsgRecv(drpc.Message, drpc.Encoding) error { return nil }
28+
func (d dummyStream) CloseSend() error { return nil }
29+
func (d dummyStream) Close() error { return nil }
30+
31+
// TestMakeStopperInterceptors verifies that the stopper interceptors allow RPCs
32+
// to run before the stopper quiesces and reject them afterward.
33+
func TestMakeStopperInterceptors(t *testing.T) {
34+
defer leaktest.AfterTest(t)()
35+
defer log.Scope(t).Close(t)
36+
ctx := context.Background()
37+
stopper := stop.NewStopper()
38+
defer stopper.Stop(ctx)
39+
40+
rpcCtx := &Context{ContextOptions: ContextOptions{Stopper: stopper}}
41+
42+
unaryInterceptor, streamInterceptor := makeStopperInterceptors(rpcCtx)
43+
44+
// Before quiesce runs.
45+
called := false
46+
_, err := unaryInterceptor(ctx, nil, "test", func(ctx context.Context, req interface{}) (interface{}, error) {
47+
called = true
48+
return nil, nil
49+
})
50+
require.NoError(t, err)
51+
require.True(t, called)
52+
53+
called = false
54+
_, err = streamInterceptor(dummyStream{ctx: ctx}, "test", func(stream drpc.Stream) (interface{}, error) {
55+
called = true
56+
return nil, nil
57+
})
58+
require.NoError(t, err)
59+
require.True(t, called)
60+
61+
// After quiesce, RPCs are rejected.
62+
stopper.Quiesce(ctx)
63+
64+
called = false
65+
_, err = unaryInterceptor(ctx, nil, "test", func(ctx context.Context, req interface{}) (interface{}, error) {
66+
called = true
67+
return nil, nil
68+
})
69+
require.ErrorIs(t, err, stop.ErrUnavailable)
70+
require.False(t, called)
71+
72+
called = false
73+
_, err = streamInterceptor(dummyStream{ctx: ctx}, "test", func(stream drpc.Stream) (interface{}, error) {
74+
called = true
75+
return nil, nil
76+
})
77+
require.ErrorIs(t, err, stop.ErrUnavailable)
78+
require.False(t, called)
79+
}

0 commit comments

Comments
 (0)