Skip to content

Commit 3246e11

Browse files
committed
fix: singleflight context isolation, composite key behavior, and key collision
- query: use db.Statement.Context for singleflight leader cache reads (BatchGetPrimaryCache, BatchGetUniqueCache, GetSearchCache) so leader cancellation does not affect cache lookups. - query: log unique cache write count only; avoid logging full uniqueKvs payload. - cache: document that BatchSetUniqueCache mutates the caller's kvs slice. - helpers: composite primary/unique keys use Cartesian product (generateCartesianProduct) so e.g. user_id=1, role_id IN (1,2,3) yields ["1:1","1:2","1:3"]; add TestGetPrimaryKeysFromWhereClause_CompositeKeyCartesianProduct. - helpers: getObjectsAfterLoad composite key: treat nil ValueOf and zero values as invalid and skip incomplete keys (valid flag, no partial key append). - util/key: encode composite key parts with base64.RawURLEncoding to avoid collision (e.g. "a:b","c" vs "a","b:c"); add TestGenPrimaryCacheKey_NoCollision; update key_test expected values for encoded output.
1 parent e188810 commit 3246e11

File tree

6 files changed

+169
-126
lines changed

6 files changed

+169
-126
lines changed

cache/cache.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,8 @@ func (c *Gorm2Cache) BatchGetUniqueCache(ctx context.Context, tableName string,
189189
return c.cache.BatchGetValues(ctx, cacheKeys)
190190
}
191191

