@@ -9,12 +9,12 @@ const (
99)
1010
1111// With creates a new CTE builder with default flavor.
12- func With (tables ... * CTETableBuilder ) * CTEBuilder {
12+ func With (tables ... * CTEQueryBuilder ) * CTEBuilder {
1313 return DefaultFlavor .NewCTEBuilder ().With (tables ... )
1414}
1515
1616// WithRecursive creates a new recursive CTE builder with default flavor.
17- func WithRecursive (tables ... * CTETableBuilder ) * CTEBuilder {
17+ func WithRecursive (tables ... * CTEQueryBuilder ) * CTEBuilder {
1818 return DefaultFlavor .NewCTEBuilder ().WithRecursive (tables ... )
1919}
2020
@@ -28,8 +28,8 @@ func newCTEBuilder() *CTEBuilder {
2828// CTEBuilder is a CTE (Common Table Expression) builder.
2929type CTEBuilder struct {
3030 recursive bool
31- tableNames [] string
32- tableBuilderVars []string
31+ queries [] * CTEQueryBuilder
32+ queryBuilderVars []string
3333
3434 args * Args
3535
@@ -40,24 +40,22 @@ type CTEBuilder struct {
4040var _ Builder = new (CTEBuilder )
4141
4242// With sets the CTE name and columns.
43- func (cteb * CTEBuilder ) With (tables ... * CTETableBuilder ) * CTEBuilder {
44- tableNames := make ([]string , 0 , len (tables ))
45- tableBuilderVars := make ([]string , 0 , len (tables ))
43+ func (cteb * CTEBuilder ) With (queries ... * CTEQueryBuilder ) * CTEBuilder {
44+ queryBuilderVars := make ([]string , 0 , len (queries ))
4645
47- for _ , table := range tables {
48- tableNames = append (tableNames , table .TableName ())
49- tableBuilderVars = append (tableBuilderVars , cteb .args .Add (table ))
46+ for _ , query := range queries {
47+ queryBuilderVars = append (queryBuilderVars , cteb .args .Add (query ))
5048 }
5149
52- cteb .tableNames = tableNames
53- cteb .tableBuilderVars = tableBuilderVars
50+ cteb .queries = queries
51+ cteb .queryBuilderVars = queryBuilderVars
5452 cteb .marker = cteMarkerAfterWith
5553 return cteb
5654}
5755
5856// WithRecursive sets the CTE name and columns and turns on the RECURSIVE keyword.
59- func (cteb * CTEBuilder ) WithRecursive (tables ... * CTETableBuilder ) * CTEBuilder {
60- cteb .With (tables ... ).recursive = true
57+ func (cteb * CTEBuilder ) WithRecursive (queries ... * CTEQueryBuilder ) * CTEBuilder {
58+ cteb .With (queries ... ).recursive = true
6159 return cteb
6260}
6361
@@ -67,6 +65,18 @@ func (cteb *CTEBuilder) Select(col ...string) *SelectBuilder {
6765 return sb .With (cteb ).Select (col ... )
6866}
6967
68+ // DeleteFrom creates a new DeleteBuilder to build a DELETE statement using this CTE.
69+ func (cteb * CTEBuilder ) DeleteFrom (table string ) * DeleteBuilder {
70+ db := cteb .args .Flavor .NewDeleteBuilder ()
71+ return db .With (cteb ).DeleteFrom (table )
72+ }
73+
74+ // Update creates a new UpdateBuilder to build an UPDATE statement using this CTE.
75+ func (cteb * CTEBuilder ) Update (table string ) * UpdateBuilder {
76+ ub := cteb .args .Flavor .NewUpdateBuilder ()
77+ return ub .With (cteb ).Update (table )
78+ }
79+
7080// String returns the compiled CTE string.
7181func (cteb * CTEBuilder ) String () string {
7282 sql , _ := cteb .Build ()
@@ -83,12 +93,12 @@ func (cteb *CTEBuilder) BuildWithFlavor(flavor Flavor, initialArg ...interface{}
8393 buf := newStringBuilder ()
8494 cteb .injection .WriteTo (buf , cteMarkerInit )
8595
86- if len (cteb .tableBuilderVars ) > 0 {
96+ if len (cteb .queryBuilderVars ) > 0 {
8797 buf .WriteLeadingString ("WITH " )
8898 if cteb .recursive {
8999 buf .WriteString ("RECURSIVE " )
90100 }
91- buf .WriteStrings (cteb .tableBuilderVars , ", " )
101+ buf .WriteStrings (cteb .queryBuilderVars , ", " )
92102 }
93103
94104 cteb .injection .WriteTo (buf , cteMarkerAfterWith )
@@ -110,5 +120,43 @@ func (cteb *CTEBuilder) SQL(sql string) *CTEBuilder {
110120
111121// TableNames returns all table names in a CTE.
112122func (cteb * CTEBuilder ) TableNames () []string {
113- return cteb .tableNames
123+ if len (cteb .queryBuilderVars ) == 0 {
124+ return nil
125+ }
126+
127+ tableNames := make ([]string , 0 , len (cteb .queries ))
128+
129+ for _ , query := range cteb .queries {
130+ tableNames = append (tableNames , query .TableName ())
131+ }
132+
133+ return tableNames
134+ }
135+
136+ // tableNamesForSelect returns a list of table names which should be automatically added to FROM clause.
137+ // It's not public, as this feature is designed only for SelectBuilder right now.
138+ func (cteb * CTEBuilder ) tableNamesForSelect () []string {
139+ cnt := 0
140+
141+ // It's rare that the ShouldAddToTableList() returns true.
142+ // Count it before allocating any memory for better performance.
143+ for _ , query := range cteb .queries {
144+ if query .ShouldAddToTableList () {
145+ cnt ++
146+ }
147+ }
148+
149+ if cnt == 0 {
150+ return nil
151+ }
152+
153+ tableNames := make ([]string , 0 , cnt )
154+
155+ for _ , query := range cteb .queries {
156+ if query .ShouldAddToTableList () {
157+ tableNames = append (tableNames , query .TableName ())
158+ }
159+ }
160+
161+ return tableNames
114162}
0 commit comments