@@ -3,7 +3,6 @@ package sqlite
3
3
import (
4
4
"database/sql"
5
5
"fmt"
6
- "regexp"
7
6
"strings"
8
7
9
8
"gorm.io/gorm"
@@ -78,23 +77,16 @@ func (m Migrator) HasColumn(value interface{}, name string) bool {
78
77
79
78
func (m Migrator ) AlterColumn (value interface {}, name string ) error {
80
79
return m .RunWithoutForeignKey (func () error {
81
- return m .recreateTable (value , nil , func (rawDDL string , stmt * gorm.Statement ) (sql string , sqlArgs []interface {}, err error ) {
80
+ return m .recreateTable (value , nil , func (ddl * ddl , stmt * gorm.Statement ) (* ddl , []interface {}, error ) {
82
81
if field := stmt .Schema .LookUpField (name ); field != nil {
83
- // lookup field from table definition, ddl might looks like `'name' int,` or `'name' int)`
84
- reg , err := regexp .Compile ("(`|'|\" | )" + field .DBName + "(`|'|\" | ) .*?(,|\\ )\\ s*$)" )
85
- if err != nil {
86
- return "" , nil , err
82
+ if ddl .alterColumn (field .DBName , fmt .Sprintf ("`%s` ?" , field .DBName )) {
83
+ return nil , nil , fmt .Errorf ("field `%s` not found in origin ddl, ddl= '%s'" , name , ddl .compile ())
87
84
}
88
85
89
- createSQL := reg .ReplaceAllString (rawDDL , fmt .Sprintf ("`%v` ?$3" , field .DBName ))
90
-
91
- if createSQL == rawDDL {
92
- return "" , nil , fmt .Errorf ("failed to look up field %v from DDL %v" , field .DBName , rawDDL )
93
- }
94
-
95
- return createSQL , []interface {}{m .FullDataTypeOf (field )}, nil
86
+ return ddl , []interface {}{m .FullDataTypeOf (field )}, nil
96
87
}
97
- return "" , nil , fmt .Errorf ("failed to alter field with name %v" , name )
88
+
89
+ return nil , nil , fmt .Errorf ("failed to alter field with name `%s`" , name )
98
90
})
99
91
})
100
92
}
@@ -149,19 +141,13 @@ func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
149
141
}
150
142
151
143
func (m Migrator ) DropColumn (value interface {}, name string ) error {
152
- return m .recreateTable (value , nil , func (rawDDL string , stmt * gorm.Statement ) (sql string , sqlArgs []interface {}, err error ) {
144
+ return m .recreateTable (value , nil , func (ddl * ddl , stmt * gorm.Statement ) (* ddl , []interface {}, error ) {
153
145
if field := stmt .Schema .LookUpField (name ); field != nil {
154
146
name = field .DBName
155
147
}
156
148
157
- reg , err := regexp .Compile ("(`|'|\" | |\\ [)" + name + "(`|'|\" | |\\ ]) .*?," )
158
- if err != nil {
159
- return "" , nil , err
160
- }
161
-
162
- createSQL := reg .ReplaceAllString (rawDDL , "" )
163
-
164
- return createSQL , nil , nil
149
+ ddl .removeColumn (name )
150
+ return ddl , nil , nil
165
151
})
166
152
}
167
153
@@ -170,7 +156,7 @@ func (m Migrator) CreateConstraint(value interface{}, name string) error {
170
156
constraint , chk , table := m .GuessConstraintAndTable (stmt , name )
171
157
172
158
return m .recreateTable (value , & table ,
173
- func (rawDDL string , stmt * gorm.Statement ) (sql string , sqlArgs []interface {}, err error ) {
159
+ func (ddl * ddl , stmt * gorm.Statement ) (* ddl , []interface {}, error ) {
174
160
var (
175
161
constraintName string
176
162
constraintSql string
@@ -185,17 +171,11 @@ func (m Migrator) CreateConstraint(value interface{}, name string) error {
185
171
constraintSql = "CONSTRAINT ? CHECK (?)"
186
172
constraintValues = []interface {}{clause.Column {Name : chk .Name }, clause.Expr {SQL : chk .Constraint }}
187
173
} else {
188
- return "" , nil , nil
174
+ return nil , nil , nil
189
175
}
190
176
191
- createDDL , err := parseDDL (rawDDL )
192
- if err != nil {
193
- return "" , nil , err
194
- }
195
- createDDL .addConstraint (constraintName , constraintSql )
196
- createSQL := createDDL .compile ()
197
-
198
- return createSQL , constraintValues , nil
177
+ ddl .addConstraint (constraintName , constraintSql )
178
+ return ddl , constraintValues , nil
199
179
})
200
180
})
201
181
}
@@ -210,15 +190,9 @@ func (m Migrator) DropConstraint(value interface{}, name string) error {
210
190
}
211
191
212
192
return m .recreateTable (value , & table ,
213
- func (rawDDL string , stmt * gorm.Statement ) (sql string , sqlArgs []interface {}, err error ) {
214
- createDDL , err := parseDDL (rawDDL )
215
- if err != nil {
216
- return "" , nil , err
217
- }
218
- createDDL .removeConstraint (name )
219
- createSQL := createDDL .compile ()
220
-
221
- return createSQL , nil , nil
193
+ func (ddl * ddl , stmt * gorm.Statement ) (* ddl , []interface {}, error ) {
194
+ ddl .removeConstraint (name )
195
+ return ddl , nil , nil
222
196
})
223
197
})
224
198
}
@@ -375,8 +349,10 @@ func (m Migrator) getRawDDL(table string) (string, error) {
375
349
return createSQL , nil
376
350
}
377
351
378
- func (m Migrator ) recreateTable (value interface {}, tablePtr * string ,
379
- getCreateSQL func (rawDDL string , stmt * gorm.Statement ) (sql string , sqlArgs []interface {}, err error )) error {
352
+ func (m Migrator ) recreateTable (
353
+ value interface {}, tablePtr * string ,
354
+ getCreateSQL func (ddl * ddl , stmt * gorm.Statement ) (sql * ddl , sqlArgs []interface {}, err error ),
355
+ ) error {
380
356
return m .RunWithValue (value , func (stmt * gorm.Statement ) error {
381
357
table := stmt .Table
382
358
if tablePtr != nil {
@@ -388,27 +364,26 @@ func (m Migrator) recreateTable(value interface{}, tablePtr *string,
388
364
return err
389
365
}
390
366
391
- newTableName := table + "__temp"
392
-
393
- createSQL , sqlArgs , err := getCreateSQL (rawDDL , stmt )
367
+ originDDL , err := parseDDL (rawDDL )
394
368
if err != nil {
395
369
return err
396
370
}
397
- if createSQL == "" {
398
- return nil
399
- }
400
371
401
- tableReg , err := regexp . Compile ( " \\ s*('|`| \" )? \\ b" + table + " \\ b('|`| \" )? \\ s*" )
372
+ createDDL , sqlArgs , err := getCreateSQL ( originDDL . clone (), stmt )
402
373
if err != nil {
403
374
return err
404
375
}
405
- createSQL = tableReg .ReplaceAllString (createSQL , fmt .Sprintf (" `%v` " , newTableName ))
376
+ if createDDL == nil {
377
+ return nil
378
+ }
406
379
407
- createDDL , err := parseDDL ( createSQL )
408
- if err != nil {
380
+ newTableName := table + "__temp"
381
+ if err := createDDL . renameTable ( newTableName , table ); err != nil {
409
382
return err
410
383
}
384
+
411
385
columns := createDDL .getColumns ()
386
+ createSQL := createDDL .compile ()
412
387
413
388
return m .DB .Transaction (func (tx * gorm.DB ) error {
414
389
if err := tx .Exec (createSQL , sqlArgs ... ).Error ; err != nil {
0 commit comments