Skip to content
This repository was archived by the owner on Sep 7, 2021. It is now read-only.
This repository is currently being migrated. It's locked while the migration is in progress.

Commit 17592d9

Browse files
authored
Add insert select where support (#1401)
1 parent b78ac8c commit 17592d9

File tree

2 files changed

+152
-21
lines changed

2 files changed

+152
-21
lines changed

session_insert.go

Lines changed: 93 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"strconv"
1313
"strings"
1414

15+
"xorm.io/builder"
1516
"xorm.io/core"
1617
)
1718

@@ -345,7 +346,7 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
345346
for _, v := range exprColumns {
346347
// remove the expr columns
347348
for i, colName := range colNames {
348-
if colName == v.colName {
349+
if colName == strings.Trim(v.colName, "`") {
349350
colNames = append(colNames[:i], colNames[i+1:]...)
350351
args = append(args[:i], args[i+1:]...)
351352
}
@@ -371,12 +372,30 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
371372
if session.engine.dialect.DBType() == core.MSSQL && len(table.AutoIncrement) > 0 {
372373
output = fmt.Sprintf(" OUTPUT Inserted.%s", table.AutoIncrement)
373374
}
375+
374376
if len(colPlaces) > 0 {
375-
sqlStr = fmt.Sprintf("INSERT INTO %s (%v)%s VALUES (%v)",
376-
session.engine.Quote(tableName),
377-
quoteColumns(colNames, session.engine.Quote, ","),
378-
output,
379-
colPlaces)
377+
if session.statement.cond.IsValid() {
378+
condSQL, condArgs, err := builder.ToSQL(session.statement.cond)
379+
if err != nil {
380+
return 0, err
381+
}
382+
383+
sqlStr = fmt.Sprintf("INSERT INTO %s (%v)%s SELECT %v FROM %v WHERE %v",
384+
session.engine.Quote(tableName),
385+
quoteColumns(colNames, session.engine.Quote, ","),
386+
output,
387+
colPlaces,
388+
session.engine.Quote(tableName),
389+
condSQL,
390+
)
391+
args = append(args, condArgs...)
392+
} else {
393+
sqlStr = fmt.Sprintf("INSERT INTO %s (%v)%s VALUES (%v)",
394+
session.engine.Quote(tableName),
395+
quoteColumns(colNames, session.engine.Quote, ","),
396+
output,
397+
colPlaces)
398+
}
380399
} else {
381400
if session.engine.dialect.DBType() == core.MYSQL {
382401
sqlStr = fmt.Sprintf("INSERT INTO %s VALUES ()", session.engine.Quote(tableName))
@@ -663,26 +682,52 @@ func (session *Session) insertMapInterface(m map[string]interface{}) (int64, err
663682
return 0, ErrParamsType
664683
}
665684

685+
tableName := session.statement.TableName()
686+
if len(tableName) <= 0 {
687+
return 0, ErrTableNotFound
688+
}
689+
666690
var columns = make([]string, 0, len(m))
667691
for k := range m {
668692
columns = append(columns, k)
669693
}
670694
sort.Strings(columns)
671695

672696
qm := strings.Repeat("?,", len(columns))
673-
qm = "(" + qm[:len(qm)-1] + ")"
674-
675-
tableName := session.statement.TableName()
676-
if len(tableName) <= 0 {
677-
return 0, ErrTableNotFound
678-
}
679697

680-
var sql = fmt.Sprintf("INSERT INTO %s (`%s`) VALUES %s", session.engine.Quote(tableName), strings.Join(columns, "`,`"), qm)
681698
var args = make([]interface{}, 0, len(m))
682699
for _, colName := range columns {
683700
args = append(args, m[colName])
684701
}
685702

703+
// insert expr columns, override if exists
704+
exprColumns := session.statement.getExpr()
705+
for _, col := range exprColumns {
706+
columns = append(columns, strings.Trim(col.colName, "`"))
707+
qm = qm + col.expr + ","
708+
}
709+
710+
qm = qm[:len(qm)-1]
711+
712+
var sql string
713+
714+
if session.statement.cond.IsValid() {
715+
condSQL, condArgs, err := builder.ToSQL(session.statement.cond)
716+
if err != nil {
717+
return 0, err
718+
}
719+
sql = fmt.Sprintf("INSERT INTO %s (`%s`) SELECT %s FROM %s WHERE %s",
720+
session.engine.Quote(tableName),
721+
strings.Join(columns, "`,`"),
722+
qm,
723+
session.engine.Quote(tableName),
724+
condSQL,
725+
)
726+
args = append(args, condArgs...)
727+
} else {
728+
sql = fmt.Sprintf("INSERT INTO %s (`%s`) VALUES (%s)", session.engine.Quote(tableName), strings.Join(columns, "`,`"), qm)
729+
}
730+
686731
if err := session.cacheInsert(tableName); err != nil {
687732
return 0, err
688733
}
@@ -703,24 +748,51 @@ func (session *Session) insertMapString(m map[string]string) (int64, error) {
703748
return 0, ErrParamsType
704749
}
705750

751+
tableName := session.statement.TableName()
752+
if len(tableName) <= 0 {
753+
return 0, ErrTableNotFound
754+
}
755+
706756
var columns = make([]string, 0, len(m))
707757
for k := range m {
708758
columns = append(columns, k)
709759
}
710760
sort.Strings(columns)
711761

762+
var args = make([]interface{}, 0, len(m))
763+
for _, colName := range columns {
764+
args = append(args, m[colName])
765+
}
766+
712767
qm := strings.Repeat("?,", len(columns))
713-
qm = "(" + qm[:len(qm)-1] + ")"
714768

715-
tableName := session.statement.TableName()
716-
if len(tableName) <= 0 {
717-
return 0, ErrTableNotFound
769+
// insert expr columns, override if exists
770+
exprColumns := session.statement.getExpr()
771+
for _, col := range exprColumns {
772+
columns = append(columns, strings.Trim(col.colName, "`"))
773+
qm = qm + col.expr + ","
718774
}
719775

720-
var sql = fmt.Sprintf("INSERT INTO %s (`%s`) VALUES %s", session.engine.Quote(tableName), strings.Join(columns, "`,`"), qm)
721-
var args = make([]interface{}, 0, len(m))
722-
for _, colName := range columns {
723-
args = append(args, m[colName])
776+
qm = qm[:len(qm)-1]
777+
778+
var sql string
779+
780+
if session.statement.cond.IsValid() {
781+
qm = "(" + qm[:len(qm)-1] + ")"
782+
condSQL, condArgs, err := builder.ToSQL(session.statement.cond)
783+
if err != nil {
784+
return 0, err
785+
}
786+
sql = fmt.Sprintf("INSERT INTO %s (`%s`) SELECT %s FROM %s WHERE %s",
787+
session.engine.Quote(tableName),
788+
strings.Join(columns, "`,`"),
789+
qm,
790+
session.engine.Quote(tableName),
791+
condSQL,
792+
)
793+
args = append(args, condArgs...)
794+
} else {
795+
sql = fmt.Sprintf("INSERT INTO %s (`%s`) VALUES (%s)", session.engine.Quote(tableName), strings.Join(columns, "`,`"), qm)
724796
}
725797

726798
if err := session.cacheInsert(tableName); err != nil {

session_insert_test.go

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -834,3 +834,62 @@ func TestInsertMap(t *testing.T) {
834834
assert.EqualValues(t, 10, ims[3].Height)
835835
assert.EqualValues(t, "lunny", ims[3].Name)
836836
}
837+
838+
/*INSERT INTO `issue` (`repo_id`, `poster_id`, ... ,`name`, `content`, ... ,`index`)
839+
SELECT $1, $2, ..., $14, $15, ..., MAX(`index`) + 1 FROM `issue` WHERE `repo_id` = $1;
840+
*/
841+
func TestInsertWhere(t *testing.T) {
842+
type InsertWhere struct {
843+
Id int64
844+
Index int `xorm:"unique(s) notnull"`
845+
RepoId int64 `xorm:"unique(s)"`
846+
Width uint32
847+
Height uint32
848+
Name string
849+
}
850+
851+
assert.NoError(t, prepareEngine())
852+
assertSync(t, new(InsertWhere))
853+
854+
var i = InsertWhere{
855+
RepoId: 1,
856+
Width: 10,
857+
Height: 20,
858+
Name: "trest",
859+
}
860+
861+
inserted, err := testEngine.SetExpr("`index`", "coalesce(MAX(`index`),0)+1").
862+
Where("repo_id=?", 1).
863+
Insert(&i)
864+
assert.NoError(t, err)
865+
assert.EqualValues(t, 1, inserted)
866+
assert.EqualValues(t, 1, i.Id)
867+
868+
var j InsertWhere
869+
has, err := testEngine.ID(i.Id).Get(&j)
870+
assert.NoError(t, err)
871+
assert.True(t, has)
872+
i.Index = 1
873+
assert.EqualValues(t, i, j)
874+
875+
inserted, err = testEngine.Table(new(InsertWhere)).Where("repo_id=?", 1).
876+
SetExpr("`index`", "coalesce(MAX(`index`),0)+1").
877+
Insert(map[string]interface{}{
878+
"repo_id": 1,
879+
"width": 20,
880+
"height": 40,
881+
"name": "trest2",
882+
})
883+
assert.NoError(t, err)
884+
assert.EqualValues(t, 1, inserted)
885+
886+
var j2 InsertWhere
887+
has, err = testEngine.ID(2).Get(&j2)
888+
assert.NoError(t, err)
889+
assert.True(t, has)
890+
assert.EqualValues(t, 1, j2.RepoId)
891+
assert.EqualValues(t, 20, j2.Width)
892+
assert.EqualValues(t, 40, j2.Height)
893+
assert.EqualValues(t, "trest2", j2.Name)
894+
assert.EqualValues(t, 2, j2.Index)
895+
}

0 commit comments

Comments
 (0)