Skip to content

Commit 606ad2b

Browse files
committed
fix: failed to modify field comments when empty due to incorrect conditions
Ref: #140
1 parent 4937266 commit 606ad2b

File tree

2 files changed

+36
-40
lines changed

2 files changed

+36
-40
lines changed

migrator.go

Lines changed: 26 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ func (m Migrator) CreateTable(values ...interface{}) (err error) {
4848
}
4949
for _, fieldName := range stmt.Schema.DBNames {
5050
field := stmt.Schema.FieldsByDBName[fieldName]
51-
if field.Comment == "" {
51+
if _, ok := field.TagSettings["COMMENT"]; !ok {
5252
continue
5353
}
5454
if err = m.setColumnComment(stmt, field, true); err != nil {
@@ -68,14 +68,14 @@ func (m Migrator) setColumnComment(stmt *gorm.Statement, field *schema.Field, ad
6868
// add field comment
6969
if add {
7070
return m.DB.Exec(
71-
"EXEC sp_addextendedproperty 'MS_Description', ?, 'SCHEMA', ?, 'TABLE', ?, 'COLUMN', ?",
72-
field.Comment, schemaName, stmt.Table, field.DBName,
71+
"EXEC sp_addextendedproperty 'MS_Description', N'?', 'SCHEMA', ?, 'TABLE', ?, 'COLUMN', ?",
72+
gorm.Expr(field.Comment), schemaName, stmt.Table, field.DBName,
7373
).Error
7474
}
7575
// update field comment
7676
return m.DB.Exec(
77-
"EXEC sp_updateextendedproperty 'MS_Description', ?, 'SCHEMA', ?, 'TABLE', ?, 'COLUMN', ?",
78-
field.Comment, schemaName, stmt.Table, field.DBName,
77+
"EXEC sp_updateextendedproperty 'MS_Description', N'?', 'SCHEMA', ?, 'TABLE', ?, 'COLUMN', ?",
78+
gorm.Expr(field.Comment), schemaName, stmt.Table, field.DBName,
7979
).Error
8080
}
8181

@@ -121,7 +121,7 @@ func getFullQualifiedTableName(stmt *gorm.Statement) string {
121121

122122
func (m Migrator) HasTable(value interface{}) bool {
123123
var count int
124-
m.RunWithValue(value, func(stmt *gorm.Statement) error {
124+
_ = m.RunWithValue(value, func(stmt *gorm.Statement) error {
125125
schemaName := getTableSchemaName(stmt.Schema)
126126
if schemaName == "" {
127127
schemaName = "%"
@@ -202,7 +202,7 @@ func (m Migrator) AddColumn(value interface{}, name string) error {
202202
return m.RunWithValue(value, func(stmt *gorm.Statement) (err error) {
203203
if stmt.Schema != nil {
204204
if field := stmt.Schema.LookUpField(name); field != nil {
205-
if field.Comment == "" {
205+
if _, ok := field.TagSettings["COMMENT"]; !ok {
206206
return
207207
}
208208
if err = m.setColumnComment(stmt, field, true); err != nil {
@@ -216,7 +216,7 @@ func (m Migrator) AddColumn(value interface{}, name string) error {
216216

217217
func (m Migrator) HasColumn(value interface{}, field string) bool {
218218
var count int64
219-
m.RunWithValue(value, func(stmt *gorm.Statement) error {
219+
_ = m.RunWithValue(value, func(stmt *gorm.Statement) error {
220220
currentDatabase := m.DB.Migrator().CurrentDatabase()
221221
name := field
222222
if stmt.Schema != nil {
@@ -273,17 +273,13 @@ func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error
273273
})
274274
}
275275

276-
func (m Migrator) GetColumnComment(stmt *gorm.Statement, fieldDBName string) (description string) {
276+
func (m Migrator) GetColumnComment(stmt *gorm.Statement, fieldDBName string) (comment sql.NullString) {
277277
queryTx := m.DB.Session(&gorm.Session{Logger: m.DB.Logger.LogMode(logger.Warn)})
278278
if m.DB.DryRun {
279279
queryTx.DryRun = false
280280
}
281-
var comment sql.NullString
282281
queryTx.Raw("SELECT value FROM ?.sys.fn_listextendedproperty('MS_Description', 'SCHEMA', ?, 'TABLE', ?, 'COLUMN', ?)",
283282
gorm.Expr(m.CurrentDatabase()), m.getTableSchemaName(stmt.Schema), stmt.Table, fieldDBName).Scan(&comment)
284-
if comment.Valid {
285-
description = comment.String
286-
}
287283
return
288284
}
289285

@@ -293,12 +289,12 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy
293289
}
294290

295291
return m.RunWithValue(value, func(stmt *gorm.Statement) (err error) {
296-
description := m.GetColumnComment(stmt, field.DBName)
297-
if field.Comment != description {
298-
if description == "" {
299-
err = m.setColumnComment(stmt, field, true)
300-
} else {
292+
comment := m.GetColumnComment(stmt, field.DBName)
293+
if field.Comment != comment.String {
294+
if comment.Valid {
301295
err = m.setColumnComment(stmt, field, false)
296+
} else {
297+
err = m.setColumnComment(stmt, field, true)
302298
}
303299
}
304300
return
@@ -317,7 +313,7 @@ func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
317313
}
318314

319315
rawColumnTypes, _ := rows.ColumnTypes()
320-
rows.Close()
316+
_ = rows.Close()
321317

322318
{
323319
_, schemaName, tableName := splitFullQualifiedName(stmt.Table)
@@ -394,7 +390,7 @@ WHERE TABLE_CATALOG = ? AND TABLE_NAME = ?`)
394390
columnTypes = append(columnTypes, column)
395391
}
396392

397-
columns.Close()
393+
_ = columns.Close()
398394
}
399395

400396
{
@@ -415,7 +411,7 @@ WHERE TABLE_CATALOG = ? AND TABLE_NAME = ?`)
415411

416412
for columnTypeRows.Next() {
417413
var name, columnType string
418-
columnTypeRows.Scan(&name, &columnType)
414+
_ = columnTypeRows.Scan(&name, &columnType)
419415
for idx, c := range columnTypes {
420416
mc := c.(migrator.ColumnType)
421417
if mc.NameValue.String == name {
@@ -431,7 +427,7 @@ WHERE TABLE_CATALOG = ? AND TABLE_NAME = ?`)
431427
}
432428
}
433429

434-
columnTypeRows.Close()
430+
_ = columnTypeRows.Close()
435431
}
436432

437433
return
@@ -473,7 +469,7 @@ func (m Migrator) CreateIndex(value interface{}, name string) error {
473469

474470
func (m Migrator) HasIndex(value interface{}, name string) bool {
475471
var count int
476-
m.RunWithValue(value, func(stmt *gorm.Statement) error {
472+
_ = m.RunWithValue(value, func(stmt *gorm.Statement) error {
477473
if stmt.Schema != nil {
478474
if idx := stmt.Schema.LookIndex(name); idx != nil {
479475
name = idx.Name
@@ -538,34 +534,34 @@ func (m Migrator) GetIndexes(value interface{}) ([]gorm.Index, error) {
538534

539535
func (m Migrator) HasConstraint(value interface{}, name string) bool {
540536
var count int64
541-
m.RunWithValue(value, func(stmt *gorm.Statement) error {
537+
_ = m.RunWithValue(value, func(stmt *gorm.Statement) error {
542538
constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name)
543539
if constraint != nil {
544540
name = constraint.GetName()
545541
}
546542

547-
tableCatalog, schema, tableName := splitFullQualifiedName(table)
543+
tableCatalog, tableSchema, tableName := splitFullQualifiedName(table)
548544
if tableCatalog == "" {
549545
tableCatalog = m.CurrentDatabase()
550546
}
551-
if schema == "" {
552-
schema = "%"
547+
if tableSchema == "" {
548+
tableSchema = "%"
553549
}
554550

555551
return m.DB.Raw(
556552
`SELECT count(*) FROM sys.foreign_keys as F inner join sys.tables as T on F.parent_object_id=T.object_id inner join INFORMATION_SCHEMA.TABLES as I on I.TABLE_NAME = T.name WHERE F.name = ? AND I.TABLE_NAME = ? AND I.TABLE_SCHEMA like ? AND I.TABLE_CATALOG = ?;`,
557-
name, tableName, schema, tableCatalog,
553+
name, tableName, tableSchema, tableCatalog,
558554
).Row().Scan(&count)
559555
})
560556
return count > 0
561557
}
562558

563559
func (m Migrator) CurrentDatabase() (name string) {
564-
m.DB.Raw("SELECT DB_NAME() AS [Current Database]").Row().Scan(&name)
560+
_ = m.DB.Raw("SELECT DB_NAME() AS [Current Database]").Row().Scan(&name)
565561
return
566562
}
567563

568564
func (m Migrator) DefaultSchema() (name string) {
569-
m.DB.Raw("SELECT SCHEMA_NAME() AS [Default Schema]").Row().Scan(&name)
565+
_ = m.DB.Raw("SELECT SCHEMA_NAME() AS [Default Schema]").Row().Scan(&name)
570566
return
571567
}

migrator_test.go

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -190,15 +190,15 @@ func testGetMigrateColumns(db *gorm.DB, dst interface{}) (columnsWithDefault, co
190190
}
191191

192192
type TestTableFieldComment struct {
193-
ID string `gorm:"column:id;primaryKey"`
193+
ID string `gorm:"column:id;primaryKey;comment:"` // field comment is an empty string
194194
Name string `gorm:"column:name;comment:姓名"`
195195
Age uint `gorm:"column:age;comment:年龄"`
196196
}
197197

198198
func (*TestTableFieldComment) TableName() string { return "test_table_field_comment" }
199199

200200
type TestTableFieldCommentUpdate struct {
201-
ID string `gorm:"column:id;primaryKey"`
201+
ID string `gorm:"column:id;primaryKey;comment:ID"`
202202
Name string `gorm:"column:name;comment:姓名"`
203203
Age uint `gorm:"column:age;comment:周岁"`
204204
Birthday *time.Time `gorm:"column:birthday;comment:生日"`
@@ -209,37 +209,37 @@ func (*TestTableFieldCommentUpdate) TableName() string { return "test_table_fiel
209209
func TestMigrator_MigrateColumnComment(t *testing.T) {
210210
db, err := gorm.Open(sqlserver.Open(sqlserverDSN))
211211
if err != nil {
212-
t.Error(err)
212+
t.Fatal(err)
213213
}
214-
migrator := db.Debug().Migrator()
214+
dm := db.Debug().Migrator()
215215

216216
tableModel := new(TestTableFieldComment)
217217
defer func() {
218-
if err = migrator.DropTable(tableModel); err != nil {
218+
if err = dm.DropTable(tableModel); err != nil {
219219
t.Errorf("couldn't drop table %q, got error: %v", tableModel.TableName(), err)
220220
}
221221
}()
222222

223-
if err = migrator.AutoMigrate(tableModel); err != nil {
223+
if err = dm.AutoMigrate(tableModel); err != nil {
224224
t.Fatal(err)
225225
}
226226
tableModelUpdate := new(TestTableFieldCommentUpdate)
227-
if err = migrator.AutoMigrate(tableModelUpdate); err != nil {
227+
if err = dm.AutoMigrate(tableModelUpdate); err != nil {
228228
t.Error(err)
229229
}
230230

231-
if m, ok := migrator.(sqlserver.Migrator); ok {
231+
if m, ok := dm.(sqlserver.Migrator); ok {
232232
stmt := db.Model(tableModelUpdate).Find(nil).Statement
233233
if stmt == nil || stmt.Schema == nil {
234234
t.Fatal("expected Statement.Schema, got nil")
235235
}
236236

237-
wantComments := []string{"", "姓名", "周岁", "生日"}
237+
wantComments := []string{"ID", "姓名", "周岁", "生日"}
238238
gotComments := make([]string, len(stmt.Schema.DBNames))
239239

240240
for i, fieldDBName := range stmt.Schema.DBNames {
241241
comment := m.GetColumnComment(stmt, fieldDBName)
242-
gotComments[i] = comment
242+
gotComments[i] = comment.String
243243
}
244244

245245
if !reflect.DeepEqual(wantComments, gotComments) {

0 commit comments

Comments
 (0)