Skip to content

Commit ef9faa3

Browse files
authored
remove plan.QueryProcess and move logic to finalizeIters (#2714)
1 parent c5725b1 commit ef9faa3

19 files changed

+149
-371
lines changed

engine.go

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -447,7 +447,7 @@ func (e *Engine) QueryWithBindings(ctx *sql.Context, query string, parsed sqlpar
447447
return nil, nil, nil, err
448448
}
449449

450-
iter = finalizeIters(analyzed, qFlags, iter)
450+
iter = finalizeIters(ctx, analyzed, qFlags, iter)
451451

452452
return analyzed.Schema(), iter, qFlags, nil
453453
}
@@ -482,7 +482,7 @@ func (e *Engine) PrepQueryPlanForExecution(ctx *sql.Context, _ string, plan sql.
482482
return nil, nil, nil, err
483483
}
484484

485-
iter = finalizeIters(plan, nil, iter)
485+
iter = finalizeIters(ctx, plan, nil, iter)
486486

487487
return plan.Schema(), iter, nil, nil
488488
}
@@ -830,7 +830,7 @@ func (e *Engine) executeEvent(ctx *sql.Context, dbName, createEventStatement, us
830830
return err
831831
}
832832

833-
iter = finalizeIters(definitionNode, nil, iter)
833+
iter = finalizeIters(ctx, definitionNode, nil, iter)
834834

835835
// Drain the iterate to execute the event body/definition
836836
// NOTE: No row data is returned for an event; we just need to execute the statements
@@ -865,8 +865,9 @@ func findCreateEventNode(planTree sql.Node) (*plan.CreateEvent, error) {
865865
}
866866

