Skip to content
Merged
41 changes: 13 additions & 28 deletions engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,8 @@ func (e *Engine) QueryWithBindings(ctx *sql.Context, query string, parsed sqlpar

return nil, nil, nil, err
}
iter = rowexec.AddExpressionCloser(analyzed, iter)

iter = finalizeIters(analyzed, qFlags, iter)

return analyzed.Schema(), iter, qFlags, nil
}
Expand Down Expand Up @@ -480,7 +481,8 @@ func (e *Engine) PrepQueryPlanForExecution(ctx *sql.Context, _ string, plan sql.

return nil, nil, nil, err
}
iter = rowexec.AddExpressionCloser(plan, iter)

iter = finalizeIters(plan, nil, iter)

return plan.Schema(), iter, nil, nil
}
Expand Down Expand Up @@ -722,31 +724,6 @@ func (e *Engine) CloseSession(connID uint32) {
e.PreparedDataCache.DeleteSessionData(connID)
}

// Count number of BindVars in given tree
func countBindVars(node sql.Node) int {
var bindVars map[string]bool
bindCntFunc := func(e sql.Expression) bool {
if bv, ok := e.(*expression.BindVar); ok {
if bindVars == nil {
bindVars = make(map[string]bool)
}
bindVars[bv.Name] = true
}
return true
}
transform.InspectExpressions(node, bindCntFunc)

// InsertInto.Source not a child of InsertInto, so also need to traverse those
transform.Inspect(node, func(n sql.Node) bool {
if in, ok := n.(*plan.InsertInto); ok {
transform.InspectExpressions(in.Source, bindCntFunc)
return false
}
return true
})
return len(bindVars)
}

