Skip to content

Commit dd0e76a

Browse files
authored
feat: support setting field comments (#140)
1 parent 94b32a6 commit dd0e76a

File tree

2 files changed

+165
-0
lines changed

2 files changed

+165
-0
lines changed

migrator.go

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,58 @@ func (m Migrator) GetTables() (tableList []string, err error) {
3636
return tableList, m.DB.Raw("SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_CATALOG = ?", m.CurrentDatabase()).Scan(&tableList).Error
3737
}
3838

39+
func (m Migrator) CreateTable(values ...interface{}) (err error) {
40+
if err = m.Migrator.CreateTable(values...); err != nil {
41+
return
42+
}
43+
for _, value := range m.ReorderModels(values, false) {
44+
if err = m.RunWithValue(value, func(stmt *gorm.Statement) (err error) {
45+
if stmt.Schema == nil {
46+
return
47+
}
48+
for _, fieldName := range stmt.Schema.DBNames {
49+
field := stmt.Schema.FieldsByDBName[fieldName]
50+
if field.Comment == "" {
51+
continue
52+
}
53+
if err = m.setColumnComment(stmt, field, true); err != nil {
54+
return
55+
}
56+
}
57+
return
58+
}); err != nil {
59+
return
60+
}
61+
}
62+
return
63+
}
64+
65+
func (m Migrator) setColumnComment(stmt *gorm.Statement, field *schema.Field, add bool) error {
66+
schemaName := m.getTableSchemaName(stmt.Schema)
67+
// add field comment
68+
if add {
69+
return m.DB.Exec(
70+
"EXEC sp_addextendedproperty 'MS_Description', ?, 'SCHEMA', ?, 'TABLE', ?, 'COLUMN', ?",
71+
field.Comment, schemaName, stmt.Table, field.DBName,
72+
).Error
73+
}
74+
// update field comment
75+
return m.DB.Exec(
76+
"EXEC sp_updateextendedproperty 'MS_Description', ?, 'SCHEMA', ?, 'TABLE', ?, 'COLUMN', ?",
77+
field.Comment, schemaName, stmt.Table, field.DBName,
78+
).Error
79+
}
80+
81+
func (m Migrator) getTableSchemaName(schema *schema.Schema) string {
82+
// return the schema name if it is explicitly provided in the table name
83+
// otherwise return default schema name
84+
schemaName := getTableSchemaName(schema)
85+
if schemaName == "" {
86+
schemaName = m.DefaultSchema()
87+
}
88+
return schemaName
89+
}
90+
3991
func getTableSchemaName(schema *schema.Schema) string {
4092
// return the schema name if it is explicitly provided in the table name
4193
// otherwise return a sql wildcard -> use any table_schema
@@ -141,6 +193,26 @@ func (m Migrator) RenameTable(oldName, newName interface{}) error {
141193
).Error
142194
}
143195

196+
func (m Migrator) AddColumn(value interface{}, name string) error {
197+
if err := m.Migrator.AddColumn(value, name); err != nil {
198+
return err
199+
}
200+
201+
return m.RunWithValue(value, func(stmt *gorm.Statement) (err error) {
202+
if stmt.Schema != nil {
203+
if field := stmt.Schema.LookUpField(name); field != nil {
204+
if field.Comment == "" {
205+
return
206+
}
207+
if err = m.setColumnComment(stmt, field, true); err != nil {
208+
return
209+
}
210+
}
211+
}
212+
return
213+
})
214+
}
215+
144216
func (m Migrator) HasColumn(value interface{}, field string) bool {
145217
var count int64
146218
m.RunWithValue(value, func(stmt *gorm.Statement) error {
@@ -200,6 +272,39 @@ func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error
200272
})
201273
}
202274

275+
func (m Migrator) GetColumnComment(stmt *gorm.Statement, fieldDBName string) (description string) {
276+
queryTx := m.DB
277+
if m.DB.DryRun {
278+
queryTx = m.DB.Session(&gorm.Session{})
279+
queryTx.DryRun = false
280+
}
281+
var comment sql.NullString
282+
queryTx.Raw("SELECT value FROM ?.sys.fn_listextendedproperty('MS_Description', 'SCHEMA', ?, 'TABLE', ?, 'COLUMN', ?)",
283+
gorm.Expr(m.CurrentDatabase()), m.getTableSchemaName(stmt.Schema), stmt.Table, fieldDBName).Scan(&comment)
284+
if comment.Valid {
285+
description = comment.String
286+
}
287+
return
288+
}
289+
290+
func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnType gorm.ColumnType) error {
291+
if err := m.Migrator.MigrateColumn(value, field, columnType); err != nil {
292+
return err
293+
}
294+
295+
return m.RunWithValue(value, func(stmt *gorm.Statement) (err error) {
296+
description := m.GetColumnComment(stmt, field.DBName)
297+
if field.Comment != description {
298+
if description == "" {
299+
err = m.setColumnComment(stmt, field, true)
300+
} else {
301+
err = m.setColumnComment(stmt, field, false)
302+
}
303+
}
304+
return
305+
})
306+
}
307+
203308
var defaultValueTrimRegexp = regexp.MustCompile("^\\('?([^']*)'?\\)$")
204309

