Skip to content

Commit 5588d0a

Browse files
committed
implement CTEBuilder
1 parent 9fc30d9 commit 5588d0a

File tree

7 files changed

+222
-10
lines changed

7 files changed

+222
-10
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ This package includes following pre-defined builders so far. API document and ex
8181
- [UpdateBuilder](https://pkg.go.dev/github.com/huandu/go-sqlbuilder#UpdateBuilder): Builder for UPDATE.
8282
- [DeleteBuilder](https://pkg.go.dev/github.com/huandu/go-sqlbuilder#DeleteBuilder): Builder for DELETE.
8383
- [UnionBuilder](https://pkg.go.dev/github.com/huandu/go-sqlbuilder#UnionBuilder): Builder for UNION and UNION ALL.
84+
- [CTEBuilder](https://pkg.go.dev/github.com/huandu/go-sqlbuilder#CTEBuilder): Builder for Common Table Expression (CTE), e.g. `WITH name (col1, col2) AS (SELECT ...)`.
8485
- [Buildf](https://pkg.go.dev/github.com/huandu/go-sqlbuilder#Buildf): Freestyle builder using `fmt.Sprintf`-like syntax.
8586
- [Build](https://pkg.go.dev/github.com/huandu/go-sqlbuilder#Build): Advanced freestyle builder using special syntax defined in [Args#Compile](https://pkg.go.dev/github.com/huandu/go-sqlbuilder#Args.Compile).
8687
- [BuildNamed](https://pkg.go.dev/github.com/huandu/go-sqlbuilder#BuildNamed): Advanced freestyle builder using `${key}` to refer the value of a map by key.

cte.go

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
// Copyright 2024 Huan Du. All rights reserved.
2+
// Licensed under the MIT license that can be found in the LICENSE file.
3+
4+
package sqlbuilder
5+
6+
const (
7+
cteMarkerInit injectionMarker = iota
8+
cteMarkerAfterWith
9+
cteMarkerAfterAs
10+
)
11+
12+
// With creates a new CTE builder with default flavor.
13+
func With(name string, cols ...string) *CTEBuilder {
14+
return DefaultFlavor.NewCTEBuilder().With(name, cols...)
15+
}
16+
17+
func newCTEBuilder() *CTEBuilder {
18+
return &CTEBuilder{
19+
args: &Args{},
20+
injection: newInjection(),
21+
}
22+
}
23+
24+
// CTEBuilder is a CTE (Common Table Expression) builder.
25+
type CTEBuilder struct {
26+
name string
27+
cols []string
28+
builderVar string
29+
30+
args *Args
31+
32+
injection *injection
33+
marker injectionMarker
34+
}
35+
36+
var _ Builder = new(CTEBuilder)
37+
38+
// With sets the CTE name and columns.
39+
func (cteb *CTEBuilder) With(name string, cols ...string) *CTEBuilder {
40+
cteb.name = name
41+
cteb.cols = cols
42+
cteb.marker = cteMarkerAfterWith
43+
return cteb
44+
}
45+
46+
// As sets the builder to select data.
47+
func (cteb *CTEBuilder) As(builder Builder) *CTEBuilder {
48+
cteb.builderVar = cteb.args.Add(builder)
49+
cteb.marker = cteMarkerAfterAs
50+
return cteb
51+
}
52+
53+
// Select creates a new SelectBuilder to build a SELECT statement using this CTE.
54+
func (cteb *CTEBuilder) Select(col ...string) *SelectBuilder {
55+
sb := cteb.args.Flavor.NewSelectBuilder()
56+
return sb.With(cteb).Select(col...)
57+
}
58+
59+
// String returns the compiled CTE string.
60+
func (cteb *CTEBuilder) String() string {
61+
sql, _ := cteb.Build()
62+
return sql
63+
}
64+
65+
// Build returns compiled CTE string and args.
66+
func (cteb *CTEBuilder) Build() (sql string, args []interface{}) {
67+
return cteb.BuildWithFlavor(cteb.args.Flavor)
68+
}
69+
70+
// BuildWithFlavor builds a CTE with the specified flavor and initial arguments.
71+
func (cteb *CTEBuilder) BuildWithFlavor(flavor Flavor, initialArg ...interface{}) (sql string, args []interface{}) {
72+
buf := newStringBuilder()
73+
cteb.injection.WriteTo(buf, cteMarkerInit)
74+
75+
if cteb.name != "" {
76+
buf.WriteLeadingString("WITH ")
77+
buf.WriteString(cteb.name)
78+
79+
if len(cteb.cols) > 0 {
80+
buf.WriteLeadingString("(")
81+
buf.WriteStrings(cteb.cols, ", ")
82+
buf.WriteString(")")
83+
}
84+
85+
cteb.injection.WriteTo(buf, cteMarkerAfterWith)
86+
}
87+
88+
if cteb.builderVar != "" {
89+
buf.WriteLeadingString("AS (")
90+
buf.WriteString(cteb.builderVar)
91+
buf.WriteRune(')')
92+
93+
cteb.injection.WriteTo(buf, cteMarkerAfterAs)
94+
}
95+
96+
return cteb.args.CompileWithFlavor(buf.String(), flavor, initialArg...)
97+
}
98+
99+
// SetFlavor sets the flavor of compiled sql.
100+
func (cteb *CTEBuilder) SetFlavor(flavor Flavor) (old Flavor) {
101+
old = cteb.args.Flavor
102+
cteb.args.Flavor = flavor
103+
return
104+
}
105+
106+
// Var returns a placeholder for value.
107+
func (cteb *CTEBuilder) Var(arg interface{}) string {
108+
return cteb.args.Add(arg)
109+
}
110+
111+
// SQL adds an arbitrary sql to current position.
112+
func (cteb *CTEBuilder) SQL(sql string) *CTEBuilder {
113+
cteb.injection.SQL(cteb.marker, sql)
114+
return cteb
115+
}
116+
117+
// TableName returns the CTE table name.
118+
func (cteb *CTEBuilder) TableName() string {
119+
return cteb.name
120+
}

cte_test.go

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
// Copyright 2024 Huan Du. All rights reserved.
2+
// Licensed under the MIT license that can be found in the LICENSE file.
3+
4+
package sqlbuilder
5+
6+
import (
7+
"fmt"
8+
"testing"
9+
10+
"github.com/huandu/go-assert"
11+
)
12+
13+
func ExampleWith() {
14+
sb := With("users", "id", "name").As(
15+
Select("id", "name").From("users").Where("name IS NOT NULL"),
16+
).Select("users.id", "orders.id").Join("orders", "users.id = orders.user_id")
17+
18+
fmt.Println(sb)
19+
20+
// Output:
21+
// WITH users (id, name) AS (SELECT id, name FROM users WHERE name IS NOT NULL) SELECT users.id, orders.id FROM users JOIN orders ON users.id = orders.user_id
22+
}
23+
24+
func ExampleCTEBuilder() {
25+
usersBuilder := Select("id", "name", "level").From("users")
26+
usersBuilder.Where(
27+
usersBuilder.GreaterEqualThan("level", 10),
28+
)
29+
cteb := With("valid_users").As(usersBuilder)
30+
fmt.Println(cteb)
31+
32+
sb := Select("valid_users.id", "valid_users.name", "orders.id").With(cteb)
33+
sb.Join("orders", "valid_users.id = orders.user_id")
34+
sb.Where(
35+
sb.LessEqualThan("orders.price", 200),
36+
"valid_users.level < orders.min_level",
37+
).OrderBy("orders.price").Desc()
38+
39+
sql, args := sb.Build()
40+
fmt.Println(sql)
41+
fmt.Println(args)
42+
43+
// Output:
44+
// WITH valid_users AS (SELECT id, name, level FROM users WHERE level >= ?)
45+
// WITH valid_users AS (SELECT id, name, level FROM users WHERE level >= ?) SELECT valid_users.id, valid_users.name, orders.id FROM valid_users JOIN orders ON valid_users.id = orders.user_id WHERE orders.price <= ? AND valid_users.level < orders.min_level ORDER BY orders.price DESC
46+
// [10 200]
47+
}
48+
49+
func TestCTEBuilder(t *testing.T) {
50+
a := assert.New(t)
51+
cteb := newCTEBuilder()
52+
cteb.SQL("/* init */")
53+
cteb.With("t", "a", "b")
54+
cteb.SQL("/* after with */")
55+
56+
// Make sure that calling Var() will not affect the As().
57+
cteb.Var(123)
58+
59+
cteb.As(Select("a", "b").From("t"))
60+
cteb.SQL("/* after as */")
61+
62+
sql, args := cteb.Build()
63+
a.Equal(sql, "/* init */ WITH t (a, b) /* after with */ AS (SELECT a, b FROM t) /* after as */")
64+
a.Assert(args == nil)
65+
}

flavor.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,13 @@ func (f Flavor) NewUnionBuilder() *UnionBuilder {
141141
return b
142142
}
143143

144+
// NewCTEBuilder creates a new CTE builder with flavor.
145+
func (f Flavor) NewCTEBuilder() *CTEBuilder {
146+
b := newCTEBuilder()
147+
b.SetFlavor(f)
148+
return b
149+
}
150+
144151
// Quote adds quote for name to make sure the name can be used safely
145152
// as table name or field name.
146153
//

select.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111

1212
const (
1313
selectMarkerInit injectionMarker = iota
14+
selectMarkerAfterWith
1415
selectMarkerAfterSelect
1516
selectMarkerAfterFrom
1617
selectMarkerAfterJoin
@@ -65,6 +66,7 @@ type SelectBuilder struct {
6566
whereClauseProxy *whereClauseProxy
6667
whereClauseExpr string
6768

69+
cteBuilder string
6870
distinct bool
6971
tables []string
7072
selectCols []string
@@ -92,6 +94,14 @@ func Select(col ...string) *SelectBuilder {
9294
return DefaultFlavor.NewSelectBuilder().Select(col...)
9395
}
9496

97+
// With sets WITH clause (the Common Table Expression) before SELECT.
98+
func (sb *SelectBuilder) With(builder *CTEBuilder) *SelectBuilder {
99+
sb.marker = selectMarkerAfterWith
100+
sb.cteBuilder = sb.Var(builder)
101+
sb.tables = []string{builder.TableName()}
102+
return sb
103+
}
104+
95105
// Select sets columns in SELECT.
96106
func (sb *SelectBuilder) Select(col ...string) *SelectBuilder {
97107
sb.selectCols = col
@@ -269,6 +279,11 @@ func (sb *SelectBuilder) BuildWithFlavor(flavor Flavor, initialArg ...interface{
269279

270280
oraclePage := flavor == Oracle && (sb.limit >= 0 || sb.offset >= 0)
271281

282+
if sb.cteBuilder != "" {
283+
buf.WriteLeadingString(sb.cteBuilder)
284+
sb.injection.WriteTo(buf, selectMarkerAfterWith)
285+
}
286+
272287
if len(sb.selectCols) > 0 {
273288
buf.WriteLeadingString("SELECT ")
274289

struct_test.go

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -638,9 +638,7 @@ func ExampleStruct_buildDELETE() {
638638

639639
// Prepare DELETE query.
640640
user := &User{
641-
ID: 1234,
642-
Name: "Huan Du",
643-
Status: 1,
641+
ID: 1234,
644642
}
645643
b := userStruct.DeleteFrom("user")
646644
b.Where(b.Equal("id", user.ID))

union.go

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ func newUnionBuilder() *UnionBuilder {
3737
// UnionBuilder is a builder to build UNION.
3838
type UnionBuilder struct {
3939
opt string
40-
builders []Builder
40+
builderVars []string
4141
orderByCols []string
4242
order string
4343
limit int
@@ -72,8 +72,14 @@ func (ub *UnionBuilder) UnionAll(builders ...Builder) *UnionBuilder {
7272
}
7373

7474
func (ub *UnionBuilder) union(opt string, builders ...Builder) *UnionBuilder {
75+
builderVars := make([]string, 0, len(builders))
76+
77+
for _, b := range builders {
78+
builderVars = append(builderVars, ub.Var(b))
79+
}
80+
7581
ub.opt = opt
76-
ub.builders = builders
82+
ub.builderVars = builderVars
7783
ub.marker = unionMarkerAfterUnion
7884
return ub
7985
}
@@ -131,25 +137,25 @@ func (ub *UnionBuilder) BuildWithFlavor(flavor Flavor, initialArg ...interface{}
131137
buf := newStringBuilder()
132138
ub.injection.WriteTo(buf, unionMarkerInit)
133139

134-
if len(ub.builders) > 0 {
140+
if len(ub.builderVars) > 0 {
135141
needParen := flavor != SQLite
136142

137143
if needParen {
138144
buf.WriteLeadingString("(")
139-
buf.WriteString(ub.Var(ub.builders[0]))
145+
buf.WriteString(ub.builderVars[0])
140146
buf.WriteRune(')')
141147
} else {
142-
buf.WriteLeadingString(ub.Var(ub.builders[0]))
148+
buf.WriteLeadingString(ub.builderVars[0])
143149
}
144150

145-
for _, b := range ub.builders[1:] {
151+
for _, b := range ub.builderVars[1:] {
146152
buf.WriteString(ub.opt)
147153

148154
if needParen {
149155
buf.WriteRune('(')
150156
}
151157

152-
buf.WriteString(ub.Var(b))
158+
buf.WriteString(b)
153159

154160
if needParen {
155161
buf.WriteRune(')')

0 commit comments

Comments
 (0)