diff --git a/README.md b/README.md index 3c0debd..ab03c0c 100644 --- a/README.md +++ b/README.md @@ -43,6 +43,37 @@ db.Use(sharding.Register(sharding.Config{ }, "orders", Notification{}, AuditLog{})) // This case for show up give notifications, audit_logs table use same sharding rule. ``` +support multiple config with multiple model register +model must implement ShardingInterface +example: base_test.go +```go +type Order struct { + gorm.Model + UserID uint64 + ProductID int64 + Amount int64 + + sharding.BaseSharding +} + +func (Order) TableName() string { + return "order" +} + +type OtherTable struct { + gorm.Model + UserID uint64 + + sharding.BaseSharding +} + +func (OtherTable) TableName() string { + return "other_table" +} + + +db.Use(sharding.RegisterWithModel(&Order{}, &OtherTable{})) +``` Use the db session as usual. Just note that the query should have the `Sharding Key` when operate sharding tables. diff --git a/base.go b/base.go new file mode 100644 index 0000000..72fd99d --- /dev/null +++ b/base.go @@ -0,0 +1,76 @@ +package sharding + +import ( + "fmt" +) + +const DefaultShardingNumber = 128 + +type ShardingInterface interface { + TableName() string + + NumberOfShards() uint + Sharding() Config + TableSuffix() string +} + +type BaseSharding struct { + // 存储子类设置的默认分片数 + DefaultShards uint `gorm:"-"` +} + +func (b *BaseSharding) TableName() string { + return "" +} + +func (b *BaseSharding) NumberOfShards() uint { + if b.DefaultShards == 0 { + return DefaultShardingNumber + } + + // 回退到默认值 + return b.DefaultShards +} + +func (b *BaseSharding) Sharding() Config { + return Config{ + DoubleWrite: false, + ShardingKey: "user_id", + NumberOfShards: b.NumberOfShards(), + ShardingSuffixs: func() []string { + suffixLists := make([]string, b.NumberOfShards()) + var i uint = 0 + for ; i < b.NumberOfShards(); i += 1 { + suffixLists[i] = fmt.Sprintf(b.TableSuffix(), i) + } + return suffixLists + }, + ShardingAlgorithm: func(columnValue any) (suffix string, err error) { + userID, ok := columnValue.(uint64) + if !ok { + return "", fmt.Errorf("invalid ID type, expected uint64") + } + // 根据 id 计算分表后缀 + return fmt.Sprintf(b.TableSuffix(), userID&uint64(b.NumberOfShards()-1)), nil + }, + PrimaryKeyGenerator: PKCustom, + PrimaryKeyGeneratorFn: func(_ int64) int64 { + return 0 + }, + } +} + +func (b *BaseSharding) TableSuffix() string { + var tableFormat string + if b.NumberOfShards() < 10 { + tableFormat = "_%01d" + } else if b.NumberOfShards() < 100 { + tableFormat = "_%02d" + } else if b.NumberOfShards() < 1000 { + tableFormat = "_%03d" + } else if b.NumberOfShards() < 10000 { + tableFormat = "_%04d" + } + + return tableFormat +} diff --git a/base_test.go b/base_test.go new file mode 100644 index 0000000..a0bbc54 --- /dev/null +++ b/base_test.go @@ -0,0 +1,42 @@ +package sharding + +import ( + "gorm.io/gorm" + "testing" +) + +type OrderOther struct { + ID int64 `gorm:"primarykey"` + UserID uint64 + Product string + Deleted gorm.DeletedAt + + BaseSharding +} + +func (OrderOther) TableName() string { + return "order_other" +} + +func useBase() { + shardingMiddleware := RegisterWithModel(&OrderOther{}) + db.Use(shardingMiddleware) +} + +func useBaseWithShardingNumbers() { + orderOther := &OrderOther{} + orderOther.DefaultShards = 2 + shardingMiddleware := RegisterWithModel(orderOther) + db.Use(shardingMiddleware) +} + +func TestUseBase(t *testing.T) { + useBase() + db.AutoMigrate(&OrderOther{}) +} + +func TestUseBaseInsert(t *testing.T) { + useBaseWithShardingNumbers() + tx := db.Create(&OrderOther{ID: 100, UserID: 100, Product: "iPhone"}) + assertQueryResult(t, `INSERT INTO order_other_0 ("user_id", "product", "deleted", "id") VALUES ($1, $2, $3, $4) RETURNING "id"`, tx) +} diff --git a/sharding.go b/sharding.go index a8010d3..35f557c 100644 --- a/sharding.go +++ b/sharding.go @@ -110,6 +110,20 @@ func Register(config Config, tables ...any) *Sharding { } } +func RegisterWithConfig(config map[string]Config) *Sharding { + return &Sharding{ + configs: config, + } +} + +func RegisterWithModel(modelList ...ShardingInterface) *Sharding { + config := make(map[string]Config) + for _, item := range modelList { + config[item.TableName()] = item.Sharding() + } + return RegisterWithConfig(config) +} + func (s *Sharding) compile() error { if s.configs == nil { s.configs = make(map[string]Config)