diff --git a/common/requestfilter/rulesverifier_test.go b/common/requestfilter/rulesverifier_test.go index 74e3a52f6..33bb72275 100644 --- a/common/requestfilter/rulesverifier_test.go +++ b/common/requestfilter/rulesverifier_test.go @@ -9,6 +9,7 @@ package requestfilter_test import ( "errors" "sync" + "sync/atomic" "testing" "time" @@ -51,16 +52,24 @@ func TestRulesVerifier(t *testing.T) { }) } -// scenario: test that verifiers can run in parallel. +// scenario: test that verifiers will execute verify in parallel. func TestConcurrentVeirfy(t *testing.T) { v := requestfilter.NewRulesVerifier(nil) - fc := &mocks.FakeFilterConfig{} - fc.GetMaxSizeBytesReturns(1000, nil) - v.AddRule(requestfilter.NewMaxSizeFilter(fc)) + fr := &mocks.FakeRule{} + var activeVerifiers int64 = 0 + var maxVerifiers int64 = 0 + fr.VerifyStub = func(r *comm.Request) error { + atomic.AddInt64(&activeVerifiers, 1) + time.Sleep(50 * time.Millisecond) + atomic.StoreInt64(&maxVerifiers, max(atomic.LoadInt64(&activeVerifiers), atomic.LoadInt64(&maxVerifiers))) + time.Sleep(50 * time.Millisecond) + atomic.AddInt64(&activeVerifiers, -1) + return nil + } + v.AddRule(fr) var wg sync.WaitGroup start := make(chan struct{}) - end := make(chan struct{}) verifiers := 8 for i := 0; i < verifiers; i++ { @@ -70,21 +79,12 @@ func TestConcurrentVeirfy(t *testing.T) { <-start err := v.Verify(&comm.Request{}) require.NoError(t, err) - time.Sleep(1 * time.Second) }() } - go func() { - close(start) - wg.Wait() - close(end) - }() - - select { - case <-end: - case <-time.After(7 * time.Second): - t.Error("concurrent verify took too long") - } + close(start) + wg.Wait() + require.True(t, maxVerifiers > 1) } // scenario: multiple goroutines call verify and update. check with -race flag. 
diff --git a/config/config.go b/config/config.go index a5707a4e9..63fcdc331 100644 --- a/config/config.go +++ b/config/config.go @@ -155,6 +155,7 @@ func (config *Configuration) ExtractRouterConfig() *nodeconfig.RouterNodeConfig NumOfgRPCStreamsPerConnection: config.LocalConfig.NodeLocalConfig.RouterParams.NumberOfStreamsPerConnection, UseTLS: config.LocalConfig.TLSConfig.Enabled, ClientAuthRequired: config.LocalConfig.TLSConfig.ClientAuthRequired, + RequestMaxBytes: config.SharedConfig.BatchingConfig.RequestMaxBytes, } return routerConfig } diff --git a/node/config/config.go b/node/config/config.go index aeb0b8df0..5e2f9e6ce 100644 --- a/node/config/config.go +++ b/node/config/config.go @@ -77,6 +77,7 @@ type RouterNodeConfig struct { NumOfgRPCStreamsPerConnection int UseTLS bool ClientAuthRequired bool + RequestMaxBytes uint64 } type AssemblerNodeConfig struct { diff --git a/node/config/utils.go b/node/config/utils.go index 4b9afb6ca..de0eae8cc 100644 --- a/node/config/utils.go +++ b/node/config/utils.go @@ -22,3 +22,7 @@ func (c *BatcherNodeConfig) GetShardsIDs() []types.ShardID { }) return ids } + +func (rc *RouterNodeConfig) GetMaxSizeBytes() (uint64, error) { + return rc.RequestMaxBytes, nil +} diff --git a/node/router/router.go b/node/router/router.go index 4af3fc025..6cd109f9e 100644 --- a/node/router/router.go +++ b/node/router/router.go @@ -20,6 +20,7 @@ import ( "github.com/hyperledger/fabric-protos-go-apiv2/common" "github.com/hyperledger/fabric-protos-go-apiv2/orderer" + "github.com/hyperledger/fabric-x-orderer/common/requestfilter" "github.com/hyperledger/fabric-x-orderer/common/types" "github.com/hyperledger/fabric-x-orderer/config" "github.com/hyperledger/fabric-x-orderer/node" @@ -40,6 +41,7 @@ type Router struct { shardIDs []types.ShardID incoming uint64 routerNodeConfig *nodeconfig.RouterNodeConfig + verifier *requestfilter.RulesVerifier } func NewRouter(config *nodeconfig.RouterNodeConfig, logger types.Logger) *Router { @@ -68,7 +70,9 @@ func 
NewRouter(config *nodeconfig.RouterNodeConfig, logger types.Logger) *Router return int(shardIDs[i]) < int(shardIDs[j]) }) - r := createRouter(shardIDs, batcherEndpoints, tlsCAsOfBatchers, config, logger) + verifier := createVerifier(config) + + r := createRouter(shardIDs, batcherEndpoints, tlsCAsOfBatchers, config, logger, verifier) r.init() return r } @@ -169,7 +173,7 @@ func (r *Router) Deliver(server orderer.AtomicBroadcast_DeliverServer) error { return fmt.Errorf("not implemented") } -func createRouter(shardIDs []types.ShardID, batcherEndpoints map[types.ShardID]string, batcherRootCAs map[types.ShardID][][]byte, rconfig *nodeconfig.RouterNodeConfig, logger types.Logger) *Router { +func createRouter(shardIDs []types.ShardID, batcherEndpoints map[types.ShardID]string, batcherRootCAs map[types.ShardID][][]byte, rconfig *nodeconfig.RouterNodeConfig, logger types.Logger, verifier *requestfilter.RulesVerifier) *Router { if rconfig.NumOfConnectionsForBatcher == 0 { rconfig.NumOfConnectionsForBatcher = config.DefaultRouterParams.NumberOfConnectionsPerBatcher } @@ -187,10 +191,11 @@ func createRouter(shardIDs []types.ShardID, batcherEndpoints map[types.ShardID]s logger: logger, shardIDs: shardIDs, routerNodeConfig: rconfig, + verifier: verifier, } for _, shardId := range shardIDs { - r.shardRouters[shardId] = NewShardRouter(logger, batcherEndpoints[shardId], batcherRootCAs[shardId], rconfig.TLSCertificateFile, rconfig.TLSPrivateKeyFile, rconfig.NumOfConnectionsForBatcher, rconfig.NumOfgRPCStreamsPerConnection) + r.shardRouters[shardId] = NewShardRouter(logger, batcherEndpoints[shardId], batcherRootCAs[shardId], rconfig.TLSCertificateFile, rconfig.TLSPrivateKeyFile, rconfig.NumOfConnectionsForBatcher, rconfig.NumOfgRPCStreamsPerConnection, verifier) } go func() { @@ -326,6 +331,13 @@ func createTraceID(rand *rand2.Rand) []byte { return trace } +func createVerifier(config *nodeconfig.RouterNodeConfig) *requestfilter.RulesVerifier { + rv := 
requestfilter.NewRulesVerifier(nil) + rv.AddRule(requestfilter.PayloadNotEmptyRule{}) + rv.AddRule(requestfilter.NewMaxSizeFilter(config)) + return rv +} + // IsAllStreamsOK checks that all the streams across all shard-routers are non-faulty. // Use for testing only. func (r *Router) IsAllStreamsOK() bool { diff --git a/node/router/router_test.go b/node/router/router_test.go index 9ed920560..7e6da2e88 100644 --- a/node/router/router_test.go +++ b/node/router/router_test.go @@ -380,6 +380,54 @@ func TestClientRouterBroadcastRequestsAgainstMultipleBatchers(t *testing.T) { }, 60*time.Second, 10*time.Millisecond) } +// test request filters +// 1) Start a client, router and stub batcher +// 2) Send valid request, expect no error. +// 3) Send request with empty payload, expect error. +// 4) Send request that exceeds the maximal size, expect error. +// 5) ** Not implemented ** send request with bad signature, expect error. +func TestRequestFilters(t *testing.T) { + // 1) Start a client, router and stub batcher + testSetup := createRouterTestSetup(t, types.PartyID(1), 1, true, false) + err := createServerTLSClientConnection(testSetup, testSetup.ca) + require.NoError(t, err) + require.NotNil(t, testSetup.clientConn) + defer testSetup.Close() + conn := testSetup.clientConn + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) + cl := protos.NewRequestTransmitClient(conn) + defer cancel() + + // 2) send a valid request. + buff := make([]byte, 300) + binary.BigEndian.PutUint32(buff, uint32(12345)) + req := &protos.Request{ + Payload: buff, + } + resp, err := cl.Submit(ctx, req) + require.NoError(t, err) + require.Equal(t, "", resp.Error) + + // 3) send request with empty payload. + req = &protos.Request{ + Payload: nil, + } + resp, err = cl.Submit(ctx, req) + require.NoError(t, err) + require.Equal(t, "request verification error: empty payload field", resp.Error) + // 4) send request with payload too big. 
(3000 is more than 1 << 10, the maximal request size in bytes) + buff = make([]byte, 3000) + binary.BigEndian.PutUint32(buff, uint32(12345)) + req = &protos.Request{ + Payload: buff, + } + resp, err = cl.Submit(ctx, req) + require.NoError(t, err) + require.Equal(t, "request verification error: the request's size exceeds the maximum size: actual = 3000, limit = 1024", resp.Error) + + // 5) send request with invalid signature. Not implemented +} + func createServerTLSClientConnection(testSetup *routerTestSetup, ca tlsgen.CA) error { cc := comm.ClientConfig{ SecOpts: comm.SecureOptions{ @@ -558,6 +606,7 @@ func createAndStartRouter(t *testing.T, partyID types.PartyID, ca tlsgen.CA, bat ListenAddress: "127.0.0.1:0", ClientAuthRequired: clientAuthRequired, Shards: shards, + RequestMaxBytes: 1 << 10, } r := router.NewRouter(conf, logger) diff --git a/node/router/shard_router.go b/node/router/shard_router.go index 2d97cb9e5..32cbdabd4 100644 --- a/node/router/shard_router.go +++ b/node/router/shard_router.go @@ -15,6 +15,7 @@ import ( "google.golang.org/grpc/connectivity" + "github.com/hyperledger/fabric-x-orderer/common/requestfilter" "github.com/hyperledger/fabric-x-orderer/common/types" "github.com/hyperledger/fabric-x-orderer/node/comm" protos "github.com/hyperledger/fabric-x-orderer/node/protos/comm" @@ -69,6 +70,7 @@ type ShardRouter struct { closeReconnectOnce sync.Once reconnectRequests chan reconnectReq closeReconnect chan bool + verifier *requestfilter.RulesVerifier } func NewShardRouter(l types.Logger, @@ -78,6 +80,7 @@ func NewShardRouter(l types.Logger, tlsKey []byte, numOfConnectionsForBatcher int, numOfgRPCStreamsPerConnection int, + verifier *requestfilter.RulesVerifier, ) *ShardRouter { cc := comm.ClientConfig{ AsyncConnect: false, @@ -106,6 +109,7 @@ func NewShardRouter(l types.Logger, clientConfig: cc, reconnectRequests: make(chan reconnectReq, 2*numOfgRPCStreamsPerConnection*numOfConnectionsForBatcher), closeReconnect: make(chan bool), + verifier: 
verifier, } return sr @@ -332,6 +336,7 @@ func (sr *ShardRouter) initStream(i int, j int) error { streamNum: j, srReconnectChan: sr.reconnectRequests, notifiedReconnect: false, + verifier: sr.verifier, } go s.sendRequests() go s.readResponses() diff --git a/node/router/shard_router_test.go b/node/router/shard_router_test.go index e1a4fb302..af5581545 100644 --- a/node/router/shard_router_test.go +++ b/node/router/shard_router_test.go @@ -15,6 +15,7 @@ import ( "google.golang.org/grpc/grpclog" + "github.com/hyperledger/fabric-x-orderer/common/requestfilter" "github.com/hyperledger/fabric-x-orderer/common/types" "github.com/hyperledger/fabric-x-orderer/node/comm/tlsgen" "github.com/hyperledger/fabric-x-orderer/node/router" @@ -157,11 +158,14 @@ func createTestSetup(t *testing.T, partyID types.PartyID) *TestSetup { ckp, err := ca.NewServerCertKeyPair("127.0.0.1") require.NoError(t, err) + verifier := requestfilter.NewRulesVerifier(nil) + verifier.AddRule(requestfilter.AcceptRule{}) + // create stub batcher batcher := NewStubBatcher(t, ca, partyID, types.ShardID(1)) // create shard router - shardRouter := router.NewShardRouter(logger, batcher.GetBatcherEndpoint(), [][]byte{ca.CertBytes()}, ckp.Cert, ckp.Key, 10, 20) + shardRouter := router.NewShardRouter(logger, batcher.GetBatcherEndpoint(), [][]byte{ca.CertBytes()}, ckp.Cert, ckp.Key, 10, 20, verifier) // start the batcher batcher.Start() diff --git a/node/router/stream.go b/node/router/stream.go index 98fd3d80b..a9ce33c90 100644 --- a/node/router/stream.go +++ b/node/router/stream.go @@ -12,6 +12,7 @@ import ( "maps" "sync" + "github.com/hyperledger/fabric-x-orderer/common/requestfilter" "github.com/hyperledger/fabric-x-orderer/common/types" protos "github.com/hyperledger/fabric-x-orderer/node/protos/comm" ) @@ -32,6 +33,7 @@ type stream struct { streamNum int srReconnectChan chan reconnectReq notifiedReconnect bool + verifier *requestfilter.RulesVerifier } // readResponses listens for responses from the batcher. 
@@ -50,18 +52,9 @@ func (s *stream) readResponses() { s.cancelOnServerError() return } - - s.lock.Lock() - ch, exists := s.requestTraceIdToResponseChannel[string(resp.TraceId)] - delete(s.requestTraceIdToResponseChannel, string(resp.TraceId)) - s.lock.Unlock() - if exists { - s.logger.Debugf("read response from batcher %s on request with trace id %x", s.endpoint, resp.TraceId) - s.logger.Debugf("registration for request with trace id %x was removed upon receiving a response", resp.TraceId) - ch <- Response{ - SubmitResponse: resp, - } - } else { + s.logger.Debugf("read response from batcher %s on request with trace id %x", s.endpoint, resp.TraceId) + err = s.sendResponseToClient(resp) + if err != nil { s.logger.Debugf("received a response from batcher %s for a request with trace id %x, which does not exist in the map, dropping response", s.endpoint, resp.TraceId) } } @@ -81,17 +74,45 @@ func (s *stream) sendRequests() { s.cancelOnServerError() return } - s.logger.Debugf("send request with trace id %x to batcher %s", msg.TraceId, s.endpoint) - err := s.requestTransmitSubmitStreamClient.Send(msg) - if err != nil { - s.logger.Errorf("Failed sending request to batcher %s", s.endpoint) - s.cancelOnServerError() - return + // verify the request + if err := s.verifier.Verify(msg); err != nil { + s.logger.Debugf("request is invalid: %s", err) + // send a response to the client + resp := protos.SubmitResponse{Error: fmt.Sprintf("request verification error: %s", err), TraceId: msg.TraceId} + err = s.sendResponseToClient(&resp) + if err != nil { + s.logger.Debugf("error sending response to client: %s", err) + } + } else { + s.logger.Debugf("send request with trace id %x to batcher %s", msg.TraceId, s.endpoint) + err := s.requestTransmitSubmitStreamClient.Send(msg) + if err != nil { + s.logger.Errorf("Failed sending request to batcher %s", s.endpoint) + s.cancelOnServerError() + return + } } } } } +func (s *stream) sendResponseToClient(response *protos.SubmitResponse) error { 
+ traceID := response.TraceId + s.lock.Lock() + ch, exists := s.requestTraceIdToResponseChannel[string(traceID)] + delete(s.requestTraceIdToResponseChannel, string(traceID)) + s.lock.Unlock() + if exists { + s.logger.Debugf("registration for request with trace id %x was removed upon receiving a response", traceID) + ch <- Response{ + SubmitResponse: response, + } + return nil + } else { + return fmt.Errorf("request with traceID %x is not in map", traceID) + } +} + func (s *stream) cancelOnServerError() { s.cancel() s.sendResponseToAllClientsOnError(fmt.Errorf("server error: could not establish connection between router and batcher %s", s.endpoint)) @@ -201,6 +222,7 @@ CopyChannelLoop: streamNum: s.streamNum, srReconnectChan: s.srReconnectChan, notifiedReconnect: false, + verifier: s.verifier, } s.lock.Unlock() diff --git a/node/router/stream_test.go b/node/router/stream_test.go index ae61ab466..0aea7ee37 100644 --- a/node/router/stream_test.go +++ b/node/router/stream_test.go @@ -14,6 +14,7 @@ import ( "testing" "time" + "github.com/hyperledger/fabric-x-orderer/common/requestfilter" "github.com/hyperledger/fabric-x-orderer/testutil" protos "github.com/hyperledger/fabric-x-orderer/node/protos/comm" @@ -52,7 +53,7 @@ func TestSendRequests(t *testing.T) { fakeSubmitStreamClient.SendReturns(nil) fakeSubmitStreamClient.ContextReturns(ctx) logger := testutil.CreateLogger(t, 0) - + verifier := createTestVerifier() s := &stream{ endpoint: "127.0.0.1:5017", logger: logger, @@ -63,6 +64,7 @@ func TestSendRequests(t *testing.T) { doneChannel: make(chan bool), requestTraceIdToResponseChannel: make(map[string]chan Response), srReconnectChan: make(chan reconnectReq, 20), + verifier: verifier, } go s.sendRequests() @@ -91,6 +93,7 @@ func TestSendRequestsReturnsWithError(t *testing.T) { fakeSubmitStreamClient.SendReturns(fmt.Errorf("error")) ctx, cancel := context.WithCancel(context.Background()) logger := testutil.CreateLogger(t, 1) + verifier := createTestVerifier() s := 
&stream{ endpoint: "127.0.0.1:5017", @@ -102,6 +105,7 @@ func TestSendRequestsReturnsWithError(t *testing.T) { doneChannel: make(chan bool), requestTraceIdToResponseChannel: make(map[string]chan Response), srReconnectChan: make(chan reconnectReq, 20), + verifier: verifier, } go s.sendRequests() @@ -135,6 +139,7 @@ func TestReadResponses(t *testing.T) { logger := testutil.CreateLogger(t, 2) ctx, cancel := context.WithCancel(context.Background()) + verifier := createTestVerifier() responseChan := make(chan Response, 1) @@ -148,6 +153,7 @@ func TestReadResponses(t *testing.T) { doneChannel: make(chan bool), requestTraceIdToResponseChannel: make(map[string]chan Response), srReconnectChan: make(chan reconnectReq, 20), + verifier: verifier, } s.registerReply(traceID, responseChan) @@ -178,6 +184,7 @@ func TestReadResponsesReturnsWithError(t *testing.T) { TraceId: traceID, }, fmt.Errorf("rpc error: service unavailable")) logger := testutil.CreateLogger(t, 3) + verifier := createTestVerifier() ctx, cancel := context.WithCancel(context.Background()) @@ -191,6 +198,7 @@ func TestReadResponsesReturnsWithError(t *testing.T) { doneChannel: make(chan bool), requestTraceIdToResponseChannel: make(map[string]chan Response), srReconnectChan: make(chan reconnectReq, 20), + verifier: verifier, } go s.readResponses() @@ -225,6 +233,7 @@ func TestRenewStreamSuccess(t *testing.T) { } requests <- req2 requestTraceIdToResponseChannel[string(req2.TraceId)] = make(chan Response, 100) + verifier := createTestVerifier() faultyStream := &stream{ endpoint: "127.0.0.1:7015", @@ -236,6 +245,7 @@ func TestRenewStreamSuccess(t *testing.T) { doneChannel: make(chan bool), requestTraceIdToResponseChannel: requestTraceIdToResponseChannel, srReconnectChan: make(chan reconnectReq, 20), + verifier: verifier, } faultyStream.cancel() @@ -294,6 +304,7 @@ func TestReconnectRequest(t *testing.T) { fakeSubmitStreamClient.SendReturns(fmt.Errorf("error")) ctx, cancel := context.WithCancel(context.Background()) 
logger := testutil.CreateLogger(t, 1) + verifier := createTestVerifier() connectionNumber := 2 streamNumber := 3 @@ -310,6 +321,7 @@ func TestReconnectRequest(t *testing.T) { srReconnectChan: make(chan reconnectReq, 20), connNum: connectionNumber, streamNum: streamNumber, + verifier: verifier, } go s.sendRequests() @@ -349,3 +361,9 @@ func (srp *safeReqPool) getElement(i int) *protos.Request { defer srp.mu.Unlock() return srp.reqPool[i] } + +func createTestVerifier() *requestfilter.RulesVerifier { + rv := requestfilter.NewRulesVerifier(nil) + rv.AddRule(requestfilter.AcceptRule{}) + return rv +} diff --git a/test/utils_test.go b/test/utils_test.go index 9a9ed5073..b787f216a 100644 --- a/test/utils_test.go +++ b/test/utils_test.go @@ -90,7 +90,8 @@ func createRouters(t *testing.T, num int, batcherInfos []nodeconfig.BatcherInfo, ShardId: shardId, Batchers: batcherInfos, }}, - UseTLS: true, + UseTLS: true, + RequestMaxBytes: 1 << 10, } router := router.NewRouter(config, l)