@@ -19,6 +19,14 @@ type UpdateStmt struct {
1919 execer Ext
2020 SelectStmt * SelectStmt
2121 SelectStmtAlias string
22+ MultipleValues MultipleValues
23+ }
24+
25+ type MultipleValues struct {
26+ Values [][]interface {}
27+ As string
28+ Columns []string
29+ Where []WhereCondition
2230}
2331
2432// Update creates a new UpdateStmt object for
@@ -131,15 +139,44 @@ func (stmt *UpdateStmt) ToSQL(rebind bool) (asSQL string, bindings []interface{}
131139 }
132140 }
133141
142+ if len (stmt .Updates ) == 0 && len (stmt .MultipleValues .Columns ) > 0 {
143+ // add the set columns
144+ for _ , column := range stmt .MultipleValues .Columns {
145+ updates = append (updates ,
146+ fmt .Sprintf ("%s.%s = %s.%s" , stmt .Table , column , stmt .MultipleValues .As , column ))
147+ }
148+ }
149+
134150 clauses = append (clauses , "SET " + strings .Join (updates , ", " ))
135151
136152 if stmt .SelectStmt != nil && stmt .SelectStmtAlias != "" {
137153 selectSQL , selectBindings := stmt .SelectStmt .ToSQL (false )
138154 selectSQL = "(" + selectSQL + ") AS " + stmt .SelectStmtAlias + " "
139-
140155 clauses = append (clauses , "FROM " )
141156 clauses = append (clauses , selectSQL )
142157 bindings = append (bindings , selectBindings ... )
158+ } else if len (stmt .MultipleValues .Values ) > 0 {
159+ // add the FROM
160+ clauses = append (clauses , "FROM" )
161+ var multipleValues []string
162+ for _ , multipleVals := range stmt .MultipleValues .Values {
163+ placeholders , bindingsToAdd := parseInsertValues (multipleVals )
164+ bindings = append (bindings , bindingsToAdd ... )
165+ multipleValues = append (multipleValues , "(" + strings .Join (placeholders , ", " )+ ")" )
166+ }
167+
168+ clauses = append (clauses , fmt .Sprintf ("(VALUES %s) AS %s(%s)" ,
169+ strings .Join (multipleValues , ", " ),
170+ stmt .MultipleValues .As ,
171+ strings .Join (stmt .MultipleValues .Columns , ", " ),
172+ ))
173+
174+ if len (stmt .MultipleValues .Where ) > 0 {
175+ whereClause , whereBindings := parseConditions (stmt .MultipleValues .Where )
176+ bindings = append (bindings , whereBindings ... )
177+ clauses = append (clauses , fmt .Sprintf ("WHERE %s" , whereClause ))
178+ }
179+
143180 }
144181
145182 if len (stmt .Conditions ) > 0 {
@@ -239,6 +276,16 @@ func (stmt *UpdateStmt) GetAllContext(ctx context.Context, into interface{}) err
239276 return err
240277}
241278
279+ // FromValues receives an array of interfaces in order to insert multiple records using the same insert statement
280+ func (stmt * UpdateStmt ) FromValues (mv MultipleValues ) * UpdateStmt {
281+ stmt .MultipleValues .Values = append (stmt .MultipleValues .Values , mv .Values ... )
282+ stmt .MultipleValues .As = mv .As
283+ stmt .MultipleValues .Columns = mv .Columns
284+ stmt .MultipleValues .Where = mv .Where
285+
286+ return stmt
287+ }
288+
242289// UpdateFunction represents a function call in the context of
243290// updating a column's value. For example, PostgreSQL provides
244291// functions to append, prepend or remove items from array
0 commit comments