192-
// BatchSetUniqueCache 批量设置unique键缓存
192+
// BatchSetUniqueCache 批量设置unique键缓存。
193+
// 注意:会原地修改 kvs 中每个元素的 Key 为完整缓存 key,调用方不应再使用传入的 kvs。
193194
func (c *Gorm2Cache) BatchSetUniqueCache(ctx context.Context, tableName string, uniqueIndexName string, kvs []util.Kv) error {
194195
for idx, kv := range kvs {
195196
// kv.Key 已经是最终格式(单个值或已用":"连接的联合unique键),直接传入

cache/helpers.go

Lines changed: 70 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -97,31 +97,12 @@ func getPrimaryKeysFromWhereClause(db *gorm.DB) []string {
9797
return uniqueStringSlice(fieldValuesMap[primaryKeyFields[0].DBName])
9898
}
9999

100-
// 对于联合主键:需要组合所有字段的值
101-
// 先获取每个字段的值列表长度,取最小值(因为可能有IN查询)
102-
maxLen := len(fieldValuesMap[primaryKeyFields[0].DBName])
103-
for _, field := range primaryKeyFields[1:] {
104-
if len(fieldValuesMap[field.DBName]) < maxLen {
105-
maxLen = len(fieldValuesMap[field.DBName])
106-
}
107-
}
108-
109-
// 生成联合主键key
110-
primaryKeys := make([]string, 0, maxLen)
111-
for i := 0; i < maxLen; i++ {
112-
keyParts := make([]string, 0, len(primaryKeyFields))
113-
for _, field := range primaryKeyFields {
114-
values := fieldValuesMap[field.DBName]
115-
if i < len(values) {
116-
keyParts = append(keyParts, values[i])
117-
} else {
118-
// 如果某个字段的值不够,使用最后一个值
119-
keyParts = append(keyParts, values[len(values)-1])
120-
}
121-
}
122-
primaryKeys = append(primaryKeys, strings.Join(keyParts, ":"))
100+
// 对于联合主键:生成所有字段值的笛卡尔积(如 user_id=1, role_id IN (1,2,3) -> "1:1","1:2","1:3")
101+
valueSlices := make([][]string, 0, len(primaryKeyFields))
102+
for _, field := range primaryKeyFields {
103+
valueSlices = append(valueSlices, fieldValuesMap[field.DBName])
123104
}
124-
105+
primaryKeys := generateCartesianProduct(valueSlices, ":")
125106
return uniqueStringSlice(primaryKeys)
126107
}
127108

@@ -312,44 +293,29 @@ func getUniqueKeysFromWhereClause(db *gorm.DB) map[string][]string {
312293
result[indexName] = uniqueStringSlice(fieldValuesMap[index.Fields[0].Field.DBName])
313294
}
314295
} else {
315-
// 联合unique键
296+
// 联合unique键:笛卡尔积
316297
if len(index.Fields) == 0 {
317298
continue
318299
}
319-
firstField := index.Fields[0].Field
320-
if firstField == nil {
321-
continue
322-
}
323-
maxLen := len(fieldValuesMap[firstField.DBName])
324-
for _, fieldOption := range index.Fields[1:] {
300+
valueSlices := make([][]string, 0, len(index.Fields))
301+
skip := false
302+
for _, fieldOption := range index.Fields {
325303
if fieldOption.Field == nil {
326-
continue
304+
skip = true
305+
break
327306
}
328-
if len(fieldValuesMap[fieldOption.Field.DBName]) < maxLen {
329-
maxLen = len(fieldValuesMap[fieldOption.Field.DBName])
307+
vals := fieldValuesMap[fieldOption.Field.DBName]
308+
if len(vals) == 0 {
309+
skip = true
310+
break
330311
}
312+
valueSlices = append(valueSlices, vals)
331313
}
332-
333-
uniqueKeys := make([]string, 0, maxLen)
334-
for i := 0; i < maxLen; i++ {
335-
keyParts := make([]string, 0, len(index.Fields))
336-
for _, fieldOption := range index.Fields {
337-
if fieldOption.Field == nil {
338-
continue
339-
}
340-
values := fieldValuesMap[fieldOption.Field.DBName]
341-
if i < len(values) {
342-
keyParts = append(keyParts, values[i])
343-
} else if len(values) > 0 {
344-
keyParts = append(keyParts, values[len(values)-1])
345-
}
314+
if !skip && len(valueSlices) == len(index.Fields) {
315+
uniqueKeys := generateCartesianProduct(valueSlices, ":")
316+
if len(uniqueKeys) > 0 {
317+
result[indexName] = uniqueStringSlice(uniqueKeys)
346318
}
347-
if len(keyParts) == len(index.Fields) {
348-
uniqueKeys = append(uniqueKeys, strings.Join(keyParts, ":"))
349-
}
350-
}
351-
if len(uniqueKeys) > 0 {
352-
result[indexName] = uniqueStringSlice(uniqueKeys)
353319
}
354320
}
355321
}
@@ -557,22 +523,23 @@ func getObjectsAfterLoad(db *gorm.DB) (primaryKeys []string, objects []interface
557523
primaryKeys = append(primaryKeys, fmt.Sprintf("%v", primaryKey))
558524
}
559525
} else {
560-
// 联合主键
526+
// 联合主键:必须所有字段都能取到非零值,且 ValueOf 非 nil
561527
keyParts := make([]string, 0, len(primaryKeyFields))
562-
allZero := true
528+
valid := true
563529
for _, field := range primaryKeyFields {
564530
valueOf := field.ValueOf
565-
if valueOf != nil {
566-
primaryKey, isZero := valueOf(context.Background(), elemValue)
567-
if isZero {
568-
allZero = true
569-
break
570-
}
571-
allZero = false
572-
keyParts = append(keyParts, fmt.Sprintf("%v", primaryKey))
531+
if valueOf == nil {
532+
valid = false
533+
break
573534
}
535+
primaryKey, isZero := valueOf(context.Background(), elemValue)
536+
if isZero {
537+
valid = false
538+
break
539+
}
540+
keyParts = append(keyParts, fmt.Sprintf("%v", primaryKey))
574541
}
575-
if allZero {
542+
if !valid || len(keyParts) != len(primaryKeyFields) {
576543
continue
577544
}
578545
primaryKeys = append(primaryKeys, strings.Join(keyParts, ":"))
@@ -636,6 +603,44 @@ func getUniqueKeysFromObjects(db *gorm.DB, objects []interface{}) map[string]map
636603
return result
637604
}
638605

606+
// generateCartesianProduct 对多列值做笛卡尔积,每行用 sep 连接成一条 key。
607+
// 例如 valueSlices = [["1"], ["1","2","3"]] -> ["1:1","1:2","1:3"]。
608+
func generateCartesianProduct(valueSlices [][]string, sep string) []string {
609+
if len(valueSlices) == 0 {
610+
return nil
611+
}
612+
n := 1
613+
for _, s := range valueSlices {
614+
if len(s) == 0 {
615+
return nil
616+
}
617+
n *= len(s)
618+
}
619+
result := make([]string, 0, n)
620+
idx := make([]int, len(valueSlices))
621+
for {
622+
parts := make([]string, len(valueSlices))
623+
for i, s := range valueSlices {
624+
parts[i] = s[idx[i]]
625+
}
626+
result = append(result, strings.Join(parts, sep))
627+
// next combination
628+
j := len(valueSlices) - 1
629+
for j >= 0 {
630+
idx[j]++
631+
if idx[j] < len(valueSlices[j]) {
632+
break
633+
}
634+
idx[j] = 0
635+
j--
636+
}
637+
if j < 0 {
638+
break
639+
}
640+
}
641+
return result
642+
}
643+
639644
func uniqueStringSlice(slice []string) []string {
640645
retSlice := make([]string, 0)
641646
mmap := make(map[string]struct{})

cache/helpers_test.go

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,49 @@ func TestGetPrimaryKeyFields(t *testing.T) {
375375
}
376376
}
377377

378+
// TestGetPrimaryKeysFromWhereClause_CompositeKeyCartesianProduct ensures
379+
// composite primary key with differing value counts (e.g. user_id=1, role_id IN (1,2,3))
380+
// yields Cartesian product ["1:1","1:2","1:3"], not just ["1:1"].
381+
func TestGetPrimaryKeysFromWhereClause_CompositeKeyCartesianProduct(t *testing.T) {
382+
s := &schema.Schema{
383+
Table: "user_roles",
384+
}
385+
s.Fields = []*schema.Field{
386+
{DBName: "user_id", PrimaryKey: true},
387+
{DBName: "role_id", PrimaryKey: true},
388+
{DBName: "name", PrimaryKey: false},
389+
}
390+
db := &gorm.DB{
391+
Statement: &gorm.Statement{
392+
Schema: s,
393+
},
394+
}
395+
db.Statement.Clauses = map[string]clause.Clause{
396+
"WHERE": {
397+
Expression: clause.Where{
398+
Exprs: []clause.Expression{
399+
clause.Eq{Column: "user_id", Value: 1},
400+
clause.IN{Column: "role_id", Values: []interface{}{1, 2, 3}},
401+
},
402+
},
403+
},
404+
}
405+
got := getPrimaryKeysFromWhereClause(db)
406+
want := []string{"1:1", "1:2", "1:3"}
407+
if len(got) != len(want) {
408+
t.Fatalf("expected %d keys, got %d: %v", len(want), len(got), got)
409+
}
410+
seen := make(map[string]bool)
411+
for _, k := range got {
412+
seen[k] = true
413+
}
414+
for _, w := range want {
415+
if !seen[w] {
416+
t.Errorf("missing expected key %q, got %v", w, got)
417+
}
418+
}
419+
}
420+
378421
func TestGetAllUniqueIndexes(t *testing.T) {
379422
tests := []struct {
380423
name string

cache/query.go

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -137,10 +137,10 @@ func (h *queryHandler) BeforeQuery() func(db *gorm.DB) {
137137
return
138138
}
139139

140-
// primary cache hit
141-
cacheValues, err := cache.BatchGetPrimaryCache(ctx, tableName, primaryKeys)
140+
// primary cache hit (use db.Statement.Context: when leader it is bgCtx to avoid cascading cancel)
141+
cacheValues, err := cache.BatchGetPrimaryCache(db.Statement.Context, tableName, primaryKeys)
142142
if err != nil {
143-
cache.Logger.CtxError(ctx, "[BeforeQuery] get primary cache value for key %v error: %v", primaryKeys, err)
143+
cache.Logger.CtxError(db.Statement.Context, "[BeforeQuery] get primary cache value for key %v error: %v", primaryKeys, err)
144144
db.Error = nil
145145
return
146146
}
@@ -193,10 +193,10 @@ func (h *queryHandler) BeforeQuery() func(db *gorm.DB) {
193193
continue
194194
}
195195

196-
// unique cache hit
197-
cacheValues, err := cache.BatchGetUniqueCache(ctx, tableName, uniqueIndexName, uniqueKeys)
196+
// unique cache hit (use db.Statement.Context: when leader it is bgCtx to avoid cascading cancel)
197+
cacheValues, err := cache.BatchGetUniqueCache(db.Statement.Context, tableName, uniqueIndexName, uniqueKeys)
198198
if err != nil {
199-
cache.Logger.CtxError(ctx, "[BeforeQuery] get unique cache value for index %s key %v error: %v", uniqueIndexName, uniqueKeys, err)
199+
cache.Logger.CtxError(db.Statement.Context, "[BeforeQuery] get unique cache value for index %s key %v error: %v", uniqueIndexName, uniqueKeys, err)
200200
continue
201201
}
202202
if len(cacheValues) != len(uniqueKeys) {
@@ -228,11 +228,11 @@ func (h *queryHandler) BeforeQuery() func(db *gorm.DB) {
228228
}
229229

230230
trySearchCache := func() (hit bool) {
231-
// search cache hit
232-
cacheValue, err := cache.GetSearchCache(ctx, tableName, sql, db.Statement.Vars...)
231+
// search cache hit (use db.Statement.Context: when leader it is bgCtx to avoid cascading cancel)
232+
cacheValue, err := cache.GetSearchCache(db.Statement.Context, tableName, sql, db.Statement.Vars...)
233233
if err != nil {
234234
if !errors.Is(err, storage.ErrCacheNotFound) {
235-
cache.Logger.CtxError(ctx, "[BeforeQuery] get cache value for sql %s error: %v", sql, err)
235+
cache.Logger.CtxError(db.Statement.Context, "[BeforeQuery] get cache value for sql %s error: %v", sql, err)
236236
}
237237
db.Error = nil
238238
return
@@ -419,7 +419,7 @@ func (h *queryHandler) AfterQuery() func(db *gorm.DB) {
419419
}
420420
}
421421
if len(uniqueKvs) > 0 {
422-
cache.Logger.CtxInfo(ctx, "[AfterQuery] start to set unique cache for index %s kvs: %+v", indexName, uniqueKvs)
422+
cache.Logger.CtxInfo(ctx, "[AfterQuery] start to set unique cache for index %s count=%d", indexName, len(uniqueKvs))
423423
err := cache.BatchSetUniqueCache(ctx, tableName, indexName, uniqueKvs)
424424
if err != nil {
425425
cache.Logger.CtxError(ctx, "[AfterQuery] batch set unique cache for index %s error: %v", indexName, err)

util/key.go

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package util
22

33
import (
4+
"encoding/base64"
45
"fmt"
56
"math/rand"
67
"reflect"
@@ -19,17 +20,25 @@ func GenInstanceId() string {
1920
return string(str)
2021
}
2122

23+
// joinKeyParts 对多段做 base64 编码后用 sep 连接,避免 "a:b","c" 与 "a","b:c" 碰撞
24+
func joinKeyParts(parts []string, sep string) string {
25+
if len(parts) == 0 {
26+
return ""
27+
}
28+
if len(parts) == 1 {
29+
return parts[0]
30+
}
31+
encoded := make([]string, len(parts))
32+
for i, p := range parts {
33+
encoded[i] = base64.RawURLEncoding.EncodeToString([]byte(p))
34+
}
35+
return strings.Join(encoded, sep)
36+
}
37+
2238
// GenPrimaryCacheKey 生成主键缓存key,支持单个主键和联合主键
23-
// 如果传入多个参数,会按顺序用":"连接;如果只传入一个参数,直接使用(可能是已经连接好的联合主键)
39+
// 联合主键各段会做 base64 编码,避免含 ":" 的值产生 key 碰撞
2440
func GenPrimaryCacheKey(instanceId string, tableName string, primaryKeyValues ...string) string {
25-
var key string
26-
if len(primaryKeyValues) == 1 {
27-
// 单个参数,直接使用(可能是单个主键值,也可能是已经连接好的联合主键)
28-
key = primaryKeyValues[0]
29-
} else {
30-
// 多个参数,用":"连接
31-
key = strings.Join(primaryKeyValues, ":")
32-
}
41+
key := joinKeyParts(primaryKeyValues, ":")
3342
return fmt.Sprintf("%s:%s:p:%s:%s", GormCachePrefix, instanceId, tableName, key)
3443
}
3544

@@ -54,16 +63,9 @@ func GenPrimaryCachePrefix(instanceId string, tableName string) string {
5463
}
5564

5665
// GenUniqueCacheKey 生成unique键缓存key,支持单个unique键和联合unique键
57-
// 如果传入多个参数,会按顺序用":"连接;如果只传入一个参数,直接使用(可能是已经连接好的联合unique键)
66+
// 联合unique键各段会做 base64 编码,避免含 ":" 的值产生 key 碰撞
5867
func GenUniqueCacheKey(instanceId string, tableName string, uniqueIndexName string, uniqueKeyValues ...string) string {
59-
var key string
60-
if len(uniqueKeyValues) == 1 {
61-
// 单个参数,直接使用(可能是单个unique键值,也可能是已经连接好的联合unique键)
62-
key = uniqueKeyValues[0]
63-
} else {
64-
// 多个参数,用":"连接
65-
key = strings.Join(uniqueKeyValues, ":")
66-
}
68+
key := joinKeyParts(uniqueKeyValues, ":")
6769
return fmt.Sprintf("%s:%s:u:%s:%s:%s", GormCachePrefix, instanceId, tableName, uniqueIndexName, key)
6870
}
6971

0 commit comments

Comments
 (0)