Skip to content

Commit 7fd87d1

Browse files
committed
tree: also visit Statement nodes in ExtendedVisitor
Similar to 300e81b, we also need to visit Statement nodes during hint injection, so add them to ExtendedVisitor. This commmit adds VisitStatementPre and VisitStatementPost to ExtendedVisitor and extendedSimpleVisitor. It also converts debugVisitor from a Visitor to an ExtendedVisitor which will provide more information. (ExprDebugString is currently only called in one place, but I have some ideas for how to use these in the future.) Informs: #153633 Release note: None
1 parent a5660a6 commit 7fd87d1

File tree

1 file changed

+110
-27
lines changed

1 file changed

+110
-27
lines changed

pkg/sql/sem/tree/walk.go

Lines changed: 110 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -41,14 +41,15 @@ type Visitor interface {
4141
//
4242
// VisitPost visits Exprs but not TableExprs. For TableExprs, VisitTablePost
4343
// must be used.
44-
VisitPost(expr Expr) (newNode Expr)
44+
VisitPost(expr Expr) (newExpr Expr)
4545
}
4646

4747
// ExtendedVisitor extends Visitor with methods that are called for TableExpr
48-
// nodes during an expression or statement walk.
48+
// nodes and Statement nodes during an expression or statement walk.
4949
//
5050
// Unlike Visitor, which does not visit some parts of the AST for historical
51-
// reasons, ExtendedVisitor is intended to visit every part of the tree.
51+
// reasons, ExtendedVisitor is intended to visit every node in the tree. (If a
52+
// node is missing, please add it.)
5253
type ExtendedVisitor interface {
5354
Visitor
5455

@@ -65,7 +66,22 @@ type ExtendedVisitor interface {
6566
// used for rewriting expressions.
6667
//
6768
// VisitTablePost is identical to VisitPost but handles TableExpr nodes.
68-
VisitTablePost(expr TableExpr) (newNode TableExpr)
69+
VisitTablePost(expr TableExpr) (newExpr TableExpr)
70+
71+
// VisitStatementPre is called for each Statement node before recursing into
72+
// that subtree. Upon return, if recurse if false, the visit will not recurse
73+
// into the subtree (and VisitStatementPost will node be called for this
74+
// Statement node).
75+
//
76+
// VisitStatementPre is identical to VisitPre but handles Statement nodes.
77+
VisitStatementPre(expr Statement) (recurse bool, newExpr Statement)
78+
79+
// VisitStatementPost is called for each Statement node after recursing into
80+
// the subtree. The returned Statement replaces the visited expression and can
81+
// be used for rewriting expressions.
82+
//
83+
// VisitStatementPost is identical to VisitPost but handles Statement nodes.
84+
VisitStatementPost(expr Statement) (newExpr Statement)
6985
}
7086

7187
// Walk implements the Expr interface.
@@ -2147,12 +2163,21 @@ var _ walkableStmt = &ValuesClause{}
21472163
// statement by itself. For example, it will not walk into Subquery nodes within
21482164
// a FROM clause or into a JoinCond (unless using an ExtendedVisitor). Walk's
21492165
// logic is pretty interdependent with the logic for constructing a query plan.
2150-
func WalkStmt(v Visitor, stmt Statement) (newStmt Statement, changed bool) {
2166+
func WalkStmt(v Visitor, stmt Statement) (Statement, bool) {
2167+
if ev, ok := v.(ExtendedVisitor); ok {
2168+
recurse, newStmt := ev.VisitStatementPre(stmt)
2169+
if walkable, ok := newStmt.(walkableStmt); recurse && ok {
2170+
newStmt = walkable.walkStmt(v)
2171+
newStmt = ev.VisitStatementPost(newStmt)
2172+
}
2173+
return newStmt, (stmt != newStmt)
2174+
}
2175+
21512176
walkable, ok := stmt.(walkableStmt)
21522177
if !ok {
21532178
return stmt, false
21542179
}
2155-
newStmt = walkable.walkStmt(v)
2180+
newStmt := walkable.walkStmt(v)
21562181
return newStmt, (stmt != newStmt)
21572182
}
21582183

@@ -2211,7 +2236,8 @@ func SimpleStmtVisit(stmt Statement, preFn SimpleVisitFn) (Statement, error) {
22112236

22122237
type extendedSimpleVisitor struct {
22132238
simpleVisitor
2214-
efn ExtendedSimpleVisitFn
2239+
preTableFn ExtendedSimpleVisitTableFn
2240+
preStmtFn ExtendedSimpleVisitStmtFn
22152241
}
22162242

22172243
var _ ExtendedVisitor = &extendedSimpleVisitor{}
@@ -2220,31 +2246,54 @@ func (ev *extendedSimpleVisitor) VisitTablePre(expr TableExpr) (recurse bool, ne
22202246
if ev.err != nil {
22212247
return false, expr
22222248
}
2223-
recurse, newExpr, ev.err = ev.efn(expr)
2249+
recurse, newExpr, ev.err = ev.preTableFn(expr)
22242250
if ev.err != nil {
22252251
return false, expr
22262252
}
22272253
return recurse, newExpr
22282254
}
22292255

2230-
func (ev *extendedSimpleVisitor) VisitTablePost(expr TableExpr) (newNode TableExpr) { return expr }
2256+
func (ev *extendedSimpleVisitor) VisitTablePost(expr TableExpr) (newExpr TableExpr) { return expr }
22312257

2232-
// ExtendedSimpleVisitFn is a function that is run for every TableExpr node in
2233-
// the VisitTablePre stage; see ExtendedSimpleVisit.
2234-
type ExtendedSimpleVisitFn func(expr TableExpr) (recurse bool, newExpr TableExpr, err error)
2258+
func (ev *extendedSimpleVisitor) VisitStatementPre(
2259+
expr Statement,
2260+
) (recurse bool, newExpr Statement) {
2261+
if ev.err != nil {
2262+
return false, expr
2263+
}
2264+
recurse, newExpr, ev.err = ev.preStmtFn(expr)
2265+
if ev.err != nil {
2266+
return false, expr
2267+
}
2268+
return recurse, newExpr
2269+
}
22352270

2236-
// ExtendedSimpleVisit is a convenience wrapper for visitors that only have
2237-
// VisitPre and VisitTablePre code, and don't return any results except an
2238-
// error. The given functions are called in VisitPre for every Expr node and
2239-
// VisitTablePre for every TableExpr node, respectively. The visitor stops as
2240-
// soon as an error is returned.
2271+
func (ev *extendedSimpleVisitor) VisitStatementPost(expr Statement) (newExpr Statement) {
2272+
return expr
2273+
}
2274+
2275+
// ExtendedSimpleVisitFn and ExtendedSimpleVisitStmtFn are functions that are
2276+
// run for every TableExpr and Statement node, respectively; see
2277+
// ExtendedSimpleVisit.
2278+
type ExtendedSimpleVisitTableFn func(expr TableExpr) (recurse bool, newExpr TableExpr, err error)
2279+
type ExtendedSimpleVisitStmtFn func(expr Statement) (recurse bool, newExpr Statement, err error)
2280+
2281+
// ExtendedSimpleVisit is a convenience wrapper for extended visitors that only
2282+
// have VisitPre, VisitTablePre, and VisitStatementPre code, and don't return
2283+
// any results except an error. The given functions are called in VisitPre for
2284+
// every Expr node, VisitTablePre for every TableExpr node, and
2285+
// VisitStatementPre for every Statement node. The visitor stops as soon as an
2286+
// error is returned.
22412287
//
22422288
// ExtendedSimpleVisit is identical to SimpleVisit but also handles TableExpr
2243-
// nodes.
2289+
// and Statement nodes.
22442290
func ExtendedSimpleVisit(
2245-
expr Expr, preFn SimpleVisitFn, preTableFn ExtendedSimpleVisitFn,
2291+
expr Expr,
2292+
preFn SimpleVisitFn,
2293+
preTableFn ExtendedSimpleVisitTableFn,
2294+
preStmtFn ExtendedSimpleVisitStmtFn,
22462295
) (Expr, error) {
2247-
ev := extendedSimpleVisitor{simpleVisitor{fn: preFn}, preTableFn}
2296+
ev := extendedSimpleVisitor{simpleVisitor{fn: preFn}, preTableFn, preStmtFn}
22482297
newExpr, _ := WalkExpr(&ev, expr)
22492298
if ev.err != nil {
22502299
return nil, ev.err
@@ -2253,14 +2302,18 @@ func ExtendedSimpleVisit(
22532302
}
22542303

22552304
// ExtendedSimpleStmtVisit is a convenience wrapper for visitors that want to
2256-
// visit all part of a statement, only have VisitPre and VisitTablePre code, and
2257-
// don't return any results except an error. The given functions are called in
2258-
// VisitPre for every Expr node and VisitTablePre for every TableExpr node,
2259-
// respectively. The visitor stops as soon as an error is returned.
2305+
// visit all part of a statement, only have VisitPre, VisitTablePre, and
2306+
// VisitStatementPre code, and don't return any results except an error. The
2307+
// given functions are called in VisitPre for every Expr node, VisitTablePre for
2308+
// every TableExpr node, and VisitStatementPre for every Statement node. The
2309+
// visitor stops as soon as an error is returned.
22602310
func ExtendedSimpleStmtVisit(
2261-
stmt Statement, preFn SimpleVisitFn, preTableFn ExtendedSimpleVisitFn,
2311+
stmt Statement,
2312+
preFn SimpleVisitFn,
2313+
preTableFn ExtendedSimpleVisitTableFn,
2314+
preStmtFn ExtendedSimpleVisitStmtFn,
22622315
) (Statement, error) {
2263-
ev := extendedSimpleVisitor{simpleVisitor{fn: preFn}, preTableFn}
2316+
ev := extendedSimpleVisitor{simpleVisitor{fn: preFn}, preTableFn, preStmtFn}
22642317
newStmt, changed := WalkStmt(&ev, stmt)
22652318
if ev.err != nil {
22662319
return nil, ev.err
@@ -2276,7 +2329,7 @@ type debugVisitor struct {
22762329
level int
22772330
}
22782331

2279-
var _ Visitor = &debugVisitor{}
2332+
var _ ExtendedVisitor = &debugVisitor{}
22802333

22812334
func (v *debugVisitor) VisitPre(expr Expr) (recurse bool, newExpr Expr) {
22822335
v.level++
@@ -2293,6 +2346,36 @@ func (v *debugVisitor) VisitPost(expr Expr) Expr {
22932346
return expr
22942347
}
22952348

2349+
func (v *debugVisitor) VisitTablePre(expr TableExpr) (recurse bool, newExpr TableExpr) {
2350+
v.level++
2351+
fmt.Fprintf(&v.buf, "%*s", 2*v.level, " ")
2352+
str := fmt.Sprintf("%#v\n", expr)
2353+
// Remove "parser." to make the string more compact.
2354+
str = strings.Replace(str, "parser.", "", -1)
2355+
v.buf.WriteString(str)
2356+
return true, expr
2357+
}
2358+
2359+
func (v *debugVisitor) VisitTablePost(expr TableExpr) TableExpr {
2360+
v.level--
2361+
return expr
2362+
}
2363+
2364+
func (v *debugVisitor) VisitStatementPre(expr Statement) (recurse bool, newExpr Statement) {
2365+
v.level++
2366+
fmt.Fprintf(&v.buf, "%*s", 2*v.level, " ")
2367+
str := fmt.Sprintf("%#v\n", expr)
2368+
// Remove "parser." to make the string more compact.
2369+
str = strings.Replace(str, "parser.", "", -1)
2370+
v.buf.WriteString(str)
2371+
return true, expr
2372+
}
2373+
2374+
func (v *debugVisitor) VisitStatementPost(expr Statement) Statement {
2375+
v.level--
2376+
return expr
2377+
}
2378+
22962379
// ExprDebugString generates a multi-line debug string with one node per line in
22972380
// Go format.
22982381
func ExprDebugString(expr Expr) string {

0 commit comments

Comments
 (0)