Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,7 @@ func (e *Engine) QueryWithBindings(ctx *sql.Context, query string, parsed sqlpar
return nil, nil, nil, err
}

iter = finalizeIters(analyzed, qFlags, iter)
iter = finalizeIters(ctx, analyzed, qFlags, iter)

return analyzed.Schema(), iter, qFlags, nil
}
Expand Down Expand Up @@ -482,7 +482,7 @@ func (e *Engine) PrepQueryPlanForExecution(ctx *sql.Context, _ string, plan sql.
return nil, nil, nil, err
}

iter = finalizeIters(plan, nil, iter)
iter = finalizeIters(ctx, plan, nil, iter)

return plan.Schema(), iter, nil, nil
}
Expand Down Expand Up @@ -830,7 +830,7 @@ func (e *Engine) executeEvent(ctx *sql.Context, dbName, createEventStatement, us
return err
}

iter = finalizeIters(definitionNode, nil, iter)
iter = finalizeIters(ctx, definitionNode, nil, iter)

// Drain the iterate to execute the event body/definition
// NOTE: No row data is returned for an event; we just need to execute the statements
Expand Down Expand Up @@ -865,8 +865,9 @@ func findCreateEventNode(planTree sql.Node) (*plan.CreateEvent, error) {
}

// finalizeIters applies the final transformations on sql.RowIter before execution.
func finalizeIters(analyzed sql.Node, qFlags *sql.QueryFlags, iter sql.RowIter) sql.RowIter {
iter = rowexec.AddTransactionCommittingIter(iter, qFlags)
func finalizeIters(ctx *sql.Context, analyzed sql.Node, qFlags *sql.QueryFlags, iter sql.RowIter) sql.RowIter {
iter = rowexec.AddTransactionCommittingIter(qFlags, iter)
iter = plan.AddTrackedRowIter(ctx, analyzed, iter)
iter = rowexec.AddExpressionCloser(analyzed, iter)
return iter
}
98 changes: 97 additions & 1 deletion engine_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2023 Dolthub, Inc.
// Copyright 2023-2024 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand All @@ -15,15 +15,21 @@
package sqle

import (
"context"
"testing"
"time"

"github.com/dolthub/vitess/go/vt/proto/query"
"github.com/stretchr/testify/require"

"github.com/dolthub/go-mysql-server/memory"
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/analyzer"
"github.com/dolthub/go-mysql-server/sql/expression"
"github.com/dolthub/go-mysql-server/sql/plan"
"github.com/dolthub/go-mysql-server/sql/rowexec"
"github.com/dolthub/go-mysql-server/sql/types"
"github.com/dolthub/go-mysql-server/sql/variables"
)

func TestBindingsToExprs(t *testing.T) {
Expand Down Expand Up @@ -145,3 +151,93 @@ func TestBindingsToExprs(t *testing.T) {
})
}
}

// wrapper around sql.Table to make it not indexable
type nonIndexableTable struct {
*memory.Table
}

var _ memory.MemTable = (*nonIndexableTable)(nil)

func (t *nonIndexableTable) IgnoreSessionData() bool {
return true
}

func getRuleFrom(rules []analyzer.Rule, id analyzer.RuleId) *analyzer.Rule {
for _, rule := range rules {
if rule.Id == id {
return &rule
}
}

return nil
}

