Skip to content

Commit 851a217

Browse files
authored
Merge branch 'go-gorm:master' into master
2 parents b1197c8 + 74475fc commit 851a217

File tree

4 files changed

+137
-76
lines changed

4 files changed

+137
-76
lines changed

ddlmod.go

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ func parseDDL(strs ...string) (*ddl, error) {
125125
ColumnTypeValue: sql.NullString{String: matches[2], Valid: true},
126126
PrimaryKeyValue: sql.NullBool{Valid: true},
127127
UniqueValue: sql.NullBool{Valid: true},
128-
NullableValue: sql.NullBool{Valid: true},
128+
NullableValue: sql.NullBool{Bool: true, Valid: true},
129129
DefaultValueValue: sql.NullString{Valid: false},
130130
}
131131

@@ -175,6 +175,18 @@ func parseDDL(strs ...string) (*ddl, error) {
175175
return &result, nil
176176
}
177177

178+
func (d *ddl) clone() *ddl {
179+
copied := new(ddl)
180+
*copied = *d
181+
182+
copied.fields = make([]string, len(d.fields))
183+
copy(copied.fields, d.fields)
184+
copied.columns = make([]migrator.ColumnType, len(d.columns))
185+
copy(copied.columns, d.columns)
186+
187+
return copied
188+
}
189+
178190
func (d *ddl) compile() string {
179191
if len(d.fields) == 0 {
180192
return d.head
@@ -183,6 +195,21 @@ func (d *ddl) compile() string {
183195
return fmt.Sprintf("%s (%s)", d.head, strings.Join(d.fields, ","))
184196
}
185197

198+
func (d *ddl) renameTable(dst, src string) error {
199+
tableReg, err := regexp.Compile("\\s*('|`|\")?\\b" + regexp.QuoteMeta(src) + "\\b('|`|\")?\\s*")
200+
if err != nil {
201+
return err
202+
}
203+
204+
replaced := tableReg.ReplaceAllString(d.head, fmt.Sprintf(" `%s` ", dst))
205+
if replaced == d.head {
206+
return fmt.Errorf("failed to look up tablename `%s` from DDL head '%s'", src, d.head)
207+
}
208+
209+
d.head = replaced
210+
return nil
211+
}
212+
186213
func (d *ddl) addConstraint(name string, sql string) {
187214
reg := regexp.MustCompile("^CONSTRAINT [\"`]?" + regexp.QuoteMeta(name) + "[\"` ]")
188215

@@ -240,3 +267,30 @@ func (d *ddl) getColumns() []string {
240267
}
241268
return res
242269
}
270+
271+
func (d *ddl) alterColumn(name, sql string) bool {
272+
reg := regexp.MustCompile("^(`|'|\"| )" + regexp.QuoteMeta(name) + "(`|'|\"| ) .*?$")
273+
274+
for i := 0; i < len(d.fields); i++ {
275+
if reg.MatchString(d.fields[i]) {
276+
d.fields[i] = sql
277+
return false
278+
}
279+
}
280+
281+
d.fields = append(d.fields, sql)
282+
return true
283+
}
284+
285+
func (d *ddl) removeColumn(name string) bool {
286+
reg := regexp.MustCompile("^(`|'|\"| )" + regexp.QuoteMeta(name) + "(`|'|\"| ) .*?$")
287+
288+
for i := 0; i < len(d.fields); i++ {
289+
if reg.MatchString(d.fields[i]) {
290+
d.fields = append(d.fields[:i], d.fields[i+1:]...)
291+
return true
292+
}
293+
}
294+
295+
return false
296+
}

ddlmod_test.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,16 @@ func TestParseDDL(t *testing.T) {
2020
"CREATE UNIQUE INDEX `idx_profiles_refer` ON `profiles`(`text`)",
2121
}, 6, []migrator.ColumnType{
2222
{NameValue: sql.NullString{String: "id", Valid: true}, DataTypeValue: sql.NullString{String: "integer", Valid: true}, ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, PrimaryKeyValue: sql.NullBool{Bool: true, Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, DefaultValueValue: sql.NullString{Valid: false}},
23-
{NameValue: sql.NullString{String: "text", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 500, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(500)", Valid: true}, DefaultValueValue: sql.NullString{String: "hello", Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Bool: true, Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
24-
{NameValue: sql.NullString{String: "age", Valid: true}, DataTypeValue: sql.NullString{String: "integer", Valid: true}, ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, DefaultValueValue: sql.NullString{String: "18", Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
25-
{NameValue: sql.NullString{String: "user_id", Valid: true}, DataTypeValue: sql.NullString{String: "integer", Valid: true}, ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
23+
{NameValue: sql.NullString{String: "text", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 500, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(500)", Valid: true}, DefaultValueValue: sql.NullString{String: "hello", Valid: true}, NullableValue: sql.NullBool{Bool: true, Valid: true}, UniqueValue: sql.NullBool{Bool: true, Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
24+
{NameValue: sql.NullString{String: "age", Valid: true}, DataTypeValue: sql.NullString{String: "integer", Valid: true}, ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, DefaultValueValue: sql.NullString{String: "18", Valid: true}, NullableValue: sql.NullBool{Bool: true, Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
25+
{NameValue: sql.NullString{String: "user_id", Valid: true}, DataTypeValue: sql.NullString{String: "integer", Valid: true}, ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, NullableValue: sql.NullBool{Bool: true, Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
2626
},
2727
},
2828
{"with_check", []string{"CREATE TABLE Persons (ID int NOT NULL,LastName varchar(255) NOT NULL,FirstName varchar(255),Age int,CHECK (Age>=18),CHECK (FirstName<>'John'))"}, 6, []migrator.ColumnType{
2929
{NameValue: sql.NullString{String: "ID", Valid: true}, DataTypeValue: sql.NullString{String: "int", Valid: true}, ColumnTypeValue: sql.NullString{String: "int", Valid: true}, NullableValue: sql.NullBool{Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
3030
{NameValue: sql.NullString{String: "LastName", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 255, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(255)", Valid: true}, NullableValue: sql.NullBool{Bool: false, Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
31-
{NameValue: sql.NullString{String: "FirstName", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 255, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(255)", Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
32-
{NameValue: sql.NullString{String: "Age", Valid: true}, DataTypeValue: sql.NullString{String: "int", Valid: true}, ColumnTypeValue: sql.NullString{String: "int", Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
31+
{NameValue: sql.NullString{String: "FirstName", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 255, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(255)", Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, NullableValue: sql.NullBool{Bool: true, Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
32+
{NameValue: sql.NullString{String: "Age", Valid: true}, DataTypeValue: sql.NullString{String: "int", Valid: true}, ColumnTypeValue: sql.NullString{String: "int", Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, NullableValue: sql.NullBool{Bool: true, Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
3333
}},
3434
{"lowercase", []string{"create table test (ID int NOT NULL)"}, 1, []migrator.ColumnType{
3535
{NameValue: sql.NullString{String: "ID", Valid: true}, DataTypeValue: sql.NullString{String: "int", Valid: true}, ColumnTypeValue: sql.NullString{String: "int", Valid: true}, NullableValue: sql.NullBool{Bool: false, Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
@@ -39,7 +39,7 @@ func TestParseDDL(t *testing.T) {
3939
{"with_special_characters", []string{
4040
"CREATE TABLE `test` (`text` varchar(10) DEFAULT \"测试, \")",
4141
}, 1, []migrator.ColumnType{
42-
{NameValue: sql.NullString{String: "text", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 10, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(10)", Valid: true}, DefaultValueValue: sql.NullString{String: "测试, ", Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
42+
{NameValue: sql.NullString{String: "text", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 10, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(10)", Valid: true}, DefaultValueValue: sql.NullString{String: "测试, ", Valid: true}, NullableValue: sql.NullBool{Bool: true, Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
4343
},
4444
},
4545
{
@@ -122,7 +122,7 @@ func TestParseDDL_Whitespaces(t *testing.T) {
122122
NameValue: sql.NullString{String: "id", Valid: true},
123123
DataTypeValue: sql.NullString{String: "integer", Valid: true},
124124
ColumnTypeValue: sql.NullString{String: "integer", Valid: true},
125-
NullableValue: sql.NullBool{Bool: false, Valid: true},
125+
NullableValue: sql.NullBool{Bool: true, Valid: true},
126126
DefaultValueValue: sql.NullString{Valid: false},
127127
UniqueValue: sql.NullBool{Bool: true, Valid: true},
128128
PrimaryKeyValue: sql.NullBool{Bool: true, Valid: true},
@@ -131,7 +131,7 @@ func TestParseDDL_Whitespaces(t *testing.T) {
131131
NameValue: sql.NullString{String: "dark_mode", Valid: true},
132132
DataTypeValue: sql.NullString{String: "numeric", Valid: true},
133133
ColumnTypeValue: sql.NullString{String: "numeric", Valid: true},
134-
NullableValue: sql.NullBool{Valid: true},
134+
NullableValue: sql.NullBool{Bool: true, Valid: true},
135135
DefaultValueValue: sql.NullString{String: "true", Valid: true},
136136
UniqueValue: sql.NullBool{Bool: false, Valid: true},
137137
PrimaryKeyValue: sql.NullBool{Bool: false, Valid: true},

migrator.go

Lines changed: 29 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ package sqlite
33
import (
44
"database/sql"
55
"fmt"
6-
"regexp"
76
"strings"
87

98
"gorm.io/gorm"
@@ -78,23 +77,16 @@ func (m Migrator) HasColumn(value interface{}, name string) bool {
7877

7978
func (m Migrator) AlterColumn(value interface{}, name string) error {
8079
return m.RunWithoutForeignKey(func() error {
81-
return m.recreateTable(value, nil, func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error) {
80+
return m.recreateTable(value, nil, func(ddl *ddl, stmt *gorm.Statement) (*ddl, []interface{}, error) {
8281
if field := stmt.Schema.LookUpField(name); field != nil {
83-
// lookup field from table definition, ddl might looks like `'name' int,` or `'name' int)`
84-
reg, err := regexp.Compile("(`|'|\"| )" + field.DBName + "(`|'|\"| ) .*?(,|\\)\\s*$)")
85-
if err != nil {
86-
return "", nil, err
82+
if ddl.alterColumn(field.DBName, fmt.Sprintf("`%s` ?", field.DBName)) {
83+
return nil, nil, fmt.Errorf("field `%s` not found in origin ddl, ddl= '%s'", name, ddl.compile())
8784
}
8885

89-
createSQL := reg.ReplaceAllString(rawDDL, fmt.Sprintf("`%v` ?$3", field.DBName))
90-
91-
if createSQL == rawDDL {
92-
return "", nil, fmt.Errorf("failed to look up field %v from DDL %v", field.DBName, rawDDL)
93-
}
94-
95-
return createSQL, []interface{}{m.FullDataTypeOf(field)}, nil
86+
return ddl, []interface{}{m.FullDataTypeOf(field)}, nil
9687
}
97-
return "", nil, fmt.Errorf("failed to alter field with name %v", name)
88+
89+
return nil, nil, fmt.Errorf("failed to alter field with name `%s`", name)
9890
})
9991
})
10092
}
@@ -149,19 +141,13 @@ func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
149141
}
150142

151143
func (m Migrator) DropColumn(value interface{}, name string) error {
152-
return m.recreateTable(value, nil, func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error) {
144+
return m.recreateTable(value, nil, func(ddl *ddl, stmt *gorm.Statement) (*ddl, []interface{}, error) {
153145
if field := stmt.Schema.LookUpField(name); field != nil {
154146
name = field.DBName
155147
}
156148

157-
reg, err := regexp.Compile("(`|'|\"| |\\[)" + name + "(`|'|\"| |\\]) .*?,")
158-
if err != nil {
159-
return "", nil, err
160-
}
161-
162-
createSQL := reg.ReplaceAllString(rawDDL, "")
163-
164-
return createSQL, nil, nil
149+
ddl.removeColumn(name)
150+
return ddl, nil, nil
165151
})
166152
}
167153

@@ -170,7 +156,7 @@ func (m Migrator) CreateConstraint(value interface{}, name string) error {
170156
constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
171157

172158
return m.recreateTable(value, &table,
173-
func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error) {
159+
func(ddl *ddl, stmt *gorm.Statement) (*ddl, []interface{}, error) {
174160
var (
175161
constraintName string
176162
constraintSql string
@@ -185,17 +171,11 @@ func (m Migrator) CreateConstraint(value interface{}, name string) error {
185171
constraintSql = "CONSTRAINT ? CHECK (?)"
186172
constraintValues = []interface{}{clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}}
187173
} else {
188-
return "", nil, nil
174+
return nil, nil, nil
189175
}
190176

191-
createDDL, err := parseDDL(rawDDL)
192-
if err != nil {
193-
return "", nil, err
194-
}
195-
createDDL.addConstraint(constraintName, constraintSql)
196-
createSQL := createDDL.compile()
197-
198-
return createSQL, constraintValues, nil
177+
ddl.addConstraint(constraintName, constraintSql)
178+
return ddl, constraintValues, nil
199179
})
200180
})
201181
}
@@ -210,15 +190,9 @@ func (m Migrator) DropConstraint(value interface{}, name string) error {
210190
}
211191

212192
return m.recreateTable(value, &table,
213-
func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error) {
214-
createDDL, err := parseDDL(rawDDL)
215-
if err != nil {
216-
return "", nil, err
217-
}
218-
createDDL.removeConstraint(name)
219-
createSQL := createDDL.compile()
220-
221-
return createSQL, nil, nil
193+
func(ddl *ddl, stmt *gorm.Statement) (*ddl, []interface{}, error) {
194+
ddl.removeConstraint(name)
195+
return ddl, nil, nil
222196
})
223197
})
224198
}
@@ -375,8 +349,10 @@ func (m Migrator) getRawDDL(table string) (string, error) {
375349
return createSQL, nil
376350
}
377351

378-
func (m Migrator) recreateTable(value interface{}, tablePtr *string,
379-
getCreateSQL func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error)) error {
352+
func (m Migrator) recreateTable(
353+
value interface{}, tablePtr *string,
354+
getCreateSQL func(ddl *ddl, stmt *gorm.Statement) (sql *ddl, sqlArgs []interface{}, err error),
355+
) error {
380356
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
381357
table := stmt.Table
382358
if tablePtr != nil {
@@ -388,27 +364,26 @@ func (m Migrator) recreateTable(value interface{}, tablePtr *string,
388364
return err
389365
}
390366

391-
newTableName := table + "__temp"
392-
393-
createSQL, sqlArgs, err := getCreateSQL(rawDDL, stmt)
367+
originDDL, err := parseDDL(rawDDL)
394368
if err != nil {
395369
return err
396370
}
397-
if createSQL == "" {
398-
return nil
399-
}
400371

401-
tableReg, err := regexp.Compile("\\s*('|`|\")?\\b" + table + "\\b('|`|\")?\\s*")
372+
createDDL, sqlArgs, err := getCreateSQL(originDDL.clone(), stmt)
402373
if err != nil {
403374
return err
404375
}
405-
createSQL = tableReg.ReplaceAllString(createSQL, fmt.Sprintf(" `%v` ", newTableName))
376+
if createDDL == nil {
377+
return nil
378+
}
406379

407-
createDDL, err := parseDDL(createSQL)
408-
if err != nil {
380+
newTableName := table + "__temp"
381+
if err := createDDL.renameTable(newTableName, table); err != nil {
409382
return err
410383
}
384+
411385
columns := createDDL.getColumns()
386+
createSQL := createDDL.compile()
412387

413388
return m.DB.Transaction(func(tx *gorm.DB) error {
414389
if err := tx.Exec(createSQL, sqlArgs...).Error; err != nil {

sqlite.go

Lines changed: 45 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ import (
44
"context"
55
"database/sql"
66
"strconv"
7-
"strings"
87

98
"gorm.io/gorm/callbacks"
109

@@ -58,7 +57,7 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
5857
if compareVersion(version, "3.35.0") >= 0 {
5958
callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{
6059
CreateClauses: []string{"INSERT", "VALUES", "ON CONFLICT", "RETURNING"},
61-
UpdateClauses: []string{"UPDATE", "SET", "WHERE", "RETURNING"},
60+
UpdateClauses: []string{"UPDATE", "SET", "FROM", "WHERE", "RETURNING"},
6261
DeleteClauses: []string{"DELETE", "FROM", "WHERE", "RETURNING"},
6362
LastInsertIDReversed: true,
6463
})
@@ -145,19 +144,51 @@ func (dialector Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement,
145144
}
146145

147146
func (dialector Dialector) QuoteTo(writer clause.Writer, str string) {
148-
writer.WriteByte('`')
149-
if strings.Contains(str, ".") {
150-
for idx, str := range strings.Split(str, ".") {
151-
if idx > 0 {
152-
writer.WriteString(".`")
147+
var (
148+
underQuoted, selfQuoted bool
149+
continuousBacktick int8
150+
shiftDelimiter int8
151+
)
152+
153+
for _, v := range []byte(str) {
154+
switch v {
155+
case '`':
156+
continuousBacktick++
157+
if continuousBacktick == 2 {
158+
writer.WriteString("``")
159+
continuousBacktick = 0
160+
}
161+
case '.':
162+
if continuousBacktick > 0 || !selfQuoted {
163+
shiftDelimiter = 0
164+
underQuoted = false
165+
continuousBacktick = 0
166+
writer.WriteString("`")
167+
}
168+
writer.WriteByte(v)
169+
continue
170+
default:
171+
if shiftDelimiter-continuousBacktick <= 0 && !underQuoted {
172+
writer.WriteString("`")
173+
underQuoted = true
174+
if selfQuoted = continuousBacktick > 0; selfQuoted {
175+
continuousBacktick -= 1
176+
}
177+
}
178+
179+
for ; continuousBacktick > 0; continuousBacktick -= 1 {
180+
writer.WriteString("``")
153181
}
154-
writer.WriteString(str)
155-
writer.WriteByte('`')
182+
183+
writer.WriteByte(v)
156184
}
157-
} else {
158-
writer.WriteString(str)
159-
writer.WriteByte('`')
185+
shiftDelimiter++
186+
}
187+
188+
if continuousBacktick > 0 && !selfQuoted {
189+
writer.WriteString("``")
160190
}
191+
writer.WriteString("`")
161192
}
162193

163194
func (dialector Dialector) Explain(sql string, vars ...interface{}) string {
@@ -169,7 +200,8 @@ func (dialector Dialector) DataTypeOf(field *schema.Field) string {
169200
case schema.Bool:
170201
return "numeric"
171202
case schema.Int, schema.Uint:
172-
if field.AutoIncrement && !field.PrimaryKey {
203+
if field.AutoIncrement {
204+
// doesn't check `PrimaryKey`, to keep backward compatibility
173205
// https://www.sqlite.org/autoinc.html
174206
return "integer PRIMARY KEY AUTOINCREMENT"
175207
} else {

0 commit comments

Comments
 (0)