Skip to content

Commit b222e65

Browse files
committed
Support Migrator ColumnType interface
1 parent 6cb5563 commit b222e65

File tree

2 files changed

+110
-0
lines changed

2 files changed

+110
-0
lines changed

migrator.go

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
package sqlserver
22

33
import (
4+
"database/sql"
45
"fmt"
6+
"regexp"
7+
"strings"
58

69
"gorm.io/gorm"
710
"gorm.io/gorm/clause"
@@ -139,6 +142,104 @@ func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error
139142
})
140143
}
141144

145+
var defaultValueTrimRegexp = regexp.MustCompile("^\\('?(.*)'?\\)$")
146+
147+
// ColumnTypes return columnTypes []gorm.ColumnType and execErr error
148+
func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
149+
columnTypes := make([]gorm.ColumnType, 0)
150+
execErr := m.RunWithValue(value, func(stmt *gorm.Statement) (err error) {
151+
rows, err := m.DB.Session(&gorm.Session{}).Table(stmt.Table).Limit(1).Rows()
152+
if err != nil {
153+
return err
154+
}
155+
156+
defer func() {
157+
err = rows.Close()
158+
}()
159+
160+
var (
161+
rawColumnTypes, _ = rows.ColumnTypes()
162+
columnTypeSQL = "SELECT column_name, data_type, column_default, is_nullable, character_maximum_length, numeric_precision, numeric_precision_radix, numeric_scale, datetime_precision FROM INFORMATION_SCHEMA.COLUMNS WHERE table_catalog = ? AND table_name = ?"
163+
columns, rowErr = m.DB.Raw(columnTypeSQL, m.CurrentDatabase(), stmt.Table).Rows()
164+
)
165+
166+
if rowErr != nil {
167+
return rowErr
168+
}
169+
170+
defer columns.Close()
171+
172+
for columns.Next() {
173+
var (
174+
column migrator.ColumnType
175+
datetimePrecision sql.NullInt64
176+
radixValue sql.NullInt64
177+
nullableValue sql.NullString
178+
values = []interface{}{
179+
&column.NameValue, &column.ColumnTypeValue, &column.DefaultValueValue, &nullableValue, &column.LengthValue, &column.DecimalSizeValue, &radixValue, &column.ScaleValue, &datetimePrecision,
180+
}
181+
)
182+
183+
if scanErr := columns.Scan(values...); scanErr != nil {
184+
return scanErr
185+
}
186+
187+
if nullableValue.Valid {
188+
column.NullableValue = sql.NullBool{Bool: strings.EqualFold(nullableValue.String, "YES"), Valid: true}
189+
}
190+
191+
if datetimePrecision.Valid {
192+
column.DecimalSizeValue = datetimePrecision
193+
}
194+
195+
if column.DefaultValueValue.Valid {
196+
matches := defaultValueTrimRegexp.FindStringSubmatch(column.DefaultValueValue.String)
197+
for len(matches) > 1 {
198+
column.DefaultValueValue.String = matches[1]
199+
matches = defaultValueTrimRegexp.FindStringSubmatch(column.DefaultValueValue.String)
200+
}
201+
}
202+
203+
for _, c := range rawColumnTypes {
204+
if c.Name() == column.NameValue.String {
205+
column.SQLColumnType = c
206+
break
207+
}
208+
}
209+
210+
columnTypes = append(columnTypes, column)
211+
}
212+
213+
columnTypeRows, err := m.DB.Raw("SELECT c.column_name, t.constraint_type FROM information_schema.table_constraints t JOIN information_schema.constraint_column_usage c ON c.constraint_name=t.constraint_name WHERE t.constraint_type IN ('PRIMARY KEY', 'UNIQUE') AND c.table_catalog = ? AND c.table_name = ?", m.CurrentDatabase(), stmt.Table).Rows()
214+
if err != nil {
215+
return err
216+
}
217+
defer columnTypeRows.Close()
218+
219+
for columnTypeRows.Next() {
220+
var name, columnType string
221+
columnTypeRows.Scan(&name, &columnType)
222+
for idx, c := range columnTypes {
223+
mc := c.(migrator.ColumnType)
224+
if mc.NameValue.String == name {
225+
switch columnType {
226+
case "PRIMARY KEY":
227+
mc.PrimayKeyValue = sql.NullBool{Bool: true, Valid: true}
228+
case "UNIQUE":
229+
mc.UniqueValue = sql.NullBool{Bool: true, Valid: true}
230+
}
231+
columnTypes[idx] = mc
232+
break
233+
}
234+
}
235+
}
236+
237+
return
238+
})
239+
240+
return columnTypes, execErr
241+
}
242+
142243
func (m Migrator) HasIndex(value interface{}, name string) bool {
143244
var count int
144245
m.RunWithValue(value, func(stmt *gorm.Statement) error {

sqlserver.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,12 @@ func (dialector Dialector) DataTypeOf(field *schema.Field) string {
169169
}
170170
return sqlType
171171
case schema.Float:
172+
if field.Precision > 0 {
173+
if field.Scale > 0 {
174+
return fmt.Sprintf("decimal(%d, %d)", field.Precision, field.Scale)
175+
}
176+
return fmt.Sprintf("decimal(%d)", field.Precision)
177+
}
172178
return "float"
173179
case schema.String:
174180
size := field.Size
@@ -185,6 +191,9 @@ func (dialector Dialector) DataTypeOf(field *schema.Field) string {
185191
}
186192
return "nvarchar(MAX)"
187193
case schema.Time:
194+
if field.Precision > 0 {
195+
return fmt.Sprintf("datetimeoffset(%d)", field.Precision)
196+
}
188197
return "datetimeoffset"
189198
case schema.Bytes:
190199
return "varbinary(MAX)"

0 commit comments

Comments
 (0)