Skip to content

Commit c6e2cad

Browse files
authored
feat: Support for AutoMigrate (#17)
1 parent d9ee537 commit c6e2cad

File tree

3 files changed

+120
-0
lines changed

3 files changed

+120
-0
lines changed

dialector.go

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
package sharding
2+
3+
import (
4+
"fmt"
5+
6+
"gorm.io/gorm"
7+
)
8+
9+
type ShardingDialector struct {
10+
gorm.Dialector
11+
sharding *Sharding
12+
}
13+
14+
type ShardingMigrator struct {
15+
gorm.Migrator
16+
sharding *Sharding
17+
dialector gorm.Dialector
18+
}
19+
20+
func NewShardingDialector(d gorm.Dialector, s *Sharding) ShardingDialector {
21+
return ShardingDialector{
22+
Dialector: d,
23+
sharding: s,
24+
}
25+
}
26+
27+
func (d ShardingDialector) Migrator(db *gorm.DB) gorm.Migrator {
28+
m := d.Dialector.Migrator(db)
29+
return ShardingMigrator{
30+
Migrator: m,
31+
sharding: d.sharding,
32+
dialector: d.Dialector,
33+
}
34+
}
35+
36+
func (m ShardingMigrator) AutoMigrate(dst ...interface{}) error {
37+
noShardingDsts := make([]interface{}, 0)
38+
for _, model := range dst {
39+
stmt := &gorm.Statement{DB: m.sharding.DB}
40+
if err := stmt.Parse(model); err == nil {
41+
if cfg, ok := m.sharding.configs[stmt.Table]; ok {
42+
// support sharding table
43+
suffixs := cfg.ShardingSuffixs()
44+
if len(suffixs) == 0 {
45+
return fmt.Errorf("sharding table:%s suffixs is empty", stmt.Table)
46+
}
47+
48+
for _, suffix := range suffixs {
49+
shardingTable := stmt.Table + suffix
50+
tx := stmt.DB.Session(&gorm.Session{}).Table(shardingTable)
51+
if err := m.dialector.Migrator(tx).AutoMigrate(model); err != nil {
52+
return err
53+
}
54+
}
55+
56+
if cfg.DoubleWrite {
57+
noShardingDsts = append(noShardingDsts, model)
58+
}
59+
} else {
60+
noShardingDsts = append(noShardingDsts, model)
61+
}
62+
} else {
63+
return err
64+
}
65+
}
66+
67+
if len(noShardingDsts) > 0 {
68+
if err := m.Migrator.AutoMigrate(noShardingDsts...); err != nil {
69+
return err
70+
}
71+
}
72+
return nil
73+
}
74+
75+
// TODO: DropTable drop sharding table
76+
// func (m ShardingMigrator) DropTable(dst ...interface{}) error {
77+
// }

sharding.go

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,19 @@ type Config struct {
5656
// }
5757
ShardingAlgorithm func(columnValue interface{}) (suffix string, err error)
5858

59+
// ShardingSuffixs specifies a function to generate all table's suffix.
60+
// Used to support Migrator.
61+
// For example, this function get a mod all sharding suffixs.
62+
//
63+
// func () (suffixs []string) {
64+
// numberOfShards := 5
65+
// for i := 0; i < numberOfShards; i++ {
66+
// suffixs = append(suffixs, fmt.Sprintf("_%02d", i%numberOfShards))
67+
// }
68+
// return
69+
// }
70+
ShardingSuffixs func() (suffixs []string)
71+
5972
// ShardingAlgorithmByPrimaryKey specifies a function to generate the sharding
6073
// table's suffix by the primary key. Used when no sharding key specified.
6174
// For example, this function use the Snowflake library to generate the suffix.
@@ -161,10 +174,24 @@ func (s *Sharding) compile() error {
161174
return "", fmt.Errorf("default algorithm only support integer and string column," +
162175
"if you use other type, specify you own ShardingAlgorithm")
163176
}
177+
164178
return fmt.Sprintf(c.tableFormat, id%int(c.NumberOfShards)), nil
165179
}
166180
}
167181

182+
if c.ShardingSuffixs == nil {
183+
c.ShardingSuffixs = func() (suffixs []string) {
184+
for i := 0; i < int(c.NumberOfShards); i++ {
185+
suffix, err := c.ShardingAlgorithm(i)
186+
if err != nil {
187+
return nil
188+
}
189+
suffixs = append(suffixs, suffix)
190+
}
191+
return
192+
}
193+
}
194+
168195
if c.ShardingAlgorithmByPrimaryKey == nil {
169196
if c.PrimaryKeyGenerator == PKSnowflake {
170197
c.ShardingAlgorithmByPrimaryKey = func(id int64) (suffix string) {
@@ -194,6 +221,7 @@ func (s *Sharding) LastQuery() string {
194221

195222
// Initialize implement for Gorm plugin interface
196223
func (s *Sharding) Initialize(db *gorm.DB) error {
224+
db.Dialector = NewShardingDialector(db.Dialector, s)
197225
s.DB = db
198226
s.registerCallbacks(db)
199227

sharding_test.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"fmt"
55
"os"
66
"regexp"
7+
"sort"
78
"strings"
89
"testing"
910

@@ -151,6 +152,20 @@ func dropTables() {
151152
}
152153
}
153154

155+
func TestAutoMigrate(t *testing.T) {
156+
targetTables := []string{"orders", "orders_0", "orders_1", "orders_2", "orders_3", "categories"}
157+
for _, table := range targetTables {
158+
db.Exec("DROP TABLE IF EXISTS " + table)
159+
db.Exec(("DROP SEQUENCE IF EXISTS gorm_sharding_" + table + "_id_seq"))
160+
}
161+
162+
db.AutoMigrate(&Order{}, &Category{})
163+
tables, _ := db.Migrator().GetTables()
164+
sort.Strings(tables)
165+
sort.Strings(targetTables)
166+
assert.Equal(t, tables, targetTables)
167+
}
168+
154169
func TestInsert(t *testing.T) {
155170
tx := db.Create(&Order{ID: 100, UserID: 100, Product: "iPhone"})
156171
assertQueryResult(t, `INSERT INTO orders_0 ("user_id", "product", "id") VALUES ($1, $2, $3) RETURNING "id"`, tx)

0 commit comments

Comments
 (0)