Skip to content

Commit 07d0daa

Browse files
themullemanuelmueller2jinzhu
authored
enable AutoMigrate when using non default sqlserver schemas (#50)
* enable AutoMigrate when using custom sqlserver schemas Sqlserver supports multiple schemas per database. This feature works in gorm (custom TableName/NamingStrategy) but fails at AutoMigrate. This commit fixes the migrator to also support non-default schemas. * fix constraint lookup when using explicit schema names * bugfixes due to upgrade * fix bug in foreign-key query * Update migrator_test.go Co-authored-by: Mueller Manuel (LWE) <[email protected]> Co-authored-by: Jinzhu <[email protected]>
1 parent 88a3f7b commit 07d0daa

File tree

2 files changed

+153
-7
lines changed

2 files changed

+153
-7
lines changed

migrator.go

Lines changed: 55 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"gorm.io/gorm"
1010
"gorm.io/gorm/clause"
1111
"gorm.io/gorm/migrator"
12+
"gorm.io/gorm/schema"
1213
)
1314

1415
type Migrator struct {
@@ -19,12 +20,46 @@ func (m Migrator) GetTables() (tableList []string, err error) {
1920
return tableList, m.DB.Raw("SELECT table_name FROM INFORMATION_SCHEMA.tables WHERE table_catalog = ?", m.CurrentDatabase()).Scan(&tableList).Error
2021
}
2122

23+
func getTableSchemaName(schema *schema.Schema) string {
24+
//return the schema name if it is explicitly provided in the table name
25+
//otherwise return a sql wildcard -> use any table_schema
26+
if schema == nil || !strings.Contains(schema.Table, ".") {
27+
return ""
28+
}
29+
_, schemaName, _ := splitFullQualifiedName(schema.Table)
30+
return schemaName
31+
}
32+
33+
func splitFullQualifiedName(name string) (string, string, string) {
34+
nameParts := strings.Split(name, ".")
35+
if len(nameParts) == 1 { //[table_name]
36+
return "", "", nameParts[0]
37+
} else if len(nameParts) == 2 { //[table_schema].[table_name]
38+
return "", nameParts[0], nameParts[1]
39+
} else if len(nameParts) == 3 { //[table_catalog].[table_schema].[table_name]
40+
return nameParts[0], nameParts[1], nameParts[2]
41+
}
42+
return "", "", ""
43+
}
44+
45+
func getFullQualifiedTableName(stmt *gorm.Statement) string {
46+
fullQualifiedTableName := stmt.Table
47+
if schemaName := getTableSchemaName(stmt.Schema); schemaName != "" {
48+
fullQualifiedTableName = schemaName + "." + fullQualifiedTableName
49+
}
50+
return fullQualifiedTableName
51+
}
52+
2253
func (m Migrator) HasTable(value interface{}) bool {
2354
var count int
2455
m.RunWithValue(value, func(stmt *gorm.Statement) error {
56+
schemaName := getTableSchemaName(stmt.Schema)
57+
if schemaName == "" {
58+
schemaName = "%"
59+
}
2560
return m.DB.Raw(
26-
"SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ?",
27-
stmt.Table, m.CurrentDatabase(),
61+
"SELECT count(*) FROM INFORMATION_SCHEMA.tables WHERE table_name = ? AND table_catalog = ? and table_schema like ? AND table_type = ?",
62+
stmt.Table, m.CurrentDatabase(), schemaName, "BASE TABLE",
2863
).Row().Scan(&count)
2964
})
3065
return count > 0
@@ -40,7 +75,7 @@ func (m Migrator) DropTable(values ...interface{}) error {
4075
Parent string
4176
}
4277
var constraints []constraint
43-
err := tx.Raw("SELECT name, OBJECT_NAME(parent_object_id) as parent FROM sys.foreign_keys WHERE referenced_object_id = object_id(?)", stmt.Table).Scan(&constraints).Error
78+
err := tx.Raw("SELECT name, OBJECT_NAME(parent_object_id) as parent FROM sys.foreign_keys WHERE referenced_object_id = object_id(?)", getFullQualifiedTableName(stmt)).Scan(&constraints).Error
4479

4580
for _, c := range constraints {
4681
if err == nil {
@@ -150,7 +185,7 @@ var defaultValueTrimRegexp = regexp.MustCompile("^\\('?([^']*)'?\\)$")
150185
func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
151186
columnTypes := make([]gorm.ColumnType, 0)
152187
execErr := m.RunWithValue(value, func(stmt *gorm.Statement) (err error) {
153-
rows, err := m.DB.Session(&gorm.Session{}).Table(stmt.Table).Limit(1).Rows()
188+
rows, err := m.DB.Session(&gorm.Session{}).Table(getFullQualifiedTableName(stmt)).Limit(1).Rows()
154189
if err != nil {
155190
return err
156191
}
@@ -259,7 +294,7 @@ func (m Migrator) HasIndex(value interface{}, name string) bool {
259294

260295
return m.DB.Raw(
261296
"SELECT count(*) FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)",
262-
name, stmt.Table,
297+
name, getFullQualifiedTableName(stmt),
263298
).Row().Scan(&count)
264299
})
265300
return count > 0
@@ -285,9 +320,17 @@ func (m Migrator) HasConstraint(value interface{}, name string) bool {
285320
name = chk.Name
286321
}
287322

323+
tableCatalog, schema, tableName := splitFullQualifiedName(table)
324+
if tableCatalog == "" {
325+
tableCatalog = m.CurrentDatabase()
326+
}
327+
if schema == "" {
328+
schema = "%"
329+
}
330+
288331
return m.DB.Raw(
289-
`SELECT count(*) FROM sys.foreign_keys as F inner join sys.tables as T on F.parent_object_id=T.object_id inner join information_schema.tables as I on I.TABLE_NAME = T.name WHERE F.name = ? AND T.Name = ? AND I.TABLE_CATALOG = ?;`,
290-
name, table, m.CurrentDatabase(),
332+
`SELECT count(*) FROM sys.foreign_keys as F inner join sys.tables as T on F.parent_object_id=T.object_id inner join information_schema.tables as I on I.TABLE_NAME = T.name WHERE F.name = ? AND I.TABLE_NAME = ? AND I.TABLE_SCHEMA like ? AND I.TABLE_CATALOG = ?;`,
333+
name, tableName, schema, tableCatalog,
291334
).Row().Scan(&count)
292335
})
293336
return count > 0
@@ -297,3 +340,8 @@ func (m Migrator) CurrentDatabase() (name string) {
297340
m.DB.Raw("SELECT DB_NAME() AS [Current Database]").Row().Scan(&name)
298341
return
299342
}
343+
344+
func (m Migrator) DefaultSchema() (name string) {
345+
m.DB.Raw("SELECT SCHEMA_NAME() AS [Default Schema]").Row().Scan(&name)
346+
return
347+
}

migrator_test.go

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
package sqlserver_test
2+
3+
import (
4+
"os"
5+
"testing"
6+
7+
"gorm.io/driver/sqlserver"
8+
"gorm.io/gorm"
9+
)
10+
11+
var sqlserverDSN = "sqlserver://gorm:LoremIpsum86@localhost:9930?database=gorm"
12+
13+
func init() {
14+
if dbDSN := os.Getenv("GORM_DSN"); dbDSN != "" {
15+
sqlserverDSN = dbDSN
16+
}
17+
}
18+
19+
type Testtable struct {
20+
Test uint64 `gorm:"index"`
21+
}
22+
23+
type Testtable2 struct {
24+
Test uint64 `gorm:"index"`
25+
Test2 uint64
26+
}
27+
28+
func (*Testtable2) TableName() string { return "testtables" }
29+
30+
type Testtable3 struct {
31+
Test3 uint64
32+
}
33+
34+
func (*Testtable3) TableName() string { return "testschema1.Testtables" }
35+
36+
type Testtable4 struct {
37+
Test4 uint64
38+
}
39+
40+
func (*Testtable4) TableName() string { return "testschema2.Testtables" }
41+
42+
type Testtable5 struct {
43+
Test4 uint64
44+
Test5 uint64 `gorm:"index"`
45+
}
46+
47+
func (*Testtable5) TableName() string { return "testschema2.Testtables" }
48+
49+
func TestAutomigrateTablesWithoutDefaultSchema(t *testing.T) {
50+
db, err := gorm.Open(sqlserver.Open(sqlserverDSN))
51+
if err != nil {
52+
t.Error(err)
53+
}
54+
55+
if tx := db.Exec("create schema testschema1"); tx.Error != nil {
56+
t.Error("couldn't create schema testschema1", tx.Error)
57+
}
58+
if tx := db.Exec("create schema testschema2"); tx.Error != nil {
59+
t.Error("couldn't create schema testschema2", tx.Error)
60+
}
61+
62+
if err = db.AutoMigrate(&Testtable{}); err != nil {
63+
t.Error("couldn't create a table at user default schema", err)
64+
}
65+
if err = db.AutoMigrate(&Testtable2{}); err != nil {
66+
t.Error("couldn't update a table at user default schema", err)
67+
}
68+
if err = db.AutoMigrate(&Testtable3{}); err != nil {
69+
t.Error("couldn't create a table at schema testschema1", err)
70+
}
71+
if err = db.AutoMigrate(&Testtable4{}); err != nil {
72+
t.Error("couldn't create a table at schema testschema2", err)
73+
}
74+
if err = db.AutoMigrate(&Testtable5{}); err != nil {
75+
t.Error("couldn't update a table at schema testschema2", err)
76+
}
77+
78+
if tx := db.Exec("drop table testtables"); tx.Error != nil {
79+
t.Error("couldn't drop table testtable at user default schema", tx.Error)
80+
}
81+
82+
if tx := db.Exec("drop table testschema1.testtables"); tx.Error != nil {
83+
t.Error("couldn't drop table testschema1.testtable", tx.Error)
84+
}
85+
86+
if tx := db.Exec("drop table testschema2.testtables"); tx.Error != nil {
87+
t.Error("couldn't drop table testschema2.testtable", tx.Error)
88+
}
89+
90+
if tx := db.Exec("drop schema testschema1"); tx.Error != nil {
91+
t.Error("couldn't drop schema testschema1", tx.Error)
92+
}
93+
94+
if tx := db.Exec("drop schema testschema2"); tx.Error != nil {
95+
t.Error("couldn't drop schema testschema2", tx.Error)
96+
}
97+
98+
}

0 commit comments

Comments
 (0)