Skip to content

Commit 175bf10

Browse files
committed
refactor(tests): streamline database connection handling in dockertest integration tests
- Removed unused database connection variables for MySQL and PostgreSQL. - Simplified database connection setup by directly using the connection returned from gorm.Open. - Enhanced cleanup logic to ensure proper closure of database connections after tests. - Improved error handling during database setup and migration processes for clearer test failure reporting.
1 parent 4928f83 commit 175bf10

File tree

3 files changed

+84
-65
lines changed

3 files changed

+84
-65
lines changed

cache/cache.go

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -149,11 +149,14 @@ func (c *Gorm2Cache) SearchKeyExists(ctx context.Context, tableName string, SQL
149149
}
150150

151151
func (c *Gorm2Cache) BatchSetPrimaryKeyCache(ctx context.Context, tableName string, kvs []util.Kv) error {
152-
for idx, kv := range kvs {
153-
// kv.Key 已经是最终格式(单个值或已用":"连接的联合主键),直接传入
154-
kvs[idx].Key = util.GenPrimaryCacheKey(c.InstanceId, tableName, kv.Key)
152+
cacheKvs := make([]util.Kv, 0, len(kvs))
153+
for _, kv := range kvs {
154+
cacheKvs = append(cacheKvs, util.Kv{
155+
Key: util.GenPrimaryCacheKey(c.InstanceId, tableName, kv.Key),
156+
Value: kv.Value,
157+
})
155158
}
156-
return c.cache.BatchSetKeys(ctx, kvs)
159+
return c.cache.BatchSetKeys(ctx, cacheKvs)
157160
}
158161

159162
func (c *Gorm2Cache) SetSearchCache(ctx context.Context, cacheValue string, tableName string,
@@ -189,14 +192,16 @@ func (c *Gorm2Cache) BatchGetUniqueCache(ctx context.Context, tableName string,
189192
return c.cache.BatchGetValues(ctx, cacheKeys)
190193
}
191194

192-
// BatchSetUniqueCache 批量设置unique键缓存。
193-
// 注意:会原地修改 kvs 中每个元素的 Key 为完整缓存 key,调用方不应再使用传入的 kvs。
195+
// BatchSetUniqueCache 批量设置 unique 键缓存。不会修改调用方传入的 kvs。
194196
func (c *Gorm2Cache) BatchSetUniqueCache(ctx context.Context, tableName string, uniqueIndexName string, kvs []util.Kv) error {
195-
for idx, kv := range kvs {
196-
// kv.Key 已经是最终格式(单个值或已用":"连接的联合unique键),直接传入
197-
kvs[idx].Key = util.GenUniqueCacheKey(c.InstanceId, tableName, uniqueIndexName, kv.Key)
197+
cacheKvs := make([]util.Kv, 0, len(kvs))
198+
for _, kv := range kvs {
199+
cacheKvs = append(cacheKvs, util.Kv{
200+
Key: util.GenUniqueCacheKey(c.InstanceId, tableName, uniqueIndexName, kv.Key),
201+
Value: kv.Value,
202+
})
198203
}
199-
return c.cache.BatchSetKeys(ctx, kvs)
204+
return c.cache.BatchSetKeys(ctx, cacheKvs)
200205
}
201206

202207
// InvalidateUniqueCache 失效unique键缓存

cache/dockertest_integration_test.go

Lines changed: 52 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -58,15 +58,13 @@ func (UserSession) TableName() string {
5858
var (
5959
mysqlPool *dockertest.Pool
6060
mysqlResource *dockertest.Resource
61-
mysqlDB *gorm.DB
6261
mysqlDSN string
6362
setupMySQLOnce sync.Once
6463
cleanupMySQLOnce sync.Once
6564
mysqlSetupErr error
6665

6766
pgPool *dockertest.Pool
6867
pgResource *dockertest.Resource
69-
pgDB *gorm.DB
7068
pgDSN string
7169
setupPGOnce sync.Once
7270
cleanupPGOnce sync.Once
@@ -106,21 +104,19 @@ func setupMySQL(t *testing.T) *gorm.DB {
106104

107105
mysqlPool.MaxWait = 120 * time.Second
108106
if err := mysqlPool.Retry(func() error {
109-
var openErr error
110-
mysqlDB, openErr = gorm.Open(mysql.Open(mysqlDSN), &gorm.Config{
107+
conn, openErr := gorm.Open(mysql.Open(mysqlDSN), &gorm.Config{
111108
Logger: logger.Default.LogMode(logger.Silent),
112109
})
113110
if openErr != nil {
114111
return openErr
115112
}
116-
sqlDB, openErr := mysqlDB.DB()
113+
sqlDB, openErr := conn.DB()
117114
if openErr != nil {
118115
return openErr
119116
}
120-
sqlDB.SetMaxOpenConns(10)
121-
sqlDB.SetMaxIdleConns(5)
122-
sqlDB.SetConnMaxLifetime(time.Hour)
123-
return sqlDB.Ping()
117+
err := sqlDB.Ping()
118+
_ = sqlDB.Close()
119+
return err
124120
}); err != nil {
125121
mysqlSetupErr = fmt.Errorf("could not connect to MySQL: %w", err)
126122
return
@@ -129,22 +125,31 @@ func setupMySQL(t *testing.T) *gorm.DB {
129125
if mysqlSetupErr != nil {
130126
t.Fatalf("MySQL setup failed: %v", mysqlSetupErr)
131127
}
132-
if mysqlDB == nil {
133-
t.Fatal("MySQL DB is nil after setup")
128+
129+
db, err := gorm.Open(mysql.Open(mysqlDSN), &gorm.Config{
130+
Logger: logger.Default.LogMode(logger.Silent),
131+
})
132+
if err != nil {
133+
t.Fatalf("open MySQL DB failed: %v", err)
134134
}
135+
sqlDB, err := db.DB()
136+
if err != nil {
137+
t.Fatalf("get sql.DB from gorm failed: %v", err)
138+
}
139+
sqlDB.SetMaxOpenConns(10)
140+
sqlDB.SetMaxIdleConns(5)
141+
sqlDB.SetConnMaxLifetime(time.Hour)
142+
t.Cleanup(func() {
143+
_ = sqlDB.Close()
144+
})
135145

136-
// Clean tables before each test
137-
if err := mysqlDB.Exec("DROP TABLE IF EXISTS user_roles, users, user_sessions").Error; err != nil {
146+
if err := db.Exec("DROP TABLE IF EXISTS user_roles, users, user_sessions").Error; err != nil {
138147
t.Fatalf("failed to drop MySQL tables: %v", err)
139148
}
140-
141-
// Auto migrate
142-
err := mysqlDB.AutoMigrate(&UserRole{}, &User{}, &UserSession{})
143-
if err != nil {
149+
if err := db.AutoMigrate(&UserRole{}, &User{}, &UserSession{}); err != nil {
144150
t.Fatalf("Auto migrate error: %v", err)
145151
}
146-
147-
return mysqlDB
152+
return db
148153
}
149154

150155
func setupPostgreSQL(t *testing.T) *gorm.DB {
@@ -188,21 +193,19 @@ func setupPostgreSQL(t *testing.T) *gorm.DB {
188193

189194
pgPool.MaxWait = 120 * time.Second
190195
if err := pgPool.Retry(func() error {
191-
var openErr error
192-
pgDB, openErr = gorm.Open(postgres.Open(pgDSN), &gorm.Config{
196+
conn, openErr := gorm.Open(postgres.Open(pgDSN), &gorm.Config{
193197
Logger: logger.Default.LogMode(logger.Silent),
194198
})
195199
if openErr != nil {
196200
return openErr
197201
}
198-
sqlDB, openErr := pgDB.DB()
202+
sqlDB, openErr := conn.DB()
199203
if openErr != nil {
200204
return openErr
201205
}
202-
sqlDB.SetMaxOpenConns(10)
203-
sqlDB.SetMaxIdleConns(5)
204-
sqlDB.SetConnMaxLifetime(time.Hour)
205-
return sqlDB.Ping()
206+
err := sqlDB.Ping()
207+
_ = sqlDB.Close()
208+
return err
206209
}); err != nil {
207210
pgSetupErr = fmt.Errorf("could not connect to PostgreSQL: %w", err)
208211
return
@@ -211,47 +214,43 @@ func setupPostgreSQL(t *testing.T) *gorm.DB {
211214
if pgSetupErr != nil {
212215
t.Fatalf("PostgreSQL setup failed: %v", pgSetupErr)
213216
}
214-
if pgDB == nil {
215-
t.Fatal("PostgreSQL DB is nil after setup")
217+
218+
db, err := gorm.Open(postgres.Open(pgDSN), &gorm.Config{
219+
Logger: logger.Default.LogMode(logger.Silent),
220+
})
221+
if err != nil {
222+
t.Fatalf("open PostgreSQL DB failed: %v", err)
223+
}
224+
sqlDB, err := db.DB()
225+
if err != nil {
226+
t.Fatalf("get sql.DB from gorm failed: %v", err)
216227
}
228+
sqlDB.SetMaxOpenConns(10)
229+
sqlDB.SetMaxIdleConns(5)
230+
sqlDB.SetConnMaxLifetime(time.Hour)
231+
t.Cleanup(func() {
232+
_ = sqlDB.Close()
233+
})
217234

218-
// Clean tables before each test
219-
if err := pgDB.Exec("DROP TABLE IF EXISTS user_roles, users, user_sessions CASCADE").Error; err != nil {
235+
if err := db.Exec("DROP TABLE IF EXISTS user_roles, users, user_sessions CASCADE").Error; err != nil {
220236
t.Fatalf("failed to drop PostgreSQL tables: %v", err)
221237
}
222-
223-
// Auto migrate
224-
err := pgDB.AutoMigrate(&UserRole{}, &User{}, &UserSession{})
225-
if err != nil {
238+
if err := db.AutoMigrate(&UserRole{}, &User{}, &UserSession{}); err != nil {
226239
t.Fatalf("Auto migrate error: %v", err)
227240
}
228-
229-
return pgDB
241+
return db
230242
}
231243

232244
func TestMain(m *testing.M) {
233245
code := m.Run()
234246

235-
// Cleanup
247+
// Cleanup: 每测试独立 *gorm.DB 由 t.Cleanup 关闭,此处仅回收容器
236248
cleanupMySQLOnce.Do(func() {
237-
if mysqlDB != nil {
238-
sqlDB, _ := mysqlDB.DB()
239-
if sqlDB != nil {
240-
sqlDB.Close()
241-
}
242-
}
243249
if mysqlResource != nil && mysqlPool != nil {
244250
_ = mysqlPool.Purge(mysqlResource)
245251
}
246252
})
247-
248253
cleanupPGOnce.Do(func() {
249-
if pgDB != nil {
250-
sqlDB, _ := pgDB.DB()
251-
if sqlDB != nil {
252-
sqlDB.Close()
253-
}
254-
}
255254
if pgResource != nil && pgPool != nil {
256255
_ = pgPool.Purge(pgResource)
257256
}
@@ -622,7 +621,9 @@ func TestCacheStats_MySQL(t *testing.T) {
622621
if err != nil {
623622
t.Fatalf("Failed to create cache: %v", err)
624623
}
625-
db.Use(cache)
624+
if err := db.Use(cache); err != nil {
625+
t.Fatalf("failed to register cache plugin: %v", err)
626+
}
626627

627628
// 创建测试数据
628629
user := User{Email: "test@example.com", Username: "test", Name: "Test User"}

cache/query.go

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -429,7 +429,17 @@ func (h *queryHandler) AfterQuery() func(db *gorm.DB) {
429429
}
430430
}
431431
}()
432-
if !cache.Config.AsyncWrite {
432+
if cache.Config.AsyncWrite {
433+
// 异步写时不在主路径 Wait,由后台 goroutine 在写完后 cancel,避免 fillCallAfterQuery 提前 cancel 导致写缓存被中止
434+
if cancelObj, hasCancel := db.InstanceGet("gorm:cache:query:single_flight_cancel"); hasCancel {
435+
if cancel, ok := cancelObj.(context.CancelFunc); ok {
436+
go func() {
437+
wg.Wait()
438+
cancel()
439+
}()
440+
}
441+
}
442+
} else {
433443
wg.Wait()
434444
}
435445
return
@@ -491,9 +501,12 @@ func (h *queryHandler) fillCallAfterQuery(db *gorm.DB) {
491501
return
492502
}
493503
// 释放 singleflight leader 使用的 background context,避免 timer 泄漏
494-
if cancelObj, hasCancel := db.InstanceGet("gorm:cache:query:single_flight_cancel"); hasCancel {
495-
if cancel, ok := cancelObj.(context.CancelFunc); ok {
496-
cancel()
504+
// AsyncWrite 时由 AfterQuery 内启动的 goroutine 在 wg.Wait() 后 cancel,此处不 cancel 以免提前中止异步写缓存
505+
if !h.cache.Config.AsyncWrite {
506+
if cancelObj, hasCancel := db.InstanceGet("gorm:cache:query:single_flight_cancel"); hasCancel {
507+
if cancel, ok := cancelObj.(context.CancelFunc); ok {
508+
cancel()
509+
}
497510
}
498511
}
499512
c.dest = db.Statement.Dest

0 commit comments

Comments
 (0)