Skip to content

Commit 139bd30

Browse files
authored
fix: AUTOINCREMENT flag cannot apply with PRIMARY KEY (go-gorm#167)
* fix: AUTOINCREMENT flag cannot apply with PRIMARY KEY * fix: migrator use ddl parser instead of regexp
1 parent af1b822 commit 139bd30

File tree

3 files changed

+85
-55
lines changed

3 files changed

+85
-55
lines changed

ddlmod.go

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,18 @@ func parseDDL(strs ...string) (*ddl, error) {
175175
return &result, nil
176176
}
177177

178+
func (d *ddl) clone() *ddl {
179+
copied := new(ddl)
180+
*copied = *d
181+
182+
copied.fields = make([]string, len(d.fields))
183+
copy(copied.fields, d.fields)
184+
copied.columns = make([]migrator.ColumnType, len(d.columns))
185+
copy(copied.columns, d.columns)
186+
187+
return copied
188+
}
189+
178190
func (d *ddl) compile() string {
179191
if len(d.fields) == 0 {
180192
return d.head
@@ -183,6 +195,21 @@ func (d *ddl) compile() string {
183195
return fmt.Sprintf("%s (%s)", d.head, strings.Join(d.fields, ","))
184196
}
185197

198+
func (d *ddl) renameTable(dst, src string) error {
199+
tableReg, err := regexp.Compile("\\s*('|`|\")?\\b" + regexp.QuoteMeta(src) + "\\b('|`|\")?\\s*")
200+
if err != nil {
201+
return err
202+
}
203+
204+
replaced := tableReg.ReplaceAllString(d.head, fmt.Sprintf(" `%s` ", dst))
205+
if replaced == d.head {
206+
return fmt.Errorf("failed to look up tablename `%s` from DDL head '%s'", src, d.head)
207+
}
208+
209+
d.head = replaced
210+
return nil
211+
}
212+
186213
func (d *ddl) addConstraint(name string, sql string) {
187214
reg := regexp.MustCompile("^CONSTRAINT [\"`]?" + regexp.QuoteMeta(name) + "[\"` ]")
188215

@@ -240,3 +267,30 @@ func (d *ddl) getColumns() []string {
240267
}
241268
return res
242269
}
270+
271+
func (d *ddl) alterColumn(name, sql string) bool {
272+
reg := regexp.MustCompile("^(`|'|\"| )" + regexp.QuoteMeta(name) + "(`|'|\"| ) .*?$")
273+
274+
for i := 0; i < len(d.fields); i++ {
275+
if reg.MatchString(d.fields[i]) {
276+
d.fields[i] = sql
277+
return false
278+
}
279+
}
280+
281+
d.fields = append(d.fields, sql)
282+
return true
283+
}
284+
285+
func (d *ddl) removeColumn(name string) bool {
286+
reg := regexp.MustCompile("^(`|'|\"| )" + regexp.QuoteMeta(name) + "(`|'|\"| ) .*?$")
287+
288+
for i := 0; i < len(d.fields); i++ {
289+
if reg.MatchString(d.fields[i]) {
290+
d.fields = append(d.fields[:i], d.fields[i+1:]...)
291+
return true
292+
}
293+
}
294+
295+
return false
296+
}

migrator.go

Lines changed: 29 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ package sqlite
33
import (
44
"database/sql"
55
"fmt"
6-
"regexp"
76
"strings"
87

98
"gorm.io/gorm"
@@ -78,23 +77,16 @@ func (m Migrator) HasColumn(value interface{}, name string) bool {
7877

7978
func (m Migrator) AlterColumn(value interface{}, name string) error {
8079
return m.RunWithoutForeignKey(func() error {
81-
return m.recreateTable(value, nil, func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error) {
80+
return m.recreateTable(value, nil, func(ddl *ddl, stmt *gorm.Statement) (*ddl, []interface{}, error) {
8281
if field := stmt.Schema.LookUpField(name); field != nil {
83-
// lookup field from table definition, ddl might looks like `'name' int,` or `'name' int)`
84-
reg, err := regexp.Compile("(`|'|\"| )" + field.DBName + "(`|'|\"| ) .*?(,|\\)\\s*$)")
85-
if err != nil {
86-
return "", nil, err
82+
if ddl.alterColumn(field.DBName, fmt.Sprintf("`%s` ?", field.DBName)) {
83+
return nil, nil, fmt.Errorf("field `%s` not found in origin ddl, ddl= '%s'", name, ddl.compile())
8784
}
8885

89-
createSQL := reg.ReplaceAllString(rawDDL, fmt.Sprintf("`%v` ?$3", field.DBName))
90-
91-
if createSQL == rawDDL {
92-
return "", nil, fmt.Errorf("failed to look up field %v from DDL %v", field.DBName, rawDDL)
93-
}
94-
95-
return createSQL, []interface{}{m.FullDataTypeOf(field)}, nil
86+
return ddl, []interface{}{m.FullDataTypeOf(field)}, nil
9687
}
97-
return "", nil, fmt.Errorf("failed to alter field with name %v", name)
88+
89+
return nil, nil, fmt.Errorf("failed to alter field with name `%s`", name)
9890
})
9991
})
10092
}
@@ -149,19 +141,13 @@ func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
149141
}
150142

151143
func (m Migrator) DropColumn(value interface{}, name string) error {
152-
return m.recreateTable(value, nil, func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error) {
144+
return m.recreateTable(value, nil, func(ddl *ddl, stmt *gorm.Statement) (*ddl, []interface{}, error) {
153145
if field := stmt.Schema.LookUpField(name); field != nil {
154146
name = field.DBName
155147
}
156148

157-
reg, err := regexp.Compile("(`|'|\"| |\\[)" + name + "(`|'|\"| |\\]) .*?,")
158-
if err != nil {
159-
return "", nil, err
160-
}
161-
162-
createSQL := reg.ReplaceAllString(rawDDL, "")
163-
164-
return createSQL, nil, nil
149+
ddl.removeColumn(name)
150+
return ddl, nil, nil
165151
})
166152
}
167153

@@ -170,7 +156,7 @@ func (m Migrator) CreateConstraint(value interface{}, name string) error {
170156
constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
171157

172158
return m.recreateTable(value, &table,
173-
func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error) {
159+
func(ddl *ddl, stmt *gorm.Statement) (*ddl, []interface{}, error) {
174160
var (
175161
constraintName string
176162
constraintSql string
@@ -185,17 +171,11 @@ func (m Migrator) CreateConstraint(value interface{}, name string) error {
185171
constraintSql = "CONSTRAINT ? CHECK (?)"
186172
constraintValues = []interface{}{clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}}
187173
} else {
188-
return "", nil, nil
174+
return nil, nil, nil
189175
}
190176

191-
createDDL, err := parseDDL(rawDDL)
192-
if err != nil {
193-
return "", nil, err
194-
}
195-
createDDL.addConstraint(constraintName, constraintSql)
196-
createSQL := createDDL.compile()
197-
198-
return createSQL, constraintValues, nil
177+
ddl.addConstraint(constraintName, constraintSql)
178+
return ddl, constraintValues, nil
199179
})
200180
})
201181
}
@@ -210,15 +190,9 @@ func (m Migrator) DropConstraint(value interface{}, name string) error {
210190
}
211191

212192
return m.recreateTable(value, &table,
213-
func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error) {
214-
createDDL, err := parseDDL(rawDDL)
215-
if err != nil {
216-
return "", nil, err
217-
}
218-
createDDL.removeConstraint(name)
219-
createSQL := createDDL.compile()
220-
221-
return createSQL, nil, nil
193+
func(ddl *ddl, stmt *gorm.Statement) (*ddl, []interface{}, error) {
194+
ddl.removeConstraint(name)
195+
return ddl, nil, nil
222196
})
223197
})
224198
}
@@ -375,8 +349,10 @@ func (m Migrator) getRawDDL(table string) (string, error) {
375349
return createSQL, nil
376350
}
377351

378-
func (m Migrator) recreateTable(value interface{}, tablePtr *string,
379-
getCreateSQL func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error)) error {
352+
func (m Migrator) recreateTable(
353+
value interface{}, tablePtr *string,
354+
getCreateSQL func(ddl *ddl, stmt *gorm.Statement) (sql *ddl, sqlArgs []interface{}, err error),
355+
) error {
380356
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
381357
table := stmt.Table
382358
if tablePtr != nil {
@@ -388,27 +364,26 @@ func (m Migrator) recreateTable(value interface{}, tablePtr *string,
388364
return err
389365
}
390366

391-
newTableName := table + "__temp"
392-
393-
createSQL, sqlArgs, err := getCreateSQL(rawDDL, stmt)
367+
originDDL, err := parseDDL(rawDDL)
394368
if err != nil {
395369
return err
396370
}
397-
if createSQL == "" {
398-
return nil
399-
}
400371

401-
tableReg, err := regexp.Compile("\\s*('|`|\")?\\b" + table + "\\b('|`|\")?\\s*")
372+
createDDL, sqlArgs, err := getCreateSQL(originDDL.clone(), stmt)
402373
if err != nil {
403374
return err
404375
}
405-
createSQL = tableReg.ReplaceAllString(createSQL, fmt.Sprintf(" `%v` ", newTableName))
376+
if createDDL == nil {
377+
return nil
378+
}
406379

407-
createDDL, err := parseDDL(createSQL)
408-
if err != nil {
380+
newTableName := table + "__temp"
381+
if err := createDDL.renameTable(newTableName, table); err != nil {
409382
return err
410383
}
384+
411385
columns := createDDL.getColumns()
386+
createSQL := createDDL.compile()
412387

413388
return m.DB.Transaction(func(tx *gorm.DB) error {
414389
if err := tx.Exec(createSQL, sqlArgs...).Error; err != nil {

sqlite.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,8 @@ func (dialector Dialector) DataTypeOf(field *schema.Field) string {
198198
case schema.Bool:
199199
return "numeric"
200200
case schema.Int, schema.Uint:
201-
if field.AutoIncrement && !field.PrimaryKey {
201+
if field.AutoIncrement {
202+
// doesn't check `PrimaryKey`, to keep backward compatibility
202203
// https://www.sqlite.org/autoinc.html
203204
return "integer PRIMARY KEY AUTOINCREMENT"
204205
} else {

0 commit comments

Comments
 (0)