Skip to content

Commit a211fba

Browse files
committed
opt: return RelExpr from (*memo).RootExpr
`(*memo).RootExpr` now returns a `RelExpr` instead of an `Expr`, helping eliminate some type assertions. Release note: None
1 parent 43c294e commit a211fba

File tree

10 files changed

+29
-38
lines changed

10 files changed

+29
-38
lines changed

pkg/sql/opt/exec/execbuilder/scalar.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -675,7 +675,7 @@ func (b *Builder) buildExistsSubquery(
675675
stmtProps := []*physical.Required{{Presentation: physical.Presentation{aliasedCol}}}
676676

677677
// Create an wrapRootExprFn that wraps input in a Limit and a Project.
678-
wrapRootExpr := func(f *norm.Factory, e memo.RelExpr) opt.Expr {
678+
wrapRootExpr := func(f *norm.Factory, e memo.RelExpr) memo.RelExpr {
679679
return f.ConstructProject(
680680
f.ConstructLimit(
681681
e,
@@ -1115,7 +1115,7 @@ func (b *Builder) initRoutineExceptionHandler(
11151115
blockState.ExceptionHandler = exceptionHandler
11161116
}
11171117

1118-
type wrapRootExprFn func(f *norm.Factory, e memo.RelExpr) opt.Expr
1118+
type wrapRootExprFn func(f *norm.Factory, e memo.RelExpr) memo.RelExpr
11191119

11201120
// buildRoutinePlanGenerator returns a tree.RoutinePlanFn that can plan the
11211121
// statements in a routine that has one or more arguments.
@@ -1249,7 +1249,7 @@ func (b *Builder) buildRoutinePlanGenerator(
12491249
f.CopyAndReplace(originalMemo, stmt, props, replaceFn)
12501250

12511251
if wrapRootExpr != nil {
1252-
wrapped := wrapRootExpr(f, f.Memo().RootExpr().(memo.RelExpr)).(memo.RelExpr)
1252+
wrapped := wrapRootExpr(f, f.Memo().RootExpr())
12531253
f.Memo().SetRoot(wrapped, props)
12541254
}
12551255

pkg/sql/opt/memo/memo.go

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ type Memo struct {
132132
// rootExpr is the root expression of the memo expression forest. It is set
133133
// via a call to SetRoot. After optimization, it is set to be the root of the
134134
// lowest cost tree in the forest.
135-
rootExpr opt.Expr
135+
rootExpr RelExpr
136136

137137
// rootProps are the physical properties required of the root memo expression.
138138
// It is set via a call to SetRoot.
@@ -366,7 +366,7 @@ func (m *Memo) Metadata() *opt.Metadata {
366366

367367
// RootExpr returns the root memo expression previously set via a call to
368368
// SetRoot.
369-
func (m *Memo) RootExpr() opt.Expr {
369+
func (m *Memo) RootExpr() RelExpr {
370370
return m.rootExpr
371371
}
372372

@@ -395,12 +395,7 @@ func (m *Memo) SetRoot(e RelExpr, phys *physical.Required) {
395395
// HasPlaceholders returns true if the memo contains at least one placeholder
396396
// operator.
397397
func (m *Memo) HasPlaceholders() bool {
398-
rel, ok := m.rootExpr.(RelExpr)
399-
if !ok {
400-
panic(errors.AssertionFailedf("placeholders only supported when memo root is relational"))
401-
}
402-
403-
return rel.Relational().HasPlaceholder
398+
return m.rootExpr.Relational().HasPlaceholder
404399
}
405400

406401
// IsStale returns true if the memo has been invalidated by changes to any of
@@ -555,8 +550,7 @@ func (m *Memo) ResetCost(e RelExpr, cost Cost) {
555550
func (m *Memo) IsOptimized() bool {
556551
// The memo is optimized once the root expression has its physical properties
557552
// assigned.
558-
rel, ok := m.rootExpr.(RelExpr)
559-
return ok && rel.RequiredPhysical() != nil
553+
return m.rootExpr != nil && m.rootExpr.RequiredPhysical() != nil
560554
}
561555

562556
// OptimizationCost returns a rough estimate of the cost of optimization of the

pkg/sql/opt/memo/memo_test.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -691,15 +691,15 @@ func TestStatsAvailable(t *testing.T) {
691691

692692
// Stats should not be available for any expression.
693693
opttestutils.BuildQuery(t, &o, catalog, &evalCtx, "SELECT * FROM t WHERE a=1")
694-
testNotAvailable(o.Memo().RootExpr().(memo.RelExpr))
694+
testNotAvailable(o.Memo().RootExpr())
695695

696696
opttestutils.BuildQuery(t, &o, catalog, &evalCtx, "SELECT sum(a), b FROM t GROUP BY b")
697-
testNotAvailable(o.Memo().RootExpr().(memo.RelExpr))
697+
testNotAvailable(o.Memo().RootExpr())
698698

699699
opttestutils.BuildQuery(t, &o, catalog, &evalCtx,
700700
"SELECT * FROM t AS t1, t AS t2 WHERE t1.a = t2.a AND t1.b = 5",
701701
)
702-
testNotAvailable(o.Memo().RootExpr().(memo.RelExpr))
702+
testNotAvailable(o.Memo().RootExpr())
703703

704704
if _, err := catalog.ExecuteDDL(
705705
`ALTER TABLE t INJECT STATISTICS '[
@@ -729,15 +729,15 @@ func TestStatsAvailable(t *testing.T) {
729729

730730
// Stats should be available for all expressions.
731731
opttestutils.BuildQuery(t, &o, catalog, &evalCtx, "SELECT * FROM t WHERE a=1")
732-
testAvailable(o.Memo().RootExpr().(memo.RelExpr))
732+
testAvailable(o.Memo().RootExpr())
733733

734734
opttestutils.BuildQuery(t, &o, catalog, &evalCtx, "SELECT sum(a), b FROM t GROUP BY b")
735-
testAvailable(o.Memo().RootExpr().(memo.RelExpr))
735+
testAvailable(o.Memo().RootExpr())
736736

737737
opttestutils.BuildQuery(t, &o, catalog, &evalCtx,
738738
"SELECT * FROM t AS t1, t AS t2 WHERE t1.a = t2.a AND t1.b = 5",
739739
)
740-
testAvailable(o.Memo().RootExpr().(memo.RelExpr))
740+
testAvailable(o.Memo().RootExpr())
741741
}
742742

743743
// traverseExpr is a helper function to recursively traverse a relational

pkg/sql/opt/norm/factory.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -420,7 +420,7 @@ func (f *Factory) AssignPlaceholders(from *memo.Memo) (err error) {
420420
}
421421
return f.CopyAndReplaceDefault(e, replaceFn)
422422
}
423-
f.CopyAndReplace(from, from.RootExpr().(memo.RelExpr), from.RootProps(), replaceFn)
423+
f.CopyAndReplace(from, from.RootExpr(), from.RootProps(), replaceFn)
424424

425425
return nil
426426
}

pkg/sql/opt/norm/factory_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ func TestCopyAndReplace(t *testing.T) {
101101
}
102102
return o.Factory().CopyAndReplaceDefault(e, replaceFn)
103103
}
104-
o.Factory().CopyAndReplace(m, m.RootExpr().(memo.RelExpr), m.RootProps(), replaceFn)
104+
o.Factory().CopyAndReplace(m, m.RootExpr(), m.RootProps(), replaceFn)
105105

106106
if e, err := o.Optimize(); err != nil {
107107
t.Fatal(err)
@@ -146,7 +146,7 @@ func TestCopyAndReplaceWithScan(t *testing.T) {
146146
replaceFn = func(e opt.Expr) opt.Expr {
147147
return o.Factory().CopyAndReplaceDefault(e, replaceFn)
148148
}
149-
o.Factory().CopyAndReplace(m, m.RootExpr().(memo.RelExpr), m.RootProps(), replaceFn)
149+
o.Factory().CopyAndReplace(m, m.RootExpr(), m.RootProps(), replaceFn)
150150

151151
if _, err := o.Optimize(); err != nil {
152152
t.Fatal(err)

pkg/sql/opt/xform/optimizer.go

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,7 @@ func (o *Optimizer) Optimize() (_ opt.Expr, err error) {
261261
o.optimizeRootWithProps()
262262

263263
// Now optimize the entire expression tree.
264-
root := o.mem.RootExpr().(memo.RelExpr)
264+
root := o.mem.RootExpr()
265265
rootProps := o.mem.RootProps()
266266
o.optimizeGroup(root, rootProps)
267267

@@ -922,10 +922,7 @@ func (o *Optimizer) ensureOptState(grp memo.RelExpr, required *physical.Required
922922
// properties required of it. This may trigger the creation of a new root and
923923
// new properties.
924924
func (o *Optimizer) optimizeRootWithProps() {
925-
root, ok := o.mem.RootExpr().(memo.RelExpr)
926-
if !ok {
927-
panic(errors.AssertionFailedf("Optimize can only be called on relational root expressions"))
928-
}
925+
root := o.mem.RootExpr()
929926
rootProps := o.mem.RootProps()
930927

931928
// [SimplifyRootOrdering]

pkg/sql/opt/xform/optimizer_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ func TestDetachMemoRace(t *testing.T) {
113113
// Rewrite the filter to use a different column, which will trigger creation
114114
// of new table statistics. If the statistics object is aliased, this will
115115
// be racy.
116-
f.CopyAndReplace(mem, mem.RootExpr().(memo.RelExpr), mem.RootProps(), replaceFn)
116+
f.CopyAndReplace(mem, mem.RootExpr(), mem.RootProps(), replaceFn)
117117
wg.Done()
118118
}()
119119
}

pkg/sql/opt/xform/placeholder_fast_path.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ func (o *Optimizer) TryPlaceholderFastPath() (ok bool, err error) {
4545
}
4646
}()
4747

48-
root := o.mem.RootExpr().(memo.RelExpr)
48+
root := o.mem.RootExpr()
4949

5050
rootRelProps := root.Relational()
5151
// We are dealing with a memo that still contains placeholders. The statistics

pkg/sql/plan_opt.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -590,7 +590,7 @@ func (opc *optPlanningCtx) reuseMemo(cachedMemo *memo.Memo) (*memo.Memo, error)
590590
opc.flags.Set(planFlagOptimized)
591591
mem := f.Memo()
592592
if prep := opc.p.stmt.Prepared; opc.allowMemoReuse && prep != nil {
593-
costWithOptimizationCost := mem.RootExpr().(memo.RelExpr).Cost()
593+
costWithOptimizationCost := mem.RootExpr().Cost()
594594
costWithOptimizationCost.Add(mem.OptimizationCost())
595595
prep.Costs.AddCustom(costWithOptimizationCost)
596596
}
@@ -772,7 +772,7 @@ func (opc *optPlanningCtx) fetchPreparedMemo(ctx context.Context) (_ *memo.Memo,
772772
prep.IdealGenericPlan = true
773773
case memoTypeGeneric:
774774
prep.GenericMemo = newMemo
775-
prep.Costs.SetGeneric(newMemo.RootExpr().(memo.RelExpr).Cost())
775+
prep.Costs.SetGeneric(newMemo.RootExpr().Cost())
776776
// Now that the cost of the generic plan is known, we need to
777777
// re-evaluate the decision to use a generic or custom plan.
778778
if !opc.chooseGenericPlan() {
@@ -961,11 +961,11 @@ func (opc *optPlanningCtx) runExecBuilder(
961961
if opc.gf.Initialized() {
962962
planTop.instrumentation.planGist = opc.gf.PlanGist()
963963
}
964-
planTop.instrumentation.costEstimate = mem.RootExpr().(memo.RelExpr).Cost().C
965-
available := mem.RootExpr().(memo.RelExpr).Relational().Statistics().Available
964+
planTop.instrumentation.costEstimate = mem.RootExpr().Cost().C
965+
available := mem.RootExpr().Relational().Statistics().Available
966966
planTop.instrumentation.statsAvailable = available
967967
if available {
968-
planTop.instrumentation.outputRows = mem.RootExpr().(memo.RelExpr).Relational().Statistics().RowCount
968+
planTop.instrumentation.outputRows = mem.RootExpr().Relational().Statistics().RowCount
969969
}
970970

971971
if stmt.ExpectedTypes != nil {
@@ -1045,7 +1045,7 @@ func (opc *optPlanningCtx) makeQueryIndexRecommendation(
10451045
f.FoldingControl().AllowStableFolds()
10461046
f.CopyAndReplace(
10471047
savedMemo,
1048-
savedMemo.RootExpr().(memo.RelExpr),
1048+
savedMemo.RootExpr(),
10491049
savedMemo.RootProps(),
10501050
f.CopyWithoutAssigningPlaceholders,
10511051
)
@@ -1066,7 +1066,7 @@ func (opc *optPlanningCtx) makeQueryIndexRecommendation(
10661066
opc.optimizer.Init(ctx, f.EvalContext(), opc.catalog)
10671067
f.CopyAndReplace(
10681068
savedMemo,
1069-
savedMemo.RootExpr().(memo.RelExpr),
1069+
savedMemo.RootExpr(),
10701070
savedMemo.RootProps(),
10711071
f.CopyWithoutAssigningPlaceholders,
10721072
)
@@ -1091,7 +1091,7 @@ func (opc *optPlanningCtx) makeQueryIndexRecommendation(
10911091
savedMemo.Metadata().UpdateTableMeta(origCtx, f.EvalContext(), optTables)
10921092
f.CopyAndReplace(
10931093
savedMemo,
1094-
savedMemo.RootExpr().(memo.RelExpr),
1094+
savedMemo.RootExpr(),
10951095
savedMemo.RootProps(),
10961096
f.CopyWithoutAssigningPlaceholders,
10971097
)

pkg/sql/reference_provider.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ func (f *referenceProviderFactory) NewReferenceProvider(
132132
case *memo.CreateTriggerExpr:
133133
planDeps, typeDeps, funcDeps, err = toPlanDependencies(t.Deps, t.TypeDeps, t.FuncDeps)
134134
default:
135-
return nil, errors.AssertionFailedf("unexpected root expression: %s", t.(memo.RelExpr).Op())
135+
return nil, errors.AssertionFailedf("unexpected root expression: %s", t.Op())
136136
}
137137
if err != nil {
138138
return nil, err

0 commit comments

Comments
 (0)