Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions cluster/router/chain/chain.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,17 +53,21 @@ type RouterChain struct {

// Route Loop routers in RouterChain and call Route method to determine the target invokers list.
func (c *RouterChain) Route(url *common.URL, invocation base.Invocation) []base.Invoker {
finalInvokers := make([]base.Invoker, 0, len(c.invokers))
c.mutex.RLock()
invokers := c.invokers
c.mutex.RUnlock()

finalInvokers := make([]base.Invoker, 0, len(invokers))
// multiple invoker may include different methods, find correct invoker otherwise
// will return the invoker without methods
for _, invoker := range c.invokers {
for _, invoker := range invokers {
if invoker.GetURL().ServiceKey() == url.ServiceKey() {
finalInvokers = append(finalInvokers, invoker)
}
}

if len(finalInvokers) == 0 {
finalInvokers = c.invokers
finalInvokers = invokers
}

for _, r := range c.copyRouters() {
Expand Down
138 changes: 138 additions & 0 deletions cluster/router/chain/chain_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package chain

import (
"testing"
)

import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

import (
"dubbo.apache.org/dubbo-go/v3/cluster/router"
"dubbo.apache.org/dubbo-go/v3/common"
"dubbo.apache.org/dubbo-go/v3/protocol/base"
"dubbo.apache.org/dubbo-go/v3/protocol/invocation"
)

const testConsumerServiceURL = "consumer://127.0.0.1/com.demo.Service"

type testPriorityRouter struct {
priority int64
called int
lastSize int

notifyFn func([]base.Invoker)
routeFn func([]base.Invoker, *common.URL, base.Invocation) []base.Invoker
}

func (r *testPriorityRouter) Route(invokers []base.Invoker, url *common.URL, inv base.Invocation) []base.Invoker {
r.called++
r.lastSize = len(invokers)
if r.routeFn != nil {
return r.routeFn(invokers, url, inv)
}
return invokers
}

func (r *testPriorityRouter) URL() *common.URL {
return nil
}

func (r *testPriorityRouter) Priority() int64 {
return r.priority
}

func (r *testPriorityRouter) Notify(invokers []base.Invoker) {
if r.notifyFn != nil {
r.notifyFn(invokers)
}
}

func buildInvoker(t *testing.T, rawURL string) base.Invoker {
u, err := common.NewURL(rawURL)
require.NoError(t, err)
return base.NewBaseInvoker(u)
}

func TestRouteUsesServiceKeyMatchWhenAvailable(t *testing.T) {
consumerURL, err := common.NewURL(testConsumerServiceURL)
require.NoError(t, err)

match := buildInvoker(t, "dubbo://127.0.0.1:20000/com.demo.Service")
nonMatch := buildInvoker(t, "dubbo://127.0.0.1:20001/com.other.Service")

r := &testPriorityRouter{priority: 1}
chain := &RouterChain{
invokers: []base.Invoker{match, nonMatch},
routers: []router.PriorityRouter{r},
}

result := chain.Route(consumerURL, invocation.NewRPCInvocation("Say", nil, nil))
assert.Len(t, result, 1)
assert.Equal(t, match.GetURL().String(), result[0].GetURL().String())
assert.Equal(t, 1, r.called)
assert.Equal(t, 1, r.lastSize)
}

func TestRouteFallsBackToAllInvokersWhenNoMatch(t *testing.T) {
consumerURL, err := common.NewURL(testConsumerServiceURL)
require.NoError(t, err)

invokerA := buildInvoker(t, "dubbo://127.0.0.1:20000/com.foo.Service")
invokerB := buildInvoker(t, "dubbo://127.0.0.1:20001/com.bar.Service")

r := &testPriorityRouter{priority: 1}
chain := &RouterChain{
invokers: []base.Invoker{invokerA, invokerB},
routers: []router.PriorityRouter{r},
}

result := chain.Route(consumerURL, invocation.NewRPCInvocation("Say", nil, nil))
assert.Len(t, result, 2)
assert.Equal(t, 1, r.called)
assert.Equal(t, 2, r.lastSize)
}

func TestRouteAppliesRoutersOnSnapshot(t *testing.T) {
consumerURL, err := common.NewURL(testConsumerServiceURL)
require.NoError(t, err)

invokerA := buildInvoker(t, "dubbo://127.0.0.1:20000/com.demo.Service")
invokerB := buildInvoker(t, "dubbo://127.0.0.1:20001/com.demo.Service")

r1 := &testPriorityRouter{priority: 1, routeFn: func(invokers []base.Invoker, _ *common.URL, _ base.Invocation) []base.Invoker {
return invokers[:1]
}}
r2 := &testPriorityRouter{priority: 2}

chain := &RouterChain{
invokers: []base.Invoker{invokerA, invokerB},
routers: []router.PriorityRouter{r1, r2},
}

result := chain.Route(consumerURL, invocation.NewRPCInvocation("Say", nil, nil))
assert.Len(t, result, 1)
assert.Equal(t, invokerA.GetURL().String(), result[0].GetURL().String())
assert.Equal(t, 1, r1.called)
assert.Equal(t, 1, r2.called)
assert.Equal(t, 1, r2.lastSize)
}
18 changes: 16 additions & 2 deletions config/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,14 @@ func GetProviderService(name string) common.RPCService {

// GetProviderServiceMap gets ProviderServiceMap
func GetProviderServiceMap() map[string]common.RPCService {
return proServices
proServicesLock.Lock()
defer proServicesLock.Unlock()

m := make(map[string]common.RPCService, len(proServices))
for k, v := range proServices {
m[k] = v
}
return m
}

func GetProviderServiceInfo(name string) any {
Expand All @@ -100,7 +107,14 @@ func GetProviderServiceInfo(name string) any {

// GetConsumerServiceMap gets ProviderServiceMap
func GetConsumerServiceMap() map[string]common.RPCService {
return conServices
conServicesLock.Lock()
defer conServicesLock.Unlock()

m := make(map[string]common.RPCService, len(conServices))
for k, v := range conServices {
m[k] = v
}
return m
}

// SetConsumerServiceByInterfaceName is used by pb serialization
Expand Down
73 changes: 73 additions & 0 deletions config/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,23 @@
package config

import (
"maps"
"testing"
)

import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

import (
"dubbo.apache.org/dubbo-go/v3/common"
)

func cloneRPCServiceMap(src map[string]common.RPCService) map[string]common.RPCService {
return maps.Clone(src)
}

func TestGetConsumerService(t *testing.T) {

SetConsumerService(&HelloService{})
Expand All @@ -43,3 +53,66 @@ func TestGetConsumerService(t *testing.T) {
callback := GetCallback(reference)
assert.Nil(t, callback)
}

func TestGetProviderServiceMapReturnsCopy(t *testing.T) {
proServicesLock.Lock()
originalProServices := cloneRPCServiceMap(proServices)
originalProServicesInfo := maps.Clone(proServicesInfo)
proServices = map[string]common.RPCService{}
proServicesInfo = map[string]any{}
proServicesLock.Unlock()

defer func() {
proServicesLock.Lock()
proServices = originalProServices
proServicesInfo = originalProServicesInfo
proServicesLock.Unlock()
}()

svc := &HelloService{}
SetProviderService(svc)

got := GetProviderServiceMap()
require.Len(t, got, 1)

got["Injected"] = &HelloService{}
got["HelloService"] = &HelloService{}

proServicesLock.Lock()
_, hasInjected := proServices["Injected"]
_, hasHelloService := proServices["HelloService"]
proServicesLock.Unlock()

assert.False(t, hasInjected)
assert.True(t, hasHelloService)
}

func TestGetConsumerServiceMapReturnsCopy(t *testing.T) {
conServicesLock.Lock()
originalConServices := cloneRPCServiceMap(conServices)
conServices = map[string]common.RPCService{}
conServicesLock.Unlock()

defer func() {
conServicesLock.Lock()
conServices = originalConServices
conServicesLock.Unlock()
}()

svc := &HelloService{}
SetConsumerService(svc)

got := GetConsumerServiceMap()
require.Len(t, got, 1)

got["Injected"] = &HelloService{}
got["HelloService"] = &HelloService{}

conServicesLock.Lock()
_, hasInjected := conServices["Injected"]
stored := conServices["HelloService"]
conServicesLock.Unlock()

assert.False(t, hasInjected)
assert.Equal(t, svc, stored)
}
6 changes: 4 additions & 2 deletions dubbo.go
Original file line number Diff line number Diff line change
Expand Up @@ -298,13 +298,15 @@ func SetConsumerService(svc common.RPCService) {
}

func SetProviderService(svc common.RPCService) {
conLock.Lock()
defer conLock.Unlock()
proLock.Lock()
defer proLock.Unlock()
providerServices[common.GetReference(svc)] = &server.ServiceDefinition{
Handler: svc,
}
}

func GetConsumerConnection(interfaceName string) (*client.Connection, error) {
conLock.RLock()
defer conLock.RUnlock()
return consumerServices[interfaceName].GetConnection()
}
70 changes: 70 additions & 0 deletions dubbo_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@
package dubbo

import (
"maps"
"testing"
)

import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

import (
Expand All @@ -32,6 +34,22 @@ import (
"dubbo.apache.org/dubbo-go/v3/server"
)

type testRPCService struct {
ref string
}

func (s *testRPCService) Reference() string {
return s.ref
}

func cloneClientDefinitions(src map[string]*client.ClientDefinition) map[string]*client.ClientDefinition {
return maps.Clone(src)
}

func cloneServiceDefinitions(src map[string]*server.ServiceDefinition) map[string]*server.ServiceDefinition {
return maps.Clone(src)
}

// TestIndependentConfig tests the configurations of the `instance`, `client`, and `server` are independent.
func TestIndependentConfig(t *testing.T) {
// instance configuration
Expand Down Expand Up @@ -96,3 +114,55 @@ func TestIndependentConfig(t *testing.T) {
panic(err)
}
}

func TestSetProviderServiceRegistersByReference(t *testing.T) {
proLock.Lock()
original := cloneServiceDefinitions(providerServices)
providerServices = make(map[string]*server.ServiceDefinition)
proLock.Unlock()

defer func() {
proLock.Lock()
providerServices = original
proLock.Unlock()
}()

svc := &testRPCService{ref: "provider.test.Service"}
SetProviderService(svc)

proLock.RLock()
defer proLock.RUnlock()
def, ok := providerServices[svc.Reference()]
require.True(t, ok)
require.NotNil(t, def)
assert.Equal(t, svc, def.Handler)
}

func TestGetConsumerConnectionFromConsumerServices(t *testing.T) {
conLock.Lock()
original := cloneClientDefinitions(consumerServices)
consumerServices = make(map[string]*client.ClientDefinition)
conLock.Unlock()

defer func() {
conLock.Lock()
consumerServices = original
conLock.Unlock()
}()

svc := &testRPCService{ref: "consumer.test.Service"}
SetConsumerService(svc)

conn, err := GetConsumerConnection(svc.Reference())
require.Error(t, err)
require.Nil(t, conn)

expectedConn := &client.Connection{}
conLock.Lock()
consumerServices[svc.Reference()].SetConnection(expectedConn)
conLock.Unlock()

conn, err = GetConsumerConnection(svc.Reference())
require.NoError(t, err)
assert.Equal(t, expectedConn, conn)
}
Loading