|
1 | 1 | package sqlserver
|
2 | 2 |
|
3 | 3 | import (
|
| 4 | + "database/sql" |
4 | 5 | "fmt"
|
| 6 | + "regexp" |
| 7 | + "strings" |
5 | 8 |
|
6 | 9 | "gorm.io/gorm"
|
7 | 10 | "gorm.io/gorm/clause"
|
@@ -139,6 +142,104 @@ func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error
|
139 | 142 | })
|
140 | 143 | }
|
141 | 144 |
|
| 145 | +var defaultValueTrimRegexp = regexp.MustCompile("^\\('?(.*)'?\\)$") |
| 146 | + |
| 147 | +// ColumnTypes return columnTypes []gorm.ColumnType and execErr error |
| 148 | +func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) { |
| 149 | + columnTypes := make([]gorm.ColumnType, 0) |
| 150 | + execErr := m.RunWithValue(value, func(stmt *gorm.Statement) (err error) { |
| 151 | + rows, err := m.DB.Session(&gorm.Session{}).Table(stmt.Table).Limit(1).Rows() |
| 152 | + if err != nil { |
| 153 | + return err |
| 154 | + } |
| 155 | + |
| 156 | + defer func() { |
| 157 | + err = rows.Close() |
| 158 | + }() |
| 159 | + |
| 160 | + var ( |
| 161 | + rawColumnTypes, _ = rows.ColumnTypes() |
| 162 | + columnTypeSQL = "SELECT column_name, data_type, column_default, is_nullable, character_maximum_length, numeric_precision, numeric_precision_radix, numeric_scale, datetime_precision FROM INFORMATION_SCHEMA.COLUMNS WHERE table_catalog = ? AND table_name = ?" |
| 163 | + columns, rowErr = m.DB.Raw(columnTypeSQL, m.CurrentDatabase(), stmt.Table).Rows() |
| 164 | + ) |
| 165 | + |
| 166 | + if rowErr != nil { |
| 167 | + return rowErr |
| 168 | + } |
| 169 | + |
| 170 | + defer columns.Close() |
| 171 | + |
| 172 | + for columns.Next() { |
| 173 | + var ( |
| 174 | + column migrator.ColumnType |
| 175 | + datetimePrecision sql.NullInt64 |
| 176 | + radixValue sql.NullInt64 |
| 177 | + nullableValue sql.NullString |
| 178 | + values = []interface{}{ |
| 179 | + &column.NameValue, &column.ColumnTypeValue, &column.DefaultValueValue, &nullableValue, &column.LengthValue, &column.DecimalSizeValue, &radixValue, &column.ScaleValue, &datetimePrecision, |
| 180 | + } |
| 181 | + ) |
| 182 | + |
| 183 | + if scanErr := columns.Scan(values...); scanErr != nil { |
| 184 | + return scanErr |
| 185 | + } |
| 186 | + |
| 187 | + if nullableValue.Valid { |
| 188 | + column.NullableValue = sql.NullBool{Bool: strings.EqualFold(nullableValue.String, "YES"), Valid: true} |
| 189 | + } |
| 190 | + |
| 191 | + if datetimePrecision.Valid { |
| 192 | + column.DecimalSizeValue = datetimePrecision |
| 193 | + } |
| 194 | + |
| 195 | + if column.DefaultValueValue.Valid { |
| 196 | + matches := defaultValueTrimRegexp.FindStringSubmatch(column.DefaultValueValue.String) |
| 197 | + for len(matches) > 1 { |
| 198 | + column.DefaultValueValue.String = matches[1] |
| 199 | + matches = defaultValueTrimRegexp.FindStringSubmatch(column.DefaultValueValue.String) |
| 200 | + } |
| 201 | + } |
| 202 | + |
| 203 | + for _, c := range rawColumnTypes { |
| 204 | + if c.Name() == column.NameValue.String { |
| 205 | + column.SQLColumnType = c |
| 206 | + break |
| 207 | + } |
| 208 | + } |
| 209 | + |
| 210 | + columnTypes = append(columnTypes, column) |
| 211 | + } |
| 212 | + |
| 213 | + columnTypeRows, err := m.DB.Raw("SELECT c.column_name, t.constraint_type FROM information_schema.table_constraints t JOIN information_schema.constraint_column_usage c ON c.constraint_name=t.constraint_name WHERE t.constraint_type IN ('PRIMARY KEY', 'UNIQUE') AND c.table_catalog = ? AND c.table_name = ?", m.CurrentDatabase(), stmt.Table).Rows() |
| 214 | + if err != nil { |
| 215 | + return err |
| 216 | + } |
| 217 | + defer columnTypeRows.Close() |
| 218 | + |
| 219 | + for columnTypeRows.Next() { |
| 220 | + var name, columnType string |
| 221 | + columnTypeRows.Scan(&name, &columnType) |
| 222 | + for idx, c := range columnTypes { |
| 223 | + mc := c.(migrator.ColumnType) |
| 224 | + if mc.NameValue.String == name { |
| 225 | + switch columnType { |
| 226 | + case "PRIMARY KEY": |
| 227 | + mc.PrimayKeyValue = sql.NullBool{Bool: true, Valid: true} |
| 228 | + case "UNIQUE": |
| 229 | + mc.UniqueValue = sql.NullBool{Bool: true, Valid: true} |
| 230 | + } |
| 231 | + columnTypes[idx] = mc |
| 232 | + break |
| 233 | + } |
| 234 | + } |
| 235 | + } |
| 236 | + |
| 237 | + return |
| 238 | + }) |
| 239 | + |
| 240 | + return columnTypes, execErr |
| 241 | +} |
| 242 | + |
142 | 243 | func (m Migrator) HasIndex(value interface{}, name string) bool {
|
143 | 244 | var count int
|
144 | 245 | m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
|
0 commit comments