Skip to content

Commit 1114558

Browse files
chenco47ido50
authored andcommitted
Support Update from Values
This commit adds the ability to generate an update statement from VALUES and not from a specific SELECT. It can be handy once someone wants to update many columns of many rows in the same time.
1 parent abb9e99 commit 1114558

File tree

2 files changed

+59
-1
lines changed

2 files changed

+59
-1
lines changed

update.go

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,14 @@ type UpdateStmt struct {
1919
execer Ext
2020
SelectStmt *SelectStmt
2121
SelectStmtAlias string
22+
MultipleValues MultipleValues
23+
}
24+
25+
type MultipleValues struct {
26+
Values [][]interface{}
27+
As string
28+
Columns []string
29+
Where []WhereCondition
2230
}
2331

2432
// Update creates a new UpdateStmt object for
@@ -131,15 +139,44 @@ func (stmt *UpdateStmt) ToSQL(rebind bool) (asSQL string, bindings []interface{}
131139
}
132140
}
133141

142+
if len(stmt.Updates) == 0 && len(stmt.MultipleValues.Columns) > 0 {
143+
// add the set columns
144+
for _, column := range stmt.MultipleValues.Columns {
145+
updates = append(updates,
146+
fmt.Sprintf("%s.%s = %s.%s", stmt.Table, column, stmt.MultipleValues.As, column))
147+
}
148+
}
149+
134150
clauses = append(clauses, "SET "+strings.Join(updates, ", "))
135151

136152
if stmt.SelectStmt != nil && stmt.SelectStmtAlias != "" {
137153
selectSQL, selectBindings := stmt.SelectStmt.ToSQL(false)
138154
selectSQL = "(" + selectSQL + ") AS " + stmt.SelectStmtAlias + " "
139-
140155
clauses = append(clauses, "FROM ")
141156
clauses = append(clauses, selectSQL)
142157
bindings = append(bindings, selectBindings...)
158+
} else if len(stmt.MultipleValues.Values) > 0 {
159+
// add the FROM
160+
clauses = append(clauses, "FROM")
161+
var multipleValues []string
162+
for _, multipleVals := range stmt.MultipleValues.Values {
163+
placeholders, bindingsToAdd := parseInsertValues(multipleVals)
164+
bindings = append(bindings, bindingsToAdd...)
165+
multipleValues = append(multipleValues, "("+strings.Join(placeholders, ", ")+")")
166+
}
167+
168+
clauses = append(clauses, fmt.Sprintf("(VALUES %s) AS %s(%s)",
169+
strings.Join(multipleValues, ", "),
170+
stmt.MultipleValues.As,
171+
strings.Join(stmt.MultipleValues.Columns, ", "),
172+
))
173+
174+
if len(stmt.MultipleValues.Where) > 0 {
175+
whereClause, whereBindings := parseConditions(stmt.MultipleValues.Where)
176+
bindings = append(bindings, whereBindings...)
177+
clauses = append(clauses, fmt.Sprintf("WHERE %s", whereClause))
178+
}
179+
143180
}
144181

145182
if len(stmt.Conditions) > 0 {
@@ -239,6 +276,16 @@ func (stmt *UpdateStmt) GetAllContext(ctx context.Context, into interface{}) err
239276
return err
240277
}
241278

279+
// FromValues receives an array of interfaces in order to insert multiple records using the same insert statement
280+
func (stmt *UpdateStmt) FromValues(mv MultipleValues) *UpdateStmt {
281+
stmt.MultipleValues.Values = append(stmt.MultipleValues.Values, mv.Values...)
282+
stmt.MultipleValues.As = mv.As
283+
stmt.MultipleValues.Columns = mv.Columns
284+
stmt.MultipleValues.Where = mv.Where
285+
286+
return stmt
287+
}
288+
242289
// UpdateFunction represents a function call in the context of
243290
// updating a column's value. For example, PostgreSQL provides
244291
// functions to append, prepend or remove items from array

update_test.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,17 @@ func TestUpdate(t *testing.T) {
6262
"UPDATE table SET something = replace(something, ?, '')",
6363
[]interface{}{"prefix/"},
6464
},
65+
{
66+
"update from values",
67+
dbz.Update("table").FromValues(MultipleValues{
68+
Values: [][]interface{}{{"Tom", 20}, {"John", 3}},
69+
As: "values",
70+
Columns: []string{"name", "age"},
71+
Where: []WhereCondition{Eq("values.name", Indirect("table.name"))},
72+
}),
73+
"UPDATE table SET table.name = values.name, table.age = values.age FROM (VALUES (?, ?), (?, ?)) AS values(name, age) WHERE values.name = table.name",
74+
[]interface{}{"Tom", 20, "John", 3},
75+
},
6576
}
6677
})
6778
}

0 commit comments

Comments
 (0)