diff --git a/migrator.go b/migrator.go index e6f2160..c737592 100644 --- a/migrator.go +++ b/migrator.go @@ -48,7 +48,7 @@ func (m Migrator) CreateTable(values ...interface{}) (err error) { } for _, fieldName := range stmt.Schema.DBNames { field := stmt.Schema.FieldsByDBName[fieldName] - if field.Comment == "" { + if _, ok := field.TagSettings["COMMENT"]; !ok { continue } if err = m.setColumnComment(stmt, field, true); err != nil { @@ -65,17 +65,18 @@ func (m Migrator) CreateTable(values ...interface{}) (err error) { func (m Migrator) setColumnComment(stmt *gorm.Statement, field *schema.Field, add bool) error { schemaName := m.getTableSchemaName(stmt.Schema) + commentExpr := gorm.Expr(strings.ReplaceAll(field.Comment, "'", "''")) // add field comment if add { return m.DB.Exec( - "EXEC sp_addextendedproperty 'MS_Description', ?, 'SCHEMA', ?, 'TABLE', ?, 'COLUMN', ?", - field.Comment, schemaName, stmt.Table, field.DBName, + "EXEC sp_addextendedproperty 'MS_Description', N'?', 'SCHEMA', ?, 'TABLE', ?, 'COLUMN', ?", + commentExpr, schemaName, stmt.Table, field.DBName, ).Error } // update field comment return m.DB.Exec( - "EXEC sp_updateextendedproperty 'MS_Description', ?, 'SCHEMA', ?, 'TABLE', ?, 'COLUMN', ?", - field.Comment, schemaName, stmt.Table, field.DBName, + "EXEC sp_updateextendedproperty 'MS_Description', N'?', 'SCHEMA', ?, 'TABLE', ?, 'COLUMN', ?", + commentExpr, schemaName, stmt.Table, field.DBName, ).Error } @@ -121,7 +122,7 @@ func getFullQualifiedTableName(stmt *gorm.Statement) string { func (m Migrator) HasTable(value interface{}) bool { var count int - m.RunWithValue(value, func(stmt *gorm.Statement) error { + _ = m.RunWithValue(value, func(stmt *gorm.Statement) error { schemaName := getTableSchemaName(stmt.Schema) if schemaName == "" { schemaName = "%" @@ -202,7 +203,7 @@ func (m Migrator) AddColumn(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) (err error) { if stmt.Schema != nil { if field := stmt.Schema.LookUpField(name); field != nil { - if field.Comment == "" { + if _, ok := field.TagSettings["COMMENT"]; !ok { return } if err = m.setColumnComment(stmt, field, true); err != nil { @@ -216,7 +217,7 @@ func (m Migrator) AddColumn(value interface{}, name string) error { func (m Migrator) HasColumn(value interface{}, field string) bool { var count int64 - m.RunWithValue(value, func(stmt *gorm.Statement) error { + _ = m.RunWithValue(value, func(stmt *gorm.Statement) error { currentDatabase := m.DB.Migrator().CurrentDatabase() name := field if stmt.Schema != nil { @@ -273,17 +274,13 @@ func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error }) } -func (m Migrator) GetColumnComment(stmt *gorm.Statement, fieldDBName string) (description string) { +func (m Migrator) GetColumnComment(stmt *gorm.Statement, fieldDBName string) (comment sql.NullString) { queryTx := m.DB.Session(&gorm.Session{Logger: m.DB.Logger.LogMode(logger.Warn)}) if m.DB.DryRun { queryTx.DryRun = false } - var comment sql.NullString queryTx.Raw("SELECT value FROM ?.sys.fn_listextendedproperty('MS_Description', 'SCHEMA', ?, 'TABLE', ?, 'COLUMN', ?)", gorm.Expr(m.CurrentDatabase()), m.getTableSchemaName(stmt.Schema), stmt.Table, fieldDBName).Scan(&comment) - if comment.Valid { - description = comment.String - } return } @@ -293,12 +290,12 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy } return m.RunWithValue(value, func(stmt *gorm.Statement) (err error) { - description := m.GetColumnComment(stmt, field.DBName) - if field.Comment != description { - if description == "" { - err = m.setColumnComment(stmt, field, true) - } else { + comment := m.GetColumnComment(stmt, field.DBName) + if field.Comment != comment.String { + if comment.Valid { err = m.setColumnComment(stmt, field, false) + } else { + err = m.setColumnComment(stmt, field, true) } } return @@ -317,7 +314,7 @@ func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) { } rawColumnTypes, _ := rows.ColumnTypes() - rows.Close() + _ = rows.Close() { _, schemaName, tableName := splitFullQualifiedName(stmt.Table) @@ -394,7 +391,7 @@ WHERE TABLE_CATALOG = ? AND TABLE_NAME = ?`) columnTypes = append(columnTypes, column) } - columns.Close() + _ = columns.Close() } { @@ -415,7 +412,7 @@ WHERE TABLE_CATALOG = ? AND TABLE_NAME = ?`) for columnTypeRows.Next() { var name, columnType string - columnTypeRows.Scan(&name, &columnType) + _ = columnTypeRows.Scan(&name, &columnType) for idx, c := range columnTypes { mc := c.(migrator.ColumnType) if mc.NameValue.String == name { @@ -431,7 +428,7 @@ WHERE TABLE_CATALOG = ? AND TABLE_NAME = ?`) } } - columnTypeRows.Close() + _ = columnTypeRows.Close() } return @@ -473,7 +470,7 @@ func (m Migrator) CreateIndex(value interface{}, name string) error { func (m Migrator) HasIndex(value interface{}, name string) bool { var count int - m.RunWithValue(value, func(stmt *gorm.Statement) error { + _ = m.RunWithValue(value, func(stmt *gorm.Statement) error { if stmt.Schema != nil { if idx := stmt.Schema.LookIndex(name); idx != nil { name = idx.Name @@ -538,34 +535,34 @@ func (m Migrator) GetIndexes(value interface{}) ([]gorm.Index, error) { func (m Migrator) HasConstraint(value interface{}, name string) bool { var count int64 - m.RunWithValue(value, func(stmt *gorm.Statement) error { + _ = m.RunWithValue(value, func(stmt *gorm.Statement) error { constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name) if constraint != nil { name = constraint.GetName() } - tableCatalog, schema, tableName := splitFullQualifiedName(table) + tableCatalog, tableSchema, tableName := splitFullQualifiedName(table) if tableCatalog == "" { tableCatalog = m.CurrentDatabase() } - if schema == "" { - schema = "%" + if tableSchema == "" { + tableSchema = "%" } return m.DB.Raw( `SELECT count(*) FROM sys.foreign_keys as F inner join sys.tables as T on F.parent_object_id=T.object_id inner join INFORMATION_SCHEMA.TABLES as I on I.TABLE_NAME = T.name WHERE F.name = ? AND I.TABLE_NAME = ? AND I.TABLE_SCHEMA like ? AND I.TABLE_CATALOG = ?;`, - name, tableName, schema, tableCatalog, + name, tableName, tableSchema, tableCatalog, ).Row().Scan(&count) }) return count > 0 } func (m Migrator) CurrentDatabase() (name string) { - m.DB.Raw("SELECT DB_NAME() AS [Current Database]").Row().Scan(&name) + _ = m.DB.Raw("SELECT DB_NAME() AS [Current Database]").Row().Scan(&name) return } func (m Migrator) DefaultSchema() (name string) { - m.DB.Raw("SELECT SCHEMA_NAME() AS [Default Schema]").Row().Scan(&name) + _ = m.DB.Raw("SELECT SCHEMA_NAME() AS [Default Schema]").Row().Scan(&name) return } diff --git a/migrator_test.go b/migrator_test.go index 056a58a..e943fda 100644 --- a/migrator_test.go +++ b/migrator_test.go @@ -190,7 +190,7 @@ func testGetMigrateColumns(db *gorm.DB, dst interface{}) (columnsWithDefault, co } type TestTableFieldComment struct { - ID string `gorm:"column:id;primaryKey"` + ID string `gorm:"column:id;primaryKey;comment:"` // field comment is an empty string Name string `gorm:"column:name;comment:姓名"` Age uint `gorm:"column:age;comment:年龄"` } @@ -198,10 +198,11 @@ type TestTableFieldComment struct { func (*TestTableFieldComment) TableName() string { return "test_table_field_comment" } type TestTableFieldCommentUpdate struct { - ID string `gorm:"column:id;primaryKey"` + ID string `gorm:"column:id;primaryKey;comment:ID"` Name string `gorm:"column:name;comment:姓名"` Age uint `gorm:"column:age;comment:周岁"` Birthday *time.Time `gorm:"column:birthday;comment:生日"` + Quote string `gorm:"column:quote;comment:注释中包含'单引号'和特殊符号❤️"` } func (*TestTableFieldCommentUpdate) TableName() string { return "test_table_field_comment" } @@ -209,37 +210,37 @@ func (*TestTableFieldCommentUpdate) TableName() string { return "test_table_fiel func TestMigrator_MigrateColumnComment(t *testing.T) { db, err := gorm.Open(sqlserver.Open(sqlserverDSN)) if err != nil { - t.Error(err) + t.Fatal(err) } - migrator := db.Debug().Migrator() + dm := db.Debug().Migrator() tableModel := new(TestTableFieldComment) defer func() { - if err = migrator.DropTable(tableModel); err != nil { + if err = dm.DropTable(tableModel); err != nil { t.Errorf("couldn't drop table %q, got error: %v", tableModel.TableName(), err) } }() - if err = migrator.AutoMigrate(tableModel); err != nil { + if err = dm.AutoMigrate(tableModel); err != nil { t.Fatal(err) } tableModelUpdate := new(TestTableFieldCommentUpdate) - if err = migrator.AutoMigrate(tableModelUpdate); err != nil { + if err = dm.AutoMigrate(tableModelUpdate); err != nil { t.Error(err) } - if m, ok := migrator.(sqlserver.Migrator); ok { + if m, ok := dm.(sqlserver.Migrator); ok { stmt := db.Model(tableModelUpdate).Find(nil).Statement if stmt == nil || stmt.Schema == nil { t.Fatal("expected Statement.Schema, got nil") } - wantComments := []string{"", "姓名", "周岁", "生日"} + wantComments := []string{"ID", "姓名", "周岁", "生日", "注释中包含'单引号'和特殊符号❤️"} gotComments := make([]string, len(stmt.Schema.DBNames)) for i, fieldDBName := range stmt.Schema.DBNames { comment := m.GetColumnComment(stmt, fieldDBName) - gotComments[i] = comment + gotComments[i] = comment.String } if !reflect.DeepEqual(wantComments, gotComments) {