867867
// finalizeIters applies the final transformations on sql.RowIter before execution.
868-
func finalizeIters(analyzed sql.Node, qFlags *sql.QueryFlags, iter sql.RowIter) sql.RowIter {
869-
iter = rowexec.AddTransactionCommittingIter(iter, qFlags)
868+
func finalizeIters(ctx *sql.Context, analyzed sql.Node, qFlags *sql.QueryFlags, iter sql.RowIter) sql.RowIter {
869+
iter = rowexec.AddTransactionCommittingIter(qFlags, iter)
870+
iter = plan.AddTrackedRowIter(ctx, analyzed, iter)
870871
iter = rowexec.AddExpressionCloser(analyzed, iter)
871872
return iter
872873
}

engine_test.go

Lines changed: 97 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright 2023 Dolthub, Inc.
1+
// Copyright 2023-2024 Dolthub, Inc.
22
//
33
// Licensed under the Apache License, Version 2.0 (the "License");
44
// you may not use this file except in compliance with the License.
@@ -15,15 +15,21 @@
1515
package sqle
1616

1717
import (
18+
"context"
1819
"testing"
1920
"time"
2021

2122
"github.com/dolthub/vitess/go/vt/proto/query"
2223
"github.com/stretchr/testify/require"
2324

25+
"github.com/dolthub/go-mysql-server/memory"
2426
"github.com/dolthub/go-mysql-server/sql"
27+
"github.com/dolthub/go-mysql-server/sql/analyzer"
2528
"github.com/dolthub/go-mysql-server/sql/expression"
29+
"github.com/dolthub/go-mysql-server/sql/plan"
30+
"github.com/dolthub/go-mysql-server/sql/rowexec"
2631
"github.com/dolthub/go-mysql-server/sql/types"
32+
"github.com/dolthub/go-mysql-server/sql/variables"
2733
)
2834

2935
func TestBindingsToExprs(t *testing.T) {
@@ -145,3 +151,93 @@ func TestBindingsToExprs(t *testing.T) {
145151
})
146152
}
147153
}
154+
155+
// wrapper around sql.Table to make it not indexable
156+
type nonIndexableTable struct {
157+
*memory.Table
158+
}
159+
160+
var _ memory.MemTable = (*nonIndexableTable)(nil)
161+
162+
func (t *nonIndexableTable) IgnoreSessionData() bool {
163+
return true
164+
}
165+
166+
func getRuleFrom(rules []analyzer.Rule, id analyzer.RuleId) *analyzer.Rule {
167+
for _, rule := range rules {
168+
if rule.Id == id {
169+
return &rule
170+
}
171+
}
172+
173+
return nil
174+
}
175+
176+
// TODO: this was an analyzer test, but we don't have a mock process list for it to use, so it has to be here
177+
func TestTrackProcess(t *testing.T) {
178+
require := require.New(t)
179+
variables.InitStatusVariables()
180+
db := memory.NewDatabase("db")
181+
provider := memory.NewDBProvider(db)
182+
a := analyzer.NewDefault(provider)
183+
sess := memory.NewSession(sql.NewBaseSession(), provider)
184+
185+
node := plan.NewInnerJoin(
186+
plan.NewResolvedTable(&nonIndexableTable{memory.NewPartitionedTable(db.BaseDatabase, "foo", sql.PrimaryKeySchema{}, nil, 2)}, nil, nil),
187+
plan.NewResolvedTable(memory.NewPartitionedTable(db.BaseDatabase, "bar", sql.PrimaryKeySchema{}, nil, 4), nil, nil),
188+
expression.NewLiteral(int64(1), types.Int64),
189+
)
190+
191+
pl := NewProcessList()
192+
193+
ctx := sql.NewContext(context.Background(), sql.WithPid(1), sql.WithProcessList(pl), sql.WithSession(sess))
194+
pl.AddConnection(ctx.Session.ID(), "localhost")
195+
pl.ConnectionReady(ctx.Session)
196+
ctx, err := ctx.ProcessList.BeginQuery(ctx, "SELECT foo")
197+
require.NoError(err)
198+
199+
rule := getRuleFrom(analyzer.OnceAfterAll, analyzer.TrackProcessId)
200+
result, _, err := rule.Apply(ctx, a, node, nil, analyzer.DefaultRuleSelector, nil)
201+
require.NoError(err)
202+
203+
processes := ctx.ProcessList.Processes()
204+
require.Len(processes, 1)
205+
require.Equal("SELECT foo", processes[0].Query)
206+
require.Equal(
207+
map[string]sql.TableProgress{
208+
"foo": {
209+
Progress: sql.Progress{Name: "foo", Done: 0, Total: 2},
210+
PartitionsProgress: map[string]sql.PartitionProgress{},
211+
},
212+
"bar": {
213+
Progress: sql.Progress{Name: "bar", Done: 0, Total: 4},
214+
PartitionsProgress: map[string]sql.PartitionProgress{},
215+
},
216+
},
217+
processes[0].Progress)
218+
219+
join, ok := result.(*plan.JoinNode)
220+
require.True(ok)
221+
require.Equal(plan.JoinTypeInner, join.JoinType())
222+
223+
lhs, ok := join.Left().(*plan.ResolvedTable)
224+
require.True(ok)
225+
_, ok = lhs.Table.(*plan.ProcessTable)
226+
require.True(ok)
227+
228+
rhs, ok := join.Right().(*plan.ResolvedTable)
229+
require.True(ok)
230+
_, ok = rhs.Table.(*plan.ProcessTable)
231+
require.True(ok)
232+
233+
iter, err := rowexec.DefaultBuilder.Build(ctx, result, nil)
234+
iter = finalizeIters(ctx, result, nil, iter)
235+
require.NoError(err)
236+
_, err = sql.RowIterToRows(ctx, iter)
237+
require.NoError(err)
238+
239+
processes = ctx.ProcessList.Processes()
240+
require.Len(processes, 1)
241+
require.Equal(sql.ProcessCommandSleep, processes[0].Command)
242+
require.Error(ctx.Err())
243+
}

enginetest/engine_only_test.go

Lines changed: 0 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -293,74 +293,6 @@ b (2/6 partitions)
293293
require.ElementsMatch(expected, rows)
294294
}
295295

