diff --git a/cluster/router/chain/chain.go b/cluster/router/chain/chain.go index 3159f9c393..d849c5804a 100644 --- a/cluster/router/chain/chain.go +++ b/cluster/router/chain/chain.go @@ -53,17 +53,21 @@ type RouterChain struct { // Route Loop routers in RouterChain and call Route method to determine the target invokers list. func (c *RouterChain) Route(url *common.URL, invocation base.Invocation) []base.Invoker { - finalInvokers := make([]base.Invoker, 0, len(c.invokers)) + c.mutex.RLock() + invokers := c.invokers + c.mutex.RUnlock() + + finalInvokers := make([]base.Invoker, 0, len(invokers)) // multiple invoker may include different methods, find correct invoker otherwise // will return the invoker without methods - for _, invoker := range c.invokers { + for _, invoker := range invokers { if invoker.GetURL().ServiceKey() == url.ServiceKey() { finalInvokers = append(finalInvokers, invoker) } } if len(finalInvokers) == 0 { - finalInvokers = c.invokers + finalInvokers = invokers } for _, r := range c.copyRouters() { diff --git a/cluster/router/chain/chain_test.go b/cluster/router/chain/chain_test.go new file mode 100644 index 0000000000..afdc394205 --- /dev/null +++ b/cluster/router/chain/chain_test.go @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package chain + +import ( + "testing" +) + +import ( + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +import ( + "dubbo.apache.org/dubbo-go/v3/cluster/router" + "dubbo.apache.org/dubbo-go/v3/common" + "dubbo.apache.org/dubbo-go/v3/protocol/base" + "dubbo.apache.org/dubbo-go/v3/protocol/invocation" +) + +const testConsumerServiceURL = "consumer://127.0.0.1/com.demo.Service" + +type testPriorityRouter struct { + priority int64 + called int + lastSize int + + notifyFn func([]base.Invoker) + routeFn func([]base.Invoker, *common.URL, base.Invocation) []base.Invoker +} + +func (r *testPriorityRouter) Route(invokers []base.Invoker, url *common.URL, inv base.Invocation) []base.Invoker { + r.called++ + r.lastSize = len(invokers) + if r.routeFn != nil { + return r.routeFn(invokers, url, inv) + } + return invokers +} + +func (r *testPriorityRouter) URL() *common.URL { + return nil +} + +func (r *testPriorityRouter) Priority() int64 { + return r.priority +} + +func (r *testPriorityRouter) Notify(invokers []base.Invoker) { + if r.notifyFn != nil { + r.notifyFn(invokers) + } +} + +func buildInvoker(t *testing.T, rawURL string) base.Invoker { + u, err := common.NewURL(rawURL) + require.NoError(t, err) + return base.NewBaseInvoker(u) +} + +func TestRouteUsesServiceKeyMatchWhenAvailable(t *testing.T) { + consumerURL, err := common.NewURL(testConsumerServiceURL) + require.NoError(t, err) + + match := buildInvoker(t, "dubbo://127.0.0.1:20000/com.demo.Service") + nonMatch := buildInvoker(t, "dubbo://127.0.0.1:20001/com.other.Service") + + r := &testPriorityRouter{priority: 1} + chain := &RouterChain{ + invokers: []base.Invoker{match, nonMatch}, + routers: []router.PriorityRouter{r}, + } + + result := chain.Route(consumerURL, invocation.NewRPCInvocation("Say", nil, nil)) + assert.Len(t, result, 1) + assert.Equal(t, match.GetURL().String(), result[0].GetURL().String()) + assert.Equal(t, 1, r.called) + assert.Equal(t, 1, r.lastSize) +} + +func TestRouteFallsBackToAllInvokersWhenNoMatch(t *testing.T) { + consumerURL, err := common.NewURL(testConsumerServiceURL) + require.NoError(t, err) + + invokerA := buildInvoker(t, "dubbo://127.0.0.1:20000/com.foo.Service") + invokerB := buildInvoker(t, "dubbo://127.0.0.1:20001/com.bar.Service") + + r := &testPriorityRouter{priority: 1} + chain := &RouterChain{ + invokers: []base.Invoker{invokerA, invokerB}, + routers: []router.PriorityRouter{r}, + } + + result := chain.Route(consumerURL, invocation.NewRPCInvocation("Say", nil, nil)) + assert.Len(t, result, 2) + assert.Equal(t, 1, r.called) + assert.Equal(t, 2, r.lastSize) +} + +func TestRouteAppliesRoutersOnSnapshot(t *testing.T) { + consumerURL, err := common.NewURL(testConsumerServiceURL) + require.NoError(t, err) + + invokerA := buildInvoker(t, "dubbo://127.0.0.1:20000/com.demo.Service") + invokerB := buildInvoker(t, "dubbo://127.0.0.1:20001/com.demo.Service") + + r1 := &testPriorityRouter{priority: 1, routeFn: func(invokers []base.Invoker, _ *common.URL, _ base.Invocation) []base.Invoker { + return invokers[:1] + }} + r2 := &testPriorityRouter{priority: 2} + + chain := &RouterChain{ + invokers: []base.Invoker{invokerA, invokerB}, + routers: []router.PriorityRouter{r1, r2}, + } + + result := chain.Route(consumerURL, invocation.NewRPCInvocation("Say", nil, nil)) + assert.Len(t, result, 1) + assert.Equal(t, invokerA.GetURL().String(), result[0].GetURL().String()) + assert.Equal(t, 1, r1.called) + assert.Equal(t, 1, r2.called) + assert.Equal(t, 1, r2.lastSize) +} diff --git a/config/service.go b/config/service.go index 2a241fa7b4..cdcfbc52a1 100644 --- a/config/service.go +++ b/config/service.go @@ -89,7 +89,14 @@ func GetProviderService(name string) common.RPCService { // GetProviderServiceMap gets ProviderServiceMap func GetProviderServiceMap() map[string]common.RPCService { - return proServices + proServicesLock.Lock() + defer proServicesLock.Unlock() + + m := make(map[string]common.RPCService, len(proServices)) + for k, v := range proServices { + m[k] = v + } + return m } func GetProviderServiceInfo(name string) any { @@ -100,7 +107,14 @@ func GetProviderServiceInfo(name string) any { // GetConsumerServiceMap gets ProviderServiceMap func GetConsumerServiceMap() map[string]common.RPCService { - return conServices + conServicesLock.Lock() + defer conServicesLock.Unlock() + + m := make(map[string]common.RPCService, len(conServices)) + for k, v := range conServices { + m[k] = v + } + return m } // SetConsumerServiceByInterfaceName is used by pb serialization diff --git a/config/service_test.go b/config/service_test.go index be94b05c7c..179ffa2f66 100644 --- a/config/service_test.go +++ b/config/service_test.go @@ -18,13 +18,23 @@ package config import ( + "maps" "testing" ) import ( "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) +import ( + "dubbo.apache.org/dubbo-go/v3/common" +) + +func cloneRPCServiceMap(src map[string]common.RPCService) map[string]common.RPCService { + return maps.Clone(src) +} + func TestGetConsumerService(t *testing.T) { SetConsumerService(&HelloService{}) @@ -43,3 +53,66 @@ func TestGetConsumerService(t *testing.T) { callback := GetCallback(reference) assert.Nil(t, callback) } + +func TestGetProviderServiceMapReturnsCopy(t *testing.T) { + proServicesLock.Lock() + originalProServices := cloneRPCServiceMap(proServices) + originalProServicesInfo := maps.Clone(proServicesInfo) + proServices = map[string]common.RPCService{} + proServicesInfo = map[string]any{} + proServicesLock.Unlock() + + defer func() { + proServicesLock.Lock() + proServices = originalProServices + proServicesInfo = originalProServicesInfo + proServicesLock.Unlock() + }() + + svc := &HelloService{} + SetProviderService(svc) + + got := GetProviderServiceMap() + require.Len(t, got, 1) + + got["Injected"] = &HelloService{} + got["HelloService"] = &HelloService{} + + proServicesLock.Lock() + _, hasInjected := proServices["Injected"] + _, hasHelloService := proServices["HelloService"] + proServicesLock.Unlock() + + assert.False(t, hasInjected) + assert.True(t, hasHelloService) +} + +func TestGetConsumerServiceMapReturnsCopy(t *testing.T) { + conServicesLock.Lock() + originalConServices := cloneRPCServiceMap(conServices) + conServices = map[string]common.RPCService{} + conServicesLock.Unlock() + + defer func() { + conServicesLock.Lock() + conServices = originalConServices + conServicesLock.Unlock() + }() + + svc := &HelloService{} + SetConsumerService(svc) + + got := GetConsumerServiceMap() + require.Len(t, got, 1) + + got["Injected"] = &HelloService{} + got["HelloService"] = &HelloService{} + + conServicesLock.Lock() + _, hasInjected := conServices["Injected"] + stored := conServices["HelloService"] + conServicesLock.Unlock() + + assert.False(t, hasInjected) + assert.Equal(t, svc, stored) +} diff --git a/dubbo.go b/dubbo.go index 2a132d3cb6..125af03856 100644 --- a/dubbo.go +++ b/dubbo.go @@ -298,13 +298,15 @@ func SetConsumerService(svc common.RPCService) { } func SetProviderService(svc common.RPCService) { - conLock.Lock() - defer conLock.Unlock() + proLock.Lock() + defer proLock.Unlock() providerServices[common.GetReference(svc)] = &server.ServiceDefinition{ Handler: svc, } } func GetConsumerConnection(interfaceName string) (*client.Connection, error) { + conLock.RLock() + defer conLock.RUnlock() return consumerServices[interfaceName].GetConnection() } diff --git a/dubbo_test.go b/dubbo_test.go index e23c97394a..01d3bc5d4c 100644 --- a/dubbo_test.go +++ b/dubbo_test.go @@ -18,11 +18,13 @@ package dubbo import ( + "maps" "testing" ) import ( "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) import ( @@ -32,6 +34,22 @@ import ( "dubbo.apache.org/dubbo-go/v3/server" ) +type testRPCService struct { + ref string +} + +func (s *testRPCService) Reference() string { + return s.ref +} + +func cloneClientDefinitions(src map[string]*client.ClientDefinition) map[string]*client.ClientDefinition { + return maps.Clone(src) +} + +func cloneServiceDefinitions(src map[string]*server.ServiceDefinition) map[string]*server.ServiceDefinition { + return maps.Clone(src) +} + // TestIndependentConfig tests the configurations of the `instance`, `client`, and `server` are independent. func TestIndependentConfig(t *testing.T) { // instance configuration @@ -96,3 +114,55 @@ func TestIndependentConfig(t *testing.T) { panic(err) } } + +func TestSetProviderServiceRegistersByReference(t *testing.T) { + proLock.Lock() + original := cloneServiceDefinitions(providerServices) + providerServices = make(map[string]*server.ServiceDefinition) + proLock.Unlock() + + defer func() { + proLock.Lock() + providerServices = original + proLock.Unlock() + }() + + svc := &testRPCService{ref: "provider.test.Service"} + SetProviderService(svc) + + proLock.RLock() + defer proLock.RUnlock() + def, ok := providerServices[svc.Reference()] + require.True(t, ok) + require.NotNil(t, def) + assert.Equal(t, svc, def.Handler) +} + +func TestGetConsumerConnectionFromConsumerServices(t *testing.T) { + conLock.Lock() + original := cloneClientDefinitions(consumerServices) + consumerServices = make(map[string]*client.ClientDefinition) + conLock.Unlock() + + defer func() { + conLock.Lock() + consumerServices = original + conLock.Unlock() + }() + + svc := &testRPCService{ref: "consumer.test.Service"} + SetConsumerService(svc) + + conn, err := GetConsumerConnection(svc.Reference()) + require.Error(t, err) + require.Nil(t, conn) + + expectedConn := &client.Connection{} + conLock.Lock() + consumerServices[svc.Reference()].SetConnection(expectedConn) + conLock.Unlock() + + conn, err = GetConsumerConnection(svc.Reference()) + require.NoError(t, err) + assert.Equal(t, expectedConn, conn) +}