// TODO: this was an analyzer test, but we don't have a mock process list for it to use, so it has to be here
func TestTrackProcess(t *testing.T) {
require := require.New(t)
variables.InitStatusVariables()
db := memory.NewDatabase("db")
provider := memory.NewDBProvider(db)
a := analyzer.NewDefault(provider)
sess := memory.NewSession(sql.NewBaseSession(), provider)

node := plan.NewInnerJoin(
plan.NewResolvedTable(&nonIndexableTable{memory.NewPartitionedTable(db.BaseDatabase, "foo", sql.PrimaryKeySchema{}, nil, 2)}, nil, nil),
plan.NewResolvedTable(memory.NewPartitionedTable(db.BaseDatabase, "bar", sql.PrimaryKeySchema{}, nil, 4), nil, nil),
expression.NewLiteral(int64(1), types.Int64),
)

pl := NewProcessList()

ctx := sql.NewContext(context.Background(), sql.WithPid(1), sql.WithProcessList(pl), sql.WithSession(sess))
pl.AddConnection(ctx.Session.ID(), "localhost")
pl.ConnectionReady(ctx.Session)
ctx, err := ctx.ProcessList.BeginQuery(ctx, "SELECT foo")
require.NoError(err)

rule := getRuleFrom(analyzer.OnceAfterAll, analyzer.TrackProcessId)
result, _, err := rule.Apply(ctx, a, node, nil, analyzer.DefaultRuleSelector, nil)
require.NoError(err)

processes := ctx.ProcessList.Processes()
require.Len(processes, 1)
require.Equal("SELECT foo", processes[0].Query)
require.Equal(
map[string]sql.TableProgress{
"foo": {
Progress: sql.Progress{Name: "foo", Done: 0, Total: 2},
PartitionsProgress: map[string]sql.PartitionProgress{},
},
"bar": {
Progress: sql.Progress{Name: "bar", Done: 0, Total: 4},
PartitionsProgress: map[string]sql.PartitionProgress{},
},
},
processes[0].Progress)

join, ok := result.(*plan.JoinNode)
require.True(ok)
require.Equal(plan.JoinTypeInner, join.JoinType())

lhs, ok := join.Left().(*plan.ResolvedTable)
require.True(ok)
_, ok = lhs.Table.(*plan.ProcessTable)
require.True(ok)

rhs, ok := join.Right().(*plan.ResolvedTable)
require.True(ok)
_, ok = rhs.Table.(*plan.ProcessTable)
require.True(ok)

iter, err := rowexec.DefaultBuilder.Build(ctx, result, nil)
iter = finalizeIters(ctx, result, nil, iter)
require.NoError(err)
_, err = sql.RowIterToRows(ctx, iter)
require.NoError(err)

processes = ctx.ProcessList.Processes()
require.Len(processes, 1)
require.Equal(sql.ProcessCommandSleep, processes[0].Command)
require.Error(ctx.Err())
}
73 changes: 0 additions & 73 deletions enginetest/engine_only_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -293,74 +293,6 @@ b (2/6 partitions)
require.ElementsMatch(expected, rows)
}

// TODO: this was an analyzer test, but we don't have a mock process list for it to use, so it has to be here
func TestTrackProcess(t *testing.T) {
require := require.New(t)
db := memory.NewDatabase("db")
provider := memory.NewDBProvider(db)
a := analyzer.NewDefault(provider)
sess := memory.NewSession(sql.NewBaseSession(), provider)

node := plan.NewInnerJoin(
plan.NewResolvedTable(&nonIndexableTable{memory.NewPartitionedTable(db.BaseDatabase, "foo", sql.PrimaryKeySchema{}, nil, 2)}, nil, nil),
plan.NewResolvedTable(memory.NewPartitionedTable(db.BaseDatabase, "bar", sql.PrimaryKeySchema{}, nil, 4), nil, nil),
expression.NewLiteral(int64(1), types.Int64),
)

pl := sqle.NewProcessList()

ctx := sql.NewContext(context.Background(), sql.WithPid(1), sql.WithProcessList(pl), sql.WithSession(sess))
pl.AddConnection(ctx.Session.ID(), "localhost")
pl.ConnectionReady(ctx.Session)
ctx, err := ctx.ProcessList.BeginQuery(ctx, "SELECT foo")
require.NoError(err)

rule := getRuleFrom(analyzer.OnceAfterAll, analyzer.TrackProcessId)
result, _, err := rule.Apply(ctx, a, node, nil, analyzer.DefaultRuleSelector, nil)
require.NoError(err)

processes := ctx.ProcessList.Processes()
require.Len(processes, 1)
require.Equal("SELECT foo", processes[0].Query)
require.Equal(
map[string]sql.TableProgress{
"foo": sql.TableProgress{
Progress: sql.Progress{Name: "foo", Done: 0, Total: 2},
PartitionsProgress: map[string]sql.PartitionProgress{}},
"bar": sql.TableProgress{
Progress: sql.Progress{Name: "bar", Done: 0, Total: 4},
PartitionsProgress: map[string]sql.PartitionProgress{}},
},
processes[0].Progress)

proc, ok := result.(*plan.QueryProcess)
require.True(ok)

join, ok := proc.Child().(*plan.JoinNode)
require.True(ok)
require.Equal(join.JoinType(), plan.JoinTypeInner)

lhs, ok := join.Left().(*plan.ResolvedTable)
require.True(ok)
_, ok = lhs.Table.(*plan.ProcessTable)
require.True(ok)

rhs, ok := join.Right().(*plan.ResolvedTable)
require.True(ok)
_, ok = rhs.Table.(*plan.ProcessTable)
require.True(ok)

iter, err := rowexec.DefaultBuilder.Build(ctx, proc, nil)
require.NoError(err)
_, err = sql.RowIterToRows(ctx, iter)
require.NoError(err)

procs := ctx.ProcessList.Processes()
require.Len(procs, 1)
require.Equal(procs[0].Command, sql.ProcessCommandSleep)
require.Error(ctx.Err())
}

