@@ -7,11 +7,19 @@ package rpc
7
7
8
8
import (
9
9
"context"
10
+ "crypto/tls"
11
+ "sync/atomic"
10
12
"testing"
13
+ "time"
11
14
15
+ "github.com/cockroachdb/cockroach/pkg/roachpb"
16
+ "github.com/cockroachdb/cockroach/pkg/rpc/rpcbase"
12
17
"github.com/cockroachdb/cockroach/pkg/util/leaktest"
13
18
"github.com/cockroachdb/cockroach/pkg/util/log"
19
+ "github.com/cockroachdb/cockroach/pkg/util/netutil"
14
20
"github.com/cockroachdb/cockroach/pkg/util/stop"
21
+ "github.com/cockroachdb/cockroach/pkg/util/timeutil"
22
+ "github.com/cockroachdb/cockroach/pkg/util/uuid"
15
23
"github.com/cockroachdb/errors"
16
24
"github.com/stretchr/testify/require"
17
25
"storj.io/drpc"
@@ -129,3 +137,47 @@ func TestGatewayRequestDRPCRecoveryInterceptor(t *testing.T) {
129
137
require .ErrorIs (t , err , expectedErr )
130
138
})
131
139
}
140
+
141
+ // TestDRPCServerWithInterceptor verifies that configured interceptors are
142
+ // invoked.
143
+ func TestDRPCServerWithInterceptor (t * testing.T ) {
144
+ defer leaktest .AfterTest (t )()
145
+ defer log .Scope (t ).Close (t )
146
+
147
+ ctx := context .Background ()
148
+ stopper := stop .NewStopper ()
149
+ defer stopper .Stop (ctx )
150
+
151
+ loopbackL := netutil .NewLoopbackListener (ctx , stopper )
152
+
153
+ serverCtx := newTestContext (uuid .MakeV4 (), timeutil .NewManualTime (timeutil .Unix (0 , 1 )), 0 , stopper )
154
+ serverCtx .NodeID .Set (ctx , 1 )
155
+ serverCtx .SetLoopbackDRPCDialer (loopbackL .Connect )
156
+ serverCtx .AdvertiseAddr = "127.0.0.1:1"
157
+ serverCtx .RPCHeartbeatInterval = 1 * time .Hour
158
+
159
+ var shouldBlock atomic.Bool
160
+ blockedErr := errors .New ("RPC blocked by interceptor" )
161
+ drpcServer , err := NewDRPCServer (ctx , serverCtx , WithInterceptor (func (rpcName string ) error {
162
+ if shouldBlock .Load () && rpcName == "/cockroach.rpc.Heartbeat/Ping" {
163
+ return blockedErr
164
+ }
165
+ return nil
166
+ }))
167
+ require .NoError (t , err )
168
+ require .NoError (t , DRPCRegisterHeartbeat (drpcServer , serverCtx .NewHeartbeatService ()))
169
+
170
+ tlsConfig , err := serverCtx .GetServerTLSConfig ()
171
+ require .NoError (t , err )
172
+ tlsListener := tls .NewListener (loopbackL , tlsConfig )
173
+ require .NoError (t , stopper .RunAsyncTask (ctx , "drpc-server" , func (ctx context.Context ) {
174
+ netutil .FatalIfUnexpected (drpcServer .Serve (ctx , tlsListener ))
175
+ }))
176
+
177
+ conn , err := serverCtx .DRPCDialNode ("127.0.0.1:1" , 1 , roachpb.Locality {}, rpcbase .DefaultClass ).Connect (ctx )
178
+ require .NoError (t , err )
179
+ client := NewDRPCHeartbeatClient (conn )
180
+ shouldBlock .Store (true )
181
+ _ , err = client .Ping (ctx , & PingRequest {})
182
+ require .Contains (t , err .Error (), "RPC blocked by interceptor" )
183
+ }
0 commit comments