diff --git a/optgen/cmd/support/memo_gen.go b/optgen/cmd/support/memo_gen.go index 1015ccf46c..356f79e3a5 100644 --- a/optgen/cmd/support/memo_gen.go +++ b/optgen/cmd/support/memo_gen.go @@ -42,36 +42,36 @@ func DecodeMemoExprs(path string) (MemoExprs, error) { var _ GenDefs = (*MemoExprs)(nil) type MemoGen struct { - defines []ExprDef - w io.Writer + defns []ExprDef + w io.Writer } -func (g *MemoGen) Generate(defines GenDefs, w io.Writer) { - g.defines = defines.(MemoExprs).Exprs +func (g *MemoGen) Generate(defns GenDefs, w io.Writer) { + g.defns = defns.(MemoExprs).Exprs g.w = w g.genImport() - for _, define := range g.defines { - g.genType(define) - g.genRelInterfaces(define) - - g.genStringer(define) - if define.SourceType != "" { - g.genSourceRelInterface(define) + for _, defn := range g.defns { + g.genType(defn) + g.genRelInterfaces(defn) + + g.genStringer(defn) + g.genFormatter(defn) + if defn.SourceType != "" { + g.genSourceRelInterface(defn) } - if define.Join { - g.genJoinRelInterface(define) - } else if define.Binary { - g.genBinaryGroupInterface(define) - } else if define.Unary { - g.genUnaryGroupInterface(define) + if defn.Join { + g.genJoinRelInterface(defn) + } else if defn.Binary { + g.genBinaryGroupInterface(defn) + } else if defn.Unary { + g.genUnaryGroupInterface(defn) } else { - g.genChildlessGroupInterface(define) + g.genChildlessGroupInterface(defn) } } - g.genFormatters(g.defines) - + g.genBuildRelExpr(g.defns) } func (g *MemoGen) genImport() { @@ -83,108 +83,116 @@ func (g *MemoGen) genImport() { fmt.Fprintf(g.w, ")\n\n") } -func (g *MemoGen) genType(define ExprDef) { - fmt.Fprintf(g.w, "type %s struct {\n", strings.Title(define.Name)) - if define.SourceType != "" { +func (g *MemoGen) genType(defn ExprDef) { + fmt.Fprintf(g.w, "type %s struct {\n", strings.Title(defn.Name)) + if defn.SourceType != "" { fmt.Fprintf(g.w, " *sourceBase\n") - fmt.Fprintf(g.w, " Table %s\n", define.SourceType) - } else if define.Join { + fmt.Fprintf(g.w, " Table 
%s\n", defn.SourceType) + } else if defn.Join { fmt.Fprintf(g.w, " *JoinBase\n") - } else if define.Unary { + } else if defn.Unary { fmt.Fprintf(g.w, " *relBase\n") fmt.Fprintf(g.w, " Child *ExprGroup\n") - } else if define.Binary { + } else if defn.Binary { fmt.Fprintf(g.w, " *relBase\n") fmt.Fprintf(g.w, " Left *ExprGroup\n") fmt.Fprintf(g.w, " Right *ExprGroup\n") } - for _, attr := range define.Attrs { + for _, attr := range defn.Attrs { fmt.Fprintf(g.w, " %s %s\n", strings.Title(attr[0]), attr[1]) } fmt.Fprintf(g.w, "}\n\n") } -func (g *MemoGen) genRelInterfaces(define ExprDef) { - fmt.Fprintf(g.w, "var _ RelExpr = (*%s)(nil)\n", define.Name) - if define.SourceType != "" { - fmt.Fprintf(g.w, "var _ SourceRel = (*%s)(nil)\n", define.Name) - } else if define.Join { - fmt.Fprintf(g.w, "var _ JoinRel = (*%s)(nil)\n", define.Name) - } else if define.Unary || define.Binary { +func (g *MemoGen) genRelInterfaces(defn ExprDef) { + fmt.Fprintf(g.w, "var _ RelExpr = (*%s)(nil)\n", defn.Name) + fmt.Fprintf(g.w, "var _ fmt.Formatter = (*%s)(nil)\n", defn.Name) + fmt.Fprintf(g.w, "var _ fmt.Stringer = (*%s)(nil)\n", defn.Name) + if defn.SourceType != "" { + fmt.Fprintf(g.w, "var _ SourceRel = (*%s)(nil)\n", defn.Name) + } else if defn.Join { + fmt.Fprintf(g.w, "var _ JoinRel = (*%s)(nil)\n", defn.Name) + } else if defn.Unary || defn.Binary { } else { panic("unreachable") } fmt.Fprintf(g.w, "\n") } -func (g *MemoGen) genScalarInterfaces(define ExprDef) { - fmt.Fprintf(g.w, "var _ ScalarExpr = (*%s)(nil)\n", define.Name) +func (g *MemoGen) genScalarInterfaces(defn ExprDef) { + fmt.Fprintf(g.w, "var _ ScalarExpr = (*%s)(nil)\n", defn.Name) fmt.Fprintf(g.w, "\n") - fmt.Fprintf(g.w, "func (r *%s) ExprId() ScalarExprId {\n", define.Name) - fmt.Fprintf(g.w, " return ScalarExpr%s\n", strings.Title(define.Name)) + fmt.Fprintf(g.w, "func (r *%s) ExprId() ScalarExprId {\n", defn.Name) + fmt.Fprintf(g.w, " return ScalarExpr%s\n", strings.Title(defn.Name)) + fmt.Fprintf(g.w, "}\n\n") +} 
+ +func (g *MemoGen) genStringer(defn ExprDef) { + fmt.Fprintf(g.w, "func (r *%s) String() string {\n", defn.Name) + fmt.Fprintf(g.w, " return fmt.Sprintf(\"%%s\", r)\n") fmt.Fprintf(g.w, "}\n\n") } -func (g *MemoGen) genStringer(define ExprDef) { - fmt.Fprintf(g.w, "func (r *%s) String() string {\n", define.Name) - fmt.Fprintf(g.w, " return FormatExpr(r)\n") +func (g *MemoGen) genFormatter(defn ExprDef) { + fmt.Fprintf(g.w, "func (r *%s) Format(s fmt.State, verb rune) {\n", defn.Name) + fmt.Fprintf(g.w, " FormatExpr(r, s, verb)\n") fmt.Fprintf(g.w, "}\n\n") } -func (g *MemoGen) genSourceRelInterface(define ExprDef) { - fmt.Fprintf(g.w, "func (r *%s) Name() string {\n", define.Name) - if !define.SkipName { +func (g *MemoGen) genSourceRelInterface(defn ExprDef) { + fmt.Fprintf(g.w, "func (r *%s) Name() string {\n", defn.Name) + if !defn.SkipName { fmt.Fprintf(g.w, " return strings.ToLower(r.Table.Name())\n") } else { fmt.Fprintf(g.w, " return \"\"\n") } fmt.Fprintf(g.w, "}\n\n") - fmt.Fprintf(g.w, "func (r *%s) TableId() sql.TableId {\n", define.Name) + fmt.Fprintf(g.w, "func (r *%s) TableId() sql.TableId {\n", defn.Name) fmt.Fprintf(g.w, " return TableIdForSource(r.g.Id)\n") fmt.Fprintf(g.w, "}\n\n") - fmt.Fprintf(g.w, "func (r *%s) TableIdNode() plan.TableIdNode {\n", define.Name) - if define.SkipTableId { + fmt.Fprintf(g.w, "func (r *%s) TableIdNode() plan.TableIdNode {\n", defn.Name) + if defn.SkipTableId { fmt.Fprintf(g.w, " return nil\n") } else { fmt.Fprintf(g.w, " return r.Table\n") } fmt.Fprintf(g.w, "}\n\n") - fmt.Fprintf(g.w, "func (r *%s) OutputCols() sql.Schema {\n", define.Name) + fmt.Fprintf(g.w, "func (r *%s) OutputCols() sql.Schema {\n", defn.Name) fmt.Fprintf(g.w, " return r.Table.Schema()\n") fmt.Fprintf(g.w, "}\n\n") } -func (g *MemoGen) genJoinRelInterface(define ExprDef) { - fmt.Fprintf(g.w, "func (r *%s) JoinPrivate() *JoinBase {\n", define.Name) +func (g *MemoGen) genJoinRelInterface(defn ExprDef) { + fmt.Fprintf(g.w, "func (r *%s) 
JoinPrivate() *JoinBase {\n", defn.Name) fmt.Fprintf(g.w, " return r.JoinBase\n") fmt.Fprintf(g.w, "}\n\n") } -func (g *MemoGen) genBinaryGroupInterface(define ExprDef) { - fmt.Fprintf(g.w, "func (r *%s) Children() []*ExprGroup {\n", define.Name) +func (g *MemoGen) genBinaryGroupInterface(defn ExprDef) { + fmt.Fprintf(g.w, "func (r *%s) Children() []*ExprGroup {\n", defn.Name) fmt.Fprintf(g.w, " return []*ExprGroup{r.Left, r.Right}\n") fmt.Fprintf(g.w, "}\n\n") } -func (g *MemoGen) genChildlessGroupInterface(define ExprDef) { - fmt.Fprintf(g.w, "func (r *%s) Children() []*ExprGroup {\n", define.Name) +func (g *MemoGen) genChildlessGroupInterface(defn ExprDef) { + fmt.Fprintf(g.w, "func (r *%s) Children() []*ExprGroup {\n", defn.Name) fmt.Fprintf(g.w, " return nil\n") fmt.Fprintf(g.w, "}\n\n") } -func (g *MemoGen) genUnaryGroupInterface(define ExprDef) { - fmt.Fprintf(g.w, "func (r *%s) Children() []*ExprGroup {\n", define.Name) +func (g *MemoGen) genUnaryGroupInterface(defn ExprDef) { + fmt.Fprintf(g.w, "func (r *%s) Children() []*ExprGroup {\n", defn.Name) fmt.Fprintf(g.w, " return []*ExprGroup{r.Child}\n") fmt.Fprintf(g.w, "}\n\n") - fmt.Fprintf(g.w, "func (r *%s) outputCols() sql.ColSet {\n", define.Name) - switch define.Name { + fmt.Fprintf(g.w, "func (r *%s) outputCols() sql.ColSet {\n", defn.Name) + switch defn.Name { case "Project": fmt.Fprintf(g.w, " return getProjectColset(r)\n") @@ -193,42 +201,14 @@ func (g *MemoGen) genUnaryGroupInterface(define ExprDef) { } fmt.Fprintf(g.w, "}\n\n") - } -func (g *MemoGen) genFormatters(defines []ExprDef) { - // printer - fmt.Fprintf(g.w, "func FormatExpr(r exprType) string {\n") - fmt.Fprintf(g.w, " switch r := r.(type) {\n") - for _, d := range defines { - loweredName := strings.ToLower(d.Name) - fmt.Fprintf(g.w, " case *%s:\n", d.Name) - if loweredName == "indexscan" { - fmt.Fprintf(g.w, " if r.Alias != \"\" {\n") - fmt.Fprintf(g.w, " return fmt.Sprintf(\"%s: %%s\", r.Alias)\n", loweredName) - fmt.Fprintf(g.w, " 
}\n") - } - if d.SourceType != "" { - fmt.Fprintf(g.w, " return fmt.Sprintf(\"%s: %%s\", r.Name())\n", loweredName) - } else if d.Join || d.Binary { - fmt.Fprintf(g.w, " return fmt.Sprintf(\"%s %%d %%d\", r.Left.Id, r.Right.Id)\n", loweredName) - } else if d.Unary { - fmt.Fprintf(g.w, " return fmt.Sprintf(\"%s: %%d\", r.Child.Id)\n", loweredName) - } else { - panic("unreachable") - } - } - fmt.Fprintf(g.w, " default:\n") - fmt.Fprintf(g.w, " panic(fmt.Sprintf(\"unknown RelExpr type: %%T\", r))\n") - fmt.Fprintf(g.w, " }\n") - fmt.Fprintf(g.w, "}\n\n") - - // to sqlNode +func (g *MemoGen) genBuildRelExpr(defns []ExprDef) { fmt.Fprintf(g.w, "func buildRelExpr(b *ExecBuilder, r RelExpr, children ...sql.Node) (sql.Node, error) {\n") fmt.Fprintf(g.w, " var result sql.Node\n") fmt.Fprintf(g.w, " var err error\n\n") fmt.Fprintf(g.w, " switch r := r.(type) {\n") - for _, d := range defines { + for _, d := range defns { if d.SkipExec { continue } diff --git a/optgen/cmd/support/memo_gen_test.go b/optgen/cmd/support/memo_gen_test.go index 4b96aa4baf..ef1f7f3e3d 100644 --- a/optgen/cmd/support/memo_gen_test.go +++ b/optgen/cmd/support/memo_gen_test.go @@ -12,7 +12,7 @@ func TestMemoGen(t *testing.T) { expected string }{ expected: ` - import ( + import ( "fmt" "strings" "github.com/dolthub/go-mysql-server/sql" @@ -26,10 +26,16 @@ func TestMemoGen(t *testing.T) { } var _ RelExpr = (*hashJoin)(nil) + var _ fmt.Formatter = (*hashJoin)(nil) + var _ fmt.Stringer = (*hashJoin)(nil) var _ JoinRel = (*hashJoin)(nil) func (r *hashJoin) String() string { - return FormatExpr(r) + return fmt.Sprintf("%s", r) + } + + func (r *hashJoin) Format(s fmt.State, verb rune) { + FormatExpr(r, s, verb) } func (r *hashJoin) JoinPrivate() *JoinBase { @@ -42,10 +48,16 @@ func TestMemoGen(t *testing.T) { } var _ RelExpr = (*tableScan)(nil) + var _ fmt.Formatter = (*tableScan)(nil) + var _ fmt.Stringer = (*tableScan)(nil) var _ SourceRel = (*tableScan)(nil) func (r *tableScan) String() string { - return 
FormatExpr(r) + return fmt.Sprintf("%s", r) + } + + func (r *tableScan) Format(s fmt.State, verb rune) { + FormatExpr(r, s, verb) } func (r *tableScan) Name() string { @@ -68,17 +80,6 @@ func TestMemoGen(t *testing.T) { return nil } - func FormatExpr(r exprType) string { - switch r := r.(type) { - case *hashJoin: - return fmt.Sprintf("hashjoin %d %d", r.Left.Id, r.Right.Id) - case *tableScan: - return fmt.Sprintf("tablescan: %s", r.Name()) - default: - panic(fmt.Sprintf("unknown RelExpr type: %T", r)) - } - } - func buildRelExpr(b *ExecBuilder, r RelExpr, children ...sql.Node) (sql.Node, error) { var result sql.Node var err error @@ -96,9 +97,9 @@ func TestMemoGen(t *testing.T) { return nil, err } - if withDescribeStats, ok := result.(sql.WithDescribeStats); ok { - withDescribeStats.SetDescribeStats(*DescribeStats(r)) - } + if withDescribeStats, ok := result.(sql.WithDescribeStats); ok { + withDescribeStats.SetDescribeStats(*DescribeStats(r)) + } result, err = r.Group().finalize(result) if err != nil { return nil, err diff --git a/sql/analyzer/analyzer.go b/sql/analyzer/analyzer.go index 56bc882cad..0e710c26e3 100644 --- a/sql/analyzer/analyzer.go +++ b/sql/analyzer/analyzer.go @@ -36,6 +36,7 @@ import ( const debugAnalyzerKey = "DEBUG_ANALYZER" const verboseAnalyzerKey = "VERBOSE_ANALYZER" +const traceAnalyzerKey = "TRACE_ANALYZER" const maxAnalysisIterations = 8 @@ -215,6 +216,7 @@ func (s simpleLogFormatter) Format(entry *logrus.Entry) ([]byte, error) { func (ab *Builder) Build() *Analyzer { _, debug := os.LookupEnv(debugAnalyzerKey) _, verbose := os.LookupEnv(verboseAnalyzerKey) + _, trace := os.LookupEnv(traceAnalyzerKey) var batches = []*Batch{ { Desc: "pre-analyzer", @@ -266,6 +268,7 @@ func (ab *Builder) Build() *Analyzer { return &Analyzer{ Debug: debug || ab.debug, Verbose: verbose, + Trace: trace, contextStack: make([]string, 0), Batches: batches, Catalog: NewCatalog(ab.provider), @@ -297,6 +300,8 @@ type Analyzer struct { Batches []*Batch // Whether to 
log various debugging messages Debug bool + // Whether to output detailed trace logging for join planning + Trace bool // Whether to output the query plan at each step of the analyzer Verbose bool } diff --git a/sql/analyzer/costed_index_scan.go b/sql/analyzer/costed_index_scan.go index e41d07cfaa..68be635157 100644 --- a/sql/analyzer/costed_index_scan.go +++ b/sql/analyzer/costed_index_scan.go @@ -131,7 +131,7 @@ func costedIndexLookup(ctx *sql.Context, n sql.Node, a *Analyzer, iat sql.IndexA if err != nil { return n, transform.SameTree, err } - // TODO(next): this is getting a GMSCast node and not getting an index assigned here + ita, stats, filters, err := getCostedIndexScan(ctx, a.Catalog, rt, indexes, SplitConjunction(oldFilter), qFlags) if err != nil || ita == nil { return n, transform.SameTree, err @@ -334,6 +334,9 @@ func getCostedIndexScan(ctx *sql.Context, statsProv sql.StatsProvider, rt sql.Ta } func addIndexScans(ctx *sql.Context, m *memo.Memo) error { + m.Tracer.PushDebugContext("addIndexScans") + defer m.Tracer.PopDebugContext() + return memo.DfsRel(m.Root(), func(e memo.RelExpr) error { filter, ok := e.(*memo.Filter) if !ok { @@ -928,7 +931,7 @@ func (b *indexScanRangeBuilder) rangeBuildOr(f *iScanOr, inScan bool) (sql.MySQL // imprecise filters cannot be removed b.markImprecise(f) - //todo union the or ranges + // todo union the or ranges var ret sql.MySQLRangeCollection for _, c := range f.children { var ranges sql.MySQLRangeCollection diff --git a/sql/analyzer/indexed_joins.go b/sql/analyzer/indexed_joins.go index a250274c06..d701726dc9 100644 --- a/sql/analyzer/indexed_joins.go +++ b/sql/analyzer/indexed_joins.go @@ -137,6 +137,7 @@ func recSchemaToGetFields(n sql.Node, sch sql.Schema) []sql.Expression { func replanJoin(ctx *sql.Context, n *plan.JoinNode, a *Analyzer, scope *plan.Scope, qFlags *sql.QueryFlags) (ret sql.Node, err error) { m := memo.NewMemo(ctx, a.Catalog, scope, len(scope.Schema()), a.Coster, qFlags) m.Debug = a.Debug + 
m.EnableTrace(a.Trace) defer func() { if r := recover(); r != nil { @@ -165,14 +166,17 @@ func replanJoin(ctx *sql.Context, n *plan.JoinNode, a *Analyzer, scope *plan.Sco if err != nil { return nil, err } + err = convertSemiToInnerJoin(m) if err != nil { return nil, err } + err = convertAntiToLeftJoin(m) if err != nil { return nil, err } + err = addRightSemiJoins(ctx, m) if err != nil { return nil, err @@ -188,18 +192,22 @@ func replanJoin(ctx *sql.Context, n *plan.JoinNode, a *Analyzer, scope *plan.Sco if err != nil { return nil, err } + } else { + m.Tracer.Log("Skipping merge joins (disabled by hints)") } - memo.CardMemoGroups(ctx, m.Root()) + m.CardMemoGroups(ctx, m.Root()) err = addCrossHashJoins(m) if err != nil { return nil, err } + err = addHashJoins(m) if err != nil { return nil, err } + err = addRangeHeapJoin(m) if err != nil { return nil, err @@ -208,14 +216,21 @@ func replanJoin(ctx *sql.Context, n *plan.JoinNode, a *Analyzer, scope *plan.Sco // Once we've enumerated all expression groups, we can apply hints. This must be done after expression // groups have been identified, so that the applied hints use the correct metadata. for _, h := range hints { + m.Tracer.Log("Applying hint: %s", h.Typ) m.ApplyHint(h) } + if m.Tracer.TraceEnabled { + m.Tracer.Log("Starting cost-based optimization for groups %s", m) + } + err = m.OptimizeRoot() if err != nil { return nil, err } + m.LogCostDebugString() + if a.Verbose && a.Debug { a.Log("%s", m.String()) } @@ -223,6 +238,8 @@ func replanJoin(ctx *sql.Context, n *plan.JoinNode, a *Analyzer, scope *plan.Sco scope.JoinTrees = append(scope.JoinTrees, m.String()) } + m.LogBestPlanDebugString() + return m.BestRootPlan(ctx) } @@ -244,7 +261,13 @@ func mergeJoinsDisabled(hints []memo.Hint) bool { // attributes in the join filter. Costing is responsible for choosing the most // appropriate execution plan among options added to an expression group. 
func addLookupJoins(ctx *sql.Context, m *memo.Memo) error { + m.Tracer.PushDebugContext("addLookupJoins") + defer m.Tracer.PopDebugContext() + return memo.DfsRel(m.Root(), func(e memo.RelExpr) error { + m.Tracer.PushDebugContextFmt("%+v", e) + defer m.Tracer.PopDebugContext() + var right *memo.ExprGroup var join *memo.JoinBase @@ -258,7 +281,7 @@ func addLookupJoins(ctx *sql.Context, m *memo.Memo) error { case *memo.LeftJoin: right = e.Right join = e.JoinBase - //TODO fullouterjoin + // TODO fullouterjoin case *memo.SemiJoin: right = e.Right join = e.JoinBase @@ -267,10 +290,12 @@ func addLookupJoins(ctx *sql.Context, m *memo.Memo) error { } if len(join.Filter) == 0 { + m.Tracer.Log("Skipping lookup join for %T - no filters", e) return nil } tableId, indexes, extraFilters := lookupCandidates(right.First, false) + m.Tracer.Log("Found %d index candidates for lookup join", len(indexes)) var rt sql.TableNode var aliasName string @@ -281,10 +306,12 @@ func addLookupJoins(ctx *sql.Context, m *memo.Memo) error { var ok bool rt, ok = n.Child.(sql.TableNode) if !ok { + m.Tracer.Log("Skipping lookup join - table alias child is not TableNode") return nil } aliasName = n.Name() default: + m.Tracer.Log("Skipping lookup join - unsupported table node type: %T", n) return nil } @@ -325,10 +352,14 @@ func addLookupJoins(ctx *sql.Context, m *memo.Memo) error { for _, idx := range indexes { keyExprs, matchedFilters, nullmask := keyExprsForIndex(tableId, idx.Cols(), append(join.Filter, extraFilters...)) if keyExprs == nil { + m.Tracer.Log("Index %s: no matching key expressions found", idx.SqlIdx().ID()) continue } + m.Tracer.Log("Index %s: found %d key expressions, %d matched filters", idx.SqlIdx().ID(), len(keyExprs), len(matchedFilters)) + ita, err := plan.NewIndexedAccessForTableNode(ctx, rt, plan.NewLookupBuilder(idx.SqlIdx(), keyExprs, nullmask)) if err != nil { + m.Tracer.Log("Index %s: failed to create indexed table access: %v", idx.SqlIdx().ID(), err) return err } lookup := 
&memo.IndexScan{ @@ -350,6 +381,7 @@ func addLookupJoins(ctx *sql.Context, m *memo.Memo) error { } } + m.Tracer.Log("Adding lookup join with index %s, %d remaining filters", idx.SqlIdx().ID(), len(filters)) m.MemoizeLookupJoin(e.Group(), join.Left, join.Right, join.Op, filters, lookup) } return nil @@ -433,6 +465,9 @@ func exprRefsTable(e sql.Expression, tableId sql.TableId) bool { // https://www.researchgate.net/publication/221311318_Cost-Based_Query_Transformation_in_Oracle // TODO: need more elegant way to extend the number of groups, interner func convertSemiToInnerJoin(m *memo.Memo) error { + m.Tracer.PushDebugContext("convertSemiToInnerJoin") + defer m.Tracer.PopDebugContext() + return memo.DfsRel(m.Root(), func(e memo.RelExpr) error { semi, ok := e.(*memo.SemiJoin) if !ok { @@ -523,6 +558,9 @@ func convertSemiToInnerJoin(m *memo.Memo) error { // convertAntiToLeftJoin adds left join alternatives for anti join // ANTI_JOIN(left, right) => PROJECT(left sch) -> FILTER(right attr IS NULL) -> LEFT_JOIN(left, right) func convertAntiToLeftJoin(m *memo.Memo) error { + m.Tracer.PushDebugContext("convertAntiToLeftJoin") + defer m.Tracer.PopDebugContext() + return memo.DfsRel(m.Root(), func(e memo.RelExpr) error { anti, ok := e.(*memo.AntiJoin) if !ok { @@ -627,7 +665,13 @@ func convertAntiToLeftJoin(m *memo.Memo) error { // addRightSemiJoins allows for a reversed semiJoin operator when // the join attributes of the left side are provably unique. 
func addRightSemiJoins(ctx *sql.Context, m *memo.Memo) error { + m.Tracer.PushDebugContext("addRightSemiJoins") + defer m.Tracer.PopDebugContext() + return memo.DfsRel(m.Root(), func(e memo.RelExpr) error { + m.Tracer.PushDebugContextFmt("%+v", e) + defer m.Tracer.PopDebugContext() + semi, ok := e.(*memo.SemiJoin) if !ok { return nil @@ -748,6 +792,9 @@ func dfsLookupCandidates(rel memo.RelExpr, limitOk bool) (sql.TableId, []*memo.I } func addCrossHashJoins(m *memo.Memo) error { + m.Tracer.PushDebugContext("addCrossHashJoins") + defer m.Tracer.PopDebugContext() + return memo.DfsRel(m.Root(), func(e memo.RelExpr) error { switch e.(type) { case *memo.CrossJoin: @@ -784,7 +831,13 @@ func addCrossHashJoins(m *memo.Memo) error { } func addHashJoins(m *memo.Memo) error { + m.Tracer.PushDebugContext("addHashJoins") + defer m.Tracer.PopDebugContext() + return memo.DfsRel(m.Root(), func(e memo.RelExpr) error { + m.Tracer.PushDebugContextFmt("%+v", e) + defer m.Tracer.PopDebugContext() + switch e.(type) { case *memo.InnerJoin, *memo.LeftJoin: default: @@ -793,9 +846,12 @@ func addHashJoins(m *memo.Memo) error { join := e.(memo.JoinRel).JoinPrivate() if len(join.Filter) == 0 { + m.Tracer.Log("Skipping hash join for %T - no filters", e) return nil } + m.Tracer.Log("Considering hash join with %d filters", len(join.Filter)) + var fromExpr, toExpr []sql.Expression for _, f := range join.Filter { switch f := f.(type) { @@ -804,22 +860,28 @@ func addHashJoins(m *memo.Memo) error { satisfiesScalarRefs(f.Right(), join.Right.RelProps.OutputTables()) { fromExpr = append(fromExpr, f.Right()) toExpr = append(toExpr, f.Left()) + m.Tracer.Log("Filter %s: found a left->right hash key mapping", f) } else if satisfiesScalarRefs(f.Right(), join.Left.RelProps.OutputTables()) && satisfiesScalarRefs(f.Left(), join.Right.RelProps.OutputTables()) { fromExpr = append(fromExpr, f.Left()) toExpr = append(toExpr, f.Right()) + m.Tracer.Log("Filter %s: found a right->left hash key mapping", f) } else { + 
m.Tracer.Log("Filter %s: does not satisfy scalar refs for hash join", f) return nil } default: + m.Tracer.Log("Filter %s: not an equality expression, skipping hash join", f) return nil } } switch join.Right.First.(type) { case *memo.RecursiveTable: + m.Tracer.Log("Skipping hash join - right side is recursive table") return nil } + m.Tracer.Log("Adding hash join with %d key expressions", len(toExpr)) m.MemoizeHashJoin(e.Group(), join, toExpr, fromExpr) return nil }) @@ -909,7 +971,13 @@ func getRangeFilters(filters []sql.Expression) (ranges []rangeFilter) { // - SELECT * FROM a JOIN b on a.value BETWEEN b.min AND b.max // - SELECT * FROM a JOIN b on b.min <= a.value AND a.value < b.max func addRangeHeapJoin(m *memo.Memo) error { + m.Tracer.PushDebugContext("addRangeHeapJoin") + defer m.Tracer.PopDebugContext() + return memo.DfsRel(m.Root(), func(e memo.RelExpr) error { + m.Tracer.PushDebugContextFmt("%+v", e) + defer m.Tracer.PopDebugContext() + switch e.(type) { case *memo.InnerJoin, *memo.LeftJoin: default: @@ -1014,29 +1082,41 @@ func satisfiesScalarRefs(e sql.Expression, tables sql.FastIntSet) bool { // filter. 
// TODO: sort-merge joins func addMergeJoins(ctx *sql.Context, m *memo.Memo) error { + m.Tracer.PushDebugContext("addMergeJoins") + defer m.Tracer.PopDebugContext() + return memo.DfsRel(m.Root(), func(e memo.RelExpr) error { + m.Tracer.PushDebugContextFmt("%+v", e) + defer m.Tracer.PopDebugContext() + var join *memo.JoinBase switch e := e.(type) { case *memo.InnerJoin: join = e.JoinBase case *memo.LeftJoin: join = e.JoinBase - //TODO semijoin, antijoin, fullouterjoin + // TODO semijoin, antijoin, fullouterjoin default: return nil } if len(join.Filter) == 0 { + m.Tracer.Log("Skipping merge join for %T - no filters", e) return nil } + m.Tracer.Log("Considering merge join for %T with %d filters", e, len(join.Filter)) + leftTabId, lIndexes, lFilters := lookupCandidates(join.Left.First, true) rightTabId, rIndexes, rFilters := lookupCandidates(join.Right.First, true) if leftTabId == 0 || rightTabId == 0 { + m.Tracer.Log("Skipping merge join - no valid table candidates found") return nil } + m.Tracer.Log("Found %d left indexes, %d right indexes for merge join", len(lIndexes), len(rIndexes)) + leftTab := join.Left.RelProps.TableIdNodes()[0] rightTab := join.Right.RelProps.TableIdNodes()[0] @@ -1095,16 +1175,21 @@ func addMergeJoins(ctx *sql.Context, m *memo.Memo) error { if lIndex.Order() == sql.IndexOrderNone { // lookups can be unordered, merge indexes need to // be globally ordered + m.Tracer.Log("Left index %s: skipping - unordered index", lIndex.SqlIdx().ID()) continue } matchedEqFilters := matchedFiltersForLeftIndex(lIndex, join.Left.RelProps.FuncDeps().Constants(), eqFilters) + m.Tracer.Log("Left index %s: matched %d equality filters", lIndex.SqlIdx().ID(), len(matchedEqFilters)) + for len(matchedEqFilters) > 0 { for _, rIndex := range rIndexes { if rIndex.Order() == sql.IndexOrderNone { + m.Tracer.Log("Right index %s: skipping - unordered index", rIndex.SqlIdx().ID()) continue } if rightIndexMatchesFilters(rIndex, join.Left.RelProps.FuncDeps().Constants(), 
matchedEqFilters) { + m.Tracer.Log("Found matching index pair: left[%s] <-> right[%s]", lIndex.SqlIdx().ID(), rIndex.SqlIdx().ID()) jb := join.Copy() if d, ok := jb.Left.First.(*memo.Distinct); ok && lIndex.SqlIdx().IsUnique() { jb.Left = d.Child @@ -1147,8 +1232,10 @@ func addMergeJoins(ctx *sql.Context, m *memo.Memo) error { return err } if !success { + m.Tracer.Log("Failed to create index scan for right index %s", rIndex.SqlIdx().ID()) continue } + m.Tracer.Log("Adding merge join with left index %s, right index %s", lIndex.SqlIdx().ID(), rIndex.SqlIdx().ID()) m.MemoizeMergeJoin(e.Group(), join.Left, join.Right, lIndexScan, rIndexScan, jb.Op.AsMerge(), newFilters, false) } } diff --git a/sql/analyzer/indexed_joins_test.go b/sql/analyzer/indexed_joins_test.go index 1f5061c6ed..08602f816a 100644 --- a/sql/analyzer/indexed_joins_test.go +++ b/sql/analyzer/indexed_joins_test.go @@ -40,16 +40,16 @@ func TestHashJoins(t *testing.T) { memo: `memo: ├── G1: (tablescan: ab) ├── G2: (tablescan: xy) -├── G3: (hashjoin 1 2) (hashjoin 2 1) (innerjoin 2 1) (innerjoin 1 2) +├── G3: (hashjoin 1[ab] 2[xy]) (hashjoin 2[xy] 1[ab]) (innerjoin 2[xy] 1[ab]) (innerjoin 1[ab] 2[xy]) ├── G4: (tablescan: pq) -├── G5: (hashjoin 3 4) (hashjoin 1 9) (hashjoin 9 1) (hashjoin 2 8) (hashjoin 8 2) (hashjoin 4 3) (innerjoin 4 3) (innerjoin 8 2) (innerjoin 2 8) (innerjoin 9 1) (innerjoin 1 9) (innerjoin 3 4) +├── G5: (hashjoin 3 4[pq]) (hashjoin 1[ab] 9) (hashjoin 9 1[ab]) (hashjoin 2[xy] 8) (hashjoin 8 2[xy]) (hashjoin 4[pq] 3) (innerjoin 4[pq] 3) (innerjoin 8 2[xy]) (innerjoin 2[xy] 8) (innerjoin 9 1[ab]) (innerjoin 1[ab] 9) (innerjoin 3 4[pq]) ├── G6: (tablescan: uv) -├── G7: (hashjoin 5 6) (hashjoin 1 12) (hashjoin 12 1) (hashjoin 2 11) (hashjoin 11 2) (hashjoin 3 10) (hashjoin 10 3) (hashjoin 6 5) (innerjoin 6 5) (innerjoin 10 3) (innerjoin 3 10) (innerjoin 11 2) (innerjoin 2 11) (innerjoin 12 1) (innerjoin 1 12) (innerjoin 5 6) -├── G8: (hashjoin 1 4) (hashjoin 4 1) (innerjoin 4 1) (innerjoin 
1 4) -├── G9: (hashjoin 2 4) (hashjoin 4 2) (innerjoin 4 2) (innerjoin 2 4) -├── G10: (hashjoin 4 6) (hashjoin 6 4) (innerjoin 6 4) (innerjoin 4 6) -├── G11: (hashjoin 1 10) (hashjoin 10 1) (hashjoin 8 6) (hashjoin 6 8) (innerjoin 6 8) (innerjoin 8 6) (innerjoin 10 1) (innerjoin 1 10) -└── G12: (hashjoin 2 10) (hashjoin 10 2) (hashjoin 9 6) (hashjoin 6 9) (innerjoin 6 9) (innerjoin 9 6) (innerjoin 10 2) (innerjoin 2 10) +├── G7: (hashjoin 5 6[uv]) (hashjoin 1[ab] 12) (hashjoin 12 1[ab]) (hashjoin 2[xy] 11) (hashjoin 11 2[xy]) (hashjoin 3 10) (hashjoin 10 3) (hashjoin 6[uv] 5) (innerjoin 6[uv] 5) (innerjoin 10 3) (innerjoin 3 10) (innerjoin 11 2[xy]) (innerjoin 2[xy] 11) (innerjoin 12 1[ab]) (innerjoin 1[ab] 12) (innerjoin 5 6[uv]) +├── G8: (hashjoin 1[ab] 4[pq]) (hashjoin 4[pq] 1[ab]) (innerjoin 4[pq] 1[ab]) (innerjoin 1[ab] 4[pq]) +├── G9: (hashjoin 2[xy] 4[pq]) (hashjoin 4[pq] 2[xy]) (innerjoin 4[pq] 2[xy]) (innerjoin 2[xy] 4[pq]) +├── G10: (hashjoin 4[pq] 6[uv]) (hashjoin 6[uv] 4[pq]) (innerjoin 6[uv] 4[pq]) (innerjoin 4[pq] 6[uv]) +├── G11: (hashjoin 1[ab] 10) (hashjoin 10 1[ab]) (hashjoin 8 6[uv]) (hashjoin 6[uv] 8) (innerjoin 6[uv] 8) (innerjoin 8 6[uv]) (innerjoin 10 1[ab]) (innerjoin 1[ab] 10) +└── G12: (hashjoin 2[xy] 10) (hashjoin 10 2[xy]) (hashjoin 9 6[uv]) (hashjoin 6[uv] 9) (innerjoin 6[uv] 9) (innerjoin 9 6[uv]) (innerjoin 10 2[xy]) (innerjoin 2[xy] 10) `, }, } diff --git a/sql/memo/dfs.go b/sql/memo/dfs.go index 1990a5b02f..1e07392abc 100644 --- a/sql/memo/dfs.go +++ b/sql/memo/dfs.go @@ -26,8 +26,8 @@ func dfsRelHelper(grp *ExprGroup, seen map[GroupId]struct{}, cb func(rel RelExpr } else { seen[grp.Id] = struct{}{} } - n := grp.First - for n != nil { + + for n := range grp.Iter() { for _, c := range n.Children() { err := dfsRelHelper(c, seen, cb) if err != nil { @@ -38,7 +38,6 @@ func dfsRelHelper(grp *ExprGroup, seen map[GroupId]struct{}, cb func(rel RelExpr if err != nil { return err } - n = n.Next() } return nil } diff --git 
a/sql/memo/expr_group.go b/sql/memo/expr_group.go index 59bf1d5cd4..2c198f7e73 100644 --- a/sql/memo/expr_group.go +++ b/sql/memo/expr_group.go @@ -16,6 +16,11 @@ package memo import ( "fmt" + "io" + "iter" + "maps" + "slices" + "sort" "strings" "github.com/dolthub/go-mysql-server/sql" @@ -25,18 +30,40 @@ import ( // ExprGroup is a linked list of plans that return the same result set // defined by row count and schema. type ExprGroup struct { - m *Memo - RelProps *relProps - First RelExpr - Best RelExpr - _children []*ExprGroup - Cost float64 + m *Memo + RelProps *relProps + First RelExpr + Best RelExpr + Cost float64 Id GroupId Done bool HintOk bool } +// Format implements the fmt.Formatter interface. +func (e *ExprGroup) Format(f fmt.State, verb rune) { + expr := e.Best + if expr == nil { + expr = e.First + } + switch ex := expr.(type) { + case sql.Nameable: + io.WriteString(f, fmt.Sprintf("%d", expr.Group().Id)) + io.WriteString(f, "[") + io.WriteString(f, ex.Name()) + io.WriteString(f, "]") + default: + if verb == 'v' && f.Flag('+') { + io.WriteString(f, fmt.Sprintf("%d{%+v}", ex.Group().Id, ex)) + } else { + io.WriteString(f, fmt.Sprintf("%d", ex.Group().Id)) + } + } +} + +var _ fmt.Formatter = (*ExprGroup)(nil) + func newExprGroup(m *Memo, id GroupId, expr exprType) *ExprGroup { // bit of circularity: |grp| references |rel|, |rel| references |grp|, // and |relProps| references |rel| and |grp| info. @@ -61,30 +88,32 @@ func (e *ExprGroup) Prepend(rel RelExpr) { rel.SetNext(first) } +// Iter returns an iterator over the RelExprs in this ExprGroup. +func (e *ExprGroup) Iter() iter.Seq[RelExpr] { + return IterRelExprs(e.First) +} + // children returns a unioned list of child ExprGroup for // every logical plan in this group. -func (e *ExprGroup) children() []*ExprGroup { - relExpr, ok := e.First.(RelExpr) - if !ok { - return e.children() - } - n := relExpr - children := make([]*ExprGroup, 0) - for n != nil { - children = append(children, n.Children()...) 
- n = n.Next() +func (e *ExprGroup) children() iter.Seq[*ExprGroup] { + children := make(map[GroupId]*ExprGroup) + for n := range e.Iter() { + for _, n := range n.Children() { + children[n.Id] = n + } } - return children + return maps.Values(children) } -// updateBest updates a group's Best to the given expression or a hinted -// operator if the hinted plan is not found. Join operator is applied as -// a local rather than global property. -func (e *ExprGroup) updateBest(n RelExpr, grpCost float64) { +// updateBest updates a group's Best to the given expression if the cost is lower than the current best. +// Returns whether the best plan was updated. +func (e *ExprGroup) updateBest(n RelExpr, grpCost float64) bool { if e.Best == nil || grpCost < e.Cost { e.Best = n e.Cost = grpCost + return true } + return false } func (e *ExprGroup) finalize(node sql.Node) (sql.Node, error) { @@ -164,18 +193,17 @@ func (e *ExprGroup) fixTableScanPath() bool { func (e *ExprGroup) String() string { b := strings.Builder{} - n := e.First sep := "" - for n != nil { + for n := range e.Iter() { b.WriteString(sep) - b.WriteString(fmt.Sprintf("(%s", FormatExpr(n))) + b.WriteString(fmt.Sprintf("(%s", n)) if e.Best != nil { cost := n.Cost() if cost == 0 { // if source relation we want the cardinality cost = float64(n.Group().RelProps.GetStats().RowCount()) } - b.WriteString(fmt.Sprintf(" %.1f", n.Cost())) + b.WriteString(fmt.Sprintf(" %.1f", cost)) childCost := 0.0 for _, c := range n.Children() { @@ -190,7 +218,50 @@ func (e *ExprGroup) String() string { b.WriteString(")") } sep = " " - n = n.Next() } return b.String() } + +// CostTreeString returns a string representation of the expression group for use in cost debug printing +func (e *ExprGroup) CostTreeString(prefix string) string { + b := strings.Builder{} + costSortedGroups := slices.Collect(e.Iter()) + sort.Slice(costSortedGroups, func(i, j int) bool { + return costSortedGroups[i].Cost() < costSortedGroups[j].Cost() + }) + + for i, n 
:= range costSortedGroups { + b.WriteString("\n") + + beg := prefix + "├── " + if i == len(costSortedGroups)-1 { + beg = prefix + "└── " + } + b.WriteString(fmt.Sprintf("%s(%s", beg, n)) + if e.Best != nil { + cost := n.Cost() + if cost == 0 { + // if source relation we want the cardinality + cost = float64(n.Group().RelProps.GetStats().RowCount()) + } + b.WriteString(fmt.Sprintf(" %.1f", cost)) + } + b.WriteString(")") + } + + return b.String() +} + +// BestPlanDebugString returns a string representation of the physical best plan for use in cost debug printing +func (e *ExprGroup) BestPlanDebugString() string { + tp := sql.NewTreePrinter() + tp.WriteNode("G%d [%s] Cost: %.1f", e.Id, e.Best, e.Best.Cost()) + children := e.Best.Children() + childrenStrings := make([]string, len(children)) + for i, c := range children { + childrenStrings[i] = c.BestPlanDebugString() + } + + tp.WriteChildren(childrenStrings...) + return tp.String() +} diff --git a/sql/memo/join_order_builder.go b/sql/memo/join_order_builder.go index 652fe2611d..ef4a03f3d3 100644 --- a/sql/memo/join_order_builder.go +++ b/sql/memo/join_order_builder.go @@ -26,6 +26,10 @@ import ( "github.com/dolthub/go-mysql-server/sql/plan" ) +// SplitConjunction is a pseudo-extension point of expression.SplitConjunction, used to alter the logic +// for different integrators. +var SplitConjunction func(expr sql.Expression) []sql.Expression = expression.SplitConjunction + // joinOrderBuilder enumerates valid plans for a join tree. 
We build the join // tree bottom up, first joining single nodes with join condition "edges", then // single nodes to hypernodes (1+n), and finally hyper nodes to @@ -166,21 +170,34 @@ func (j *joinOrderBuilder) useFastReorder() bool { } func (j *joinOrderBuilder) ReorderJoin(n sql.Node) { + j.m.Tracer.PushDebugContext("ReorderJoin") + defer j.m.Tracer.PopDebugContext() + j.populateSubgraph(n) + if j.useFastReorder() { - j.buildSingleLookupPlan() - return + j.m.Tracer.Log("Using fast reorder algorithm (large join with %d tables)", len(j.vertices)) + if j.buildSingleLookupPlan() { + j.m.Tracer.Log("Successfully built single lookup plan") + return + } + j.m.Tracer.Log("Failed to build single lookup plan, falling back to exhaustive enumeration") } else if j.hasCrossJoin { + j.m.Tracer.Log("Join contains cross joins, attempting single lookup plan first") // Rely on FastReorder to avoid plans that drop filters with cross joins if j.buildSingleLookupPlan() { + j.m.Tracer.Log("Successfully built single lookup plan for cross join") return } + j.m.Tracer.Log("Failed to build single lookup plan for cross join, using exhaustive enumeration") } + // TODO: consider if buildSingleLookupPlan can/should run after ensureClosure. This could allow us to use analysis // from ensureClosure in buildSingleLookupPlan, but the equivalence sets could create multiple possible join orders // for the single-lookup plan, which would complicate things. j.ensureClosure(j.m.root) j.dpEnumerateSubsets() + j.m.Tracer.Log("Completed join reordering") return } @@ -420,6 +437,9 @@ func (j *joinOrderBuilder) findVertexFromGroup(grp GroupId) vertexIndex { // makeTransitiveEdge constructs a new join tree edge and memo group // on an equality filter between two columns. 
func (j *joinOrderBuilder) makeTransitiveEdge(col1, col2 sql.ColumnId) { + j.m.Tracer.PushDebugContext("makeTransitiveEdge") + defer j.m.Tracer.PopDebugContext() + var vert vertexSet v1, _, v1found := j.findVertexFromCol(col1) v2, _, v2found := j.findVertexFromCol(col2) @@ -468,17 +488,22 @@ func (j *joinOrderBuilder) makeTransitiveEdge(col1, col2 sql.ColumnId) { return } - j.edges = append(j.edges, *j.makeEdge(op, expression.NewEquals(gf1, gf2))) + eq := expression.NewEquals(gf1, gf2) + j.m.Tracer.Log("adding edge %s", eq) + j.edges = append(j.edges, *j.makeEdge(op, eq)) j.innerEdges.Add(len(j.edges) - 1) - } func (j *joinOrderBuilder) buildJoinOp(n *plan.JoinNode) *ExprGroup { + j.m.Tracer.PushDebugContext("buildJoinOp") + defer j.m.Tracer.PopDebugContext() + leftV, leftE, _ := j.populateSubgraph(n.Left()) rightV, rightE, _ := j.populateSubgraph(n.Right()) typ := n.JoinType() if typ.IsPhysical() { typ = plan.JoinTypeInner + j.m.Tracer.Log("Converted physical join type to inner join") } isInner := typ.IsInner() op := &operator{ @@ -489,21 +514,25 @@ func (j *joinOrderBuilder) buildJoinOp(n *plan.JoinNode) *ExprGroup { rightEdges: rightE, } - filters := expression.SplitConjunction(n.JoinCond()) + filters := SplitConjunction(n.JoinCond()) + j.m.Tracer.Log("Join filters: %v", filters) union := leftV.union(rightV) group, ok := j.plans[union] if !ok { // TODO: memo and root should be initialized prior to join planning left := j.plans[leftV] right := j.plans[rightV] - group = j.memoize(op.joinType, left, right, filters, nil) + group = j.memoize(op.joinType, left, right, filters) j.plans[union] = group j.m.root = group + j.m.Tracer.Log("Created new memo group for join combination") } if !isInner { + j.m.Tracer.Log("Building non-inner edge for join type: %s", typ) j.buildNonInnerEdge(op, filters...) } else { + j.m.Tracer.Log("Building inner edge for join type: %s", typ) j.buildInnerEdge(op, filters...) 
} return group @@ -513,7 +542,7 @@ func (j *joinOrderBuilder) buildFilter(child sql.Node, e sql.Expression) (vertex // memoize child childV, childE, childGrp := j.populateSubgraph(child) - filterGrp := j.m.MemoizeFilter(nil, childGrp, expression.SplitConjunction(e)) + filterGrp := j.m.MemoizeFilter(nil, childGrp, SplitConjunction(e)) // filter will absorb child relation for join reordering j.plans[childV] = filterGrp @@ -633,6 +662,9 @@ func (j *joinOrderBuilder) checkSize() { // adding plans to the tree when we find two sets that can // be joined func (j *joinOrderBuilder) dpEnumerateSubsets() { + j.m.Tracer.PushDebugContext("dpEnumerateSubsets") + defer j.m.Tracer.PopDebugContext() + all := j.allVertices() for subset := vertexSet(1); subset <= all; subset++ { if subset.isSingleton() { @@ -668,16 +700,20 @@ func setPrinter(all, s1, s2 vertexSet) { // addPlans finds operators that let us join (s1 op s2) and (s2 op s1). func (j *joinOrderBuilder) addPlans(s1, s2 vertexSet) { + j.m.Tracer.PushDebugContextFmt("addPlans/%s<->%s", s1, s2) + defer j.m.Tracer.PopDebugContext() + // all inner filters could be applied if j.plans[s1] == nil || j.plans[s2] == nil { // Both inputs must have plans. 
// need this to prevent cross-joins higher in tree + j.m.Tracer.Log("Skipping join - one or both input plans are nil") return } - //TODO collect all inner join filters that can be used as select filters - //TODO collect functional dependencies to avoid redundant filters - //TODO relational nodes track functional dependencies + // TODO collect all inner join filters that can be used as select filters + // TODO collect functional dependencies to avoid redundant filters + // TODO relational nodes track functional dependencies var innerJoinFilters []sql.Expression var addInnerJoin bool @@ -691,6 +727,7 @@ func (j *joinOrderBuilder) addPlans(s1, s2 vertexSet) { } isRedundant = isRedundant || e.joinIsRedundant(s1, s2) addInnerJoin = true + j.m.Tracer.Log("Found applicable inner edge %d with filters: %v", i, e.filters) } } @@ -699,12 +736,14 @@ func (j *joinOrderBuilder) addPlans(s1, s2 vertexSet) { for i, ok := j.nonInnerEdges.Next(0); ok; i, ok = j.nonInnerEdges.Next(i + 1) { e := &j.edges[i] if e.applicable(s1, s2) { + j.m.Tracer.Log("Found applicable non-inner edge %d, adding join: %s", i, e.op.joinType) j.addJoin(e.op.joinType, s1, s2, e.filters, innerJoinFilters, e.joinIsRedundant(s1, s2)) return } if e.applicable(s2, s1) { // This is necessary because we only iterate s1 up to subset / 2 // in DPSube() + j.m.Tracer.Log("Found applicable non-inner edge %d (swapped), adding join: %s", i, e.op.joinType) j.addJoin(e.op.joinType, s2, s1, e.filters, innerJoinFilters, e.joinIsRedundant(s2, s1)) return } @@ -715,10 +754,14 @@ func (j *joinOrderBuilder) addPlans(s1, s2 vertexSet) { // already been constructed, because doing so can lead to a case where an // inner join replaces a non-inner join. 
if innerJoinFilters == nil { + j.m.Tracer.Log("Adding cross join") j.addJoin(plan.JoinTypeCross, s1, s2, nil, nil, isRedundant) } else { + j.m.Tracer.Log("Adding inner join with filters: %v", innerJoinFilters) j.addJoin(plan.JoinTypeInner, s1, s2, innerJoinFilters, nil, isRedundant) } + } else { + j.m.Tracer.Log("No applicable edges found for join") } } @@ -733,7 +776,7 @@ func (j *joinOrderBuilder) addJoin(op plan.JoinType, s1, s2 vertexSet, joinFilte group, ok := j.plans[union] if !isRedundant { if !ok { - group = j.memoize(op, left, right, joinFilter, selFilters) + group = j.memoize(op, left, right, joinFilter) j.plans[union] = group } else { j.addJoinToGroup(op, left, right, joinFilter, selFilters, group) @@ -781,7 +824,6 @@ func (j *joinOrderBuilder) memoize( left *ExprGroup, right *ExprGroup, joinFilter []sql.Expression, - selFilter []sql.Expression, ) *ExprGroup { rel := j.constructJoin(op, left, right, joinFilter, nil) return j.m.NewExprGroup(rel) @@ -908,9 +950,9 @@ func (e *edge) populateEdgeProps(tableIds []sql.TableId, edges []edge) { e.freeVars = cols // TODO implement, we currently limit transforms assuming no strong null safety - //e.nullRejectedRels = e.nullRejectingTables(nullAccepting, allNames, allV) + // e.nullRejectedRels = e.nullRejectingTables(nullAccepting, allNames, allV) - //SES is vertexSet of all tables referenced in cols + // SES is vertexSet of all tables referenced in cols e.calcSES(tables, tableIds) // use CD-C to expand dependency sets for operators // front load preventing applicable operators that would push crossjoins diff --git a/sql/memo/join_order_builder_test.go b/sql/memo/join_order_builder_test.go index 8ea10875c6..a0dd6b63ab 100644 --- a/sql/memo/join_order_builder_test.go +++ b/sql/memo/join_order_builder_test.go @@ -57,19 +57,19 @@ func TestJoinOrderBuilder(t *testing.T) { plans: `memo: ├── G1: (tablescan: a) ├── G2: (tablescan: b) -├── G3: (innerjoin 2 1) (innerjoin 1 2) +├── G3: (innerjoin 2[b] 1[a]) (innerjoin 1[a] 
2[b]) ├── G4: (tablescan: c) -├── G5: (innerjoin 4 3) (innerjoin 8 2) (innerjoin 2 8) (innerjoin 9 1) (innerjoin 1 9) (innerjoin 3 4) +├── G5: (innerjoin 4[c] 3) (innerjoin 8 2[b]) (innerjoin 2[b] 8) (innerjoin 9 1[a]) (innerjoin 1[a] 9) (innerjoin 3 4[c]) ├── G6: (tablescan: d) -├── G7: (innerjoin 6 5) (innerjoin 10 9) (innerjoin 9 10) (innerjoin 11 8) (innerjoin 8 11) (innerjoin 12 4) (innerjoin 4 12) (innerjoin 13 3) (innerjoin 3 13) (innerjoin 14 2) (innerjoin 2 14) (innerjoin 15 1) (innerjoin 1 15) (innerjoin 5 6) -├── G8: (innerjoin 4 1) (innerjoin 1 4) -├── G9: (innerjoin 4 2) (innerjoin 2 4) -├── G10: (innerjoin 6 1) (innerjoin 1 6) -├── G11: (innerjoin 6 2) (innerjoin 2 6) -├── G12: (innerjoin 6 3) (innerjoin 3 6) (innerjoin 10 2) (innerjoin 2 10) (innerjoin 11 1) (innerjoin 1 11) -├── G13: (innerjoin 6 4) (innerjoin 4 6) -├── G14: (innerjoin 6 8) (innerjoin 8 6) (innerjoin 10 4) (innerjoin 4 10) (innerjoin 13 1) (innerjoin 1 13) -└── G15: (innerjoin 6 9) (innerjoin 9 6) (innerjoin 11 4) (innerjoin 4 11) (innerjoin 13 2) (innerjoin 2 13) +├── G7: (innerjoin 6[d] 5) (innerjoin 10 9) (innerjoin 9 10) (innerjoin 11 8) (innerjoin 8 11) (innerjoin 12 4[c]) (innerjoin 4[c] 12) (innerjoin 13 3) (innerjoin 3 13) (innerjoin 14 2[b]) (innerjoin 2[b] 14) (innerjoin 15 1[a]) (innerjoin 1[a] 15) (innerjoin 5 6[d]) +├── G8: (innerjoin 4[c] 1[a]) (innerjoin 1[a] 4[c]) +├── G9: (innerjoin 4[c] 2[b]) (innerjoin 2[b] 4[c]) +├── G10: (innerjoin 6[d] 1[a]) (innerjoin 1[a] 6[d]) +├── G11: (innerjoin 6[d] 2[b]) (innerjoin 2[b] 6[d]) +├── G12: (innerjoin 6[d] 3) (innerjoin 3 6[d]) (innerjoin 10 2[b]) (innerjoin 2[b] 10) (innerjoin 11 1[a]) (innerjoin 1[a] 11) +├── G13: (innerjoin 6[d] 4[c]) (innerjoin 4[c] 6[d]) +├── G14: (innerjoin 6[d] 8) (innerjoin 8 6[d]) (innerjoin 10 4[c]) (innerjoin 4[c] 10) (innerjoin 13 1[a]) (innerjoin 1[a] 13) +└── G15: (innerjoin 6[d] 9) (innerjoin 9 6[d]) (innerjoin 11 4[c]) (innerjoin 4[c] 11) (innerjoin 13 2[b]) (innerjoin 2[b] 13) `, }, { @@ 
-102,32 +102,32 @@ func TestJoinOrderBuilder(t *testing.T) { plans: `memo: ├── G1: (tablescan: a) ├── G2: (tablescan: b) -├── G3: (leftjoin 1 2) +├── G3: (leftjoin 1[a] 2[b]) ├── G4: (tablescan: c) ├── G5: (tablescan: d) -├── G6: (fullouterjoin 4 5) +├── G6: (fullouterjoin 4[c] 5[d]) ├── G7: (tablescan: e) -├── G8: (leftjoin 6 7) -├── G9: (innerjoin 8 3) (leftjoin 14 2) (innerjoin 3 8) +├── G8: (leftjoin 6 7[e]) +├── G9: (innerjoin 8 3) (leftjoin 14 2[b]) (innerjoin 3 8) ├── G10: (tablescan: f) ├── G11: (tablescan: g) -├── G12: (innerjoin 11 10) (innerjoin 10 11) -├── G13: (innerjoin 11 19) (innerjoin 19 11) (innerjoin 21 17) (innerjoin 17 21) (innerjoin 22 16) (innerjoin 16 22) (innerjoin 24 10) (innerjoin 10 24) (innerjoin 12 9) (innerjoin 26 8) (innerjoin 8 26) (innerjoin 27 3) (innerjoin 3 27) (leftjoin 28 2) (innerjoin 9 12) -├── G14: (innerjoin 8 1) (innerjoin 1 8) -├── G15: (innerjoin 10 1) (innerjoin 1 10) -├── G16: (innerjoin 10 3) (innerjoin 3 10) (leftjoin 15 2) -├── G17: (innerjoin 10 8) (innerjoin 8 10) -├── G18: (innerjoin 10 14) (innerjoin 14 10) (innerjoin 15 8) (innerjoin 8 15) (innerjoin 17 1) (innerjoin 1 17) -├── G19: (innerjoin 10 9) (innerjoin 9 10) (innerjoin 16 8) (innerjoin 8 16) (innerjoin 17 3) (innerjoin 3 17) (leftjoin 18 2) -├── G20: (innerjoin 11 1) (innerjoin 1 11) -├── G21: (innerjoin 11 3) (innerjoin 3 11) (leftjoin 20 2) -├── G22: (innerjoin 11 8) (innerjoin 8 11) -├── G23: (innerjoin 11 14) (innerjoin 14 11) (innerjoin 20 8) (innerjoin 8 20) (innerjoin 22 1) (innerjoin 1 22) -├── G24: (innerjoin 11 9) (innerjoin 9 11) (innerjoin 21 8) (innerjoin 8 21) (innerjoin 22 3) (innerjoin 3 22) (leftjoin 23 2) -├── G25: (innerjoin 11 15) (innerjoin 15 11) (innerjoin 20 10) (innerjoin 10 20) (innerjoin 12 1) (innerjoin 1 12) -├── G26: (innerjoin 11 16) (innerjoin 16 11) (innerjoin 21 10) (innerjoin 10 21) (innerjoin 12 3) (innerjoin 3 12) (leftjoin 25 2) -├── G27: (innerjoin 11 17) (innerjoin 17 11) (innerjoin 22 10) (innerjoin 10 22) 
(innerjoin 12 8) (innerjoin 8 12) -└── G28: (innerjoin 11 18) (innerjoin 18 11) (innerjoin 20 17) (innerjoin 17 20) (innerjoin 22 15) (innerjoin 15 22) (innerjoin 23 10) (innerjoin 10 23) (innerjoin 12 14) (innerjoin 14 12) (innerjoin 25 8) (innerjoin 8 25) (innerjoin 27 1) (innerjoin 1 27) +├── G12: (innerjoin 11[g] 10[f]) (innerjoin 10[f] 11[g]) +├── G13: (innerjoin 11[g] 19) (innerjoin 19 11[g]) (innerjoin 21 17) (innerjoin 17 21) (innerjoin 22 16) (innerjoin 16 22) (innerjoin 24 10[f]) (innerjoin 10[f] 24) (innerjoin 12 9) (innerjoin 26 8) (innerjoin 8 26) (innerjoin 27 3) (innerjoin 3 27) (leftjoin 28 2[b]) (innerjoin 9 12) +├── G14: (innerjoin 8 1[a]) (innerjoin 1[a] 8) +├── G15: (innerjoin 10[f] 1[a]) (innerjoin 1[a] 10[f]) +├── G16: (innerjoin 10[f] 3) (innerjoin 3 10[f]) (leftjoin 15 2[b]) +├── G17: (innerjoin 10[f] 8) (innerjoin 8 10[f]) +├── G18: (innerjoin 10[f] 14) (innerjoin 14 10[f]) (innerjoin 15 8) (innerjoin 8 15) (innerjoin 17 1[a]) (innerjoin 1[a] 17) +├── G19: (innerjoin 10[f] 9) (innerjoin 9 10[f]) (innerjoin 16 8) (innerjoin 8 16) (innerjoin 17 3) (innerjoin 3 17) (leftjoin 18 2[b]) +├── G20: (innerjoin 11[g] 1[a]) (innerjoin 1[a] 11[g]) +├── G21: (innerjoin 11[g] 3) (innerjoin 3 11[g]) (leftjoin 20 2[b]) +├── G22: (innerjoin 11[g] 8) (innerjoin 8 11[g]) +├── G23: (innerjoin 11[g] 14) (innerjoin 14 11[g]) (innerjoin 20 8) (innerjoin 8 20) (innerjoin 22 1[a]) (innerjoin 1[a] 22) +├── G24: (innerjoin 11[g] 9) (innerjoin 9 11[g]) (innerjoin 21 8) (innerjoin 8 21) (innerjoin 22 3) (innerjoin 3 22) (leftjoin 23 2[b]) +├── G25: (innerjoin 11[g] 15) (innerjoin 15 11[g]) (innerjoin 20 10[f]) (innerjoin 10[f] 20) (innerjoin 12 1[a]) (innerjoin 1[a] 12) +├── G26: (innerjoin 11[g] 16) (innerjoin 16 11[g]) (innerjoin 21 10[f]) (innerjoin 10[f] 21) (innerjoin 12 3) (innerjoin 3 12) (leftjoin 25 2[b]) +├── G27: (innerjoin 11[g] 17) (innerjoin 17 11[g]) (innerjoin 22 10[f]) (innerjoin 10[f] 22) (innerjoin 12 8) (innerjoin 8 12) +└── G28: (innerjoin 11[g] 
18) (innerjoin 18 11[g]) (innerjoin 20 17) (innerjoin 17 20) (innerjoin 22 15) (innerjoin 15 22) (innerjoin 23 10[f]) (innerjoin 10[f] 23) (innerjoin 12 14) (innerjoin 14 12) (innerjoin 25 8) (innerjoin 8 25) (innerjoin 27 1[a]) (innerjoin 1[a] 27) `, }, { @@ -146,10 +146,10 @@ func TestJoinOrderBuilder(t *testing.T) { plans: `memo: ├── G1: (tablescan: a) ├── G2: (tablescan: c) -├── G3: (crossjoin 1 2) +├── G3: (crossjoin 1[a] 2[c]) ├── G4: (tablescan: b) -├── G5: (innerjoin 1 6) (innerjoin 6 1) (innerjoin 3 4) -└── G6: (innerjoin 4 2) (innerjoin 2 4) +├── G5: (innerjoin 1[a] 6) (innerjoin 6 1[a]) (innerjoin 3 4[b]) +└── G6: (innerjoin 4[b] 2[c]) (innerjoin 2[c] 4[b]) `, }, { @@ -173,12 +173,12 @@ func TestJoinOrderBuilder(t *testing.T) { plans: `memo: ├── G1: (tablescan: c) ├── G2: (tablescan: d) -├── G3: (innerjoin 1 2) (innerjoin 2 1) (innerjoin 1 2) +├── G3: (innerjoin 1[c] 2[d]) (innerjoin 2[d] 1[c]) (innerjoin 1[c] 2[d]) ├── G4: (tablescan: a) ├── G5: (tablescan: b) -├── G6: (innerjoin 4 5) -├── G7: (innerjoin 4 8) (innerjoin 8 4) (innerjoin 3 6) -└── G8: (innerjoin 5 3) (innerjoin 3 5) +├── G6: (innerjoin 4[a] 5[b]) +├── G7: (innerjoin 4[a] 8) (innerjoin 8 4[a]) (innerjoin 3 6) +└── G8: (innerjoin 5[b] 3) (innerjoin 3 5[b]) `, }, } diff --git a/sql/memo/memo.go b/sql/memo/memo.go index cc017fad8d..649c9c5d8a 100644 --- a/sql/memo/memo.go +++ b/sql/memo/memo.go @@ -16,6 +16,9 @@ package memo import ( "fmt" + "io" + "iter" + "slices" "strings" "github.com/dolthub/go-mysql-server/sql" @@ -48,6 +51,7 @@ type Memo struct { scopeLen int cnt uint16 Debug bool + Tracer *TraceLogger } func NewMemo(ctx *sql.Context, stats sql.StatsProvider, s *plan.Scope, scopeLen int, cost Coster, qFlags *sql.QueryFlags) *Memo { @@ -60,6 +64,7 @@ func NewMemo(ctx *sql.Context, stats sql.StatsProvider, s *plan.Scope, scopeLen TableProps: newTableProps(), hints: &joinHints{}, QFlags: qFlags, + Tracer: &TraceLogger{}, } } @@ -71,6 +76,10 @@ func (m *Memo) HandleErr(err error) { 
panic(MemoErr{Err: err}) } +func (m *Memo) EnableTrace(enable bool) { + m.Tracer.TraceEnabled = enable +} + func (m *Memo) Root() *ExprGroup { return m.root } @@ -399,6 +408,9 @@ func (m *Memo) MemoizeMax1Row(grp, child *ExprGroup) *ExprGroup { // OptimizeRoot finds the implementation for the root expression // that has the lowest cost. func (m *Memo) OptimizeRoot() error { + m.Tracer.PushDebugContext("OptimizeRoot") + defer m.Tracer.PopDebugContext() + err := m.optimizeMemoGroup(m.root) if err != nil { return err @@ -423,6 +435,9 @@ func (m *Memo) optimizeMemoGroup(grp *ExprGroup) error { return nil } + m.Tracer.PushDebugContextFmt("optimizeMemoGroup/%d", grp.Id) + defer m.Tracer.PopDebugContext() + n := grp.First if _, ok := n.(SourceRel); ok { // We should order the search bottom-up so that physical operators @@ -434,10 +449,12 @@ func (m *Memo) optimizeMemoGroup(grp *ExprGroup) error { grp.HintOk = true grp.Best = grp.First grp.Best.SetDistinct(NoDistinctOp) + m.Tracer.Log("source relation, setting as best plan", grp) return nil } for n != nil { + m.Tracer.Log("Evaluating plan (%s)", n) var cost float64 for _, g := range n.Children() { err := m.optimizeMemoGroup(g) @@ -454,10 +471,12 @@ func (m *Memo) optimizeMemoGroup(grp *ExprGroup) error { if grp.RelProps.Distinct.IsHash() { if sortedInputs(n) { n.SetDistinct(SortedDistinctOp) + m.Tracer.Log("Plan %s: using sorted distinct", n) } else { n.SetDistinct(HashDistinctOp) d := &Distinct{Child: grp} - relCost += float64(statsForRel(m.Ctx, d).RowCount()) + relCost += float64(m.statsForRel(m.Ctx, d).RowCount()) + m.Tracer.Log("Plan %s: using hash distinct", n) } } else { n.SetDistinct(NoDistinctOp) @@ -465,6 +484,7 @@ func (m *Memo) optimizeMemoGroup(grp *ExprGroup) error { n.SetCost(relCost) cost += relCost + m.Tracer.Log("Plan %s: relCost=%.2f, totalCost=%.2f", n, relCost, cost) m.updateBest(grp, n, cost) n = n.Next() } @@ -483,15 +503,22 @@ func (m *Memo) updateBest(grp *ExprGroup, n RelExpr, cost float64) { 
grp.Best = n grp.Cost = cost grp.HintOk = true + m.Tracer.Log("Set best plan for group %d to hinted plan %s with cost %.2f", grp.Id, n, cost) return } - grp.updateBest(n, cost) + if grp.updateBest(n, cost) { + m.Tracer.Log("Updated best plan for group %d to hinted plan %s with cost %.2f", grp.Id, n, cost) + } } else if grp.Best == nil || !grp.HintOk { - grp.updateBest(n, cost) + if grp.updateBest(n, cost) { + m.Tracer.Log("Updated best plan for group %d to plan %s with cost %.2f (no hints satisfied)", grp.Id, n, cost) + } } return } - grp.updateBest(n, cost) + if grp.updateBest(n, cost) { + m.Tracer.Log("Updated best plan for group %d to plan %s with cost %.2f", grp.Id, n, cost) + } } func (m *Memo) BestRootPlan(ctx *sql.Context) (sql.Node, error) { @@ -585,17 +612,16 @@ func (m *Memo) SetJoinOp(op HintType, left, right string) { m.hints.ops = append(m.hints.ops, hint) } +var _ fmt.Stringer = (*Memo)(nil) + func (m *Memo) String() string { exprs := make([]string, m.cnt) groups := make([]*ExprGroup, 0) if m.root != nil { - r := m.root.First - for r != nil { - groups = append(groups, r.Group()) - groups = append(groups, r.Children()...) - r = r.Next() - } + groups = append(groups, m.root.First.Group()) } + + // breadth-first traversal of memo groups via their children for len(groups) > 0 { newGroups := make([]*ExprGroup, 0) for _, g := range groups { @@ -603,7 +629,7 @@ func (m *Memo) String() string { continue } exprs[int(TableIdForSource(g.Id))] = g.String() - newGroups = append(newGroups, g.children()...) + newGroups = slices.AppendSeq(newGroups, g.children()) } groups = newGroups } @@ -619,6 +645,67 @@ func (m *Memo) String() string { return b.String() } +// LogCostDebugString logs a string representation of the memo with cost +// information for each expression, ordered by best to worst for each group, +// displayed in a tree structure. +// Only logs if tracing is enabled. 
+func (m *Memo) LogCostDebugString() { + if m.root == nil || !m.Tracer.TraceEnabled { + return + } + + exprs := make([]string, m.cnt) + groups := make([]*ExprGroup, 0) + + b := strings.Builder{} + b.WriteString(fmt.Sprintf("costed memo (root group %d):\n", m.root.Id)) + + if m.root != nil { + groups = append(groups, m.root.First.Group()) + } + + // breadth-first traversal of memo groups via their children + for len(groups) > 0 { + newGroups := make([]*ExprGroup, 0) + for _, g := range groups { + if exprs[int(TableIdForSource(g.Id))] != "" { + continue + } + + prefix := "| " + if int(g.Id) == int(m.cnt) { + prefix = " " + } + + exprs[int(TableIdForSource(g.Id))] = g.CostTreeString(prefix) + newGroups = slices.AppendSeq(newGroups, g.children()) + } + groups = newGroups + } + + beg := "├──" + for i, g := range exprs { + if i == len(exprs)-1 { + beg = "└──" + } + b.WriteString(fmt.Sprintf("%s G%d: %s\n", beg, i+1, g)) + } + + m.Tracer.Log("Completed cost-based optimization:\n%s", b.String()) +} + +// LogBestPlanDebugString logs a physical tree representation of the best plan for each group in the tree that is +// referenced by the best plan in the root. This differs from other debug strings in that it represents the groups +// as children of their parents, rather than as a flat list, and only includes groups that are part of the best plan. +// Only logs if tracing is enabled. 
+func (m *Memo) LogBestPlanDebugString() { + if m.root == nil || !m.Tracer.TraceEnabled { + return + } + + m.Tracer.Log("Best root plan:\n%s", m.root.BestPlanDebugString()) +} + type tableProps struct { grpToName map[GroupId]string nameToGrp map[string]GroupId @@ -693,6 +780,19 @@ func relKey(r RelExpr) uint64 { return uint64(key) } +// IterRelExprs returns an iterator over the linked list of RelExprs beginning at the head e +func IterRelExprs(e RelExpr) iter.Seq[RelExpr] { + curr := e + return func(yield func(RelExpr) bool) { + for curr != nil { + if !yield(curr) { + return + } + curr = curr.Next() + } + } +} + type distinctOp uint8 const ( @@ -878,3 +978,73 @@ type RangeHeap struct { RangeClosedOnLowerBound bool RangeClosedOnUpperBound bool } + +// FormatExpr formats an exprType for debugging purposes, compatible with fmt.Formatter +func FormatExpr(r exprType, s fmt.State, verb rune) { + verbString := fmt.Sprintf("%%%c", verb) + if verb == 'v' && s.Flag('+') { + verbString = "%+v" + } + switch r := r.(type) { + case *CrossJoin: + io.WriteString(s, fmt.Sprintf("crossjoin "+verbString+" "+verbString, r.Left, r.Right)) + case *InnerJoin: + io.WriteString(s, fmt.Sprintf("innerjoin "+verbString+" "+verbString, r.Left, r.Right)) + case *LeftJoin: + io.WriteString(s, fmt.Sprintf("leftjoin "+verbString+" "+verbString, r.Left, r.Right)) + case *SemiJoin: + io.WriteString(s, fmt.Sprintf("semijoin "+verbString+" "+verbString, r.Left, r.Right)) + case *AntiJoin: + io.WriteString(s, fmt.Sprintf("antijoin "+verbString+" "+verbString, r.Left, r.Right)) + case *LookupJoin: + io.WriteString(s, fmt.Sprintf("lookupjoin "+verbString+" "+verbString+" on %s", + r.Left, r.Right, r.Lookup.Index.idx.ID())) + case *RangeHeapJoin: + io.WriteString(s, fmt.Sprintf("rangeheapjoin "+verbString+" "+verbString, r.Left, r.Right)) + case *ConcatJoin: + io.WriteString(s, fmt.Sprintf("concatjoin "+verbString+" "+verbString, r.Left, r.Right)) + case *HashJoin: + io.WriteString(s, 
fmt.Sprintf("hashjoin "+verbString+" "+verbString, r.Left, r.Right)) + case *MergeJoin: + io.WriteString(s, fmt.Sprintf("mergejoin "+verbString+" "+verbString, r.Left, r.Right)) + case *FullOuterJoin: + io.WriteString(s, fmt.Sprintf("fullouterjoin "+verbString+" "+verbString, r.Left, r.Right)) + case *LateralJoin: + io.WriteString(s, fmt.Sprintf("lateraljoin "+verbString+" "+verbString, r.Left, r.Right)) + case *TableScan: + io.WriteString(s, fmt.Sprintf("tablescan: %s", r.Name())) + case *IndexScan: + if r.Alias != "" { + io.WriteString(s, fmt.Sprintf("indexscan on %s: %s", r.Index.SqlIdx().ID(), r.Alias)) + } + io.WriteString(s, fmt.Sprintf("indexscan on %s: %s", r.Index.SqlIdx().ID(), r.Name())) + case *Values: + io.WriteString(s, fmt.Sprintf("values: %s", r.Name())) + case *TableAlias: + io.WriteString(s, fmt.Sprintf("tablealias: %s", r.Name())) + case *RecursiveTable: + io.WriteString(s, fmt.Sprintf("recursivetable: %s", r.Name())) + case *RecursiveCte: + io.WriteString(s, fmt.Sprintf("recursivecte: %s", r.Name())) + case *SubqueryAlias: + io.WriteString(s, fmt.Sprintf("subqueryalias: %s", r.Name())) + case *TableFunc: + io.WriteString(s, fmt.Sprintf("tablefunc: %s", r.Name())) + case *JSONTable: + io.WriteString(s, fmt.Sprintf("jsontable: %s", r.Name())) + case *EmptyTable: + io.WriteString(s, fmt.Sprintf("emptytable: %s", r.Name())) + case *SetOp: + io.WriteString(s, fmt.Sprintf("setop: %s", r.Name())) + case *Project: + io.WriteString(s, fmt.Sprintf("project: %d", r.Child.Id)) + case *Distinct: + io.WriteString(s, fmt.Sprintf("distinct: %d", r.Child.Id)) + case *Max1Row: + io.WriteString(s, fmt.Sprintf("max1row: %d", r.Child.Id)) + case *Filter: + io.WriteString(s, fmt.Sprintf("filter: %d", r.Child.Id)) + default: + panic(fmt.Sprintf("unknown RelExpr type: %T", r)) + } +} diff --git a/sql/memo/memo.og.go b/sql/memo/memo.og.go index 09d52be0f9..e4db9bbd25 100644 --- a/sql/memo/memo.og.go +++ b/sql/memo/memo.og.go @@ -15,10 +15,16 @@ type CrossJoin struct { } 
var _ RelExpr = (*CrossJoin)(nil) +var _ fmt.Formatter = (*CrossJoin)(nil) +var _ fmt.Stringer = (*CrossJoin)(nil) var _ JoinRel = (*CrossJoin)(nil) func (r *CrossJoin) String() string { - return FormatExpr(r) + return fmt.Sprintf("%s", r) +} + +func (r *CrossJoin) Format(s fmt.State, verb rune) { + FormatExpr(r, s, verb) } func (r *CrossJoin) JoinPrivate() *JoinBase { @@ -30,10 +36,16 @@ type InnerJoin struct { } var _ RelExpr = (*InnerJoin)(nil) +var _ fmt.Formatter = (*InnerJoin)(nil) +var _ fmt.Stringer = (*InnerJoin)(nil) var _ JoinRel = (*InnerJoin)(nil) func (r *InnerJoin) String() string { - return FormatExpr(r) + return fmt.Sprintf("%s", r) +} + +func (r *InnerJoin) Format(s fmt.State, verb rune) { + FormatExpr(r, s, verb) } func (r *InnerJoin) JoinPrivate() *JoinBase { @@ -45,10 +57,16 @@ type LeftJoin struct { } var _ RelExpr = (*LeftJoin)(nil) +var _ fmt.Formatter = (*LeftJoin)(nil) +var _ fmt.Stringer = (*LeftJoin)(nil) var _ JoinRel = (*LeftJoin)(nil) func (r *LeftJoin) String() string { - return FormatExpr(r) + return fmt.Sprintf("%s", r) +} + +func (r *LeftJoin) Format(s fmt.State, verb rune) { + FormatExpr(r, s, verb) } func (r *LeftJoin) JoinPrivate() *JoinBase { @@ -60,10 +78,16 @@ type SemiJoin struct { } var _ RelExpr = (*SemiJoin)(nil) +var _ fmt.Formatter = (*SemiJoin)(nil) +var _ fmt.Stringer = (*SemiJoin)(nil) var _ JoinRel = (*SemiJoin)(nil) func (r *SemiJoin) String() string { - return FormatExpr(r) + return fmt.Sprintf("%s", r) +} + +func (r *SemiJoin) Format(s fmt.State, verb rune) { + FormatExpr(r, s, verb) } func (r *SemiJoin) JoinPrivate() *JoinBase { @@ -75,10 +99,16 @@ type AntiJoin struct { } var _ RelExpr = (*AntiJoin)(nil) +var _ fmt.Formatter = (*AntiJoin)(nil) +var _ fmt.Stringer = (*AntiJoin)(nil) var _ JoinRel = (*AntiJoin)(nil) func (r *AntiJoin) String() string { - return FormatExpr(r) + return fmt.Sprintf("%s", r) +} + +func (r *AntiJoin) Format(s fmt.State, verb rune) { + FormatExpr(r, s, verb) } func (r *AntiJoin) 
JoinPrivate() *JoinBase { @@ -92,10 +122,16 @@ type LookupJoin struct { } var _ RelExpr = (*LookupJoin)(nil) +var _ fmt.Formatter = (*LookupJoin)(nil) +var _ fmt.Stringer = (*LookupJoin)(nil) var _ JoinRel = (*LookupJoin)(nil) func (r *LookupJoin) String() string { - return FormatExpr(r) + return fmt.Sprintf("%s", r) +} + +func (r *LookupJoin) Format(s fmt.State, verb rune) { + FormatExpr(r, s, verb) } func (r *LookupJoin) JoinPrivate() *JoinBase { @@ -108,10 +144,16 @@ type RangeHeapJoin struct { } var _ RelExpr = (*RangeHeapJoin)(nil) +var _ fmt.Formatter = (*RangeHeapJoin)(nil) +var _ fmt.Stringer = (*RangeHeapJoin)(nil) var _ JoinRel = (*RangeHeapJoin)(nil) func (r *RangeHeapJoin) String() string { - return FormatExpr(r) + return fmt.Sprintf("%s", r) +} + +func (r *RangeHeapJoin) Format(s fmt.State, verb rune) { + FormatExpr(r, s, verb) } func (r *RangeHeapJoin) JoinPrivate() *JoinBase { @@ -124,10 +166,16 @@ type ConcatJoin struct { } var _ RelExpr = (*ConcatJoin)(nil) +var _ fmt.Formatter = (*ConcatJoin)(nil) +var _ fmt.Stringer = (*ConcatJoin)(nil) var _ JoinRel = (*ConcatJoin)(nil) func (r *ConcatJoin) String() string { - return FormatExpr(r) + return fmt.Sprintf("%s", r) +} + +func (r *ConcatJoin) Format(s fmt.State, verb rune) { + FormatExpr(r, s, verb) } func (r *ConcatJoin) JoinPrivate() *JoinBase { @@ -141,10 +189,16 @@ type HashJoin struct { } var _ RelExpr = (*HashJoin)(nil) +var _ fmt.Formatter = (*HashJoin)(nil) +var _ fmt.Stringer = (*HashJoin)(nil) var _ JoinRel = (*HashJoin)(nil) func (r *HashJoin) String() string { - return FormatExpr(r) + return fmt.Sprintf("%s", r) +} + +func (r *HashJoin) Format(s fmt.State, verb rune) { + FormatExpr(r, s, verb) } func (r *HashJoin) JoinPrivate() *JoinBase { @@ -161,10 +215,16 @@ type MergeJoin struct { } var _ RelExpr = (*MergeJoin)(nil) +var _ fmt.Formatter = (*MergeJoin)(nil) +var _ fmt.Stringer = (*MergeJoin)(nil) var _ JoinRel = (*MergeJoin)(nil) func (r *MergeJoin) String() string { - return 
FormatExpr(r) + return fmt.Sprintf("%s", r) +} + +func (r *MergeJoin) Format(s fmt.State, verb rune) { + FormatExpr(r, s, verb) } func (r *MergeJoin) JoinPrivate() *JoinBase { @@ -176,10 +236,16 @@ type FullOuterJoin struct { } var _ RelExpr = (*FullOuterJoin)(nil) +var _ fmt.Formatter = (*FullOuterJoin)(nil) +var _ fmt.Stringer = (*FullOuterJoin)(nil) var _ JoinRel = (*FullOuterJoin)(nil) func (r *FullOuterJoin) String() string { - return FormatExpr(r) + return fmt.Sprintf("%s", r) +} + +func (r *FullOuterJoin) Format(s fmt.State, verb rune) { + FormatExpr(r, s, verb) } func (r *FullOuterJoin) JoinPrivate() *JoinBase { @@ -191,10 +257,16 @@ type LateralJoin struct { } var _ RelExpr = (*LateralJoin)(nil) +var _ fmt.Formatter = (*LateralJoin)(nil) +var _ fmt.Stringer = (*LateralJoin)(nil) var _ JoinRel = (*LateralJoin)(nil) func (r *LateralJoin) String() string { - return FormatExpr(r) + return fmt.Sprintf("%s", r) +} + +func (r *LateralJoin) Format(s fmt.State, verb rune) { + FormatExpr(r, s, verb) } func (r *LateralJoin) JoinPrivate() *JoinBase { @@ -207,10 +279,16 @@ type TableScan struct { } var _ RelExpr = (*TableScan)(nil) +var _ fmt.Formatter = (*TableScan)(nil) +var _ fmt.Stringer = (*TableScan)(nil) var _ SourceRel = (*TableScan)(nil) func (r *TableScan) String() string { - return FormatExpr(r) + return fmt.Sprintf("%s", r) +} + +func (r *TableScan) Format(s fmt.State, verb rune) { + FormatExpr(r, s, verb) } func (r *TableScan) Name() string { @@ -242,10 +320,16 @@ type IndexScan struct { } var _ RelExpr = (*IndexScan)(nil) +var _ fmt.Formatter = (*IndexScan)(nil) +var _ fmt.Stringer = (*IndexScan)(nil) var _ SourceRel = (*IndexScan)(nil) func (r *IndexScan) String() string { - return FormatExpr(r) + return fmt.Sprintf("%s", r) +} + +func (r *IndexScan) Format(s fmt.State, verb rune) { + FormatExpr(r, s, verb) } func (r *IndexScan) Name() string { @@ -274,10 +358,16 @@ type Values struct { } var _ RelExpr = (*Values)(nil) +var _ fmt.Formatter = 
(*Values)(nil) +var _ fmt.Stringer = (*Values)(nil) var _ SourceRel = (*Values)(nil) func (r *Values) String() string { - return FormatExpr(r) + return fmt.Sprintf("%s", r) +} + +func (r *Values) Format(s fmt.State, verb rune) { + FormatExpr(r, s, verb) } func (r *Values) Name() string { @@ -306,10 +396,16 @@ type TableAlias struct { } var _ RelExpr = (*TableAlias)(nil) +var _ fmt.Formatter = (*TableAlias)(nil) +var _ fmt.Stringer = (*TableAlias)(nil) var _ SourceRel = (*TableAlias)(nil) func (r *TableAlias) String() string { - return FormatExpr(r) + return fmt.Sprintf("%s", r) +} + +func (r *TableAlias) Format(s fmt.State, verb rune) { + FormatExpr(r, s, verb) } func (r *TableAlias) Name() string { @@ -338,10 +434,16 @@ type RecursiveTable struct { } var _ RelExpr = (*RecursiveTable)(nil) +var _ fmt.Formatter = (*RecursiveTable)(nil) +var _ fmt.Stringer = (*RecursiveTable)(nil) var _ SourceRel = (*RecursiveTable)(nil) func (r *RecursiveTable) String() string { - return FormatExpr(r) + return fmt.Sprintf("%s", r) +} + +func (r *RecursiveTable) Format(s fmt.State, verb rune) { + FormatExpr(r, s, verb) } func (r *RecursiveTable) Name() string { @@ -370,10 +472,16 @@ type RecursiveCte struct { } var _ RelExpr = (*RecursiveCte)(nil) +var _ fmt.Formatter = (*RecursiveCte)(nil) +var _ fmt.Stringer = (*RecursiveCte)(nil) var _ SourceRel = (*RecursiveCte)(nil) func (r *RecursiveCte) String() string { - return FormatExpr(r) + return fmt.Sprintf("%s", r) +} + +func (r *RecursiveCte) Format(s fmt.State, verb rune) { + FormatExpr(r, s, verb) } func (r *RecursiveCte) Name() string { @@ -402,10 +510,16 @@ type SubqueryAlias struct { } var _ RelExpr = (*SubqueryAlias)(nil) +var _ fmt.Formatter = (*SubqueryAlias)(nil) +var _ fmt.Stringer = (*SubqueryAlias)(nil) var _ SourceRel = (*SubqueryAlias)(nil) func (r *SubqueryAlias) String() string { - return FormatExpr(r) + return fmt.Sprintf("%s", r) +} + +func (r *SubqueryAlias) Format(s fmt.State, verb rune) { + FormatExpr(r, s, verb) 
} func (r *SubqueryAlias) Name() string { @@ -434,10 +548,16 @@ type TableFunc struct { } var _ RelExpr = (*TableFunc)(nil) +var _ fmt.Formatter = (*TableFunc)(nil) +var _ fmt.Stringer = (*TableFunc)(nil) var _ SourceRel = (*TableFunc)(nil) func (r *TableFunc) String() string { - return FormatExpr(r) + return fmt.Sprintf("%s", r) +} + +func (r *TableFunc) Format(s fmt.State, verb rune) { + FormatExpr(r, s, verb) } func (r *TableFunc) Name() string { @@ -466,10 +586,16 @@ type JSONTable struct { } var _ RelExpr = (*JSONTable)(nil) +var _ fmt.Formatter = (*JSONTable)(nil) +var _ fmt.Stringer = (*JSONTable)(nil) var _ SourceRel = (*JSONTable)(nil) func (r *JSONTable) String() string { - return FormatExpr(r) + return fmt.Sprintf("%s", r) +} + +func (r *JSONTable) Format(s fmt.State, verb rune) { + FormatExpr(r, s, verb) } func (r *JSONTable) Name() string { @@ -498,10 +624,16 @@ type EmptyTable struct { } var _ RelExpr = (*EmptyTable)(nil) +var _ fmt.Formatter = (*EmptyTable)(nil) +var _ fmt.Stringer = (*EmptyTable)(nil) var _ SourceRel = (*EmptyTable)(nil) func (r *EmptyTable) String() string { - return FormatExpr(r) + return fmt.Sprintf("%s", r) +} + +func (r *EmptyTable) Format(s fmt.State, verb rune) { + FormatExpr(r, s, verb) } func (r *EmptyTable) Name() string { @@ -530,10 +662,16 @@ type SetOp struct { } var _ RelExpr = (*SetOp)(nil) +var _ fmt.Formatter = (*SetOp)(nil) +var _ fmt.Stringer = (*SetOp)(nil) var _ SourceRel = (*SetOp)(nil) func (r *SetOp) String() string { - return FormatExpr(r) + return fmt.Sprintf("%s", r) +} + +func (r *SetOp) Format(s fmt.State, verb rune) { + FormatExpr(r, s, verb) } func (r *SetOp) Name() string { @@ -563,9 +701,15 @@ type Project struct { } var _ RelExpr = (*Project)(nil) +var _ fmt.Formatter = (*Project)(nil) +var _ fmt.Stringer = (*Project)(nil) func (r *Project) String() string { - return FormatExpr(r) + return fmt.Sprintf("%s", r) +} + +func (r *Project) Format(s fmt.State, verb rune) { + FormatExpr(r, s, verb) } func 
(r *Project) Children() []*ExprGroup { @@ -582,9 +726,15 @@ type Distinct struct { } var _ RelExpr = (*Distinct)(nil) +var _ fmt.Formatter = (*Distinct)(nil) +var _ fmt.Stringer = (*Distinct)(nil) func (r *Distinct) String() string { - return FormatExpr(r) + return fmt.Sprintf("%s", r) +} + +func (r *Distinct) Format(s fmt.State, verb rune) { + FormatExpr(r, s, verb) } func (r *Distinct) Children() []*ExprGroup { @@ -601,9 +751,15 @@ type Max1Row struct { } var _ RelExpr = (*Max1Row)(nil) +var _ fmt.Formatter = (*Max1Row)(nil) +var _ fmt.Stringer = (*Max1Row)(nil) func (r *Max1Row) String() string { - return FormatExpr(r) + return fmt.Sprintf("%s", r) +} + +func (r *Max1Row) Format(s fmt.State, verb rune) { + FormatExpr(r, s, verb) } func (r *Max1Row) Children() []*ExprGroup { @@ -621,9 +777,15 @@ type Filter struct { } var _ RelExpr = (*Filter)(nil) +var _ fmt.Formatter = (*Filter)(nil) +var _ fmt.Stringer = (*Filter)(nil) func (r *Filter) String() string { - return FormatExpr(r) + return fmt.Sprintf("%s", r) +} + +func (r *Filter) Format(s fmt.State, verb rune) { + FormatExpr(r, s, verb) } func (r *Filter) Children() []*ExprGroup { @@ -634,70 +796,6 @@ func (r *Filter) outputCols() sql.ColSet { return r.Child.RelProps.OutputCols() } -func FormatExpr(r exprType) string { - switch r := r.(type) { - case *CrossJoin: - return fmt.Sprintf("crossjoin %d %d", r.Left.Id, r.Right.Id) - case *InnerJoin: - return fmt.Sprintf("innerjoin %d %d", r.Left.Id, r.Right.Id) - case *LeftJoin: - return fmt.Sprintf("leftjoin %d %d", r.Left.Id, r.Right.Id) - case *SemiJoin: - return fmt.Sprintf("semijoin %d %d", r.Left.Id, r.Right.Id) - case *AntiJoin: - return fmt.Sprintf("antijoin %d %d", r.Left.Id, r.Right.Id) - case *LookupJoin: - return fmt.Sprintf("lookupjoin %d %d", r.Left.Id, r.Right.Id) - case *RangeHeapJoin: - return fmt.Sprintf("rangeheapjoin %d %d", r.Left.Id, r.Right.Id) - case *ConcatJoin: - return fmt.Sprintf("concatjoin %d %d", r.Left.Id, r.Right.Id) - case *HashJoin: - 
return fmt.Sprintf("hashjoin %d %d", r.Left.Id, r.Right.Id) - case *MergeJoin: - return fmt.Sprintf("mergejoin %d %d", r.Left.Id, r.Right.Id) - case *FullOuterJoin: - return fmt.Sprintf("fullouterjoin %d %d", r.Left.Id, r.Right.Id) - case *LateralJoin: - return fmt.Sprintf("lateraljoin %d %d", r.Left.Id, r.Right.Id) - case *TableScan: - return fmt.Sprintf("tablescan: %s", r.Name()) - case *IndexScan: - if r.Alias != "" { - return fmt.Sprintf("indexscan: %s", r.Alias) - } - return fmt.Sprintf("indexscan: %s", r.Name()) - case *Values: - return fmt.Sprintf("values: %s", r.Name()) - case *TableAlias: - return fmt.Sprintf("tablealias: %s", r.Name()) - case *RecursiveTable: - return fmt.Sprintf("recursivetable: %s", r.Name()) - case *RecursiveCte: - return fmt.Sprintf("recursivecte: %s", r.Name()) - case *SubqueryAlias: - return fmt.Sprintf("subqueryalias: %s", r.Name()) - case *TableFunc: - return fmt.Sprintf("tablefunc: %s", r.Name()) - case *JSONTable: - return fmt.Sprintf("jsontable: %s", r.Name()) - case *EmptyTable: - return fmt.Sprintf("emptytable: %s", r.Name()) - case *SetOp: - return fmt.Sprintf("setop: %s", r.Name()) - case *Project: - return fmt.Sprintf("project: %d", r.Child.Id) - case *Distinct: - return fmt.Sprintf("distinct: %d", r.Child.Id) - case *Max1Row: - return fmt.Sprintf("max1row: %d", r.Child.Id) - case *Filter: - return fmt.Sprintf("filter: %d", r.Child.Id) - default: - panic(fmt.Sprintf("unknown RelExpr type: %T", r)) - } -} - func buildRelExpr(b *ExecBuilder, r RelExpr, children ...sql.Node) (sql.Node, error) { var result sql.Node var err error diff --git a/sql/memo/rel_props.go b/sql/memo/rel_props.go index f1cc8d5d17..2601e61f42 100644 --- a/sql/memo/rel_props.go +++ b/sql/memo/rel_props.go @@ -312,21 +312,24 @@ func (p *relProps) populateFds() { p.fds = fds } -func CardMemoGroups(ctx *sql.Context, g *ExprGroup) { +func (m *Memo) CardMemoGroups(ctx *sql.Context, g *ExprGroup) { // card checking is called after indexScans and lookups joins 
are generated, // both of which have metadata that makes cardinality estimation more // accurate. if g.RelProps.stat != nil { return } - for _, g := range g.children() { - CardMemoGroups(ctx, g) + for g := range g.children() { + m.CardMemoGroups(ctx, g) } - s := statsForRel(ctx, g.First) + s := m.statsForRel(ctx, g.First) g.RelProps.SetStats(s) } -func statsForRel(ctx *sql.Context, rel RelExpr) sql.Statistic { +func (m *Memo) statsForRel(ctx *sql.Context, rel RelExpr) sql.Statistic { + m.Tracer.PushDebugContext("statsForRel") + defer m.Tracer.PopDebugContext() + var stat sql.Statistic switch rel := rel.(type) { case JoinRel: diff --git a/sql/memo/select_hints.go b/sql/memo/select_hints.go index c197181453..a3a69c3def 100644 --- a/sql/memo/select_hints.go +++ b/sql/memo/select_hints.go @@ -199,7 +199,7 @@ func (o joinOrderHint) build(grp *ExprGroup) { } o.groups[grp.Id] = s - for _, g := range grp.children() { + for g := range grp.children() { if _, ok := o.groups[g.Id]; !ok { // avoid duplicate work o.build(g) diff --git a/sql/memo/trace_logger.go b/sql/memo/trace_logger.go new file mode 100755 index 0000000000..78e6cccbd9 --- /dev/null +++ b/sql/memo/trace_logger.go @@ -0,0 +1,65 @@ +// Copyright 2025 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package memo + +import ( + "fmt" + "strings" + + "github.com/sirupsen/logrus" +) + +type TraceLogger struct { + // A stack of debugger context. 
See PushDebugContext, PopDebugContext + contextStack []string + TraceEnabled bool +} + +var log = logrus.New() + +// PushDebugContext pushes the given context string onto the context stack, to use when logging debug messages. It is a no-op when the receiver is nil or TraceEnabled is false. +func (a *TraceLogger) PushDebugContext(msg string) { + if a != nil && a.TraceEnabled { + a.contextStack = append(a.contextStack, msg) + } +} + +// PushDebugContextFmt pushes a formatted context string onto the context stack, to use when logging debug messages. +// The fmt.Sprintf call only runs when the receiver is non-nil and TraceEnabled is true, which avoids the cost of formatting when tracing is disabled. +func (a *TraceLogger) PushDebugContextFmt(fmtStr string, args ...any) { + if a != nil && a.TraceEnabled { + a.contextStack = append(a.contextStack, fmt.Sprintf(fmtStr, args...)) + } +} + +// PopDebugContext pops the most recently pushed context message off the context stack. It is a no-op when the receiver is nil, TraceEnabled is false, or the stack is empty. +func (a *TraceLogger) PopDebugContext() { + if a != nil && a.TraceEnabled && len(a.contextStack) > 0 { + a.contextStack = a.contextStack[:len(a.contextStack)-1] + } +} + +// Log prints msg, used as a fmt.Printf format string with args, to stdout followed by a newline. When the context stack is non-empty, the output is prefixed with the stack entries joined by "/" and a ": " separator. +// It prints nothing when the receiver is nil or TraceEnabled is false. +func (a *TraceLogger) Log(msg string, args ...interface{}) { + if a != nil && a.TraceEnabled { + if len(a.contextStack) > 0 { + ctx := strings.Join(a.contextStack, "/") + fmt.Printf("%s: "+msg+"\n", append([]interface{}{ctx}, args...)...) + } else { + fmt.Printf(msg+"\n", args...) + } + } +}