diff --git a/engine.go b/engine.go index faee6d5dbb..a555633ce3 100644 --- a/engine.go +++ b/engine.go @@ -447,7 +447,7 @@ func (e *Engine) QueryWithBindings(ctx *sql.Context, query string, parsed sqlpar return nil, nil, nil, err } - iter = finalizeIters(analyzed, qFlags, iter) + iter = finalizeIters(ctx, analyzed, qFlags, iter) return analyzed.Schema(), iter, qFlags, nil } @@ -482,7 +482,7 @@ func (e *Engine) PrepQueryPlanForExecution(ctx *sql.Context, _ string, plan sql. return nil, nil, nil, err } - iter = finalizeIters(plan, nil, iter) + iter = finalizeIters(ctx, plan, nil, iter) return plan.Schema(), iter, nil, nil } @@ -830,7 +830,7 @@ func (e *Engine) executeEvent(ctx *sql.Context, dbName, createEventStatement, us return err } - iter = finalizeIters(definitionNode, nil, iter) + iter = finalizeIters(ctx, definitionNode, nil, iter) // Drain the iterate to execute the event body/definition // NOTE: No row data is returned for an event; we just need to execute the statements @@ -865,8 +865,9 @@ func findCreateEventNode(planTree sql.Node) (*plan.CreateEvent, error) { } // finalizeIters applies the final transformations on sql.RowIter before execution. -func finalizeIters(analyzed sql.Node, qFlags *sql.QueryFlags, iter sql.RowIter) sql.RowIter { - iter = rowexec.AddTransactionCommittingIter(iter, qFlags) +func finalizeIters(ctx *sql.Context, analyzed sql.Node, qFlags *sql.QueryFlags, iter sql.RowIter) sql.RowIter { + iter = rowexec.AddTransactionCommittingIter(qFlags, iter) + iter = plan.AddTrackedRowIter(ctx, analyzed, iter) iter = rowexec.AddExpressionCloser(analyzed, iter) return iter } diff --git a/engine_test.go b/engine_test.go index 0986f10d5e..4838b34911 100755 --- a/engine_test.go +++ b/engine_test.go @@ -1,4 +1,4 @@ -// Copyright 2023 Dolthub, Inc. +// Copyright 2023-2024 Dolthub, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -15,15 +15,21 @@ package sqle import ( + "context" "testing" "time" "github.com/dolthub/vitess/go/vt/proto/query" "github.com/stretchr/testify/require" + "github.com/dolthub/go-mysql-server/memory" "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/analyzer" "github.com/dolthub/go-mysql-server/sql/expression" + "github.com/dolthub/go-mysql-server/sql/plan" + "github.com/dolthub/go-mysql-server/sql/rowexec" "github.com/dolthub/go-mysql-server/sql/types" + "github.com/dolthub/go-mysql-server/sql/variables" ) func TestBindingsToExprs(t *testing.T) { @@ -145,3 +151,93 @@ func TestBindingsToExprs(t *testing.T) { }) } } + +// wrapper around sql.Table to make it not indexable +type nonIndexableTable struct { + *memory.Table +} + +var _ memory.MemTable = (*nonIndexableTable)(nil) + +func (t *nonIndexableTable) IgnoreSessionData() bool { + return true +} + +func getRuleFrom(rules []analyzer.Rule, id analyzer.RuleId) *analyzer.Rule { + for _, rule := range rules { + if rule.Id == id { + return &rule + } + } + + return nil +} + +// TODO: this was an analyzer test, but we don't have a mock process list for it to use, so it has to be here +func TestTrackProcess(t *testing.T) { + require := require.New(t) + variables.InitStatusVariables() + db := memory.NewDatabase("db") + provider := memory.NewDBProvider(db) + a := analyzer.NewDefault(provider) + sess := memory.NewSession(sql.NewBaseSession(), provider) + + node := plan.NewInnerJoin( + plan.NewResolvedTable(&nonIndexableTable{memory.NewPartitionedTable(db.BaseDatabase, "foo", sql.PrimaryKeySchema{}, nil, 2)}, nil, nil), + plan.NewResolvedTable(memory.NewPartitionedTable(db.BaseDatabase, "bar", sql.PrimaryKeySchema{}, nil, 4), nil, nil), + expression.NewLiteral(int64(1), types.Int64), + ) + + pl := NewProcessList() + + ctx := sql.NewContext(context.Background(), sql.WithPid(1), sql.WithProcessList(pl), sql.WithSession(sess)) + pl.AddConnection(ctx.Session.ID(), "localhost") + pl.ConnectionReady(ctx.Session) + ctx, err := ctx.ProcessList.BeginQuery(ctx, "SELECT foo") + require.NoError(err) + + rule := getRuleFrom(analyzer.OnceAfterAll, analyzer.TrackProcessId) + result, _, err := rule.Apply(ctx, a, node, nil, analyzer.DefaultRuleSelector, nil) + require.NoError(err) + + processes := ctx.ProcessList.Processes() + require.Len(processes, 1) + require.Equal("SELECT foo", processes[0].Query) + require.Equal( + map[string]sql.TableProgress{ + "foo": { + Progress: sql.Progress{Name: "foo", Done: 0, Total: 2}, + PartitionsProgress: map[string]sql.PartitionProgress{}, + }, + "bar": { + Progress: sql.Progress{Name: "bar", Done: 0, Total: 4}, + PartitionsProgress: map[string]sql.PartitionProgress{}, + }, + }, + processes[0].Progress) + + join, ok := result.(*plan.JoinNode) + require.True(ok) + require.Equal(plan.JoinTypeInner, join.JoinType()) + + lhs, ok := join.Left().(*plan.ResolvedTable) + require.True(ok) + _, ok = lhs.Table.(*plan.ProcessTable) + require.True(ok) + + rhs, ok := join.Right().(*plan.ResolvedTable) + require.True(ok) + _, ok = rhs.Table.(*plan.ProcessTable) + require.True(ok) + + iter, err := rowexec.DefaultBuilder.Build(ctx, result, nil) + iter = finalizeIters(ctx, result, nil, iter) + require.NoError(err) + _, err = sql.RowIterToRows(ctx, iter) + require.NoError(err) + + processes = ctx.ProcessList.Processes() + require.Len(processes, 1) + require.Equal(sql.ProcessCommandSleep, processes[0].Command) + require.Error(ctx.Err()) +} diff --git a/enginetest/engine_only_test.go b/enginetest/engine_only_test.go index 92c16714f0..f8b709d9ad 100644 --- a/enginetest/engine_only_test.go +++ b/enginetest/engine_only_test.go @@ -293,74 +293,6 @@ b (2/6 partitions) require.ElementsMatch(expected, rows) } -// TODO: this was an analyzer test, but we don't have a mock process list for it to use, so it has to be here -func TestTrackProcess(t *testing.T) { - require := require.New(t) - db := memory.NewDatabase("db") - provider := memory.NewDBProvider(db) - a := analyzer.NewDefault(provider) - sess := memory.NewSession(sql.NewBaseSession(), provider) - - node := plan.NewInnerJoin( - plan.NewResolvedTable(&nonIndexableTable{memory.NewPartitionedTable(db.BaseDatabase, "foo", sql.PrimaryKeySchema{}, nil, 2)}, nil, nil), - plan.NewResolvedTable(memory.NewPartitionedTable(db.BaseDatabase, "bar", sql.PrimaryKeySchema{}, nil, 4), nil, nil), - expression.NewLiteral(int64(1), types.Int64), - ) - - pl := sqle.NewProcessList() - - ctx := sql.NewContext(context.Background(), sql.WithPid(1), sql.WithProcessList(pl), sql.WithSession(sess)) - pl.AddConnection(ctx.Session.ID(), "localhost") - pl.ConnectionReady(ctx.Session) - ctx, err := ctx.ProcessList.BeginQuery(ctx, "SELECT foo") - require.NoError(err) - - rule := getRuleFrom(analyzer.OnceAfterAll, analyzer.TrackProcessId) - result, _, err := rule.Apply(ctx, a, node, nil, analyzer.DefaultRuleSelector, nil) - require.NoError(err) - - processes := ctx.ProcessList.Processes() - require.Len(processes, 1) - require.Equal("SELECT foo", processes[0].Query) - require.Equal( - map[string]sql.TableProgress{ - "foo": sql.TableProgress{ - Progress: sql.Progress{Name: "foo", Done: 0, Total: 2}, - PartitionsProgress: map[string]sql.PartitionProgress{}}, - "bar": sql.TableProgress{ - Progress: sql.Progress{Name: "bar", Done: 0, Total: 4}, - PartitionsProgress: map[string]sql.PartitionProgress{}}, - }, - processes[0].Progress) - - proc, ok := result.(*plan.QueryProcess) - require.True(ok) - - join, ok := proc.Child().(*plan.JoinNode) - require.True(ok) - require.Equal(join.JoinType(), plan.JoinTypeInner) - - lhs, ok := join.Left().(*plan.ResolvedTable) - require.True(ok) - _, ok = lhs.Table.(*plan.ProcessTable) - require.True(ok) - - rhs, ok := join.Right().(*plan.ResolvedTable) - require.True(ok) - _, ok = rhs.Table.(*plan.ProcessTable) - require.True(ok) - - iter, err := rowexec.DefaultBuilder.Build(ctx, proc, nil) - require.NoError(err) - _, err = sql.RowIterToRows(ctx, iter) - require.NoError(err) - - procs := ctx.ProcessList.Processes() - require.Len(procs, 1) - require.Equal(procs[0].Command, sql.ProcessCommandSleep) - require.Error(ctx.Err()) -} - func TestConcurrentProcessList(t *testing.T) { enginetest.TestConcurrentProcessList(t, enginetest.NewDefaultMemoryHarness()) } @@ -515,7 +447,6 @@ func TestAnalyzer_Exp(t *testing.T) { require.NoError(t, err) analyzed, err := e.EngineAnalyzer().Analyze(ctx, parsed, nil, nil) - analyzed = analyzer.StripPassthroughNodes(analyzed) if tt.err != nil { require.Error(t, err) assert.True(t, tt.err.Is(err)) @@ -527,10 +458,6 @@ func TestAnalyzer_Exp(t *testing.T) { } func assertNodesEqualWithDiff(t *testing.T, expected, actual sql.Node) { - if x, ok := actual.(*plan.QueryProcess); ok { - actual = x.Child() - } - if !assert.Equal(t, expected, actual) { expectedStr := sql.DebugString(expected) actualStr := sql.DebugString(actual) diff --git a/enginetest/evaluation.go b/enginetest/evaluation.go index ef57d62dbb..8e0994079a 100644 --- a/enginetest/evaluation.go +++ b/enginetest/evaluation.go @@ -1080,8 +1080,6 @@ func assertSchemasEqualWithDefaults(t *testing.T, expected, actual sql.Schema) b func ExtractQueryNode(node sql.Node) sql.Node { switch node := node.(type) { - case *plan.QueryProcess: - return ExtractQueryNode(node.Child()) case *plan.Releaser: return ExtractQueryNode(node.Child) default: diff --git a/sql/analyzer/describe.go b/sql/analyzer/describe.go index eacf6fa049..1e85ff84db 100644 --- a/sql/analyzer/describe.go +++ b/sql/analyzer/describe.go @@ -32,5 +32,5 @@ func resolveDescribeQuery(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan return nil, transform.SameTree, err } - return d.WithQuery(StripPassthroughNodes(q)), transform.NewTree, nil + return d.WithQuery(q), transform.NewTree, nil } diff --git a/sql/analyzer/inserts.go b/sql/analyzer/inserts.go index fdc2040219..69262461f1 100644 --- a/sql/analyzer/inserts.go +++ b/sql/analyzer/inserts.go @@ -62,8 +62,6 @@ func resolveInsertRows(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Sc if err != nil { return nil, transform.SameTree, err } - - source = StripPassthroughNodes(source) } dstSchema := insertable.Schema() diff --git a/sql/analyzer/parallelize.go b/sql/analyzer/parallelize.go index d36c38636d..7f4b32f3d7 100644 --- a/sql/analyzer/parallelize.go +++ b/sql/analyzer/parallelize.go @@ -63,8 +63,7 @@ func parallelize(ctx *sql.Context, a *Analyzer, node sql.Node, scope *plan.Scope return node, transform.SameTree, nil } - proc, ok := node.(*plan.QueryProcess) - if (ok && !shouldParallelize(proc.Child(), nil)) || !shouldParallelize(node, scope) { + if !shouldParallelize(node, scope) { return node, transform.SameTree, nil } diff --git a/sql/analyzer/process.go b/sql/analyzer/process.go index e1b40e0598..77904032d5 100644 --- a/sql/analyzer/process.go +++ b/sql/analyzer/process.go @@ -39,15 +39,10 @@ func trackProcess(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, if !n.Resolved() { return n, transform.SameTree, nil } - - if _, ok := n.(*plan.QueryProcess); ok { - return n, transform.SameTree, nil - } - processList := ctx.ProcessList var seen = make(map[string]struct{}) - n, _, err := transform.Node(n, func(n sql.Node) (sql.Node, transform.TreeIdentity, error) { + n, same, err := transform.Node(n, func(n sql.Node) (sql.Node, transform.TreeIdentity, error) { switch n := n.(type) { case *plan.ResolvedTable: switch n.Table.(type) { @@ -106,41 +101,9 @@ func trackProcess(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, return n, transform.SameTree, nil } }) - if err != nil { - return nil, transform.SameTree, err - } - - // Don't wrap CreateIndex in a QueryProcess, as it is a CreateIndexProcess. - // CreateIndex will take care of marking the process as done on its own. - if _, ok := n.(*plan.CreateIndex); ok { - return n, transform.SameTree, nil - } - // Remove QueryProcess nodes from the subqueries and trigger bodies. Otherwise, the process - // will be marked as done as soon as a subquery / trigger finishes. - node, _, err := transform.Node(n, func(n sql.Node) (sql.Node, transform.TreeIdentity, error) { - if sq, ok := n.(*plan.SubqueryAlias); ok { - if qp, ok := sq.Child.(*plan.QueryProcess); ok { - n, err := sq.WithChildren(qp.Child()) - return n, transform.NewTree, err - } - } - if t, ok := n.(*plan.TriggerExecutor); ok { - if qp, ok := t.Right().(*plan.QueryProcess); ok { - n, err := t.WithChildren(t.Left(), qp.Child()) - return n, transform.NewTree, err - } - } - return n, transform.SameTree, nil - }) if err != nil { return nil, transform.SameTree, err } - - return plan.NewQueryProcess(node, func() { - processList.EndQuery(ctx) - if span := ctx.RootSpan(); span != nil { - span.End() - } - }), transform.NewTree, nil + return n, same, nil } diff --git a/sql/analyzer/process_test.go b/sql/analyzer/process_test.go deleted file mode 100644 index 970013c213..0000000000 --- a/sql/analyzer/process_test.go +++ /dev/null @@ -1,70 +0,0 @@ -// Copyright 2020-2021 Dolthub, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package analyzer - -import ( - "testing" - - "github.com/stretchr/testify/require" - - "github.com/dolthub/go-mysql-server/memory" - "github.com/dolthub/go-mysql-server/sql" - "github.com/dolthub/go-mysql-server/sql/plan" -) - -func TestTrackProcessSubquery(t *testing.T) { - require := require.New(t) - rule := getRuleFrom(OnceAfterAll, TrackProcessId) - a := NewDefault(sql.NewDatabaseProvider()) - - db := memory.NewDatabase("db") - pro := memory.NewDBProvider(db) - ctx := newContext(pro) - - node := plan.NewProject( - nil, - plan.NewSubqueryAlias("f", "", - plan.NewQueryProcess( - plan.NewResolvedTable(memory.NewTable(db, "foo", sql.PrimaryKeySchema{}, nil), nil, nil), - nil, - ), - ), - ) - - result, _, err := rule.Apply(ctx, a, node, nil, DefaultRuleSelector, nil) - require.NoError(err) - - expectedChild := plan.NewProject( - nil, - plan.NewSubqueryAlias("f", "", - plan.NewResolvedTable(memory.NewTable(db, "foo", sql.PrimaryKeySchema{}, nil), nil, nil), - ), - ) - - proc, ok := result.(*plan.QueryProcess) - require.True(ok) - require.Equal(expectedChild, proc.Child()) -} - -// wrapper around sql.Table to make it not indexable -type table struct { - sql.Table -} - -var _ sql.PartitionCounter = (*table)(nil) - -func (t *table) PartitionCount(ctx *sql.Context) (int64, error) { - return t.Table.(sql.PartitionCounter).PartitionCount(ctx) -} diff --git a/sql/analyzer/resolve_create_select.go b/sql/analyzer/resolve_create_select.go index 34960ca178..d7c7de9965 100644 --- a/sql/analyzer/resolve_create_select.go +++ b/sql/analyzer/resolve_create_select.go @@ -55,7 +55,7 @@ func resolveCreateSelect(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan. return nil, transform.SameTree, err } - return plan.NewTableCopier(ct.Database(), StripPassthroughNodes(analyzedCreate), StripPassthroughNodes(analyzedSelect), plan.CopierProps{}), transform.NewTree, nil + return plan.NewTableCopier(ct.Database(), analyzedCreate, analyzedSelect, plan.CopierProps{}), transform.NewTree, nil } // stripSchema removes all non-type information from a schema, such as the key info, default value, etc. diff --git a/sql/analyzer/resolve_subqueries.go b/sql/analyzer/resolve_subqueries.go index fcd114bd7d..39c2969aca 100644 --- a/sql/analyzer/resolve_subqueries.go +++ b/sql/analyzer/resolve_subqueries.go @@ -252,7 +252,7 @@ func analyzeSubqueryExpression(ctx *sql.Context, a *Analyzer, n sql.Node, sq *pl // to the expense of positive errors, where a rule reports a change when the plan // is the same before/after. // .Resolved() might be useful for fixing these bugs. - return sq.WithQuery(StripPassthroughNodes(analyzed)).WithExecBuilder(a.ExecBuilder), transform.NewTree, nil + return sq.WithQuery(analyzed).WithExecBuilder(a.ExecBuilder), transform.NewTree, nil } // analyzeSubqueryAlias runs analysis on the specified subquery alias, |sqa|. The |finalize| parameter indicates if this is @@ -282,28 +282,10 @@ func analyzeSubqueryAlias(ctx *sql.Context, a *Analyzer, sqa *plan.SubqueryAlias if same { return sqa, transform.SameTree, nil } - newn, err := sqa.WithChildren(StripPassthroughNodes(child)) + newn, err := sqa.WithChildren(child) return newn, transform.NewTree, err } -// StripPassthroughNodes strips all top-level passthrough nodes meant to apply only to top-level queries (query -// tracking, transaction logic, etc) from the node tree given and return the first non-passthrough child element. This -// is useful for when we invoke the analyzer recursively when e.g. analyzing subqueries or triggers -// TODO: instead of stripping this node off after analysis, it would be better to just not add it in the first place. -func StripPassthroughNodes(n sql.Node) sql.Node { - nodeIsPassthrough := true - for nodeIsPassthrough { - switch tn := n.(type) { - case *plan.QueryProcess: - n = tn.Child() - default: - nodeIsPassthrough = false - } - } - - return n -} - // cacheSubqueryAlisesInJoins will look for joins against subquery aliases that // will repeatedly execute the subquery, and will insert a *plan.CachedResults // node on top of those nodes. The left-most child of a join root is an exception diff --git a/sql/analyzer/resolve_unions.go b/sql/analyzer/resolve_unions.go index 4efcad572e..5425a54201 100644 --- a/sql/analyzer/resolve_unions.go +++ b/sql/analyzer/resolve_unions.go @@ -51,7 +51,7 @@ func resolveUnions(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, return nil, transform.SameTree, err } - ret, err := n.WithChildren(StripPassthroughNodes(left), StripPassthroughNodes(right)) + ret, err := n.WithChildren(left, right) if err != nil { return nil, transform.SameTree, err } @@ -95,7 +95,7 @@ func finalizeUnions(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope scope.SetJoin(false) - newn, err := n.WithChildren(StripPassthroughNodes(left), StripPassthroughNodes(right)) + newn, err := n.WithChildren(left, right) if err != nil { return nil, transform.SameTree, err } diff --git a/sql/analyzer/stored_procedures.go b/sql/analyzer/stored_procedures.go index 26f1584711..35a555a522 100644 --- a/sql/analyzer/stored_procedures.go +++ b/sql/analyzer/stored_procedures.go @@ -154,7 +154,7 @@ func analyzeProcedureBodies(ctx *sql.Context, a *Analyzer, node sql.Node, skipCa if err != nil { return nil, transform.SameTree, err } - newChildren[i] = StripPassthroughNodes(newChild) + newChildren[i] = newChild } node, err = node.WithChildren(newChildren...) if err != nil { diff --git a/sql/analyzer/triggers.go b/sql/analyzer/triggers.go index e60ced633b..7fb8b4402a 100644 --- a/sql/analyzer/triggers.go +++ b/sql/analyzer/triggers.go @@ -475,7 +475,7 @@ func getTriggerLogic(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scop triggerLogic, _, err = a.analyzeWithSelector(ctx, trigger.Body, s, SelectAllBatches, noRowUpdateAccumulators, qFlags) } - return StripPassthroughNodes(triggerLogic), err + return triggerLogic, err } // validateNoCircularUpdates returns an error if the trigger logic attempts to update the table that invoked it (or any diff --git a/sql/plan/process.go b/sql/plan/process.go index db9e510288..b25b568a37 100644 --- a/sql/plan/process.go +++ b/sql/plan/process.go @@ -22,92 +22,9 @@ import ( "github.com/dolthub/go-mysql-server/sql" ) -// QueryProcess represents a running query process node. It will use a callback -// to notify when it has finished running. -// TODO: QueryProcess -> trackedRowIter is required to dispose certain iter caches. -// Make a proper scheduler interface to perform lifecycle management, caching, and -// scan attaching -type QueryProcess struct { - UnaryNode - Notify NotifyFunc -} - -var _ sql.Node = (*QueryProcess)(nil) -var _ sql.CollationCoercible = (*QueryProcess)(nil) - // NotifyFunc is a function to notify about some event. type NotifyFunc func() -// NewQueryProcess creates a new QueryProcess node. -func NewQueryProcess(node sql.Node, notify NotifyFunc) *QueryProcess { - return &QueryProcess{UnaryNode{Child: node}, notify} -} - -func (p *QueryProcess) Child() sql.Node { - return p.UnaryNode.Child -} - -func (p *QueryProcess) IsReadOnly() bool { - return p.Child().IsReadOnly() -} - -// WithChildren implements the Node interface. -func (p *QueryProcess) WithChildren(children ...sql.Node) (sql.Node, error) { - if len(children) != 1 { - return nil, sql.ErrInvalidChildrenNumber.New(p, len(children), 1) - } - - return NewQueryProcess(children[0], p.Notify), nil -} - -// CheckPrivileges implements the interface sql.Node. -func (p *QueryProcess) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { - return p.Child().CheckPrivileges(ctx, opChecker) -} - -// CollationCoercibility implements the interface sql.CollationCoercible. -func (p *QueryProcess) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { - return sql.GetCoercibility(ctx, p.Child()) -} - -func (p *QueryProcess) String() string { return p.Child().String() } - -func (p *QueryProcess) DebugString() string { - tp := sql.NewTreePrinter() - _ = tp.WriteNode("QueryProcess") - _ = tp.WriteChildren(sql.DebugString(p.Child())) - return tp.String() -} - -// ShouldSetFoundRows returns whether the query process should set the FOUND_ROWS query variable. It should do this for -// any select except a Limit with a SQL_CALC_FOUND_ROWS modifier, which is handled in the Limit node itself. -func (p *QueryProcess) ShouldSetFoundRows() bool { - var fromLimit *bool - var fromTopN *bool - transform.Inspect(p.Child(), func(n sql.Node) bool { - switch n := n.(type) { - case *StartTransaction: - return true - case *Limit: - fromLimit = &n.CalcFoundRows - return true - case *TopN: - fromTopN = &n.CalcFoundRows - return true - default: - return true - } - }) - - if fromLimit == nil && fromTopN == nil { - return true - } - if fromTopN != nil { - return !*fromTopN - } - return !*fromLimit -} - // ProcessIndexableTable is a wrapper for sql.Tables inside a query process // that support indexing. // It notifies the process manager about the status of a query when a @@ -325,6 +242,38 @@ func NewTrackedRowIter( return &TrackedRowIter{node: node, iter: iter, onDone: onDone, onNext: onNext} } +// ShouldSetFoundRows returns whether the query process should set the FOUND_ROWS query variable. It should do this for +// any select except a Limit with a SQL_CALC_FOUND_ROWS modifier, which is handled in the Limit node itself. +func shouldSetFoundRows(node sql.Node) bool { + result := true + transform.Inspect(node, func(n sql.Node) bool { + switch nn := n.(type) { + case *Limit: + if nn.CalcFoundRows { + result = false + } + case *TopN: + if nn.CalcFoundRows { + result = false + } + } + return true + }) + return result +} + +func AddTrackedRowIter(ctx *sql.Context, node sql.Node, iter sql.RowIter) sql.RowIter { + trackedIter := NewTrackedRowIter(node, iter, nil, func() { + ctx.ProcessList.EndQuery(ctx) + if span := ctx.RootSpan(); span != nil { + span.End() + } + }) + trackedIter.QueryType = GetQueryType(node) + trackedIter.ShouldSetFoundRows = trackedIter.QueryType == QueryTypeSelect && shouldSetFoundRows(node) + return trackedIter +} + func (i *TrackedRowIter) done() { if i.onDone != nil { i.onDone() diff --git a/sql/rowexec/node_builder.gen.go b/sql/rowexec/node_builder.gen.go index 2bde43b174..4ec32514df 100644 --- a/sql/rowexec/node_builder.gen.go +++ b/sql/rowexec/node_builder.gen.go @@ -62,8 +62,6 @@ func (b *BaseBuilder) buildNodeExecNoAnalyze(ctx *sql.Context, n sql.Node, row s return b.buildUpdateHistogram(ctx, n, row) case *plan.DropHistogram: return b.buildDropHistogram(ctx, n, row) - case *plan.QueryProcess: - return b.buildQueryProcess(ctx, n, row) case *plan.ShowBinlogs: return b.buildShowBinlogs(ctx, n, row) case *plan.ShowBinlogStatus: diff --git a/sql/rowexec/other.go b/sql/rowexec/other.go index f69094b68a..f1f6fef3b6 100644 --- a/sql/rowexec/other.go +++ b/sql/rowexec/other.go @@ -367,21 +367,6 @@ func (b *BaseBuilder) buildPrependNode(ctx *sql.Context, n *plan.PrependNode, ro }, nil } -func (b *BaseBuilder) buildQueryProcess(ctx *sql.Context, n *plan.QueryProcess, row sql.Row) (sql.RowIter, error) { - iter, err := b.Build(ctx, n.Child(), row) - if err != nil { - return nil, err - } - - qType := plan.GetQueryType(n.Child()) - - trackedIter := plan.NewTrackedRowIter(n.Child(), iter, nil, n.Notify) - trackedIter.QueryType = qType - trackedIter.ShouldSetFoundRows = qType == plan.QueryTypeSelect && n.ShouldSetFoundRows() - - return trackedIter, nil -} - func (b *BaseBuilder) buildAnalyzeTable(ctx *sql.Context, n *plan.AnalyzeTable, row sql.Row) (sql.RowIter, error) { // Assume table is in current database database := ctx.GetCurrentDatabase() diff --git a/sql/rowexec/process_test.go b/sql/rowexec/process_test.go index 6243c264be..10da29ef56 100644 --- a/sql/rowexec/process_test.go +++ b/sql/rowexec/process_test.go @@ -26,49 +26,6 @@ import ( "github.com/dolthub/go-mysql-server/sql/types" ) -func TestQueryProcess(t *testing.T) { - require := require.New(t) - - db := memory.NewDatabase("test") - pro := memory.NewDBProvider(db) - ctx := newContext(pro) - - table := memory.NewTable(db.BaseDatabase, "foo", sql.NewPrimaryKeySchema(sql.Schema{ - {Name: "a", Type: types.Int64}, - }), nil) - - table.Insert(ctx, sql.NewRow(int64(1))) - table.Insert(ctx, sql.NewRow(int64(2))) - - var notifications int - - node := plan.NewQueryProcess( - plan.NewProject( - []sql.Expression{ - expression.NewGetField(0, types.Int64, "a", false), - }, - plan.NewResolvedTable(table, nil, nil), - ), - func() { - notifications++ - }, - ) - - iter, err := DefaultBuilder.Build(ctx, node, nil) - require.NoError(err) - - rows, err := sql.RowIterToRows(ctx, iter) - require.NoError(err) - - expected := []sql.Row{ - {int64(1)}, - {int64(2)}, - } - - require.ElementsMatch(expected, rows) - require.Equal(1, notifications) -} - func TestProcessTable(t *testing.T) { require := require.New(t) diff --git a/sql/rowexec/transaction_iters.go b/sql/rowexec/transaction_iters.go index ee3ad8da23..4332ce9c6d 100644 --- a/sql/rowexec/transaction_iters.go +++ b/sql/rowexec/transaction_iters.go @@ -74,17 +74,12 @@ type TransactionCommittingIter struct { transactionDatabase string } -func AddTransactionCommittingIter(child sql.RowIter, qFlags *sql.QueryFlags) sql.RowIter { +func AddTransactionCommittingIter(qFlags *sql.QueryFlags, iter sql.RowIter) sql.RowIter { // TODO: This is a bit of a hack. Need to figure out better relationship between new transaction node and warnings. if qFlags != nil && qFlags.IsSet(sql.QFlagShowWarnings) { - return child + return iter } - // TODO: remove this once trackedRowIter is moved out of planbuilder - // Insert TransactionCommittingIter as child of TrackedRowIter - if trackedRowIter, ok := child.(*plan.TrackedRowIter); ok { - return trackedRowIter.WithChildIter(&TransactionCommittingIter{childIter: trackedRowIter.GetIter()}) - } - return &TransactionCommittingIter{childIter: child} + return &TransactionCommittingIter{childIter: iter} } func (t *TransactionCommittingIter) Next(ctx *sql.Context) (sql.Row, error) {