func TestConcurrentProcessList(t *testing.T) {
enginetest.TestConcurrentProcessList(t, enginetest.NewDefaultMemoryHarness())
}
Expand Down Expand Up @@ -515,7 +447,6 @@ func TestAnalyzer_Exp(t *testing.T) {
require.NoError(t, err)

analyzed, err := e.EngineAnalyzer().Analyze(ctx, parsed, nil, nil)
analyzed = analyzer.StripPassthroughNodes(analyzed)
if tt.err != nil {
require.Error(t, err)
assert.True(t, tt.err.Is(err))
Expand All @@ -527,10 +458,6 @@ func TestAnalyzer_Exp(t *testing.T) {
}

func assertNodesEqualWithDiff(t *testing.T, expected, actual sql.Node) {
if x, ok := actual.(*plan.QueryProcess); ok {
actual = x.Child()
}

if !assert.Equal(t, expected, actual) {
expectedStr := sql.DebugString(expected)
actualStr := sql.DebugString(actual)
Expand Down
2 changes: 0 additions & 2 deletions enginetest/evaluation.go
Original file line number Diff line number Diff line change
Expand Up @@ -1080,8 +1080,6 @@ func assertSchemasEqualWithDefaults(t *testing.T, expected, actual sql.Schema) b

func ExtractQueryNode(node sql.Node) sql.Node {
switch node := node.(type) {
case *plan.QueryProcess:
return ExtractQueryNode(node.Child())
case *plan.Releaser:
return ExtractQueryNode(node.Child)
default:
Expand Down
2 changes: 1 addition & 1 deletion sql/analyzer/describe.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,5 +32,5 @@ func resolveDescribeQuery(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan
return nil, transform.SameTree, err
}

return d.WithQuery(StripPassthroughNodes(q)), transform.NewTree, nil
return d.WithQuery(q), transform.NewTree, nil
}
2 changes: 0 additions & 2 deletions sql/analyzer/inserts.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,6 @@ func resolveInsertRows(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Sc
if err != nil {
return nil, transform.SameTree, err
}

source = StripPassthroughNodes(source)
}

dstSchema := insertable.Schema()
Expand Down
3 changes: 1 addition & 2 deletions sql/analyzer/parallelize.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,7 @@ func parallelize(ctx *sql.Context, a *Analyzer, node sql.Node, scope *plan.Scope
return node, transform.SameTree, nil
}

proc, ok := node.(*plan.QueryProcess)
if (ok && !shouldParallelize(proc.Child(), nil)) || !shouldParallelize(node, scope) {
if !shouldParallelize(node, scope) {
return node, transform.SameTree, nil
}

Expand Down
41 changes: 2 additions & 39 deletions sql/analyzer/process.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,10 @@ func trackProcess(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope,
if !n.Resolved() {
return n, transform.SameTree, nil
}

if _, ok := n.(*plan.QueryProcess); ok {
return n, transform.SameTree, nil
}

processList := ctx.ProcessList

var seen = make(map[string]struct{})
n, _, err := transform.Node(n, func(n sql.Node) (sql.Node, transform.TreeIdentity, error) {
n, same, err := transform.Node(n, func(n sql.Node) (sql.Node, transform.TreeIdentity, error) {
switch n := n.(type) {
case *plan.ResolvedTable:
switch n.Table.(type) {
Expand Down Expand Up @@ -106,41 +101,9 @@ func trackProcess(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope,
return n, transform.SameTree, nil
}
})
if err != nil {
return nil, transform.SameTree, err
}

// Don't wrap CreateIndex in a QueryProcess, as it is a CreateIndexProcess.
// CreateIndex will take care of marking the process as done on its own.
if _, ok := n.(*plan.CreateIndex); ok {
return n, transform.SameTree, nil
}

// Remove QueryProcess nodes from the subqueries and trigger bodies. Otherwise, the process
// will be marked as done as soon as a subquery / trigger finishes.
node, _, err := transform.Node(n, func(n sql.Node) (sql.Node, transform.TreeIdentity, error) {
if sq, ok := n.(*plan.SubqueryAlias); ok {
if qp, ok := sq.Child.(*plan.QueryProcess); ok {
n, err := sq.WithChildren(qp.Child())
return n, transform.NewTree, err
}
}
if t, ok := n.(*plan.TriggerExecutor); ok {
if qp, ok := t.Right().(*plan.QueryProcess); ok {
n, err := t.WithChildren(t.Left(), qp.Child())
return n, transform.NewTree, err
}
}
return n, transform.SameTree, nil
})
if err != nil {
return nil, transform.SameTree, err
}

return plan.NewQueryProcess(node, func() {
processList.EndQuery(ctx)
if span := ctx.RootSpan(); span != nil {
span.End()
}
}), transform.NewTree, nil
return n, same, nil
}
Loading
Loading