diff --git a/internal/bootstrap/db.go b/internal/bootstrap/db.go index d97cb6796..e7ae35048 100644 --- a/internal/bootstrap/db.go +++ b/internal/bootstrap/db.go @@ -2,6 +2,8 @@ package bootstrap import ( "fmt" + "github.com/OpenListTeam/OpenList/v4/internal/bootstrap/dbengine" + "github.com/OpenListTeam/OpenList/v4/internal/model" stdlog "log" "strings" "time" @@ -10,9 +12,6 @@ import ( "github.com/OpenListTeam/OpenList/v4/internal/conf" "github.com/OpenListTeam/OpenList/v4/internal/db" log "github.com/sirupsen/logrus" - "gorm.io/driver/mysql" - "gorm.io/driver/postgres" - "gorm.io/driver/sqlite" "gorm.io/gorm" "gorm.io/gorm/logger" "gorm.io/gorm/schema" @@ -38,10 +37,10 @@ func InitDB() { }, Logger: newLogger, } - var dB *gorm.DB + var dB model.Connection var err error if flags.Dev { - dB, err = gorm.Open(sqlite.Open("file::memory:?cache=shared"), gormConfig) + dB, err = dbengine.CreateSqliteCon("file::memory:?cache=shared", gormConfig) conf.Conf.Database.Type = "sqlite3" } else { database := conf.Conf.Database @@ -51,18 +50,18 @@ func InitDB() { if !(strings.HasSuffix(database.DBFile, ".db") && len(database.DBFile) > 3) { log.Fatalf("db name error.") } - dB, err = gorm.Open(sqlite.Open(fmt.Sprintf("%s?_journal=WAL&_vacuum=incremental", - database.DBFile)), gormConfig) + dB, err = dbengine.CreateSqliteCon(fmt.Sprintf("%s?_journal=WAL&_vacuum=incremental&_txlock=immediate", + database.DBFile), gormConfig) } case "mysql": { dsn := database.DSN if dsn == "" { - //[username[:password]@][protocol[(address)]]/dbname[?param1=value1&...¶mN=valueN] + // [username[:password]@][protocol[(address)]]/dbname[?param1=value1&...¶mN=valueN] dsn = fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=True&loc=Local&tls=%s", database.User, database.Password, database.Host, database.Port, database.Name, database.SSLMode) } - dB, err = gorm.Open(mysql.Open(dsn), gormConfig) + dB, err = dbengine.CreateMysqlCon(dsn, gormConfig) } case "postgres": { @@ -76,7 +75,7 @@ func InitDB() { database.Host, database.User, database.Name, database.Port, database.SSLMode) } } - dB, err = gorm.Open(postgres.Open(dsn), gormConfig) + dB, err = dbengine.CreatePostgresCon(dsn, gormConfig) } default: log.Fatalf("not supported database type: %s", database.Type) diff --git a/internal/bootstrap/dbengine/mysql.go b/internal/bootstrap/dbengine/mysql.go new file mode 100644 index 000000000..e15e3e402 --- /dev/null +++ b/internal/bootstrap/dbengine/mysql.go @@ -0,0 +1,38 @@ +package dbengine + +import ( + "fmt" + "github.com/OpenListTeam/OpenList/v4/internal/model" + "gorm.io/driver/mysql" + "gorm.io/gorm" +) + +// CreateMysqlCon creates MySQL database connections +func CreateMysqlCon(dsn string, gormConfig *gorm.Config) (con model.Connection, err error) { + var ( + db *gorm.DB + ) + + // Create MySQL database connection + db, err = gorm.Open(mysql.Open(dsn), gormConfig) + if err != nil { + return model.Connection{}, fmt.Errorf("cannot create MySQL database connection: %w", err) + } + + // Get underlying database connection for configuration + sqlDB, err := db.DB() + if err != nil { + return model.Connection{}, fmt.Errorf("cannot access underlying MySQL database connection: %w", err) + } + + // Set connection pool parameters + sqlDB.SetMaxOpenConns(100) + sqlDB.SetMaxIdleConns(10) + sqlDB.SetConnMaxLifetime(0) + + // For MySQL, both read and write connections point to the same database instance + return model.Connection{ + Read: db, // Read connection + Write: db, // Write connection + }, nil +} diff --git a/internal/bootstrap/dbengine/postgres.go b/internal/bootstrap/dbengine/postgres.go new file mode 100644 index 000000000..de15650a7 --- /dev/null +++ b/internal/bootstrap/dbengine/postgres.go @@ -0,0 +1,38 @@ +package dbengine + +import ( + "fmt" + "github.com/OpenListTeam/OpenList/v4/internal/model" + "gorm.io/driver/postgres" + "gorm.io/gorm" +) + +// CreatePostgresCon creates PostgreSQL database connections +func CreatePostgresCon(dsn string, gormConfig *gorm.Config) (con model.Connection, err error) { + var ( + db *gorm.DB + ) + + // Create PostgreSQL database connection + db, err = gorm.Open(postgres.Open(dsn), gormConfig) + if err != nil { + return model.Connection{}, fmt.Errorf("cannot create PostgreSQL database connection: %w", err) + } + + // Get underlying database connection for configuration + sqlDB, err := db.DB() + if err != nil { + return model.Connection{}, fmt.Errorf("cannot access underlying PostgreSQL database connection: %w", err) + } + + // Set connection pool parameters + sqlDB.SetMaxOpenConns(100) + sqlDB.SetMaxIdleConns(10) + sqlDB.SetConnMaxLifetime(0) + + // For PostgreSQL, both read and write connections point to the same database instance + return model.Connection{ + Read: db, // Read connection + Write: db, // Write connection + }, nil +} diff --git a/internal/bootstrap/dbengine/sqlite.go b/internal/bootstrap/dbengine/sqlite.go new file mode 100644 index 000000000..e47dddf3a --- /dev/null +++ b/internal/bootstrap/dbengine/sqlite.go @@ -0,0 +1,102 @@ +package dbengine + +import ( + "database/sql" + "fmt" + "github.com/OpenListTeam/OpenList/v4/internal/model" + "gorm.io/driver/sqlite" + "gorm.io/gorm" + "runtime" + "strings" +) + +// CreateSqliteCon uses best practices for sqlite and creates two connections +// One optimized for reading and the other optimized for writing +// found the information here: https://kerkour.com/sqlite-for-servers +// copied from https://github.com/bihe/monorepo/blob/477e534bd4c0814cdca73fea774b518148cebd3f/pkg/persistence/sqlite.go#L59 +// with little edit. +// it should solve "database is locked error", and make better performance. +func CreateSqliteCon(dsn string, gormConfig *gorm.Config) (con model.Connection, err error) { + var ( + read *gorm.DB + write *gorm.DB + ) + + // Read DB + read, err = gorm.Open(sqlite.Open(dsn), gormConfig) + if err != nil { + return model.Connection{}, fmt.Errorf("cannot create read database connection: %w", err) + } + readDB, err := read.DB() + if err != nil { + return model.Connection{}, fmt.Errorf("cannot access underlying read database connection: %w", err) + } + if !strings.Contains(dsn, ":memory:") && !strings.Contains(dsn, "mode=memory") { + err = setDefaultPragmas(readDB) + } + if err != nil { + return model.Connection{}, err + } + readDB.SetMaxOpenConns(max(4, runtime.NumCPU())) // read in parallel with open connection per core + + // WriteDB + write, err = gorm.Open(sqlite.Open(dsn), gormConfig) + if err != nil { + return model.Connection{}, fmt.Errorf("cannot create write database connection: %w", err) + } + writeDB, err := write.DB() + if err != nil { + return model.Connection{}, fmt.Errorf("cannot access underlying write database connection: %w", err) + } + if !strings.Contains(dsn, ":memory:") && !strings.Contains(dsn, "mode=memory") { + err = setDefaultPragmas(writeDB) + } + if err != nil { + return model.Connection{}, err + } + writeDB.SetMaxOpenConns(1) // only use one active connection for writing + + return model.Connection{ + Read: read, + Write: write, + }, nil +} + +// SetDefaultPragmas defines some sqlite pragmas for good performance and litestream compatibility +// https://highperformancesqlite.com/articles/sqlite-recommended-pragmas +// https://litestream.io/tips/ +func setDefaultPragmas(db *sql.DB) error { + var ( + stmt string + val string + ) + defaultPragmas := map[string]string{ + "journal_mode": "wal", // https://www.sqlite.org/pragma.html#pragma_journal_mode + "busy_timeout": "5000", // https://www.sqlite.org/pragma.html#pragma_busy_timeout + "synchronous": "1", // NORMAL --> https://www.sqlite.org/pragma.html#pragma_synchronous + "cache_size": "10000", // 10000 pages = 40MB --> https://www.sqlite.org/pragma.html#pragma_cache_size + "foreign_keys": "1", // 1(bool) --> https://www.sqlite.org/pragma.html#pragma_foreign_keys + } + + // set the pragmas + for k := range defaultPragmas { + stmt = fmt.Sprintf("pragma %s = %s", k, defaultPragmas[k]) + if _, err := db.Exec(stmt); err != nil { + return err + } + } + + // validate the pragmas + for k := range defaultPragmas { + row := db.QueryRow(fmt.Sprintf("pragma %s", k)) + err := row.Scan(&val) + if err != nil { + return err + } + if val != defaultPragmas[k] { + return fmt.Errorf("could not set pragma %s to %s", k, defaultPragmas[k]) + } + } + + return nil +} diff --git a/internal/db/db.go b/internal/db/db.go index 96529c15d..0c700e9d1 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -5,13 +5,12 @@ import ( "github.com/OpenListTeam/OpenList/v4/internal/conf" "github.com/OpenListTeam/OpenList/v4/internal/model" - "gorm.io/gorm" ) -var db *gorm.DB +var rwDb model.Connection -func Init(d *gorm.DB) { - db = d +func Init(d model.Connection) { + rwDb = d err := AutoMigrate(new(model.Storage), new(model.User), new(model.Meta), new(model.SettingItem), new(model.SearchNode), new(model.TaskItem), new(model.SSHPublicKey), new(model.SharingDB)) if err != nil { log.Fatalf("failed migrate database: %s", err.Error()) @@ -21,27 +20,28 @@ func Init(d *gorm.DB) { func AutoMigrate(dst ...interface{}) error { var err error if conf.Conf.Database.Type == "mysql" { - err = db.Set("gorm:table_options", "ENGINE=InnoDB CHARSET=utf8mb4").AutoMigrate(dst...) + err = rwDb.W().Set("gorm:table_options", "ENGINE=InnoDB CHARSET=utf8mb4").AutoMigrate(dst...) } else { - err = db.AutoMigrate(dst...) + err = rwDb.W().AutoMigrate(dst...) } return err } -func GetDb() *gorm.DB { - return db +func GetDb() model.Connection { + return rwDb } func Close() { log.Info("closing db") - sqlDB, err := db.DB() - if err != nil { - log.Errorf("failed to get db: %s", err.Error()) - return + var err error + switch conf.Conf.Database.Type { + case "sqlite3": + err = rwDb.Close(true) + default: + err = rwDb.Close(false) } - err = sqlDB.Close() if err != nil { - log.Errorf("failed to close db: %s", err.Error()) + log.Errorf(err.Error()) return } } diff --git a/internal/db/meta.go b/internal/db/meta.go index 32eec2c38..34f836d2a 100644 --- a/internal/db/meta.go +++ b/internal/db/meta.go @@ -7,7 +7,7 @@ import ( func GetMetaByPath(path string) (*model.Meta, error) { meta := model.Meta{Path: path} - if err := db.Where(meta).First(&meta).Error; err != nil { + if err := rwDb.R().Where(meta).First(&meta).Error; err != nil { return nil, errors.Wrapf(err, "failed select meta") } return &meta, nil @@ -15,22 +15,22 @@ func GetMetaByPath(path string) (*model.Meta, error) { func GetMetaById(id uint) (*model.Meta, error) { var u model.Meta - if err := db.First(&u, id).Error; err != nil { + if err := rwDb.R().First(&u, id).Error; err != nil { return nil, errors.Wrapf(err, "failed get old meta") } return &u, nil } func CreateMeta(u *model.Meta) error { - return errors.WithStack(db.Create(u).Error) + return errors.WithStack(rwDb.W().Create(u).Error) } func UpdateMeta(u *model.Meta) error { - return errors.WithStack(db.Save(u).Error) + return errors.WithStack(rwDb.W().Save(u).Error) } func GetMetas(pageIndex, pageSize int) (metas []model.Meta, count int64, err error) { - metaDB := db.Model(&model.Meta{}) + metaDB := rwDb.R().Model(&model.Meta{}) if err = metaDB.Count(&count).Error; err != nil { return nil, 0, errors.Wrapf(err, "failed get metas count") } @@ -41,5 +41,5 @@ func GetMetas(pageIndex, pageSize int) (metas []model.Meta, count int64, err err } func DeleteMetaById(id uint) error { - return errors.WithStack(db.Delete(&model.Meta{}, id).Error) + return errors.WithStack(rwDb.W().Delete(&model.Meta{}, id).Error) } diff --git a/internal/db/searchnode.go b/internal/db/searchnode.go index 11d345a77..888af02d1 100644 --- a/internal/db/searchnode.go +++ b/internal/db/searchnode.go @@ -14,40 +14,40 @@ import ( func whereInParent(parent string) *gorm.DB { if parent == "/" { - return db.Where("1 = 1") + return rwDb.R().Where("1 = 1") } - return db.Where(fmt.Sprintf("%s LIKE ?", columnName("parent")), + return rwDb.R().Where(fmt.Sprintf("%s LIKE ?", columnName("parent")), fmt.Sprintf("%s/%%", parent)). Or(fmt.Sprintf("%s = ?", columnName("parent")), parent) } func CreateSearchNode(node *model.SearchNode) error { - return db.Create(node).Error + return rwDb.W().Create(node).Error } func BatchCreateSearchNodes(nodes *[]model.SearchNode) error { - return db.CreateInBatches(nodes, 1000).Error + return rwDb.W().CreateInBatches(nodes, 1000).Error } func DeleteSearchNodesByParent(path string) error { path = utils.FixAndCleanPath(path) - err := db.Where(whereInParent(path)).Delete(&model.SearchNode{}).Error + err := rwDb.W().Where(whereInParent(path)).Delete(&model.SearchNode{}).Error if err != nil { return err } dir, name := stdpath.Split(path) - return db.Where(fmt.Sprintf("%s = ? AND %s = ?", + return rwDb.W().Where(fmt.Sprintf("%s = ? AND %s = ?", columnName("parent"), columnName("name")), dir, name).Delete(&model.SearchNode{}).Error } func ClearSearchNodes() error { - return db.Where("1 = 1").Delete(&model.SearchNode{}).Error + return rwDb.W().Where("1 = 1").Delete(&model.SearchNode{}).Error } func GetSearchNodesByParent(parent string) ([]model.SearchNode, error) { var nodes []model.SearchNode - if err := db.Where(fmt.Sprintf("%s = ?", + if err := rwDb.R().Where(fmt.Sprintf("%s = ?", columnName("parent")), parent).Find(&nodes).Error; err != nil { return nil, err } @@ -57,25 +57,25 @@ func GetSearchNodesByParent(parent string) ([]model.SearchNode, error) { func SearchNode(req model.SearchReq, useFullText bool) ([]model.SearchNode, int64, error) { var searchDB *gorm.DB if !useFullText || conf.Conf.Database.Type == "sqlite3" { - keywordsClause := db.Where("1 = 1") + keywordsClause := rwDb.R().Where("1 = 1") for _, keyword := range strings.Fields(req.Keywords) { keywordsClause = keywordsClause.Where("name LIKE ?", fmt.Sprintf("%%%s%%", keyword)) } - searchDB = db.Model(&model.SearchNode{}).Where(whereInParent(req.Parent)).Where(keywordsClause) + searchDB = rwDb.R().Model(&model.SearchNode{}).Where(whereInParent(req.Parent)).Where(keywordsClause) } else { switch conf.Conf.Database.Type { case "mysql": - searchDB = db.Model(&model.SearchNode{}).Where(whereInParent(req.Parent)). + searchDB = rwDb.R().Model(&model.SearchNode{}).Where(whereInParent(req.Parent)). Where("MATCH (name) AGAINST (? IN BOOLEAN MODE)", "'*"+req.Keywords+"*'") case "postgres": - searchDB = db.Model(&model.SearchNode{}).Where(whereInParent(req.Parent)). + searchDB = rwDb.R().Model(&model.SearchNode{}).Where(whereInParent(req.Parent)). Where("to_tsvector(name) @@ to_tsquery(?)", strings.Join(strings.Fields(req.Keywords), " & ")) } } if req.Scope != 0 { isDir := req.Scope == 1 - searchDB.Where(db.Where("is_dir = ?", isDir)) + searchDB = searchDB.Where("is_dir = ?", isDir) } var count int64 diff --git a/internal/db/settingitem.go b/internal/db/settingitem.go index f20e507f0..a6d3279ad 100644 --- a/internal/db/settingitem.go +++ b/internal/db/settingitem.go @@ -9,7 +9,7 @@ import ( func GetSettingItems() ([]model.SettingItem, error) { var settingItems []model.SettingItem - if err := db.Find(&settingItems).Error; err != nil { + if err := rwDb.R().Find(&settingItems).Error; err != nil { return nil, errors.WithStack(err) } return settingItems, nil @@ -17,7 +17,7 @@ func GetSettingItems() ([]model.SettingItem, error) { func GetSettingItemByKey(key string) (*model.SettingItem, error) { var settingItem model.SettingItem - if err := db.Where(fmt.Sprintf("%s = ?", columnName("key")), key).First(&settingItem).Error; err != nil { + if err := rwDb.R().Where(fmt.Sprintf("%s = ?", columnName("key")), key).First(&settingItem).Error; err != nil { return nil, errors.WithStack(err) } return &settingItem, nil @@ -33,7 +33,7 @@ func GetSettingItemByKey(key string) (*model.SettingItem, error) { func GetPublicSettingItems() ([]model.SettingItem, error) { var settingItems []model.SettingItem - if err := db.Where(fmt.Sprintf("%s in ?", columnName("flag")), []int{model.PUBLIC, model.READONLY}).Find(&settingItems).Error; err != nil { + if err := rwDb.R().Where(fmt.Sprintf("%s in ?", columnName("flag")), []int{model.PUBLIC, model.READONLY}).Find(&settingItems).Error; err != nil { return nil, errors.WithStack(err) } return settingItems, nil @@ -41,7 +41,7 @@ func GetPublicSettingItems() ([]model.SettingItem, error) { func GetSettingItemsByGroup(group int) ([]model.SettingItem, error) { var settingItems []model.SettingItem - if err := db.Where(fmt.Sprintf("%s = ?", columnName("group")), group).Find(&settingItems).Error; err != nil { + if err := rwDb.R().Where(fmt.Sprintf("%s = ?", columnName("group")), group).Find(&settingItems).Error; err != nil { return nil, errors.WithStack(err) } return settingItems, nil @@ -49,7 +49,7 @@ func GetSettingItemsByGroup(group int) ([]model.SettingItem, error) { func GetSettingItemsInGroups(groups []int) ([]model.SettingItem, error) { var settingItems []model.SettingItem - err := db.Order(columnName("index")).Where(fmt.Sprintf("%s in ?", columnName("group")), groups).Find(&settingItems).Error + err := rwDb.R().Order(columnName("index")).Where(fmt.Sprintf("%s in ?", columnName("group")), groups).Find(&settingItems).Error if err != nil { return nil, errors.WithStack(err) } @@ -57,13 +57,13 @@ func GetSettingItemsInGroups(groups []int) ([]model.SettingItem, error) { } func SaveSettingItems(items []model.SettingItem) (err error) { - return errors.WithStack(db.Save(items).Error) + return errors.WithStack(rwDb.W().Save(items).Error) } func SaveSettingItem(item *model.SettingItem) error { - return errors.WithStack(db.Save(item).Error) + return errors.WithStack(rwDb.W().Save(item).Error) } func DeleteSettingItemByKey(key string) error { - return errors.WithStack(db.Delete(&model.SettingItem{Key: key}).Error) + return errors.WithStack(rwDb.W().Delete(&model.SettingItem{Key: key}).Error) } diff --git a/internal/db/sharing.go b/internal/db/sharing.go index 3748796b0..56a50aed5 100644 --- a/internal/db/sharing.go +++ b/internal/db/sharing.go @@ -8,14 +8,14 @@ import ( func GetSharingById(id string) (*model.SharingDB, error) { s := model.SharingDB{ID: id} - if err := db.Where(s).First(&s).Error; err != nil { + if err := rwDb.R().Where(s).First(&s).Error; err != nil { return nil, errors.Wrapf(err, "failed get sharing") } return &s, nil } func GetSharings(pageIndex, pageSize int) (sharings []model.SharingDB, count int64, err error) { - sharingDB := db.Model(&model.SharingDB{}) + sharingDB := rwDb.R().Model(&model.SharingDB{}) if err := sharingDB.Count(&count).Error; err != nil { return nil, 0, errors.Wrapf(err, "failed get sharings count") } @@ -26,7 +26,7 @@ func GetSharings(pageIndex, pageSize int) (sharings []model.SharingDB, count int } func GetSharingsByCreatorId(creator uint, pageIndex, pageSize int) (sharings []model.SharingDB, count int64, err error) { - sharingDB := db.Model(&model.SharingDB{}) + sharingDB := rwDb.R().Model(&model.SharingDB{}) cond := model.SharingDB{CreatorId: creator} if err := sharingDB.Where(cond).Count(&count).Error; err != nil { return nil, 0, errors.Wrapf(err, "failed get sharings count") @@ -43,9 +43,9 @@ func CreateSharing(s *model.SharingDB) (string, error) { old := model.SharingDB{ ID: id, } - if err := db.Where(old).First(&old).Error; err != nil { + if err := rwDb.R().Where(old).First(&old).Error; err != nil { s.ID = id - return id, errors.WithStack(db.Create(s).Error) + return id, errors.WithStack(rwDb.W().Create(s).Error) } id += random.String(1) } @@ -53,10 +53,10 @@ func CreateSharing(s *model.SharingDB) (string, error) { } func UpdateSharing(s *model.SharingDB) error { - return errors.WithStack(db.Save(s).Error) + return errors.WithStack(rwDb.W().Save(s).Error) } func DeleteSharingById(id string) error { s := model.SharingDB{ID: id} - return errors.WithStack(db.Where(s).Delete(&s).Error) + return errors.WithStack(rwDb.W().Where(s).Delete(&s).Error) } diff --git a/internal/db/sshkey.go b/internal/db/sshkey.go index 9d6526b4c..ab1e29fbf 100644 --- a/internal/db/sshkey.go +++ b/internal/db/sshkey.go @@ -6,7 +6,7 @@ import ( ) func GetSSHPublicKeyByUserId(userId uint, pageIndex, pageSize int) (keys []model.SSHPublicKey, count int64, err error) { - keyDB := db.Model(&model.SSHPublicKey{}) + keyDB := rwDb.R().Model(&model.SSHPublicKey{}) query := model.SSHPublicKey{UserId: userId} if err := keyDB.Where(query).Count(&count).Error; err != nil { return nil, 0, errors.Wrapf(err, "failed get user's keys count") @@ -19,7 +19,7 @@ func GetSSHPublicKeyByUserId(userId uint, pageIndex, pageSize int) (keys []model func GetSSHPublicKeyById(id uint) (*model.SSHPublicKey, error) { var k model.SSHPublicKey - if err := db.First(&k, id).Error; err != nil { + if err := rwDb.R().First(&k, id).Error; err != nil { return nil, errors.Wrapf(err, "failed get old key") } return &k, nil @@ -27,22 +27,22 @@ func GetSSHPublicKeyById(id uint) (*model.SSHPublicKey, error) { func GetSSHPublicKeyByUserTitle(userId uint, title string) (*model.SSHPublicKey, error) { key := model.SSHPublicKey{UserId: userId, Title: title} - if err := db.Where(key).First(&key).Error; err != nil { + if err := rwDb.R().Where(key).First(&key).Error; err != nil { return nil, errors.Wrapf(err, "failed find key with title of user") } return &key, nil } func CreateSSHPublicKey(k *model.SSHPublicKey) error { - return errors.WithStack(db.Create(k).Error) + return errors.WithStack(rwDb.W().Create(k).Error) } func UpdateSSHPublicKey(k *model.SSHPublicKey) error { - return errors.WithStack(db.Save(k).Error) + return errors.WithStack(rwDb.W().Save(k).Error) } func GetSSHPublicKeys(pageIndex, pageSize int) (keys []model.SSHPublicKey, count int64, err error) { - keyDB := db.Model(&model.SSHPublicKey{}) + keyDB := rwDb.R().Model(&model.SSHPublicKey{}) if err := keyDB.Count(&count).Error; err != nil { return nil, 0, errors.Wrapf(err, "failed get keys count") } @@ -53,5 +53,5 @@ func GetSSHPublicKeys(pageIndex, pageSize int) (keys []model.SSHPublicKey, count } func DeleteSSHPublicKeyById(id uint) error { - return errors.WithStack(db.Delete(&model.SSHPublicKey{}, id).Error) + return errors.WithStack(rwDb.W().Delete(&model.SSHPublicKey{}, id).Error) } diff --git a/internal/db/storage.go b/internal/db/storage.go index 0c660a156..12fb4a839 100644 --- a/internal/db/storage.go +++ b/internal/db/storage.go @@ -14,22 +14,22 @@ import ( // CreateStorage just insert storage to database func CreateStorage(storage *model.Storage) error { - return errors.WithStack(db.Create(storage).Error) + return errors.WithStack(rwDb.W().Create(storage).Error) } // UpdateStorage just update storage in database func UpdateStorage(storage *model.Storage) error { - return errors.WithStack(db.Save(storage).Error) + return errors.WithStack(rwDb.W().Save(storage).Error) } // DeleteStorageById just delete storage from database by id func DeleteStorageById(id uint) error { - return errors.WithStack(db.Delete(&model.Storage{}, id).Error) + return errors.WithStack(rwDb.W().Delete(&model.Storage{}, id).Error) } // GetStorages Get all storages from database order by index func GetStorages(pageIndex, pageSize int) ([]model.Storage, int64, error) { - storageDB := db.Model(&model.Storage{}) + storageDB := rwDb.R().Model(&model.Storage{}) var count int64 if err := storageDB.Count(&count).Error; err != nil { return nil, 0, errors.Wrapf(err, "failed get storages count") @@ -45,7 +45,7 @@ func GetStorages(pageIndex, pageSize int) ([]model.Storage, int64, error) { func GetStorageById(id uint) (*model.Storage, error) { var storage model.Storage storage.ID = id - if err := db.First(&storage).Error; err != nil { + if err := rwDb.R().First(&storage).Error; err != nil { return nil, errors.WithStack(err) } return &storage, nil @@ -54,7 +54,7 @@ func GetStorageById(id uint) (*model.Storage, error) { // GetStorageByMountPath Get Storage by mountPath, used to update storage usually func GetStorageByMountPath(mountPath string) (*model.Storage, error) { var storage model.Storage - if err := db.Where("mount_path = ?", mountPath).First(&storage).Error; err != nil { + if err := rwDb.R().Where("mount_path = ?", mountPath).First(&storage).Error; err != nil { return nil, errors.WithStack(err) } return &storage, nil @@ -62,7 +62,7 @@ func GetStorageByMountPath(mountPath string) (*model.Storage, error) { func GetEnabledStorages() ([]model.Storage, error) { var storages []model.Storage - err := addStorageOrder(db).Where(fmt.Sprintf("%s = ?", columnName("disabled")), false).Find(&storages).Error + err := addStorageOrder(rwDb.R()).Where(fmt.Sprintf("%s = ?", columnName("disabled")), false).Find(&storages).Error if err != nil { return nil, errors.WithStack(err) } diff --git a/internal/db/tasks.go b/internal/db/tasks.go index dcb9dfeab..a31c59a2d 100644 --- a/internal/db/tasks.go +++ b/internal/db/tasks.go @@ -7,18 +7,18 @@ import ( func GetTaskDataByType(type_s string) (*model.TaskItem, error) { task := model.TaskItem{Key: type_s} - if err := db.Where(task).First(&task).Error; err != nil { + if err := rwDb.R().Where(task).First(&task).Error; err != nil { return nil, errors.Wrapf(err, "failed find task") } return &task, nil } func UpdateTaskData(t *model.TaskItem) error { - return errors.WithStack(db.Model(&model.TaskItem{}).Where("key = ?", t.Key).Update("persist_data", t.PersistData).Error) + return errors.WithStack(rwDb.W().Model(&model.TaskItem{}).Where("key = ?", t.Key).Update("persist_data", t.PersistData).Error) } func CreateTaskData(t *model.TaskItem) error { - return errors.WithStack(db.Create(t).Error) + return errors.WithStack(rwDb.W().Create(t).Error) } func GetTaskDataFunc(type_s string, enabled bool) func() ([]byte, error) { diff --git a/internal/db/user.go b/internal/db/user.go index 4b9c67ece..d37b24760 100644 --- a/internal/db/user.go +++ b/internal/db/user.go @@ -11,7 +11,7 @@ import ( func GetUserByRole(role int) (*model.User, error) { user := model.User{Role: role} - if err := db.Where(user).Take(&user).Error; err != nil { + if err := rwDb.R().Where(user).Take(&user).Error; err != nil { return nil, err } return &user, nil @@ -19,7 +19,7 @@ func GetUserByRole(role int) (*model.User, error) { func GetUserByName(username string) (*model.User, error) { user := model.User{Username: username} - if err := db.Where(user).First(&user).Error; err != nil { + if err := rwDb.R().Where(user).First(&user).Error; err != nil { return nil, errors.Wrapf(err, "failed find user") } return &user, nil @@ -27,7 +27,7 @@ func GetUserByName(username string) (*model.User, error) { func GetUserBySSOID(ssoID string) (*model.User, error) { user := model.User{SsoID: ssoID} - if err := db.Where(user).First(&user).Error; err != nil { + if err := rwDb.R().Where(user).First(&user).Error; err != nil { return nil, errors.Wrapf(err, "The single sign on platform is not bound to any users") } return &user, nil @@ -35,22 +35,22 @@ func GetUserBySSOID(ssoID string) (*model.User, error) { func GetUserById(id uint) (*model.User, error) { var u model.User - if err := db.First(&u, id).Error; err != nil { + if err := rwDb.R().First(&u, id).Error; err != nil { return nil, errors.Wrapf(err, "failed get old user") } return &u, nil } func CreateUser(u *model.User) error { - return errors.WithStack(db.Create(u).Error) + return errors.WithStack(rwDb.W().Create(u).Error) } func UpdateUser(u *model.User) error { - return errors.WithStack(db.Save(u).Error) + return errors.WithStack(rwDb.W().Save(u).Error) } func GetUsers(pageIndex, pageSize int) (users []model.User, count int64, err error) { - userDB := db.Model(&model.User{}) + userDB := rwDb.R().Model(&model.User{}) if err := userDB.Count(&count).Error; err != nil { return nil, 0, errors.Wrapf(err, "failed get users count") } @@ -61,11 +61,11 @@ func GetUsers(pageIndex, pageSize int) (users []model.User, count int64, err err } func DeleteUserById(id uint) error { - return errors.WithStack(db.Delete(&model.User{}, id).Error) + return errors.WithStack(rwDb.W().Delete(&model.User{}, id).Error) } func UpdateAuthn(userID uint, authn string) error { - return db.Model(&model.User{ID: userID}).Update("authn", authn).Error + return rwDb.W().Model(&model.User{ID: userID}).Update("authn", authn).Error } func RegisterAuthn(u *model.User, credential *webauthn.Credential) error { diff --git a/internal/model/connection.go b/internal/model/connection.go new file mode 100644 index 000000000..3982940b5 --- /dev/null +++ b/internal/model/connection.go @@ -0,0 +1,74 @@ +package model + +import ( + "fmt" + "gorm.io/gorm" +) + +// A Connection struct abstracts the access to the underlying database connections +// It is used for reading from a database connection and writing to a database connection. +// The writing is assumed is always in the context of a database transaction. +type Connection struct { + // Read is a read connection used for fast access to the underlying database transaction + Read *gorm.DB + // Write is a write connection which is used primarily to write in particular to create a transaction connection + Write *gorm.DB +} + +// R returns a suitable connection. It is either read focused connection +// or a transaction. +// The func panics if no read connection is available +func (c Connection) R() *gorm.DB { + if c.Read == nil { + panic("no read database connection is available") + } + return c.Read +} + +// W retrieves a write connection. If this is a transaction use the tx connection +// The func panics if no write connection is available +func (c Connection) W() *gorm.DB { + if c.Write == nil { + panic("no write database connection is available") + } + return c.Write +} + +// Close closes the connection and cleans up resources. +// If twiceFlag = true, need to close both writeDB and readDB. +// If twiceFlag = false, only need to close readDB (when readDB = writeDB). +func (c Connection) Close(twiceFlag bool) error { + var err error + + // Close readDB + readDBRaw, readErr := c.Read.DB() + if readErr != nil { + err = fmt.Errorf("failed to get read db: %s", readErr.Error()) + } else { + if closeErr := readDBRaw.Close(); closeErr != nil { + err = fmt.Errorf("failed to close read db: %s", closeErr.Error()) + } + } + + // Close writeDB if twiceFlag is true + if twiceFlag { + writeDBRaw, writeErr := c.Write.DB() + if writeErr != nil { + if err != nil { + err = fmt.Errorf("%s; failed to get write db: %s", err.Error(), writeErr.Error()) + } else { + err = fmt.Errorf("failed to get write db: %s", writeErr.Error()) + } + } else { + if closeErr := writeDBRaw.Close(); closeErr != nil { + if err != nil { + err = fmt.Errorf("%s; failed to close write db: %s", err.Error(), closeErr.Error()) + } else { + err = fmt.Errorf("failed to close write db: %s", closeErr.Error()) + } + } + } + } + + return err +} diff --git a/internal/op/storage_test.go b/internal/op/storage_test.go index 2b191bd56..4cd4448d3 100644 --- a/internal/op/storage_test.go +++ b/internal/op/storage_test.go @@ -2,6 +2,7 @@ package op_test import ( "context" + "github.com/OpenListTeam/OpenList/v4/internal/bootstrap/dbengine" "testing" "github.com/OpenListTeam/OpenList/v4/internal/conf" @@ -10,17 +11,16 @@ import ( "github.com/OpenListTeam/OpenList/v4/internal/op" "github.com/OpenListTeam/OpenList/v4/pkg/utils" mapset "github.com/deckarep/golang-set/v2" - "gorm.io/driver/sqlite" "gorm.io/gorm" ) func init() { - dB, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{}) + con, err := dbengine.CreateSqliteCon("file::memory:?cache=shared", &gorm.Config{}) if err != nil { panic("failed to connect database") } conf.Conf = conf.DefaultConfig("data") - db.Init(dB) + db.Init(con) } func TestCreateStorage(t *testing.T) { diff --git a/internal/search/db/init.go b/internal/search/db/init.go index 8d4011653..a9cfe400f 100644 --- a/internal/search/db/init.go +++ b/internal/search/db/init.go @@ -18,20 +18,20 @@ var config = searcher.Config{ func init() { searcher.RegisterSearcher(config, func() (searcher.Searcher, error) { - db := db.GetDb() + rwDb := db.GetDb() switch conf.Conf.Database.Type { case "mysql": tableName := fmt.Sprintf("%ssearch_nodes", conf.Conf.Database.TablePrefix) - tx := db.Exec(fmt.Sprintf("CREATE FULLTEXT INDEX idx_%s_name_fulltext ON %s(name);", tableName, tableName)) + tx := rwDb.W().Exec(fmt.Sprintf("CREATE FULLTEXT INDEX idx_%s_name_fulltext ON %s(name);", tableName, tableName)) if err := tx.Error; err != nil && !strings.Contains(err.Error(), "Error 1061 (42000)") { // duplicate error log.Errorf("failed to create full text index: %v", err) return nil, err } case "postgres": - db.Exec("CREATE EXTENSION pg_trgm;") - db.Exec("CREATE EXTENSION btree_gin;") + rwDb.W().Exec("CREATE EXTENSION pg_trgm;") + rwDb.W().Exec("CREATE EXTENSION btree_gin;") tableName := fmt.Sprintf("%ssearch_nodes", conf.Conf.Database.TablePrefix) - tx := db.Exec(fmt.Sprintf("CREATE INDEX idx_%s_name ON %s USING GIN (name);", tableName, tableName)) + tx := rwDb.W().Exec(fmt.Sprintf("CREATE INDEX idx_%s_name ON %s USING GIN (name);", tableName, tableName)) if err := tx.Error; err != nil && !strings.Contains(err.Error(), "SQLSTATE 42P07") { log.Errorf("failed to create index using GIN: %v", err) return nil, err