Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 27 additions & 30 deletions migrator.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func (m Migrator) CreateTable(values ...interface{}) (err error) {
}
for _, fieldName := range stmt.Schema.DBNames {
field := stmt.Schema.FieldsByDBName[fieldName]
if field.Comment == "" {
if _, ok := field.TagSettings["COMMENT"]; !ok {
continue
}
if err = m.setColumnComment(stmt, field, true); err != nil {
Expand All @@ -65,17 +65,18 @@ func (m Migrator) CreateTable(values ...interface{}) (err error) {

func (m Migrator) setColumnComment(stmt *gorm.Statement, field *schema.Field, add bool) error {
schemaName := m.getTableSchemaName(stmt.Schema)
commentExpr := gorm.Expr(strings.ReplaceAll(field.Comment, "'", "''"))
// add field comment
if add {
return m.DB.Exec(
"EXEC sp_addextendedproperty 'MS_Description', ?, 'SCHEMA', ?, 'TABLE', ?, 'COLUMN', ?",
field.Comment, schemaName, stmt.Table, field.DBName,
"EXEC sp_addextendedproperty 'MS_Description', N'?', 'SCHEMA', ?, 'TABLE', ?, 'COLUMN', ?",
commentExpr, schemaName, stmt.Table, field.DBName,
).Error
}
// update field comment
return m.DB.Exec(
"EXEC sp_updateextendedproperty 'MS_Description', ?, 'SCHEMA', ?, 'TABLE', ?, 'COLUMN', ?",
field.Comment, schemaName, stmt.Table, field.DBName,
"EXEC sp_updateextendedproperty 'MS_Description', N'?', 'SCHEMA', ?, 'TABLE', ?, 'COLUMN', ?",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will this change cause sql inject?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Apologies, it indeed could lead to SQL injection, and this issue has been fixed.

commentExpr, schemaName, stmt.Table, field.DBName,
).Error
}

Expand Down Expand Up @@ -121,7 +122,7 @@ func getFullQualifiedTableName(stmt *gorm.Statement) string {

func (m Migrator) HasTable(value interface{}) bool {
var count int
m.RunWithValue(value, func(stmt *gorm.Statement) error {
_ = m.RunWithValue(value, func(stmt *gorm.Statement) error {
schemaName := getTableSchemaName(stmt.Schema)
if schemaName == "" {
schemaName = "%"
Expand Down Expand Up @@ -202,7 +203,7 @@ func (m Migrator) AddColumn(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) (err error) {
if stmt.Schema != nil {
if field := stmt.Schema.LookUpField(name); field != nil {
if field.Comment == "" {
if _, ok := field.TagSettings["COMMENT"]; !ok {
return
}
if err = m.setColumnComment(stmt, field, true); err != nil {
Expand All @@ -216,7 +217,7 @@ func (m Migrator) AddColumn(value interface{}, name string) error {

func (m Migrator) HasColumn(value interface{}, field string) bool {
var count int64
m.RunWithValue(value, func(stmt *gorm.Statement) error {
_ = m.RunWithValue(value, func(stmt *gorm.Statement) error {
currentDatabase := m.DB.Migrator().CurrentDatabase()
name := field
if stmt.Schema != nil {
Expand Down Expand Up @@ -273,17 +274,13 @@ func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error
})
}

func (m Migrator) GetColumnComment(stmt *gorm.Statement, fieldDBName string) (description string) {
func (m Migrator) GetColumnComment(stmt *gorm.Statement, fieldDBName string) (comment sql.NullString) {
queryTx := m.DB.Session(&gorm.Session{Logger: m.DB.Logger.LogMode(logger.Warn)})
if m.DB.DryRun {
queryTx.DryRun = false
}
var comment sql.NullString
queryTx.Raw("SELECT value FROM ?.sys.fn_listextendedproperty('MS_Description', 'SCHEMA', ?, 'TABLE', ?, 'COLUMN', ?)",
gorm.Expr(m.CurrentDatabase()), m.getTableSchemaName(stmt.Schema), stmt.Table, fieldDBName).Scan(&comment)
if comment.Valid {
description = comment.String
}
return
}

Expand All @@ -293,12 +290,12 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
}

return m.RunWithValue(value, func(stmt *gorm.Statement) (err error) {
description := m.GetColumnComment(stmt, field.DBName)
if field.Comment != description {
if description == "" {
err = m.setColumnComment(stmt, field, true)
} else {
comment := m.GetColumnComment(stmt, field.DBName)
if field.Comment != comment.String {
if comment.Valid {
err = m.setColumnComment(stmt, field, false)
} else {
err = m.setColumnComment(stmt, field, true)
}
}
return
Expand All @@ -317,7 +314,7 @@ func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
}

rawColumnTypes, _ := rows.ColumnTypes()
rows.Close()
_ = rows.Close()

{
_, schemaName, tableName := splitFullQualifiedName(stmt.Table)
Expand Down Expand Up @@ -394,7 +391,7 @@ WHERE TABLE_CATALOG = ? AND TABLE_NAME = ?`)
columnTypes = append(columnTypes, column)
}

columns.Close()
_ = columns.Close()
}

{
Expand All @@ -415,7 +412,7 @@ WHERE TABLE_CATALOG = ? AND TABLE_NAME = ?`)

for columnTypeRows.Next() {
var name, columnType string
columnTypeRows.Scan(&name, &columnType)
_ = columnTypeRows.Scan(&name, &columnType)
for idx, c := range columnTypes {
mc := c.(migrator.ColumnType)
if mc.NameValue.String == name {
Expand All @@ -431,7 +428,7 @@ WHERE TABLE_CATALOG = ? AND TABLE_NAME = ?`)
}
}

columnTypeRows.Close()
_ = columnTypeRows.Close()
}

return
Expand Down Expand Up @@ -473,7 +470,7 @@ func (m Migrator) CreateIndex(value interface{}, name string) error {

func (m Migrator) HasIndex(value interface{}, name string) bool {
var count int
m.RunWithValue(value, func(stmt *gorm.Statement) error {
_ = m.RunWithValue(value, func(stmt *gorm.Statement) error {
if stmt.Schema != nil {
if idx := stmt.Schema.LookIndex(name); idx != nil {
name = idx.Name
Expand Down Expand Up @@ -538,34 +535,34 @@ func (m Migrator) GetIndexes(value interface{}) ([]gorm.Index, error) {

func (m Migrator) HasConstraint(value interface{}, name string) bool {
var count int64
m.RunWithValue(value, func(stmt *gorm.Statement) error {
_ = m.RunWithValue(value, func(stmt *gorm.Statement) error {
constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name)
if constraint != nil {
name = constraint.GetName()
}

tableCatalog, schema, tableName := splitFullQualifiedName(table)
tableCatalog, tableSchema, tableName := splitFullQualifiedName(table)
if tableCatalog == "" {
tableCatalog = m.CurrentDatabase()
}
if schema == "" {
schema = "%"
if tableSchema == "" {
tableSchema = "%"
}

return m.DB.Raw(
`SELECT count(*) FROM sys.foreign_keys as F inner join sys.tables as T on F.parent_object_id=T.object_id inner join INFORMATION_SCHEMA.TABLES as I on I.TABLE_NAME = T.name WHERE F.name = ? AND I.TABLE_NAME = ? AND I.TABLE_SCHEMA like ? AND I.TABLE_CATALOG = ?;`,
name, tableName, schema, tableCatalog,
name, tableName, tableSchema, tableCatalog,
).Row().Scan(&count)
})
return count > 0
}

func (m Migrator) CurrentDatabase() (name string) {
m.DB.Raw("SELECT DB_NAME() AS [Current Database]").Row().Scan(&name)
_ = m.DB.Raw("SELECT DB_NAME() AS [Current Database]").Row().Scan(&name)
return
}

func (m Migrator) DefaultSchema() (name string) {
m.DB.Raw("SELECT SCHEMA_NAME() AS [Default Schema]").Row().Scan(&name)
_ = m.DB.Raw("SELECT SCHEMA_NAME() AS [Default Schema]").Row().Scan(&name)
return
}
21 changes: 11 additions & 10 deletions migrator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,56 +190,57 @@ func testGetMigrateColumns(db *gorm.DB, dst interface{}) (columnsWithDefault, co
}

type TestTableFieldComment struct {
ID string `gorm:"column:id;primaryKey"`
ID string `gorm:"column:id;primaryKey;comment:"` // field comment is an empty string
Name string `gorm:"column:name;comment:姓名"`
Age uint `gorm:"column:age;comment:年龄"`
}

func (*TestTableFieldComment) TableName() string { return "test_table_field_comment" }

type TestTableFieldCommentUpdate struct {
ID string `gorm:"column:id;primaryKey"`
ID string `gorm:"column:id;primaryKey;comment:ID"`
Name string `gorm:"column:name;comment:姓名"`
Age uint `gorm:"column:age;comment:周岁"`
Birthday *time.Time `gorm:"column:birthday;comment:生日"`
Quote string `gorm:"column:quote;comment:注释中包含'单引号'和特殊符号❤️"`
}

func (*TestTableFieldCommentUpdate) TableName() string { return "test_table_field_comment" }

func TestMigrator_MigrateColumnComment(t *testing.T) {
db, err := gorm.Open(sqlserver.Open(sqlserverDSN))
if err != nil {
t.Error(err)
t.Fatal(err)
}
migrator := db.Debug().Migrator()
dm := db.Debug().Migrator()

tableModel := new(TestTableFieldComment)
defer func() {
if err = migrator.DropTable(tableModel); err != nil {
if err = dm.DropTable(tableModel); err != nil {
t.Errorf("couldn't drop table %q, got error: %v", tableModel.TableName(), err)
}
}()

if err = migrator.AutoMigrate(tableModel); err != nil {
if err = dm.AutoMigrate(tableModel); err != nil {
t.Fatal(err)
}
tableModelUpdate := new(TestTableFieldCommentUpdate)
if err = migrator.AutoMigrate(tableModelUpdate); err != nil {
if err = dm.AutoMigrate(tableModelUpdate); err != nil {
t.Error(err)
}

if m, ok := migrator.(sqlserver.Migrator); ok {
if m, ok := dm.(sqlserver.Migrator); ok {
stmt := db.Model(tableModelUpdate).Find(nil).Statement
if stmt == nil || stmt.Schema == nil {
t.Fatal("expected Statement.Schema, got nil")
}

wantComments := []string{"", "姓名", "周岁", "生日"}
wantComments := []string{"ID", "姓名", "周岁", "生日", "注释中包含'单引号'和特殊符号❤️"}
gotComments := make([]string, len(stmt.Schema.DBNames))

for i, fieldDBName := range stmt.Schema.DBNames {
comment := m.GetColumnComment(stmt, fieldDBName)
gotComments[i] = comment
gotComments[i] = comment.String
}

if !reflect.DeepEqual(wantComments, gotComments) {
Expand Down