Skip to content

Commit 7da654f

Browse files
committed
Split primary_key gen as single file
1 parent 63aa7f8 commit 7da654f

File tree

3 files changed

+42
-19
lines changed

3 files changed

+42
-19
lines changed

primary_key.go

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
package sharding
2+
3+
import "fmt"
4+
5+
const (
6+
// Use Snowflake primary key generator
7+
PKSnowflake = iota
8+
// Use PostgreSQL sequence primary key generator
9+
PKPGSequence
10+
// Use custom primary key generator
11+
PKCustom
12+
)
13+
14+
func (s *Sharding) genSnowflakeKey(index int64) int64 {
15+
return s.snowflakeNodes[index].Generate().Int64()
16+
}
17+
18+
func (s *Sharding) genPostgreSQLSequenceKey(tableName string, index int64) int64 {
19+
var id int64
20+
err := s.DB.Raw("SELECT nextval('" + pgSeqName(tableName) + "')").Scan(&id).Error
21+
if err != nil {
22+
panic(err)
23+
}
24+
return id
25+
}
26+
27+
func pgSeqName(table string) string {
28+
return fmt.Sprintf("gorm_sharding_%s_id_seq", table)
29+
}

primary_key_test.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
package sharding
2+
3+
import (
4+
"testing"
5+
6+
"github.com/longbridgeapp/assert"
7+
)
8+
9+
func Test_pgSeqName(t *testing.T) {
10+
assert.Equal(t, "gorm_sharding_users_id_seq", pgSeqName("users"))
11+
}

sharding.go

Lines changed: 2 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,6 @@ import (
1414
"gorm.io/gorm/schema"
1515
)
1616

17-
const (
18-
PKSnowflake = iota // Use Snowflake primary key generator
19-
PKPGSequence // Use PostgreSQL sequence primary key generator
20-
PKCustom // Use custom primary key generator
21-
)
22-
2317
var (
2418
ErrMissingShardingKey = errors.New("sharding key or id required, and use operator =")
2519
ErrInvalidID = errors.New("invalid id format")
@@ -109,17 +103,10 @@ func (s *Sharding) Register(config Config, tables ...interface{}) *Sharding {
109103
}
110104

111105
if c.PrimaryKeyGenerator == PKSnowflake {
112-
c.PrimaryKeyGeneratorFn = func(index int64) int64 {
113-
return s.snowflakeNodes[index].Generate().Int64()
114-
}
106+
c.PrimaryKeyGeneratorFn = s.genSnowflakeKey
115107
} else if c.PrimaryKeyGenerator == PKPGSequence {
116108
c.PrimaryKeyGeneratorFn = func(index int64) int64 {
117-
var id int64
118-
err := s.DB.Raw("SELECT nextval('" + pgSeqName(t) + "')").Scan(&id).Error
119-
if err != nil {
120-
panic(err)
121-
}
122-
return id
109+
return s.genPostgreSQLSequenceKey(t, index)
123110
}
124111
} else if c.PrimaryKeyGenerator == PKCustom {
125112
if c.PrimaryKeyGeneratorFn == nil {
@@ -435,7 +422,3 @@ func replaceOrderByTableName(orderBy []*sqlparser.OrderingTerm, oldName, newName
435422

436423
return orderBy
437424
}
438-
439-
func pgSeqName(table string) string {
440-
return fmt.Sprintf("gorm_sharding_%s_id_seq", table)
441-
}

0 commit comments

Comments
 (0)