@@ -3,33 +3,32 @@ package walk
33import (
44 "fmt"
55 "log"
6- "sort"
76 "strings"
87
98 "github.com/auxten/postgresql-parser/pkg/sql/parser"
109 "github.com/auxten/postgresql-parser/pkg/sql/sem/tree"
10+ "github.com/auxten/postgresql-parser/pkg/util/set"
1111)
1212
1313type AstWalker struct {
14- unknownNodes []interface {}
14+ UnknownNodes []interface {}
1515 Fn func (ctx interface {}, node interface {}) (stop bool )
1616}
1717type ReferredCols map [string ]int
1818
1919func (rc ReferredCols ) ToList () []string {
2020 cols := make ([]string , len (rc ))
2121 i := 0
22- for k , _ := range rc {
22+ for k := range rc {
2323 cols [i ] = k
2424 i ++
2525 }
26- sort .Strings (cols )
27- return cols
26+ return set .SortDeDup (cols )
2827}
2928
3029func (w * AstWalker ) Walk (stmts parser.Statements , ctx interface {}) (ok bool , err error ) {
3130
32- w .unknownNodes = make ([]interface {}, 0 )
31+ w .UnknownNodes = make ([]interface {}, 0 )
3332 asts := make ([]tree.NodeFormatter , len (stmts ))
3433 for si , stmt := range stmts {
3534 asts [si ] = stmt .AST
@@ -69,8 +68,8 @@ func (w *AstWalker) Walk(stmts parser.Statements, ctx interface{}) (ok bool, err
6968 walk (node .Left , node .Right )
7069 case * tree.CaseExpr :
7170 walk (node .Expr , node .Else )
72- for _ , w := range node .Whens {
73- walk (w .Cond , w .Val )
71+ for _ , when := range node .Whens {
72+ walk (when .Cond , when .Val )
7473 }
7574 case * tree.RangeCond :
7675 walk (node .Left , node .From , node .To )
@@ -98,6 +97,11 @@ func (w *AstWalker) Walk(stmts parser.Statements, ctx interface{}) (ok bool, err
9897 walk (expr )
9998 }
10099 case * tree.FamilyTableDef :
100+ case * tree.From :
101+ walk (node .AsOf )
102+ for _ , table := range node .Tables {
103+ walk (table )
104+ }
101105 case * tree.FuncExpr :
102106 if node .WindowDef != nil {
103107 walk (node .WindowDef )
@@ -111,6 +115,12 @@ func (w *AstWalker) Walk(stmts parser.Statements, ctx interface{}) (ok bool, err
111115 case * tree.NumVal :
112116 case * tree.OnJoinCond :
113117 walk (node .Expr )
118+ case * tree.Order :
119+ walk (node .Expr , node .Table )
120+ case tree.OrderBy :
121+ for _ , order := range node {
122+ walk (order )
123+ }
114124 case * tree.OrExpr :
115125 walk (node .Left , node .Right )
116126 case * tree.ParenExpr :
@@ -126,16 +136,12 @@ func (w *AstWalker) Walk(stmts parser.Statements, ctx interface{}) (ok bool, err
126136 walk (node .With )
127137 }
128138 if node .OrderBy != nil {
129- for _ , order := range node .OrderBy {
130- walk (order )
131- }
139+ walk (node .OrderBy )
132140 }
133141 if node .Limit != nil {
134142 walk (node .Limit )
135143 }
136144 walk (node .Select )
137- case * tree.Order :
138- walk (node .Expr , node .Table )
139145 case * tree.Limit :
140146 walk (node .Count )
141147 case * tree.SelectClause :
@@ -156,10 +162,7 @@ func (w *AstWalker) Walk(stmts parser.Statements, ctx interface{}) (ok bool, err
156162 walk (group )
157163 }
158164 }
159- walk (node .From .AsOf )
160- for _ , table := range node .From .Tables {
161- walk (table )
162- }
165+ walk (& node .From )
163166 case tree.SelectExpr :
164167 walk (node .Expr )
165168 case tree.SelectExprs :
@@ -173,6 +176,10 @@ func (w *AstWalker) Walk(stmts parser.Statements, ctx interface{}) (ok bool, err
173176 case * tree.StrVal :
174177 case * tree.Subquery :
175178 walk (node .Select )
179+ case tree.TableExprs :
180+ for _ , expr := range node {
181+ walk (expr )
182+ }
176183 case * tree.TableName , tree.TableName :
177184 case * tree.Tuple :
178185 for _ , expr := range node .Exprs {
@@ -214,8 +221,8 @@ func (w *AstWalker) Walk(stmts parser.Statements, ctx interface{}) (ok bool, err
214221 walk (expr )
215222 }
216223 default :
217- if w .unknownNodes != nil {
218- w .unknownNodes = append (w .unknownNodes , node )
224+ if w .UnknownNodes != nil {
225+ w .UnknownNodes = append (w .UnknownNodes , node )
219226 }
220227 }
221228 }
@@ -270,8 +277,27 @@ func ColNamesInSelect(sql string) (referredCols ReferredCols, err error) {
270277 if err != nil {
271278 return
272279 }
273- for _ , col := range w .unknownNodes {
280+ for _ , col := range w .UnknownNodes {
274281 log .Printf ("unhandled column type %T" , col )
275282 }
276283 return
277284}
285+
286+ func AllColsContained (set ReferredCols , cols []string ) bool {
287+ if cols == nil {
288+ if set == nil {
289+ return true
290+ } else {
291+ return false
292+ }
293+ }
294+ if len (set ) != len (cols ) {
295+ return false
296+ }
297+ for _ , col := range cols {
298+ if _ , exist := set [col ]; ! exist {
299+ return false
300+ }
301+ }
302+ return true
303+ }
0 commit comments