Skip to content

Commit e109096

Browse files
committed
rpc: drpc interceptors run task within stopper
This change adds support for running tasks within stoppers for DRPC interceptors. Now, DRPC unary and streaming RPCs run their handlers inside Stopper tasks, ensuring requests are tracked and rejected once draining begins. Epic: CRDB-49359 Fixes: #144371 Release note: None
1 parent 9a77ab6 commit e109096

File tree

3 files changed

+126
-13
lines changed

3 files changed

+126
-13
lines changed

pkg/rpc/BUILD.bazel

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ go_test(
115115
"context_test.go",
116116
"datadriven_test.go",
117117
"down_node_test.go",
118+
"drpc_test.go",
118119
"heartbeat_test.go",
119120
"helpers_test.go",
120121
"main_test.go",
@@ -181,6 +182,7 @@ go_test(
181182
"@com_github_prometheus_client_model//go",
182183
"@com_github_stretchr_testify//assert",
183184
"@com_github_stretchr_testify//require",
185+
"@io_storj_drpc//:drpc",
184186
"@io_storj_drpc//drpcctx",
185187
"@io_storj_drpc//drpcmetadata",
186188
"@org_golang_google_grpc//:grpc",

pkg/rpc/drpc.go

Lines changed: 45 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,41 @@ type drpcServer struct {
156156
drpc.Mux
157157
}
158158

159+
// makeStopperInterceptors returns unary and stream interceptors that run
160+
// incoming RPCs in stopper tasks.
161+
func makeStopperInterceptors(
162+
rpcCtx *Context,
163+
) (drpcmux.UnaryServerInterceptor, drpcmux.StreamServerInterceptor) {
164+
unary := func(
165+
ctx context.Context, req interface{}, rpc string, handler drpcmux.UnaryHandler,
166+
) (interface{}, error) {
167+
var resp interface{}
168+
if err := rpcCtx.Stopper.RunTaskWithErr(ctx, rpc, func(ctx context.Context) error {
169+
var err error
170+
resp, err = handler(ctx, req)
171+
return err
172+
}); err != nil {
173+
return nil, err
174+
}
175+
return resp, nil
176+
}
177+
178+
stream := func(
179+
stream drpc.Stream, rpc string, handler drpcmux.StreamHandler,
180+
) (interface{}, error) {
181+
var resp interface{}
182+
if err := rpcCtx.Stopper.RunTaskWithErr(stream.Context(), rpc, func(ctx context.Context) error {
183+
var err error
184+
resp, err = handler(stream)
185+
return err
186+
}); err != nil {
187+
return nil, err
188+
}
189+
return resp, nil
190+
}
191+
return unary, stream
192+
}
193+
159194
// NewDRPCServer creates a new DRPCServer with the provided rpc context.
160195
func NewDRPCServer(_ context.Context, rpcCtx *Context, opts ...ServerOption) (DRPCServer, error) {
161196
d := &drpcServer{}
@@ -168,6 +203,16 @@ func NewDRPCServer(_ context.Context, rpcCtx *Context, opts ...ServerOption) (DR
168203
var unaryInterceptors []drpcmux.UnaryServerInterceptor
169204
var streamInterceptors []drpcmux.StreamServerInterceptor
170205

206+
// These interceptors run in the order they're appended. The first
207+
// interceptor added becomes the outermost wrapper around the handler.
208+
209+
// We start with an interceptor that ensures every RPC executes inside a
210+
// stopper task. Running the handler in a stopper task lets the stopper
211+
// keep track of in-flight RPCs and reject new ones once draining begins.
212+
stopUnary, stopStream := makeStopperInterceptors(rpcCtx)
213+
unaryInterceptors = append(unaryInterceptors, stopUnary)
214+
streamInterceptors = append(streamInterceptors, stopStream)
215+
171216
if !rpcCtx.ContextOptions.Insecure {
172217
a := kvAuth{
173218
sv: &rpcCtx.Settings.SV,
@@ -194,19 +239,6 @@ func NewDRPCServer(_ context.Context, rpcCtx *Context, opts ...ServerOption) (DR
194239
})
195240
d.Mux = mux
196241

197-
// NB: any server middleware (server interceptors in gRPC parlance) would go
198-
// here:
199-
// dmux = whateverMiddleware1(dmux)
200-
// dmux = whateverMiddleware2(dmux)
201-
// ...
202-
//
203-
// Each middleware must implement the Handler interface:
204-
//
205-
// HandleRPC(stream Stream, rpc string) error
206-
//
207-
// where Stream
208-
// See here for an example:
209-
// https://github.com/bryk-io/pkg/blob/4da5fbfef47770be376e4022eab5c6c324984bf7/net/drpc/server.go#L91-L101
210242
return d, nil
211243
}
212244

pkg/rpc/drpc_test.go

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
// Copyright 2025 The Cockroach Authors.
2+
//
3+
// Use of this software is governed by the CockroachDB Software License
4+
// included in the /LICENSE file.
5+
6+
package rpc
7+
8+
import (
9+
"context"
10+
"testing"
11+
12+
"github.com/cockroachdb/cockroach/pkg/util/leaktest"
13+
"github.com/cockroachdb/cockroach/pkg/util/log"
14+
"github.com/cockroachdb/cockroach/pkg/util/stop"
15+
"github.com/stretchr/testify/require"
16+
"storj.io/drpc"
17+
)
18+
19+
// dummyStream is a minimal implementation of drpc.Stream used for testing the
20+
// stream interceptor.
21+
type dummyStream struct {
22+
ctx context.Context
23+
}
24+
25+
func (d dummyStream) Context() context.Context { return d.ctx }
26+
func (d dummyStream) MsgSend(drpc.Message, drpc.Encoding) error { return nil }
27+
func (d dummyStream) MsgRecv(drpc.Message, drpc.Encoding) error { return nil }
28+
func (d dummyStream) CloseSend() error { return nil }
29+
func (d dummyStream) Close() error { return nil }
30+
31+
// TestMakeStopperInterceptors verifies that the stopper interceptors allow RPCs
32+
// to run before the stopper quiesces and reject them afterward.
33+
func TestMakeStopperInterceptors(t *testing.T) {
34+
defer leaktest.AfterTest(t)()
35+
defer log.Scope(t).Close(t)
36+
ctx := context.Background()
37+
stopper := stop.NewStopper()
38+
defer stopper.Stop(ctx)
39+
40+
rpcCtx := &Context{ContextOptions: ContextOptions{Stopper: stopper}}
41+
42+
unaryInterceptor, streamInterceptor := makeStopperInterceptors(rpcCtx)
43+
44+
// Before quiesce runs.
45+
called := false
46+
_, err := unaryInterceptor(ctx, nil, "test", func(ctx context.Context, req interface{}) (interface{}, error) {
47+
called = true
48+
return nil, nil
49+
})
50+
require.NoError(t, err)
51+
require.True(t, called)
52+
53+
called = false
54+
_, err = streamInterceptor(dummyStream{ctx: ctx}, "test", func(stream drpc.Stream) (interface{}, error) {
55+
called = true
56+
return nil, nil
57+
})
58+
require.NoError(t, err)
59+
require.True(t, called)
60+
61+
// After quiesce, RPCs are rejected.
62+
stopper.Quiesce(ctx)
63+
64+
called = false
65+
_, err = unaryInterceptor(ctx, nil, "test", func(ctx context.Context, req interface{}) (interface{}, error) {
66+
called = true
67+
return nil, nil
68+
})
69+
require.ErrorIs(t, err, stop.ErrUnavailable)
70+
require.False(t, called)
71+
72+
called = false
73+
_, err = streamInterceptor(dummyStream{ctx: ctx}, "test", func(stream drpc.Stream) (interface{}, error) {
74+
called = true
75+
return nil, nil
76+
})
77+
require.ErrorIs(t, err, stop.ErrUnavailable)
78+
require.False(t, called)
79+
}

0 commit comments

Comments
 (0)