Skip to content

Commit 2f9abde

Browse files
authored
fix: dryrun migration should run select (#251)
1 parent b3b67da commit 2f9abde

File tree

1 file changed

+22
-12
lines changed

1 file changed

+22
-12
lines changed

migrator.go

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,18 @@ type Migrator struct {
5454
migrator.Migrator
5555
}
5656

57+
// select querys ignore dryrun
58+
func (m Migrator) queryRaw(sql string, values ...interface{}) (tx *gorm.DB) {
59+
queryTx := m.DB
60+
if m.DB.DryRun {
61+
queryTx = m.DB.Session(&gorm.Session{})
62+
queryTx.DryRun = false
63+
}
64+
return queryTx.Raw(sql, values...)
65+
}
66+
5767
func (m Migrator) CurrentDatabase() (name string) {
58-
m.DB.Raw("SELECT CURRENT_DATABASE()").Scan(&name)
68+
m.queryRaw("SELECT CURRENT_DATABASE()").Scan(&name)
5969
return
6070
}
6171

@@ -87,7 +97,7 @@ func (m Migrator) HasIndex(value interface{}, name string) bool {
8797
}
8898
}
8999
currentSchema, curTable := m.CurrentSchema(stmt, stmt.Table)
90-
return m.DB.Raw(
100+
return m.queryRaw(
91101
"SELECT count(*) FROM pg_indexes WHERE tablename = ? AND indexname = ? AND schemaname = ?", curTable, name, currentSchema,
92102
).Scan(&count).Error
93103
})
@@ -155,7 +165,7 @@ func (m Migrator) DropIndex(value interface{}, name string) error {
155165

156166
func (m Migrator) GetTables() (tableList []string, err error) {
157167
currentSchema, _ := m.CurrentSchema(m.DB.Statement, "")
158-
return tableList, m.DB.Raw("SELECT table_name FROM information_schema.tables WHERE table_schema = ? AND table_type = ?", currentSchema, "BASE TABLE").Scan(&tableList).Error
168+
return tableList, m.queryRaw("SELECT table_name FROM information_schema.tables WHERE table_schema = ? AND table_type = ?", currentSchema, "BASE TABLE").Scan(&tableList).Error
159169
}
160170

161171
func (m Migrator) CreateTable(values ...interface{}) (err error) {
@@ -189,7 +199,7 @@ func (m Migrator) HasTable(value interface{}) bool {
189199
var count int64
190200
m.RunWithValue(value, func(stmt *gorm.Statement) error {
191201
currentSchema, curTable := m.CurrentSchema(stmt, stmt.Table)
192-
return m.DB.Raw("SELECT count(*) FROM information_schema.tables WHERE table_schema = ? AND table_name = ? AND table_type = ?", currentSchema, curTable, "BASE TABLE").Scan(&count).Error
202+
return m.queryRaw("SELECT count(*) FROM information_schema.tables WHERE table_schema = ? AND table_name = ? AND table_type = ?", currentSchema, curTable, "BASE TABLE").Scan(&count).Error
193203
})
194204
return count > 0
195205
}
@@ -241,7 +251,7 @@ func (m Migrator) HasColumn(value interface{}, field string) bool {
241251
}
242252

243253
currentSchema, curTable := m.CurrentSchema(stmt, stmt.Table)
244-
return m.DB.Raw(
254+
return m.queryRaw(
245255
"SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_schema = ? AND table_name = ? AND column_name = ?",
246256
currentSchema, curTable, name,
247257
).Scan(&count).Error
@@ -266,7 +276,7 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
266276
checkSQL += "WHERE objsubid = (SELECT ordinal_position FROM information_schema.columns WHERE table_schema = ? AND table_name = ? AND column_name = ?) "
267277
checkSQL += "AND objoid = (SELECT oid FROM pg_catalog.pg_class WHERE relname = ? AND relnamespace = "
268278
checkSQL += "(SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = ?))"
269-
m.DB.Raw(checkSQL, values...).Scan(&description)
279+
m.queryRaw(checkSQL, values...).Scan(&description)
270280

271281
comment := strings.Trim(field.Comment, "'")
272282
comment = strings.Trim(comment, `"`)
@@ -414,7 +424,7 @@ func (m Migrator) HasConstraint(value interface{}, name string) bool {
414424
}
415425
currentSchema, curTable := m.CurrentSchema(stmt, table)
416426

417-
return m.DB.Raw(
427+
return m.queryRaw(
418428
"SELECT count(*) FROM INFORMATION_SCHEMA.table_constraints WHERE table_schema = ? AND table_name = ? AND constraint_name = ?",
419429
currentSchema, curTable, name,
420430
).Scan(&count).Error
@@ -429,7 +439,7 @@ func (m Migrator) ColumnTypes(value interface{}) (columnTypes []gorm.ColumnType,
429439
var (
430440
currentDatabase = m.DB.Migrator().CurrentDatabase()
431441
currentSchema, table = m.CurrentSchema(stmt, stmt.Table)
432-
columns, err = m.DB.Raw(
442+
columns, err = m.queryRaw(
433443
"SELECT c.column_name, c.is_nullable = 'YES', c.udt_name, c.character_maximum_length, c.numeric_precision, c.numeric_precision_radix, c.numeric_scale, c.datetime_precision, 8 * typlen, c.column_default, pd.description, c.identity_increment FROM information_schema.columns AS c JOIN pg_type AS pgt ON c.udt_name = pgt.typname LEFT JOIN pg_catalog.pg_description as pd ON pd.objsubid = c.ordinal_position AND pd.objoid = (SELECT oid FROM pg_catalog.pg_class WHERE relname = c.table_name AND relnamespace = (SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = c.table_schema)) where table_catalog = ? AND table_schema = ? AND table_name = ?",
434444
currentDatabase, currentSchema, table).Rows()
435445
)
@@ -503,7 +513,7 @@ func (m Migrator) ColumnTypes(value interface{}) (columnTypes []gorm.ColumnType,
503513

504514
// check primary, unique field
505515
{
506-
columnTypeRows, err := m.DB.Raw("SELECT constraint_name FROM information_schema.table_constraints tc JOIN information_schema.constraint_column_usage AS ccu USING (constraint_schema, constraint_catalog, table_name, constraint_name) JOIN information_schema.columns AS c ON c.table_schema = tc.constraint_schema AND tc.table_name = c.table_name AND ccu.column_name = c.column_name WHERE constraint_type IN ('PRIMARY KEY', 'UNIQUE') AND c.table_catalog = ? AND c.table_schema = ? AND c.table_name = ? AND constraint_type = ?", currentDatabase, currentSchema, table, "UNIQUE").Rows()
516+
columnTypeRows, err := m.queryRaw("SELECT constraint_name FROM information_schema.table_constraints tc JOIN information_schema.constraint_column_usage AS ccu USING (constraint_schema, constraint_catalog, table_name, constraint_name) JOIN information_schema.columns AS c ON c.table_schema = tc.constraint_schema AND tc.table_name = c.table_name AND ccu.column_name = c.column_name WHERE constraint_type IN ('PRIMARY KEY', 'UNIQUE') AND c.table_catalog = ? AND c.table_schema = ? AND c.table_name = ? AND constraint_type = ?", currentDatabase, currentSchema, table, "UNIQUE").Rows()
507517
if err != nil {
508518
return err
509519
}
@@ -515,7 +525,7 @@ func (m Migrator) ColumnTypes(value interface{}) (columnTypes []gorm.ColumnType,
515525
}
516526
columnTypeRows.Close()
517527

518-
columnTypeRows, err = m.DB.Raw("SELECT c.column_name, constraint_name, constraint_type FROM information_schema.table_constraints tc JOIN information_schema.constraint_column_usage AS ccu USING (constraint_schema, constraint_catalog, table_name, constraint_name) JOIN information_schema.columns AS c ON c.table_schema = tc.constraint_schema AND tc.table_name = c.table_name AND ccu.column_name = c.column_name WHERE constraint_type IN ('PRIMARY KEY', 'UNIQUE') AND c.table_catalog = ? AND c.table_schema = ? AND c.table_name = ?", currentDatabase, currentSchema, table).Rows()
528+
columnTypeRows, err = m.queryRaw("SELECT c.column_name, constraint_name, constraint_type FROM information_schema.table_constraints tc JOIN information_schema.constraint_column_usage AS ccu USING (constraint_schema, constraint_catalog, table_name, constraint_name) JOIN information_schema.columns AS c ON c.table_schema = tc.constraint_schema AND tc.table_name = c.table_name AND ccu.column_name = c.column_name WHERE constraint_type IN ('PRIMARY KEY', 'UNIQUE') AND c.table_catalog = ? AND c.table_schema = ? AND c.table_name = ?", currentDatabase, currentSchema, table).Rows()
519529
if err != nil {
520530
return err
521531
}
@@ -542,7 +552,7 @@ func (m Migrator) ColumnTypes(value interface{}) (columnTypes []gorm.ColumnType,
542552

543553
// check column type
544554
{
545-
dataTypeRows, err := m.DB.Raw(`SELECT a.attname as column_name, format_type(a.atttypid, a.atttypmod) AS data_type
555+
dataTypeRows, err := m.queryRaw(`SELECT a.attname as column_name, format_type(a.atttypid, a.atttypmod) AS data_type
546556
FROM pg_attribute a JOIN pg_class b ON a.attrelid = b.oid AND relnamespace = (SELECT oid FROM pg_catalog.pg_namespace WHERE nspname = ?)
547557
WHERE a.attnum > 0 -- hide internal columns
548558
AND NOT a.attisdropped -- hide deleted columns
@@ -700,7 +710,7 @@ func (m Migrator) GetIndexes(value interface{}) ([]gorm.Index, error) {
700710

701711
err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
702712
result := make([]*Index, 0)
703-
scanErr := m.DB.Raw(indexSql, stmt.Table).Scan(&result).Error
713+
scanErr := m.queryRaw(indexSql, stmt.Table).Scan(&result).Error
704714
if scanErr != nil {
705715
return scanErr
706716
}

0 commit comments

Comments
 (0)