Skip to content

Commit 1a26fe7

Browse files
craig[bot]Nukitt
andcommitted
Merge #153868
153868: rpc, server: add interceptor to catch panics in DRPC for DB console requests r=cthumuluru-crdb,shubhamdhama a=Nukitt Previously if a DRPC handler experienced an uncaught panic, the entire node would crash. When serving a DB console request, uncaught panics due to its functionality should not cause the CRDB node to crash. To address this, this patch introduces metadata into the context for all requests originating from the gateway and this is used by a panic recovery interceptor to detect such requests. This allows the server to decide whether to recover from the panic instead of letting it propagate and crash the node. Epic: none Fixes: #153452 Release note: None Co-authored-by: Nukitt <[email protected]>
2 parents eb654e4 + fcca805 commit 1a26fe7

File tree

5 files changed

+85
-1
lines changed

5 files changed

+85
-1
lines changed

pkg/rpc/drpc.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,9 @@ func NewDRPCServer(_ context.Context, rpcCtx *Context, opts ...ServerOption) (DR
219219
unaryInterceptors = append(unaryInterceptors, stopUnary)
220220
streamInterceptors = append(streamInterceptors, stopStream)
221221

222+
// Recover from any uncaught panics caused by DB Console requests.
223+
unaryInterceptors = append(unaryInterceptors, DRPCGatewayRequestRecoveryInterceptor)
224+
222225
if !rpcCtx.ContextOptions.Insecure {
223226
a := kvAuth{
224227
sv: &rpcCtx.Settings.SV,

pkg/rpc/drpc_test.go

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"github.com/cockroachdb/cockroach/pkg/util/leaktest"
1313
"github.com/cockroachdb/cockroach/pkg/util/log"
1414
"github.com/cockroachdb/cockroach/pkg/util/stop"
15+
"github.com/cockroachdb/errors"
1516
"github.com/stretchr/testify/require"
1617
"storj.io/drpc"
1718
)
@@ -77,3 +78,54 @@ func TestMakeStopperInterceptors(t *testing.T) {
7778
require.ErrorIs(t, err, stop.ErrUnavailable)
7879
require.False(t, called)
7980
}
81+
82+
func TestGatewayRequestDRPCRecoveryInterceptor(t *testing.T) {
83+
defer leaktest.AfterTest(t)()
84+
85+
// With gateway metadata - should recover from panic
86+
t.Run("with gateway metadata", func(t *testing.T) {
87+
ctx := MarkDRPCGatewayRequest(context.Background())
88+
89+
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
90+
panic("test panic")
91+
}
92+
93+
resp, err := DRPCGatewayRequestRecoveryInterceptor(ctx, nil, "test", handler)
94+
95+
require.Nil(t, resp)
96+
require.ErrorContains(t, err, "unexpected error occurred")
97+
})
98+
99+
// Without gateway metadata - should not recover from panic
100+
t.Run("without gateway metadata", func(t *testing.T) {
101+
ctx := context.Background()
102+
103+
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
104+
panic("test panic")
105+
}
106+
107+
defer func() {
108+
if r := recover(); r == nil {
109+
t.Fatal("expected panic to propagate, got none")
110+
}
111+
}()
112+
113+
_, _ = DRPCGatewayRequestRecoveryInterceptor(ctx, nil, "test", handler)
114+
})
115+
116+
// With gateway metadata but no panic - should pass through normally
117+
t.Run("with gateway metadata no panic", func(t *testing.T) {
118+
ctx := MarkDRPCGatewayRequest(context.Background())
119+
120+
expectedResp := "success"
121+
expectedErr := errors.New("expected error")
122+
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
123+
return expectedResp, expectedErr
124+
}
125+
126+
resp, err := DRPCGatewayRequestRecoveryInterceptor(ctx, nil, "test", handler)
127+
128+
require.Equal(t, expectedResp, resp)
129+
require.ErrorIs(t, err, expectedErr)
130+
})
131+
}

pkg/rpc/metrics.go

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,12 @@ import (
2323
"google.golang.org/grpc/codes"
2424
"google.golang.org/grpc/metadata"
2525
"google.golang.org/grpc/status"
26+
"storj.io/drpc/drpcmetadata"
27+
"storj.io/drpc/drpcmux"
2628
)
2729

2830
// gwRequestKey is a field set on the context to indicate a request
29-
// is coming from gRPC gateway.
31+
// is coming from RPC gateway.
3032
const gwRequestKey = "gw-request"
3133

3234
var (
@@ -512,3 +514,27 @@ func gatewayRequestRecoveryInterceptor(
512514
resp, err = handler(ctx, req)
513515
return resp, err
514516
}
517+
518+
// MarkDRPCGatewayRequest annotates ctx so that downstream DRPC calls can
519+
// be recognized as originating from the DB Console HTTP gateway.
520+
func MarkDRPCGatewayRequest(ctx context.Context) context.Context {
521+
return drpcmetadata.Add(ctx, gwRequestKey, "true")
522+
}
523+
524+
// DRPCGatewayRequestRecoveryInterceptor recovers from panics in DRPC handlers
525+
// that are invoked due to DB console requests. For these requests, we do not
526+
// want an uncaught panic to crash the node.
527+
func DRPCGatewayRequestRecoveryInterceptor(
528+
ctx context.Context, req interface{}, rpc string, handler drpcmux.UnaryHandler,
529+
) (resp interface{}, err error) {
530+
if val, ok := drpcmetadata.GetValue(ctx, gwRequestKey); ok && val != "" {
531+
defer func() {
532+
if p := recover(); p != nil {
533+
logcrash.ReportPanic(ctx, nil, p, 1 /* depth */)
534+
err = errors.New("an unexpected error occurred")
535+
}
536+
}()
537+
}
538+
resp, err = handler(ctx, req)
539+
return resp, err
540+
}

pkg/server/apiinternal/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ go_library(
1010
visibility = ["//visibility:public"],
1111
deps = [
1212
"//pkg/roachpb",
13+
"//pkg/rpc",
1314
"//pkg/rpc/rpcbase",
1415
"//pkg/server/authserver",
1516
"//pkg/server/serverpb",

pkg/server/apiinternal/api_internal.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"reflect"
1414

1515
"github.com/cockroachdb/cockroach/pkg/roachpb"
16+
"github.com/cockroachdb/cockroach/pkg/rpc"
1617
"github.com/cockroachdb/cockroach/pkg/rpc/rpcbase"
1718
"github.com/cockroachdb/cockroach/pkg/server/authserver"
1819
"github.com/cockroachdb/cockroach/pkg/server/serverpb"
@@ -120,6 +121,7 @@ func executeRPC[TReq, TResp protoutil.Message](
120121
) error {
121122
ctx := req.Context()
122123
ctx = authserver.ForwardHTTPAuthInfoToRPCCalls(ctx, req)
124+
ctx = rpc.MarkDRPCGatewayRequest(ctx)
123125

124126
if err := decoder.Decode(rpcReq, req.URL.Query()); err != nil {
125127
return err

0 commit comments

Comments
 (0)