296-
// TODO: this was an analyzer test, but we don't have a mock process list for it to use, so it has to be here
297-
func TestTrackProcess(t *testing.T) {
298-
require := require.New(t)
299-
db := memory.NewDatabase("db")
300-
provider := memory.NewDBProvider(db)
301-
a := analyzer.NewDefault(provider)
302-
sess := memory.NewSession(sql.NewBaseSession(), provider)
303-
304-
node := plan.NewInnerJoin(
305-
plan.NewResolvedTable(&nonIndexableTable{memory.NewPartitionedTable(db.BaseDatabase, "foo", sql.PrimaryKeySchema{}, nil, 2)}, nil, nil),
306-
plan.NewResolvedTable(memory.NewPartitionedTable(db.BaseDatabase, "bar", sql.PrimaryKeySchema{}, nil, 4), nil, nil),
307-
expression.NewLiteral(int64(1), types.Int64),
308-
)
309-
310-
pl := sqle.NewProcessList()
311-
312-
ctx := sql.NewContext(context.Background(), sql.WithPid(1), sql.WithProcessList(pl), sql.WithSession(sess))
313-
pl.AddConnection(ctx.Session.ID(), "localhost")
314-
pl.ConnectionReady(ctx.Session)
315-
ctx, err := ctx.ProcessList.BeginQuery(ctx, "SELECT foo")
316-
require.NoError(err)
317-
318-
rule := getRuleFrom(analyzer.OnceAfterAll, analyzer.TrackProcessId)
319-
result, _, err := rule.Apply(ctx, a, node, nil, analyzer.DefaultRuleSelector, nil)
320-
require.NoError(err)
321-
322-
processes := ctx.ProcessList.Processes()
323-
require.Len(processes, 1)
324-
require.Equal("SELECT foo", processes[0].Query)
325-
require.Equal(
326-
map[string]sql.TableProgress{
327-
"foo": sql.TableProgress{
328-
Progress: sql.Progress{Name: "foo", Done: 0, Total: 2},
329-
PartitionsProgress: map[string]sql.PartitionProgress{}},
330-
"bar": sql.TableProgress{
331-
Progress: sql.Progress{Name: "bar", Done: 0, Total: 4},
332-
PartitionsProgress: map[string]sql.PartitionProgress{}},
333-
},
334-
processes[0].Progress)
335-
336-
proc, ok := result.(*plan.QueryProcess)
337-
require.True(ok)
338-
339-
join, ok := proc.Child().(*plan.JoinNode)
340-
require.True(ok)
341-
require.Equal(join.JoinType(), plan.JoinTypeInner)
342-
343-
lhs, ok := join.Left().(*plan.ResolvedTable)
344-
require.True(ok)
345-
_, ok = lhs.Table.(*plan.ProcessTable)
346-
require.True(ok)
347-
348-
rhs, ok := join.Right().(*plan.ResolvedTable)
349-
require.True(ok)
350-
_, ok = rhs.Table.(*plan.ProcessTable)
351-
require.True(ok)
352-
353-
iter, err := rowexec.DefaultBuilder.Build(ctx, proc, nil)
354-
require.NoError(err)
355-
_, err = sql.RowIterToRows(ctx, iter)
356-
require.NoError(err)
357-
358-
procs := ctx.ProcessList.Processes()
359-
require.Len(procs, 1)
360-
require.Equal(procs[0].Command, sql.ProcessCommandSleep)
361-
require.Error(ctx.Err())
362-
}
363-
364296
func TestConcurrentProcessList(t *testing.T) {
365297
enginetest.TestConcurrentProcessList(t, enginetest.NewDefaultMemoryHarness())
366298
}
@@ -515,7 +447,6 @@ func TestAnalyzer_Exp(t *testing.T) {
515447
require.NoError(t, err)
516448

517449
analyzed, err := e.EngineAnalyzer().Analyze(ctx, parsed, nil, nil)
518-
analyzed = analyzer.StripPassthroughNodes(analyzed)
519450
if tt.err != nil {
520451
require.Error(t, err)
521452
assert.True(t, tt.err.Is(err))
@@ -527,10 +458,6 @@ func TestAnalyzer_Exp(t *testing.T) {
527458
}
528459

529460
func assertNodesEqualWithDiff(t *testing.T, expected, actual sql.Node) {
530-
if x, ok := actual.(*plan.QueryProcess); ok {
531-
actual = x.Child()
532-
}
533-
534461
if !assert.Equal(t, expected, actual) {
535462
expectedStr := sql.DebugString(expected)
536463
actualStr := sql.DebugString(actual)

enginetest/evaluation.go

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1080,8 +1080,6 @@ func assertSchemasEqualWithDefaults(t *testing.T, expected, actual sql.Schema) b
10801080

10811081
func ExtractQueryNode(node sql.Node) sql.Node {
10821082
switch node := node.(type) {
1083-
case *plan.QueryProcess:
1084-
return ExtractQueryNode(node.Child())
10851083
case *plan.Releaser:
10861084
return ExtractQueryNode(node.Child)
10871085
default:

sql/analyzer/describe.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,5 +32,5 @@ func resolveDescribeQuery(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan
3232
return nil, transform.SameTree, err
3333
}
3434

35-
return d.WithQuery(StripPassthroughNodes(q)), transform.NewTree, nil
35+
return d.WithQuery(q), transform.NewTree, nil
3636
}

sql/analyzer/inserts.go

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,6 @@ func resolveInsertRows(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Sc
6262
if err != nil {
6363
return nil, transform.SameTree, err
6464
}
65-
66-
source = StripPassthroughNodes(source)
6765
}
6866

6967
dstSchema := insertable.Schema()

sql/analyzer/parallelize.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,7 @@ func parallelize(ctx *sql.Context, a *Analyzer, node sql.Node, scope *plan.Scope
6363
return node, transform.SameTree, nil
6464
}
6565

66-
proc, ok := node.(*plan.QueryProcess)
67-
if (ok && !shouldParallelize(proc.Child(), nil)) || !shouldParallelize(node, scope) {
66+
if !shouldParallelize(node, scope) {
6867
return node, transform.SameTree, nil
6968
}
7069

sql/analyzer/process.go

Lines changed: 2 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -39,15 +39,10 @@ func trackProcess(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope,
3939
if !n.Resolved() {
4040
return n, transform.SameTree, nil
4141
}
42-
43-
if _, ok := n.(*plan.QueryProcess); ok {
44-
return n, transform.SameTree, nil
45-
}
46-
4742
processList := ctx.ProcessList
4843

4944
var seen = make(map[string]struct{})
50-
n, _, err := transform.Node(n, func(n sql.Node) (sql.Node, transform.TreeIdentity, error) {
45+
n, same, err := transform.Node(n, func(n sql.Node) (sql.Node, transform.TreeIdentity, error) {
5146
switch n := n.(type) {
5247
case *plan.ResolvedTable:
5348
switch n.Table.(type) {
@@ -106,41 +101,9 @@ func trackProcess(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope,
106101
return n, transform.SameTree, nil
107102
}
108103
})
109-
if err != nil {
110-
return nil, transform.SameTree, err
111-
}
112-
113-
// Don't wrap CreateIndex in a QueryProcess, as it is a CreateIndexProcess.
114-
// CreateIndex will take care of marking the process as done on its own.
115-
if _, ok := n.(*plan.CreateIndex); ok {
116-
return n, transform.SameTree, nil
117-
}
118104

119-
// Remove QueryProcess nodes from the subqueries and trigger bodies. Otherwise, the process
120-
// will be marked as done as soon as a subquery / trigger finishes.
121-
node, _, err := transform.Node(n, func(n sql.Node) (sql.Node, transform.TreeIdentity, error) {
122-
if sq, ok := n.(*plan.SubqueryAlias); ok {
123-
if qp, ok := sq.Child.(*plan.QueryProcess); ok {
124-
n, err := sq.WithChildren(qp.Child())
125-
return n, transform.NewTree, err
126-
}
127-
}
128-
if t, ok := n.(*plan.TriggerExecutor); ok {
129-
if qp, ok := t.Right().(*plan.QueryProcess); ok {
130-
n, err := t.WithChildren(t.Left(), qp.Child())
131-
return n, transform.NewTree, err
132-
}
133-
}
134-
return n, transform.SameTree, nil
135-
})
136105
if err != nil {
137106
return nil, transform.SameTree, err
138107
}
139-
140-
return plan.NewQueryProcess(node, func() {
141-
processList.EndQuery(ctx)
142-
if span := ctx.RootSpan(); span != nil {
143-
span.End()
144-
}
145-
}), transform.NewTree, nil
108+
return n, same, nil
146109
}

0 commit comments

Comments
 (0)