Skip to content

Commit aaefbe6

Browse files
committed
feat(data): improve concurrency safety and resource management
- Replace map with SyncMap for concurrent access in closes field
- Ensure proper closure of resources using Set method
- Fix closure variable capture issue in loop
- Add stopOnce and clustersMu for message repository synchronization
- Shuffle cluster indices instead of modifying original slice
- Handle closed message channel in AppendMessage
- Lock when updating clusters slice in initClusters
1 parent 4829a06 commit aaefbe6

File tree

3 files changed

+63
-23
lines changed

3 files changed

+63
-23
lines changed

cmd/gorm/gorm.go

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -136,13 +136,13 @@ func initDB() (*gorm.DB, error) {
136136
}
137137
databases = append(databases, database)
138138
}
139+
defer rows.Close()
139140
klog.Debugw("msg", "show databases success", "databases", databases)
140141
if !slices.Contains(databases, flags.database) {
141142
// create database
142143
klog.Warnw("msg", "database not exists", "database", flags.database)
143144
klog.Debugw("msg", "create database", "database", flags.database)
144-
_, err := sqlDB.Exec(fmt.Sprintf("CREATE DATABASE %s", flags.database))
145-
if err != nil {
145+
if _, err := sqlDB.Exec(fmt.Sprintf("CREATE DATABASE %s", flags.database)); err != nil {
146146
klog.Errorw("msg", "create database failed", "error", err, "database", flags.database)
147147
return nil, err
148148
}

internal/data/data.go

Lines changed: 14 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -34,7 +34,7 @@ func New(c *conf.Bootstrap, helper *klog.Helper) (*Data, func(), error) {
3434
helper: helper,
3535
c: c,
3636
dbs: safety.NewSyncMap(make(map[string]*gorm.DB)),
37-
closes: make(map[string]func() error),
37+
closes: safety.NewSyncMap(make(map[string]func() error)),
3838
useDatabase: strutil.IsNotEmpty(c.GetUseDatabase()) && strings.EqualFold(c.GetUseDatabase(), "true"),
3939
reloadFuncs: safety.NewSyncMap(make(map[string]func())),
4040
}
@@ -48,7 +48,7 @@ func New(c *conf.Bootstrap, helper *klog.Helper) (*Data, func(), error) {
4848
return nil, d.close, err
4949
}
5050
d.mainDB = mainDB
51-
d.closes["mainDB"] = func() error { return connect.CloseDB(mainDB) }
51+
d.closes.Set("mainDB", func() error { return connect.CloseDB(mainDB) })
5252

5353
for namespace, biz := range d.c.GetBiz() {
5454
db, err := connect.NewGorm(biz, d.helper)
@@ -60,7 +60,10 @@ func New(c *conf.Bootstrap, helper *klog.Helper) (*Data, func(), error) {
6060
d.dbs.Set(ns, db)
6161
}
6262

63-
d.closes["bizDB.["+namespace+"]"] = func() error { return connect.CloseDB(db) }
63+
// 使用局部变量避免闭包捕获问题
64+
namespaceKey := "bizDB.[" + namespace + "]"
65+
dbToClose := db
66+
d.closes.Set(namespaceKey, func() error { return connect.CloseDB(dbToClose) })
6467
}
6568
} else {
6669
if err := d.LoadFileConfig(d.c, d.helper); err != nil {
@@ -75,7 +78,7 @@ func New(c *conf.Bootstrap, helper *klog.Helper) (*Data, func(), error) {
7578
return nil, d.close, err
7679
}
7780
d.cache = cache
78-
d.closes["cache"] = func() error { return cache.Close() }
81+
d.closes.Set("cache", func() error { return cache.Close() })
7982

8083
return d, d.close, nil
8184
}
@@ -87,24 +90,25 @@ type Data struct {
8790
mainDB *gorm.DB
8891
registry connect.Registry
8992
cache cache.Interface
90-
closes map[string]func() error
93+
closes *safety.SyncMap[string, func() error] // 使用SyncMap保证并发安全
9194
useDatabase bool
9295
fileConfig conf.Config
9396
reloadFuncs *safety.SyncMap[string, func()]
9497
}
9598

9699
func (d *Data) AppendClose(name string, close func() error) {
97-
d.closes[name] = close
100+
d.closes.Set(name, close)
98101
}
99102

100103
func (d *Data) close() {
101-
for name, close := range d.closes {
104+
d.closes.Range(func(name string, close func() error) bool {
102105
if err := close(); err != nil {
103106
d.helper.Errorw("msg", "close db failed", "name", name, "error", err)
104-
continue
107+
return true // 继续遍历
105108
}
106109
d.helper.Debugw("msg", "close success", "name", name)
107-
}
110+
return true // 继续遍历
111+
})
108112
}
109113

110114
func (d *Data) UseDatabase() bool {
@@ -183,7 +187,7 @@ func (d *Data) initRegistry() error {
183187
}
184188
registrar := etcd.New(client, etcd.Namespace(namespace))
185189
d.registry = registrar
186-
d.closes["etcdClient"] = func() error { return client.Close() }
190+
d.closes.Set("etcdClient", func() error { return client.Close() })
187191
}
188192
return nil
189193
}

internal/data/impl/message.go

Lines changed: 47 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -82,6 +82,8 @@ type messageRepositoryImpl struct {
8282

8383
clusters []sender.Sender
8484
clusterInitOnce sync.Once
85+
stopOnce sync.Once // 确保Stop只执行一次
86+
clustersMu sync.RWMutex // 保护clusters的并发访问
8587
}
8688

8789
// start 启动后台处理goroutine
@@ -99,18 +101,33 @@ func (m *messageRepositoryImpl) Start(ctx context.Context) error {
99101

100102
// Stop 停止事件总线
101103
func (m *messageRepositoryImpl) Stop(ctx context.Context) error {
104+
var stopped bool
105+
m.stopOnce.Do(func() {
106+
stopped = true
107+
close(m.stopChan)
108+
m.wg.Wait()
109+
close(m.messageChan)
110+
m.helper.Debug("msg", "message bus stopped")
111+
})
112+
113+
if !stopped {
114+
// 如果已经停止,等待context或直接返回
115+
select {
116+
case <-ctx.Done():
117+
m.helper.Debug("msg", "message bus already stopped, context done")
118+
return nil
119+
case <-m.stopChan:
120+
m.helper.Debug("msg", "message bus already stopped")
121+
return nil
122+
}
123+
}
124+
125+
// 检查context是否已取消
102126
select {
103127
case <-ctx.Done():
104128
m.helper.Debug("msg", "message bus stopped by context done")
105129
return nil
106-
case <-m.stopChan:
107-
m.helper.Debug("msg", "message bus stopped by stop channel")
108-
return nil
109130
default:
110-
close(m.stopChan)
111-
m.wg.Wait()
112-
close(m.messageChan)
113-
m.helper.Debug("msg", "message bus stopped")
114131
return nil
115132
}
116133
}
@@ -144,12 +161,22 @@ func (m *messageRepositoryImpl) waitProcessMessage(ctx context.Context, messageU
144161
// notice: 没有使用外部存储,不允许使用集群模式, 避免消息无法共享到其他节点
145162
if m.d.UseDatabase() {
146163
m.initClusters()
147-
rand.Shuffle(len(m.clusters), func(i, j int) {
148-
m.clusters[i], m.clusters[j] = m.clusters[j], m.clusters[i]
164+
165+
clustersIdx := make([]int, 0, len(m.clusters))
166+
for i := range m.clusters {
167+
clustersIdx = append(clustersIdx, i)
168+
}
169+
170+
// 打乱副本,避免修改原始slice
171+
rand.Shuffle(len(clustersIdx), func(i, j int) {
172+
clustersIdx[i], clustersIdx[j] = clustersIdx[j], clustersIdx[i]
149173
})
150174

151175
// 按打乱后的顺序尝试发送,失败则重试下一个节点
152-
for _, cluster := range m.clusters {
176+
for _, clusterIdx := range clustersIdx {
177+
m.clustersMu.RLock()
178+
cluster := m.clusters[clusterIdx]
179+
m.clustersMu.RUnlock()
153180
reply, err := cluster.SendMessage(ctx, req)
154181
if err != nil {
155182
m.helper.Errorw("msg", "send message failed", "error", err, "uid", messageUID, "reply", reply, "cluster", cluster)
@@ -262,6 +289,10 @@ func (m *messageRepositoryImpl) processMessage(ctx context.Context, message *bo.
262289
func (m *messageRepositoryImpl) AppendMessage(ctx context.Context, messageUID snowflake.ID) error {
263290
// 将消息放入channel异步处理
264291
select {
292+
case <-m.stopChan:
293+
// channel已关闭,返回错误
294+
m.helper.Debugw("msg", "message channel is closed, cannot append message", "uid", messageUID)
295+
return merr.ErrorInternal("message channel is closed")
265296
case m.messageChan <- &messageTask{ctx: safety.CopyValueCtx(ctx), messageUID: messageUID}:
266297
m.helper.Debugw("msg", "message appended to channel", "uid", messageUID)
267298
return nil
@@ -278,6 +309,7 @@ func (m *messageRepositoryImpl) initClusters() {
278309
clusterEndpoints := strutil.SplitSkipEmpty(clusterConfig.GetEndpoints(), ",")
279310
clusterTimeout := clusterConfig.GetTimeout().AsDuration()
280311
clusterName := clusterConfig.GetName()
312+
var clusters []sender.Sender
281313
for _, clusterEndpoint := range clusterEndpoints {
282314
opts := []connect.InitOption{
283315
connect.WithProtocol(config.ClusterConfig_GRPC.String()),
@@ -290,8 +322,12 @@ func (m *messageRepositoryImpl) initClusters() {
290322
continue
291323
}
292324
m.d.AppendClose("grpcClient", func() error { return grpcClient.Close() })
293-
m.clusters = append(m.clusters, sender.NewClusterSender(grpcClient))
325+
clusters = append(clusters, sender.NewClusterSender(grpcClient))
294326
}
327+
// 加锁更新clusters
328+
m.clustersMu.Lock()
329+
m.clusters = clusters
330+
m.clustersMu.Unlock()
295331
})
296332
}
297333

0 commit comments

Comments
 (0)