Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion config/application_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func TestApplicationConfig(t *testing.T) {
err := Load(WithPath("./testdata/config/application/application.yaml"))
require.NoError(t, err)

center := rootConfig.Registries
center := GetRootConfig().Registries
assert.NotNil(t, center)
}

Expand Down
40 changes: 27 additions & 13 deletions config/config_loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package config

import (
"errors"
"sync/atomic"
)

import (
Expand All @@ -34,9 +35,21 @@ import (
)

var (
rootConfig = NewRootConfigBuilder().Build()
rootConfigStore = func() *atomic.Pointer[RootConfig] {
store := &atomic.Pointer[RootConfig]{}
store.Store(NewRootConfigBuilder().Build())
return store
}()
)

func getRootConfigInternal() *RootConfig {
return rootConfigStore.Load()
}
Comment on lines +45 to +47
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个函数有啥用啊 返回的又不是深拷贝

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

确实,感谢,我现在去去掉这个加锁的逻辑

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

图片 已经修改了,把加锁逻辑去掉了,换成原子读写rootConfig ,然后修改了相关的单测,将直接读全局变量的地方改成统一走 GetRootConfig()/SetRootConfig()


func setRootConfigInternal(rc *RootConfig) {
rootConfigStore.Store(rc)
}

func init() {
log := zap.NewDefault()
logger.SetLogger(log)
Expand All @@ -45,39 +58,40 @@ func init() {
func Load(opts ...LoaderConfOption) error {
// conf
conf := NewLoaderConf(opts...)
loadConfig := conf.rc

if conf.rc == nil {
loadConfig = NewRootConfigBuilder().Build()
koan := GetConfigResolver(conf)
koan = conf.MergeConfig(koan)
if err := koan.UnmarshalWithConf(rootConfig.Prefix(),
rootConfig, koanf.UnmarshalConf{Tag: "yaml"}); err != nil {
if err := koan.UnmarshalWithConf(loadConfig.Prefix(),
loadConfig, koanf.UnmarshalConf{Tag: "yaml"}); err != nil {
return err
}
} else {
rootConfig = conf.rc
}

if err := rootConfig.Init(); err != nil {
if err := loadConfig.Init(); err != nil {
return err
}
return nil
}

func check() error {
if rootConfig == nil {
if GetRootConfig() == nil {
return errors.New("execute the config.Load() method first")
}
return nil
}

// GetRPCService get rpc service for consumer
func GetRPCService(name string) common.RPCService {
return rootConfig.Consumer.References[name].GetRPCService()
return GetRootConfig().Consumer.References[name].GetRPCService()
}

// RPCService create rpc service for consumer
func RPCService(service common.RPCService) {
ref := common.GetReference(service)
rootConfig.Consumer.References[ref].Implement(service)
GetRootConfig().Consumer.References[ref].Implement(service)
}

// GetMetricConfig find the MetricsConfig
Expand All @@ -95,17 +109,17 @@ func GetMetricConfig() *MetricsConfig {
// }
//}
//return GetBaseConfig().Metrics
return rootConfig.Metrics
return GetRootConfig().Metrics
}

func GetTracingConfig(tracingKey string) *TracingConfig {
return rootConfig.Tracing[tracingKey]
return GetRootConfig().Tracing[tracingKey]
}

func GetMetadataReportConfg() *MetadataReportConfig {
return rootConfig.MetadataReport
return GetRootConfig().MetadataReport
}

func IsProvider() bool {
return len(rootConfig.Provider.Services) > 0
return len(GetRootConfig().Provider.Services) > 0
}
10 changes: 8 additions & 2 deletions config/consumer_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ func (cc *ConsumerConfig) Load() {
// use interface name defined by pb
refConfig.InterfaceName = triplePBService.XXX_InterfaceName()
}
if err := refConfig.Init(rootConfig); err != nil {
if err := refConfig.Init(GetRootConfig()); err != nil {
logger.Errorf(fmt.Sprintf("reference with registeredTypeName = %s init failed! err: %#v", registeredTypeName, err))
continue
}
Expand Down Expand Up @@ -185,7 +185,13 @@ func (cc *ConsumerConfig) Load() {

// SetConsumerConfig sets consumerConfig by @c
func SetConsumerConfig(c ConsumerConfig) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SetConsumerConfig 里的 nil 分支在 rc == nil 时会丢失其他字段。如果 GetRootConfig() 还没被调用过(返回 nil),这里直接 SetRootConfig(RootConfig{Consumer: &c}) 会创建一个只有 Consumer 的 RootConfig,之前通过 Load() 设置的 Protocol、Application 等字段全部丢失。

建议:如果 GetRootConfig() 返回 nil,直接用 NewRootConfigBuilder() 构建一个默认值,再设置 Consumer:

rc := GetRootConfig()
if rc == nil {
    rc = NewRootConfigBuilder().Build()
}
next := *rc
next.Consumer = &c
SetRootConfig(next)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

感谢,我将尽快修复

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

图片 已按照建议修复,谢谢

rootConfig.Consumer = &c
rc := GetRootConfig()
if rc == nil {
rc = NewRootConfigBuilder().Build()
}
next := *rc
next.Consumer = &c
SetRootConfig(next)
}

func newEmptyConsumerConfig() *ConsumerConfig {
Expand Down
10 changes: 5 additions & 5 deletions config/custom_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ func TestCustomInit(t *testing.T) {
t.Run("empty use default", func(t *testing.T) {
err := Load(WithPath("./testdata/config/custom/empty.yaml"))
require.NoError(t, err)
assert.NotNil(t, rootConfig)
customConfig := rootConfig.Custom
assert.NotNil(t, GetRootConfig())
customConfig := GetRootConfig().Custom
assert.NotNil(t, customConfig)
assert.Equal(t, customConfig.ConfigMap, map[string]any(nil))
assert.Equal(t, "test", customConfig.GetDefineValue("test", "test"))
Expand All @@ -47,8 +47,8 @@ func TestCustomInit(t *testing.T) {
t.Run("use config", func(t *testing.T) {
err := Load(WithPath("./testdata/config/custom/custom.yaml"))
require.NoError(t, err)
assert.NotNil(t, rootConfig)
customConfig := rootConfig.Custom
assert.NotNil(t, GetRootConfig())
customConfig := GetRootConfig().Custom
assert.NotNil(t, customConfig)
assert.Equal(t, map[string]any{"test-config": true}, customConfig.ConfigMap)
assert.Equal(t, true, customConfig.GetDefineValue("test-config", false))
Expand All @@ -65,7 +65,7 @@ func TestCustomInit(t *testing.T) {
assert.Equal(t, true, customConfig.GetDefineValue("test-build", false))
assert.Equal(t, false, customConfig.GetDefineValue("test-no-build", false))
// todo @(laurence) now we should guarantee rootConfig ptr can't be changed during test
tempRootConfig := rootConfig
tempRootConfig := GetRootConfig()
rt := NewRootConfigBuilder().SetCustom(customConfig).Build()
SetRootConfig(*rt)
assert.Equal(t, true, GetDefineValue("test-build", false))
Expand Down
43 changes: 24 additions & 19 deletions config/graceful_shutdown.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,12 @@ func gracefulShutdownInit() {
if !exist {
return
}
if filter, ok := gracefulShutdownConsumerFilter.(Setter); ok && rootConfig.Shutdown != nil {
rc := GetRootConfig()
if filter, ok := gracefulShutdownConsumerFilter.(Setter); ok && rc != nil && rc.Shutdown != nil {
filter.Set(constant.GracefulShutdownFilterShutdownConfig, GetShutDown())
}

if filter, ok := gracefulShutdownProviderFilter.(Setter); ok && rootConfig.Shutdown != nil {
if filter, ok := gracefulShutdownProviderFilter.(Setter); ok && rc != nil && rc.Shutdown != nil {
filter.Set(constant.GracefulShutdownFilterShutdownConfig, GetShutDown())
}

Expand Down Expand Up @@ -127,21 +128,22 @@ func destroyAllRegistries() {
func destroyProtocols() {
logger.Info("Graceful shutdown --- Destroy protocols. ")

if rootConfig.Protocols == nil {
rc := GetRootConfig()
if rc == nil || rc.Protocols == nil {
return
}

consumerProtocols := getConsumerProtocols()
consumerProtocols := getConsumerProtocols(rc)

destroyProviderProtocols(consumerProtocols)
destroyProviderProtocols(rc, consumerProtocols)
destroyConsumerProtocols(consumerProtocols)
}

// destroyProviderProtocols destroys the provider's protocol.
// if the protocol is consumer's protocol too, we will keep it
func destroyProviderProtocols(consumerProtocols *gxset.HashSet) {
func destroyProviderProtocols(rc *RootConfig, consumerProtocols *gxset.HashSet) {
logger.Info("Graceful shutdown --- First destroy provider's protocols. ")
for _, protocol := range rootConfig.Protocols {
for _, protocol := range rc.Protocols {
// the protocol is the consumer's protocol too, we can not destroy it.
if consumerProtocols.Contains(protocol.Name) {
continue
Expand All @@ -159,18 +161,19 @@ func destroyConsumerProtocols(consumerProtocols *gxset.HashSet) {

func waitAndAcceptNewRequests() {
logger.Info("Graceful shutdown --- Keep waiting and accept new requests for a short time. ")
if rootConfig.Shutdown == nil {
rc := GetRootConfig()
if rc == nil || rc.Shutdown == nil {
return
}

time.Sleep(rootConfig.Shutdown.GetConsumerUpdateWaitTime())
time.Sleep(rc.Shutdown.GetConsumerUpdateWaitTime())

timeout := rootConfig.Shutdown.GetStepTimeout()
timeout := rc.Shutdown.GetStepTimeout()
// ignore this step
if timeout < 0 {
return
}
waitingProviderProcessedTimeout(rootConfig.Shutdown)
waitingProviderProcessedTimeout(rc.Shutdown)
}

func waitingProviderProcessedTimeout(shutdownConfig *ShutdownConfig) {
Expand All @@ -193,12 +196,13 @@ func waitingProviderProcessedTimeout(shutdownConfig *ShutdownConfig) {
// for provider. It will wait for processing receiving requests
func waitForSendingAndReceivingRequests() {
logger.Info("Graceful shutdown --- Keep waiting until sending/accepting requests finish or timeout. ")
if rootConfig == nil || rootConfig.Shutdown == nil {
rc := GetRootConfig()
if rc == nil || rc.Shutdown == nil {
// ignore this step
return
}
rootConfig.Shutdown.RejectRequest.Store(true)
waitingConsumerProcessedTimeout(rootConfig.Shutdown)
rc.Shutdown.RejectRequest.Store(true)
waitingConsumerProcessedTimeout(rc.Shutdown)
}

func waitingConsumerProcessedTimeout(shutdownConfig *ShutdownConfig) {
Expand All @@ -217,21 +221,22 @@ func waitingConsumerProcessedTimeout(shutdownConfig *ShutdownConfig) {

func totalTimeout() time.Duration {
timeout := defaultShutDownTime
if rootConfig.Shutdown != nil && rootConfig.Shutdown.GetTimeout() > timeout {
timeout = rootConfig.Shutdown.GetTimeout()
rc := GetRootConfig()
if rc != nil && rc.Shutdown != nil && rc.Shutdown.GetTimeout() > timeout {
timeout = rc.Shutdown.GetTimeout()
}

return timeout
}

// we can not get the protocols from consumerConfig because some protocol don't have configuration, like jsonrpc.
func getConsumerProtocols() *gxset.HashSet {
func getConsumerProtocols(rc *RootConfig) *gxset.HashSet {
result := gxset.NewSet()
if rootConfig.Consumer == nil || rootConfig.Consumer.References == nil {
if rc == nil || rc.Consumer == nil || rc.Consumer.References == nil {
return result
}

for _, reference := range rootConfig.Consumer.References {
for _, reference := range rc.Consumer.References {
result.Add(reference.Protocol)
}
return result
Expand Down
8 changes: 4 additions & 4 deletions config/logger_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,15 @@ func TestLoggerInit(t *testing.T) {
t.Run("empty use default", func(t *testing.T) {
err := Load(WithPath("./testdata/config/logger/empty_log.yaml"))
require.NoError(t, err)
assert.NotNil(t, rootConfig)
loggerConfig := rootConfig.Logger
assert.NotNil(t, GetRootConfig())
loggerConfig := GetRootConfig().Logger
assert.NotNil(t, loggerConfig)
})

t.Run("use config", func(t *testing.T) {
err := Load(WithPath("./testdata/config/logger/log.yaml"))
require.NoError(t, err)
loggerConfig := rootConfig.Logger
loggerConfig := GetRootConfig().Logger
assert.NotNil(t, loggerConfig)
// default
logger.Info("hello")
Expand All @@ -49,7 +49,7 @@ func TestLoggerInit(t *testing.T) {
t.Run("use config with file", func(t *testing.T) {
err := Load(WithPath("./testdata/config/logger/file_log.yaml"))
require.NoError(t, err)
loggerConfig := rootConfig.Logger
loggerConfig := GetRootConfig().Logger
assert.NotNil(t, loggerConfig)
logger.Debug("debug")
logger.Info("info")
Expand Down
2 changes: 1 addition & 1 deletion config/metadata_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ func initMetadata(rc *RootConfig) error {
func getMetadataPort(rc *RootConfig) int {
port := rc.Application.MetadataServicePort
if port == "" {
protocolConfig, ok := rootConfig.Protocols[constant.DefaultProtocol]
protocolConfig, ok := rc.Protocols[constant.DefaultProtocol]
if ok {
port = protocolConfig.Port
} else {
Expand Down
4 changes: 2 additions & 2 deletions config/protocol_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func TestGetProtocolsConfig(t *testing.T) {
t.Run("empty use default", func(t *testing.T) {
err := Load(WithPath("./testdata/config/protocol/empty_application.yaml"))
require.NoError(t, err)
protocols := rootConfig.Protocols
protocols := GetRootConfig().Protocols
assert.NotNil(t, protocols)
// default
assert.Equal(t, "tri", protocols["tri"].Name)
Expand All @@ -42,7 +42,7 @@ func TestGetProtocolsConfig(t *testing.T) {
t.Run("use config", func(t *testing.T) {
err := Load(WithPath("./testdata/config/protocol/application.yaml"))
require.NoError(t, err)
protocols := rootConfig.Protocols
protocols := GetRootConfig().Protocols
assert.NotNil(t, protocols)
// default
assert.Equal(t, "dubbo", protocols["dubbo"].Name)
Expand Down
2 changes: 1 addition & 1 deletion config/provider_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ func (c *ProviderConfig) Load() {
serviceConfig = NewServiceConfigBuilder().Build()
// use interface name defined by pb
serviceConfig.Interface = supportPBPackagerNameService.XXX_InterfaceName()
if err := serviceConfig.Init(rootConfig); err != nil {
if err := serviceConfig.Init(GetRootConfig()); err != nil {
logger.Errorf("Service with registeredTypeName = %s init failed with error = %#v", registeredTypeName, err)
}
serviceConfig.adaptiveService = c.AdaptiveService
Expand Down
4 changes: 2 additions & 2 deletions config/provider_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,15 @@ import (
func TestProviderConfigEmptyRegistry(t *testing.T) {
err := Load(WithPath("./testdata/config/provider/empty_registry_application.yaml"))
require.NoError(t, err)
provider := rootConfig.Provider
provider := GetRootConfig().Provider
assert.Len(t, provider.RegistryIDs, 1)
assert.Equal(t, "nacos", provider.RegistryIDs[0])
}

func TestProviderConfigRootRegistry(t *testing.T) {
err := Load(WithPath("./testdata/config/provider/registry_application.yaml"))
require.NoError(t, err)
provider := rootConfig.Provider
provider := GetRootConfig().Provider
assert.NotNil(t, provider)
assert.NotNil(t, provider.Services["HelloService"])
assert.NotNil(t, provider.Services["OrderService"])
Expand Down
Loading
Loading