Skip to content

Commit 51f40ba

Browse files
authored
Merge pull request #212 from huandu/feature/returning-clause-issue-210
feat: implement Returning method for UpdateBuilder and DeleteBuilder
2 parents 38c1feb + 5349146 commit 51f40ba

File tree

4 files changed

+239
-0
lines changed

4 files changed

+239
-0
lines changed

delete.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ const (
1010
deleteMarkerAfterWhere
1111
deleteMarkerAfterOrderBy
1212
deleteMarkerAfterLimit
13+
deleteMarkerAfterReturning
1314
)
1415

1516
// NewDeleteBuilder creates a new DELETE builder.
@@ -47,6 +48,7 @@ type DeleteBuilder struct {
4748
orderByCols []string
4849
order string
4950
limitVar string
51+
returning []string
5052

5153
args *Args
5254

@@ -157,6 +159,14 @@ func (db *DeleteBuilder) Limit(limit int) *DeleteBuilder {
157159
return db
158160
}
159161

162+
// Returning sets returning columns.
163+
// For DBMS that doesn't support RETURNING, e.g. MySQL, it will be ignored.
164+
func (db *DeleteBuilder) Returning(col ...string) *DeleteBuilder {
165+
db.returning = col
166+
db.marker = deleteMarkerAfterReturning
167+
return db
168+
}
169+
160170
// String returns the compiled DELETE string.
161171
func (db *DeleteBuilder) String() string {
162172
s, _ := db.Build()
@@ -218,6 +228,15 @@ func (db *DeleteBuilder) BuildWithFlavor(flavor Flavor, initialArg ...interface{
218228
db.injection.WriteTo(buf, deleteMarkerAfterLimit)
219229
}
220230

231+
if flavor == PostgreSQL || flavor == SQLite {
232+
if len(db.returning) > 0 {
233+
buf.WriteLeadingString("RETURNING ")
234+
buf.WriteStrings(db.returning, ", ")
235+
}
236+
237+
db.injection.WriteTo(buf, deleteMarkerAfterReturning)
238+
}
239+
221240
return db.args.CompileWithFlavor(buf.String(), flavor, initialArg...)
222241
}
223242

delete_test.go

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,3 +96,100 @@ func TestDeleteBuilderGetFlavor(t *testing.T) {
9696
flavor = dbClick.Flavor()
9797
a.Equal(ClickHouse, flavor)
9898
}
99+
100+
func ExampleDeleteBuilder_Returning() {
101+
db := NewDeleteBuilder()
102+
db.DeleteFrom("user")
103+
db.Where(db.Equal("id", 123))
104+
db.Returning("id", "deleted_at")
105+
106+
sql, args := db.BuildWithFlavor(PostgreSQL)
107+
fmt.Println(sql)
108+
fmt.Println(args)
109+
110+
// Output:
111+
// DELETE FROM user WHERE id = $1 RETURNING id, deleted_at
112+
// [123]
113+
}
114+
115+
func TestDeleteBuilderReturning(t *testing.T) {
116+
a := assert.New(t)
117+
db := NewDeleteBuilder()
118+
db.DeleteFrom("user")
119+
db.Where(db.Equal("id", 123))
120+
db.Returning("id", "deleted_at")
121+
122+
sql, _ := db.BuildWithFlavor(MySQL)
123+
a.Equal("DELETE FROM user WHERE id = ?", sql)
124+
125+
sql, _ = db.BuildWithFlavor(PostgreSQL)
126+
a.Equal("DELETE FROM user WHERE id = $1 RETURNING id, deleted_at", sql)
127+
128+
sql, _ = db.BuildWithFlavor(SQLite)
129+
a.Equal("DELETE FROM user WHERE id = ? RETURNING id, deleted_at", sql)
130+
131+
sql, _ = db.BuildWithFlavor(SQLServer)
132+
a.Equal("DELETE FROM user WHERE id = @p1", sql)
133+
134+
sql, _ = db.BuildWithFlavor(CQL)
135+
a.Equal("DELETE FROM user WHERE id = ?", sql)
136+
137+
sql, _ = db.BuildWithFlavor(ClickHouse)
138+
a.Equal("DELETE FROM user WHERE id = ?", sql)
139+
140+
sql, _ = db.BuildWithFlavor(Presto)
141+
a.Equal("DELETE FROM user WHERE id = ?", sql)
142+
143+
// Test with no returning columns
144+
db2 := NewDeleteBuilder()
145+
db2.DeleteFrom("user")
146+
db2.Where(db2.Equal("id", 1))
147+
db2.Returning() // Empty returning
148+
149+
sql, _ = db2.BuildWithFlavor(PostgreSQL)
150+
a.Equal("DELETE FROM user WHERE id = $1", sql)
151+
152+
// Test with single column
153+
db3 := NewDeleteBuilder()
154+
db3.DeleteFrom("user")
155+
db3.Where(db3.Equal("id", 1))
156+
db3.Returning("id")
157+
158+
sql, _ = db3.BuildWithFlavor(PostgreSQL)
159+
a.Equal("DELETE FROM user WHERE id = $1 RETURNING id", sql)
160+
161+
// Test with ORDER BY and LIMIT
162+
db4 := NewDeleteBuilder()
163+
db4.DeleteFrom("user")
164+
db4.Where(db4.Equal("status", 1))
165+
db4.OrderBy("id").Asc()
166+
db4.Limit(5)
167+
db4.Returning("id", "name")
168+
169+
sql, _ = db4.BuildWithFlavor(PostgreSQL)
170+
a.Equal("DELETE FROM user WHERE status = $1 ORDER BY id ASC LIMIT $2 RETURNING id, name", sql)
171+
172+
// Test chaining
173+
db5 := NewDeleteBuilder().DeleteFrom("user").Where("status = 0").Returning("id").Returning("name", "deleted_at")
174+
sql, _ = db5.BuildWithFlavor(PostgreSQL)
175+
a.Equal("DELETE FROM user WHERE status = 0 RETURNING name, deleted_at", sql) // Last Returning call overwrites
176+
177+
// Test SQL injection after RETURNING
178+
db6 := NewDeleteBuilder()
179+
db6.DeleteFrom("user")
180+
db6.Where(db6.Equal("id", 1))
181+
db6.Returning("id", "name")
182+
db6.SQL("/* comment after returning */")
183+
184+
sql, _ = db6.BuildWithFlavor(PostgreSQL)
185+
a.Equal("DELETE FROM user WHERE id = $1 RETURNING id, name /* comment after returning */", sql)
186+
187+
// Test with CTE (WITH clause)
188+
cte := With(CTETable("temp_user").As(Select("id").From("inactive_users")))
189+
db7 := cte.DeleteFrom("user")
190+
db7.Where("user.id IN (SELECT id FROM temp_user)")
191+
db7.Returning("id", "deleted_at")
192+
193+
sql, _ = db7.BuildWithFlavor(PostgreSQL)
194+
a.Equal("WITH temp_user AS (SELECT id FROM inactive_users) DELETE FROM user, temp_user WHERE user.id IN (SELECT id FROM temp_user) RETURNING id, deleted_at", sql)
195+
}

update.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ const (
1515
updateMarkerAfterWhere
1616
updateMarkerAfterOrderBy
1717
updateMarkerAfterLimit
18+
updateMarkerAfterReturning
1819
)
1920

2021
// NewUpdateBuilder creates a new UPDATE builder.
@@ -53,6 +54,7 @@ type UpdateBuilder struct {
5354
orderByCols []string
5455
order string
5556
limitVar string
57+
returning []string
5658

5759
args *Args
5860

@@ -216,6 +218,14 @@ func (ub *UpdateBuilder) Limit(limit int) *UpdateBuilder {
216218
return ub
217219
}
218220

221+
// Returning sets returning columns.
222+
// For DBMS that doesn't support RETURNING, e.g. MySQL, it will be ignored.
223+
func (ub *UpdateBuilder) Returning(col ...string) *UpdateBuilder {
224+
ub.returning = col
225+
ub.marker = updateMarkerAfterReturning
226+
return ub
227+
}
228+
219229
// NumAssignment returns the number of assignments to update.
220230
func (ub *UpdateBuilder) NumAssignment() int {
221231
return len(ub.assignments)
@@ -310,6 +320,15 @@ func (ub *UpdateBuilder) BuildWithFlavor(flavor Flavor, initialArg ...interface{
310320
ub.injection.WriteTo(buf, updateMarkerAfterLimit)
311321
}
312322

323+
if flavor == PostgreSQL || flavor == SQLite {
324+
if len(ub.returning) > 0 {
325+
buf.WriteLeadingString("RETURNING ")
326+
buf.WriteStrings(ub.returning, ", ")
327+
}
328+
329+
ub.injection.WriteTo(buf, updateMarkerAfterReturning)
330+
}
331+
313332
return ub.args.CompileWithFlavor(buf.String(), flavor, initialArg...)
314333
}
315334

update_test.go

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,3 +162,107 @@ func TestUpdateBuilderGetFlavor(t *testing.T) {
162162
flavor = ubClick.Flavor()
163163
a.Equal(ClickHouse, flavor)
164164
}
165+
166+
func ExampleUpdateBuilder_Returning() {
167+
ub := NewUpdateBuilder()
168+
ub.Update("user")
169+
ub.Set(ub.Assign("name", "Huan Du"))
170+
ub.Where(ub.Equal("id", 123))
171+
ub.Returning("id", "updated_at")
172+
173+
sql, args := ub.BuildWithFlavor(PostgreSQL)
174+
fmt.Println(sql)
175+
fmt.Println(args)
176+
177+
// Output:
178+
// UPDATE user SET name = $1 WHERE id = $2 RETURNING id, updated_at
179+
// [Huan Du 123]
180+
}
181+
182+
func TestUpdateBuilderReturning(t *testing.T) {
183+
a := assert.New(t)
184+
ub := NewUpdateBuilder()
185+
ub.Update("user")
186+
ub.Set(ub.Assign("name", "Huan Du"))
187+
ub.Where(ub.Equal("id", 123))
188+
ub.Returning("id", "updated_at")
189+
190+
sql, _ := ub.BuildWithFlavor(MySQL)
191+
a.Equal("UPDATE user SET name = ? WHERE id = ?", sql)
192+
193+
sql, _ = ub.BuildWithFlavor(PostgreSQL)
194+
a.Equal("UPDATE user SET name = $1 WHERE id = $2 RETURNING id, updated_at", sql)
195+
196+
sql, _ = ub.BuildWithFlavor(SQLite)
197+
a.Equal("UPDATE user SET name = ? WHERE id = ? RETURNING id, updated_at", sql)
198+
199+
sql, _ = ub.BuildWithFlavor(SQLServer)
200+
a.Equal("UPDATE user SET name = @p1 WHERE id = @p2", sql)
201+
202+
sql, _ = ub.BuildWithFlavor(CQL)
203+
a.Equal("UPDATE user SET name = ? WHERE id = ?", sql)
204+
205+
sql, _ = ub.BuildWithFlavor(ClickHouse)
206+
a.Equal("UPDATE user SET name = ? WHERE id = ?", sql)
207+
208+
sql, _ = ub.BuildWithFlavor(Presto)
209+
a.Equal("UPDATE user SET name = ? WHERE id = ?", sql)
210+
211+
// Test with no returning columns
212+
ub2 := NewUpdateBuilder()
213+
ub2.Update("user")
214+
ub2.Set(ub2.Assign("name", "Test"))
215+
ub2.Where(ub2.Equal("id", 1))
216+
ub2.Returning() // Empty returning
217+
218+
sql, _ = ub2.BuildWithFlavor(PostgreSQL)
219+
a.Equal("UPDATE user SET name = $1 WHERE id = $2", sql)
220+
221+
// Test with single column
222+
ub3 := NewUpdateBuilder()
223+
ub3.Update("user")
224+
ub3.Set(ub3.Assign("name", "Test"))
225+
ub3.Where(ub3.Equal("id", 1))
226+
ub3.Returning("id")
227+
228+
sql, _ = ub3.BuildWithFlavor(PostgreSQL)
229+
a.Equal("UPDATE user SET name = $1 WHERE id = $2 RETURNING id", sql)
230+
231+
// Test with ORDER BY and LIMIT
232+
ub4 := NewUpdateBuilder()
233+
ub4.Update("user")
234+
ub4.Set(ub4.Assign("name", "Test"))
235+
ub4.Where(ub4.Equal("status", 1))
236+
ub4.OrderBy("id").Asc()
237+
ub4.Limit(5)
238+
ub4.Returning("id", "name")
239+
240+
sql, _ = ub4.BuildWithFlavor(PostgreSQL)
241+
a.Equal("UPDATE user SET name = $1 WHERE status = $2 ORDER BY id ASC LIMIT $3 RETURNING id, name", sql)
242+
243+
// Test chaining
244+
ub5 := NewUpdateBuilder().Update("user").Set("status = 1").Returning("id").Returning("name", "updated_at")
245+
sql, _ = ub5.BuildWithFlavor(PostgreSQL)
246+
a.Equal("UPDATE user SET status = 1 RETURNING name, updated_at", sql) // Last Returning call overwrites
247+
248+
// Test SQL injection after RETURNING
249+
ub6 := NewUpdateBuilder()
250+
ub6.Update("user")
251+
ub6.Set(ub6.Assign("name", "Test"))
252+
ub6.Where(ub6.Equal("id", 1))
253+
ub6.Returning("id", "name")
254+
ub6.SQL("/* comment after returning */")
255+
256+
sql, _ = ub6.BuildWithFlavor(PostgreSQL)
257+
a.Equal("UPDATE user SET name = $1 WHERE id = $2 RETURNING id, name /* comment after returning */", sql)
258+
259+
// Test with CTE (WITH clause)
260+
cte := With(CTETable("temp_user").As(Select("id").From("active_users")))
261+
ub7 := cte.Update("user")
262+
ub7.Set(ub7.Assign("status", "active"))
263+
ub7.Where("user.id IN (SELECT id FROM temp_user)")
264+
ub7.Returning("id", "status")
265+
266+
sql, _ = ub7.BuildWithFlavor(PostgreSQL)
267+
a.Equal("WITH temp_user AS (SELECT id FROM active_users) UPDATE user SET status = $1 FROM temp_user WHERE user.id IN (SELECT id FROM temp_user) RETURNING id, status", sql)
268+
}

0 commit comments

Comments
 (0)