func (e *Engine) beginTransaction(ctx *sql.Context) error {
beginNewTransaction := ctx.GetTransaction() == nil || plan.ReadCommitted(ctx)
if beginNewTransaction {
Expand Down Expand Up @@ -852,7 +829,8 @@ func (e *Engine) executeEvent(ctx *sql.Context, dbName, createEventStatement, us
}
return err
}
iter = rowexec.AddExpressionCloser(definitionNode, iter)

iter = finalizeIters(definitionNode, nil, iter)

// Drain the iterate to execute the event body/definition
// NOTE: No row data is returned for an event; we just need to execute the statements
Expand Down Expand Up @@ -885,3 +863,10 @@ func findCreateEventNode(planTree sql.Node) (*plan.CreateEvent, error) {

return createEventNode, nil
}

// finalizeIters applies the final transformations on sql.RowIter before execution.
func finalizeIters(analyzed sql.Node, qFlags *sql.QueryFlags, iter sql.RowIter) sql.RowIter {
iter = rowexec.AddTransactionCommittingIter(iter, qFlags)
iter = rowexec.AddExpressionCloser(analyzed, iter)
return iter
}
16 changes: 0 additions & 16 deletions enginetest/memory_engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ import (
"github.com/dolthub/go-mysql-server/memory"
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/expression"
"github.com/dolthub/go-mysql-server/sql/plan"
"github.com/dolthub/go-mysql-server/sql/types"
_ "github.com/dolthub/go-mysql-server/sql/variables"
)
Expand Down Expand Up @@ -196,13 +195,6 @@ func TestSingleQueryPrepared(t *testing.T) {
enginetest.TestScriptWithEnginePrepared(t, engine, harness, test)
}

func newUpdateResult(matched, updated int) types.OkResult {
return types.OkResult{
RowsAffected: uint64(updated),
Info: plan.UpdateInfo{Matched: matched, Updated: updated},
}
}

// Convenience test for debugging a single query. Unskip and set to the desired query.
func TestSingleScript(t *testing.T) {
t.Skip()
Expand Down Expand Up @@ -1065,14 +1057,6 @@ func findTable(dbs []sql.Database, tableName string) (sql.Database, sql.Table) {
return nil, nil
}

func mergeSetupScripts(scripts ...setup.SetupScript) []string {
var all []string
for _, s := range scripts {
all = append(all, s...)
}
return all
}

func TestSQLLogicTests(t *testing.T) {
enginetest.TestSQLLogicTests(t, enginetest.NewMemoryHarness("default", 1, testNumPartitions, true, mergableIndexDriver))
}
Expand Down
22 changes: 10 additions & 12 deletions server/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -519,28 +519,26 @@ func resultForEmptyIter(ctx *sql.Context, iter sql.RowIter, resultFields []*quer
func GetDeferredProjections(iter sql.RowIter) (sql.RowIter, []sql.Expression) {
switch i := iter.(type) {
case *rowexec.ExprCloserIter:
_, projs := GetDeferredProjections(i.GetIter())
return i, projs
if newChild, projs := GetDeferredProjections(i.GetIter()); projs != nil {
return i.WithChildIter(newChild), projs
}
case *plan.TrackedRowIter:
_, projs := GetDeferredProjections(i.GetIter())
return i, projs
if newChild, projs := GetDeferredProjections(i.GetIter()); projs != nil {
return i.WithChildIter(newChild), projs
}
case *rowexec.TransactionCommittingIter:
newChild, projs := GetDeferredProjections(i.GetIter())
if projs != nil {
i.WithChildIter(newChild)
if newChild, projs := GetDeferredProjections(i.GetIter()); projs != nil {
return i.WithChildIter(newChild), projs
}
return i, projs
case *iters.LimitIter:
newChild, projs := GetDeferredProjections(i.ChildIter)
if projs != nil {
if newChild, projs := GetDeferredProjections(i.ChildIter); projs != nil {
i.ChildIter = newChild
return i, projs
}
return i, projs
case *rowexec.ProjectIter:
if i.CanDefer() {
return i.GetChildIter(), i.GetProjections()
}
return i, nil
}
return iter, nil
}
Expand Down
1 change: 0 additions & 1 deletion sql/analyzer/analyzer.go
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,6 @@ func NewProcRuleSelector(sel RuleSelector) RuleSelector {
unnestInSubqueriesId,

// once after default rules should only be run once
AutocommitId,
TrackProcessId,
parallelizeId:
return false
Expand Down
34 changes: 0 additions & 34 deletions sql/analyzer/autocommit.go

This file was deleted.

4 changes: 0 additions & 4 deletions sql/analyzer/parallelize.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,6 @@ func shouldParallelize(node sql.Node, scope *plan.Scope) bool {
return false
}

if tc, ok := node.(*plan.TransactionCommittingNode); ok {
return shouldParallelize(tc.Child(), scope)
}

// Do not try to parallelize DDL or descriptive operations
return !plan.IsNoRowNode(node)
}
Expand Down
2 changes: 0 additions & 2 deletions sql/analyzer/resolve_subqueries.go
Original file line number Diff line number Diff line change
Expand Up @@ -296,8 +296,6 @@ func StripPassthroughNodes(n sql.Node) sql.Node {
switch tn := n.(type) {
case *plan.QueryProcess:
n = tn.Child()
case *plan.TransactionCommittingNode:
n = tn.Child()
default:
nodeIsPassthrough = false
}
Expand Down
3 changes: 1 addition & 2 deletions sql/analyzer/rule_ids.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,7 @@ const (

// after all
cacheSubqueryAliasesInJoinsId // cacheSubqueryAliasesInJoins
backtickDefaulColumnValueNamesId // backtickDefaultColumnValueNames
AutocommitId // addAutocommit
BacktickDefaulColumnValueNamesId // backtickDefaultColumnValueNames
TrackProcessId // trackProcess
parallelizeId // parallelize
)
11 changes: 5 additions & 6 deletions sql/analyzer/ruleid_string.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 1 addition & 2 deletions sql/analyzer/rules.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@ func init() {
{wrapWithRollbackId, wrapWithRollback},
{inlineSubqueryAliasRefsId, inlineSubqueryAliasRefs},
{cacheSubqueryAliasesInJoinsId, cacheSubqueryAliasesInJoins},
{backtickDefaulColumnValueNamesId, backtickDefaultColumnValueNames},
{AutocommitId, addAutocommit},
{BacktickDefaulColumnValueNamesId, backtickDefaultColumnValueNames},
{TrackProcessId, trackProcess},
{parallelizeId, parallelize},
}
Expand Down
6 changes: 6 additions & 0 deletions sql/plan/process.go
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,12 @@ func (i *TrackedRowIter) updateSessionVars(ctx *sql.Context) {
}
}

func (i *TrackedRowIter) WithChildIter(childIter sql.RowIter) sql.RowIter {
ni := *i
ni.iter = childIter
return &ni
}

type trackedPartitionIndexKeyValueIter struct {
sql.PartitionIndexKeyValueIter
OnPartitionDone NamedNotifyFunc
Expand Down
79 changes: 0 additions & 79 deletions sql/plan/transaction_committing_iter.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,84 +15,9 @@
package plan

import (
"fmt"
"os"

"github.com/dolthub/go-mysql-server/sql"
)

const (
fakeReadCommittedEnvVar = "READ_COMMITTED_HACK"
)

var fakeReadCommitted bool

func init() {
_, ok := os.LookupEnv(fakeReadCommittedEnvVar)
if ok {
fakeReadCommitted = true
}
}

// TransactionCommittingNode implements autocommit logic. It wraps relevant queries and ensures the database commits
// the transaction.
type TransactionCommittingNode struct {
UnaryNode
}

var _ sql.Node = (*TransactionCommittingNode)(nil)
var _ sql.CollationCoercible = (*TransactionCommittingNode)(nil)

// NewTransactionCommittingNode returns a TransactionCommittingNode.
func NewTransactionCommittingNode(child sql.Node) *TransactionCommittingNode {
return &TransactionCommittingNode{UnaryNode: UnaryNode{Child: child}}
}

// String implements the sql.Node interface.
func (t *TransactionCommittingNode) String() string {
return t.Child().String()
}

// DebugString implements the sql.DebugStringer interface.
func (t *TransactionCommittingNode) DebugString() string {
return sql.DebugString(t.Child())
}

// Describe implements the sql.Describable interface.
func (t *TransactionCommittingNode) Describe(options sql.DescribeOptions) string {
return sql.Describe(t.Child(), options)
}

func (t *TransactionCommittingNode) IsReadOnly() bool {
return t.Child().IsReadOnly()
}

// WithChildren implements the sql.Node interface.
func (t *TransactionCommittingNode) WithChildren(children ...sql.Node) (sql.Node, error) {
if len(children) != 1 {
return nil, fmt.Errorf("ds")
}

t2 := *t
t2.UnaryNode = UnaryNode{Child: children[0]}
return &t2, nil
}

// CheckPrivileges implements the sql.Node interface.
func (t *TransactionCommittingNode) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool {
return t.Child().CheckPrivileges(ctx, opChecker)
}

// CollationCoercibility implements the interface sql.CollationCoercible.
func (*TransactionCommittingNode) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
return sql.Collation_binary, 7
}

// Child implements the sql.UnaryNode interface.
func (t *TransactionCommittingNode) Child() sql.Node {
return t.UnaryNode.Child
}

// IsSessionAutocommit returns true if the current session is using implicit transaction management
// through autocommit.
func IsSessionAutocommit(ctx *sql.Context) (bool, error) {
Expand All @@ -108,10 +33,6 @@ func IsSessionAutocommit(ctx *sql.Context) (bool, error) {
}

func ReadCommitted(ctx *sql.Context) bool {
if !fakeReadCommitted {
return false
}

val, err := ctx.GetSessionVariable(ctx, "transaction_isolation")
if err != nil {
return false
Expand Down
6 changes: 6 additions & 0 deletions sql/rowexec/expr_closer.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,9 @@ func (eci *ExprCloserIter) Close(ctx *sql.Context) error {
func (eci *ExprCloserIter) GetIter() sql.RowIter {
return eci.iter
}

func (eci *ExprCloserIter) WithChildIter(childIter sql.RowIter) sql.RowIter {
neci := *eci
neci.iter = childIter
return &neci
}
2 changes: 0 additions & 2 deletions sql/rowexec/node_builder.gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,6 @@ func (b *BaseBuilder) buildNodeExecNoAnalyze(ctx *sql.Context, n sql.Node, row s
return b.buildCreateRole(ctx, n, row)
case *plan.Loop:
return b.buildLoop(ctx, n, row)
case *plan.TransactionCommittingNode:
return b.buildTransactionCommittingNode(ctx, n, row)
case *plan.DropColumn:
return b.buildDropColumn(ctx, n, row)
case *plan.AnalyzeTable:
Expand Down
8 changes: 0 additions & 8 deletions sql/rowexec/transaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -300,11 +300,3 @@ func (b *BaseBuilder) buildExecuteQuery(ctx *sql.Context, n *plan.ExecuteQuery,
func (b *BaseBuilder) buildUse(ctx *sql.Context, n *plan.Use, row sql.Row) (sql.RowIter, error) {
return n.RowIter(ctx, row)
}

func (b *BaseBuilder) buildTransactionCommittingNode(ctx *sql.Context, n *plan.TransactionCommittingNode, row sql.Row) (sql.RowIter, error) {
iter, err := b.Build(ctx, n.Child(), row)
if err != nil {
return nil, err
}
return &TransactionCommittingIter{childIter: iter}, nil
}
Loading
Loading