205310
// ColumnTypes return columnTypes []gorm.ColumnType and execErr error

migrator_test.go

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,3 +188,63 @@ func testGetMigrateColumns(db *gorm.DB, dst interface{}) (columnsWithDefault, co
188188
}
189189
return
190190
}
191+
192+
type TestTableFieldComment struct {
193+
ID string `gorm:"column:id;primaryKey"`
194+
Name string `gorm:"column:name;comment:姓名"`
195+
Age uint `gorm:"column:age;comment:年龄"`
196+
}
197+
198+
func (*TestTableFieldComment) TableName() string { return "test_table_field_comment" }
199+
200+
type TestTableFieldCommentUpdate struct {
201+
ID string `gorm:"column:id;primaryKey"`
202+
Name string `gorm:"column:name;comment:姓名"`
203+
Age uint `gorm:"column:age;comment:周岁"`
204+
Birthday *time.Time `gorm:"column:birthday;comment:生日"`
205+
}
206+
207+
func (*TestTableFieldCommentUpdate) TableName() string { return "test_table_field_comment" }
208+
209+
func TestMigrator_MigrateColumnComment(t *testing.T) {
210+
db, err := gorm.Open(sqlserver.Open(sqlserverDSN))
211+
if err != nil {
212+
t.Error(err)
213+
}
214+
migrator := db.Debug().Migrator()
215+
216+
tableModel := new(TestTableFieldComment)
217+
defer func() {
218+
if err = migrator.DropTable(tableModel); err != nil {
219+
t.Errorf("couldn't drop table %q, got error: %v", tableModel.TableName(), err)
220+
}
221+
}()
222+
223+
if err = migrator.AutoMigrate(tableModel); err != nil {
224+
t.Fatal(err)
225+
}
226+
tableModelUpdate := new(TestTableFieldCommentUpdate)
227+
if err = migrator.AutoMigrate(tableModelUpdate); err != nil {
228+
t.Error(err)
229+
}
230+
231+
if m, ok := migrator.(sqlserver.Migrator); ok {
232+
stmt := db.Model(tableModelUpdate).Find(nil).Statement
233+
if stmt == nil || stmt.Schema == nil {
234+
t.Fatal("expected Statement.Schema, got nil")
235+
}
236+
237+
wantComments := []string{"", "姓名", "周岁", "生日"}
238+
gotComments := make([]string, len(stmt.Schema.DBNames))
239+
240+
for i, fieldDBName := range stmt.Schema.DBNames {
241+
comment := m.GetColumnComment(stmt, fieldDBName)
242+
gotComments[i] = comment
243+
}
244+
245+
if !reflect.DeepEqual(wantComments, gotComments) {
246+
t.Fatalf("expected comments %#v, got %#v", wantComments, gotComments)
247+
}
248+
t.Logf("got comments: %#v", gotComments)
249+
}
250+
}

0 commit comments

Comments
 (0)