@@ -7,10 +7,12 @@ import (
77 "fmt"
88 "io"
99 "net/url"
10- "sync/atomic"
1110
1211 "github.com/hashicorp/go-multierror"
1312 "github.com/ydb-platform/ydb-go-sdk/v3"
13+ "github.com/ydb-platform/ydb-go-sdk/v3/balancers"
14+ "github.com/ydb-platform/ydb-go-sdk/v3/retry"
15+ "go.uber.org/atomic"
1416
1517 "github.com/golang-migrate/migrate/v4"
1618 "github.com/golang-migrate/migrate/v4/database"
@@ -22,9 +24,11 @@ func init() {
2224
2325const (
2426 defaultMigrationsTable = "schema_migrations"
27+ defaultLockTable = "schema_lock"
2528
2629 queryParamAuthToken = "x-auth-token"
2730 queryParamMigrationsTable = "x-migrations-table"
31+ queryParamLockTable = "x-lock-table"
2832 queryParamUseGRPCS = "x-use-grpcs"
2933 queryParamTLSCertificateAuthorities = "x-tls-ca"
3034 queryParamTLSInsecureSkipVerify = "x-tls-insecure-skip-verify"
3943
4044type Config struct {
4145 MigrationsTable string
46+ LockTable string
47+ DatabaseName string
4248}
4349
4450type YDB struct {
@@ -63,6 +69,10 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
6369 config .MigrationsTable = defaultMigrationsTable
6470 }
6571
72+ if len (config .LockTable ) == 0 {
73+ config .LockTable = defaultLockTable
74+ }
75+
6676 conn , err := instance .Conn (context .TODO ())
6777 if err != nil {
6878 return nil , err
@@ -73,6 +83,9 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
7383 db : instance ,
7484 config : config ,
7585 }
86+ if err = db .ensureLockTable (); err != nil {
87+ return nil , err
88+ }
7689 if err = db .ensureVersionTable (); err != nil {
7790 return nil , err
7891 }
@@ -109,7 +122,11 @@ func (y *YDB) Open(dsn string) (database.Driver, error) {
109122 return nil , err
110123 }
111124
112- nativeDriver , err := ydb .Open (context .TODO (), purl .String (), append (tlsOptions , credentials )... )
125+ nativeDriver , err := ydb .Open (
126+ context .TODO (),
127+ purl .String (),
128+ append (tlsOptions , credentials , ydb .WithBalancer (balancers .SingleConn ()))... ,
129+ )
113130 if err != nil {
114131 return nil , err
115132 }
@@ -123,6 +140,8 @@ func (y *YDB) Open(dsn string) (database.Driver, error) {
123140
124141 db , err := WithInstance (sql .OpenDB (connector ), & Config {
125142 MigrationsTable : pquery .Get (queryParamMigrationsTable ),
143+ LockTable : pquery .Get (queryParamLockTable ),
144+ DatabaseName : purl .Path ,
126145 })
127146 if err != nil {
128147 return nil , err
@@ -188,7 +207,7 @@ func (y *YDB) Run(migration io.Reader) error {
188207 return err
189208 }
190209
191- if _ , err = y .conn .ExecContext (ydb . WithQueryMode ( context .TODO (), ydb . SchemeQueryMode ), string (rawMigrations )); err != nil {
210+ if _ , err = y .conn .ExecContext (context .Background ( ), string (rawMigrations )); err != nil {
192211 return database.Error {OrigErr : err , Err : "migration failed" , Query : rawMigrations }
193212 }
194213 return nil
@@ -278,23 +297,77 @@ func (y *YDB) Drop() (err error) {
278297
279298 for _ , path := range paths {
280299 dropQuery := fmt .Sprintf ("DROP TABLE IF EXISTS `%s`" , path )
281- if _ , err = y .conn .ExecContext (ydb . WithQueryMode ( context .TODO (), ydb . SchemeQueryMode ), dropQuery ); err != nil {
300+ if _ , err = y .conn .ExecContext (context .Background ( ), dropQuery ); err != nil {
282301 return & database.Error {OrigErr : err , Query : []byte (dropQuery )}
283302 }
284303 }
285304 return nil
286305}
287306
288307func (y * YDB ) Lock () error {
289- if ! y .isLocked .CompareAndSwap (false , true ) {
290- return database .ErrLocked
291- }
292- return nil
308+ return database .CasRestoreOnErr (& y .isLocked , false , true , database .ErrLocked , func () (err error ) {
309+ return retry .DoTx (context .TODO (), y .db , func (ctx context.Context , tx * sql.Tx ) (err error ) {
310+ aid , err := database .GenerateAdvisoryLockId (y .config .DatabaseName )
311+ if err != nil {
312+ return err
313+ }
314+
315+ getLockQuery := fmt .Sprintf ("SELECT * FROM %s WHERE lock_id = '%s'" , y .config .LockTable , aid )
316+ rows , err := tx .Query (getLockQuery , aid )
317+ if err != nil {
318+ return database.Error {OrigErr : err , Err : "failed to fetch migration lock" , Query : []byte (getLockQuery )}
319+ }
320+ defer func () {
321+ if errClose := rows .Close (); errClose != nil {
322+ err = multierror .Append (err , errClose )
323+ }
324+ }()
325+
326+ // If row exists at all, lock is present
327+ locked := rows .Next ()
328+ if locked {
329+ return database .ErrLocked
330+ }
331+
332+ setLockQuery := fmt .Sprintf ("INSERT INTO %s (lock_id) VALUES ('%s')" , y .config .LockTable , aid )
333+ if _ , err = tx .Exec (setLockQuery ); err != nil {
334+ return database.Error {OrigErr : err , Err : "failed to set migration lock" , Query : []byte (setLockQuery )}
335+ }
336+ return nil
337+ }, retry .WithTxOptions (& sql.TxOptions {Isolation : sql .LevelSerializable }))
338+ })
293339}
294340
295341func (y * YDB ) Unlock () error {
296- if ! y .isLocked .CompareAndSwap (true , false ) {
297- return database .ErrNotLocked
342+ return database .CasRestoreOnErr (& y .isLocked , true , false , database .ErrNotLocked , func () (err error ) {
343+ aid , err := database .GenerateAdvisoryLockId (y .config .DatabaseName )
344+ if err != nil {
345+ return err
346+ }
347+
348+ releaseLockQuery := fmt .Sprintf ("DELETE FROM %s WHERE lock_id = '%s'" , y .config .LockTable , aid )
349+ if _ , err = y .conn .ExecContext (context .TODO (), releaseLockQuery ); err != nil {
350+ // On drops, the lock table is fully removed; This is fine, and is a valid "unlocked" state for the schema.
351+ if ydb .IsOperationErrorSchemeError (err ) {
352+ return nil
353+ }
354+ return database.Error {OrigErr : err , Err : "failed to release migration lock" , Query : []byte (releaseLockQuery )}
355+ }
356+
357+ return nil
358+ })
359+ }
360+
361+ // ensureLockTable checks if lock table exists and, if not, creates it.
362+ func (y * YDB ) ensureLockTable () (err error ) {
363+ createLockTableQuery := fmt .Sprintf (`
364+ CREATE TABLE IF NOT EXISTS %s (
365+ lock_id String NOT NULL,
366+ PRIMARY KEY(lock_id)
367+ )
368+ ` , y .config .LockTable )
369+ if _ , err = y .conn .ExecContext (context .Background (), createLockTableQuery ); err != nil {
370+ return & database.Error {OrigErr : err , Query : []byte (createLockTableQuery )}
298371 }
299372 return nil
300373}
@@ -323,7 +396,7 @@ func (y *YDB) ensureVersionTable() (err error) {
323396 PRIMARY KEY(version)
324397 )
325398 ` , y .config .MigrationsTable )
326- if _ , err = y .conn .ExecContext (ydb . WithQueryMode ( context .TODO (), ydb . SchemeQueryMode ), createVersionTableQuery ); err != nil {
399+ if _ , err = y .conn .ExecContext (context .Background ( ), createVersionTableQuery ); err != nil {
327400 return & database.Error {OrigErr : err , Query : []byte (createVersionTableQuery )}
328401 }
329402 return nil
0 commit comments