From 3cdb14e69488955f3cc035a774e8d765745b77b5 Mon Sep 17 00:00:00 2001 From: James Cor Date: Thu, 13 Feb 2025 14:21:06 -0800 Subject: [PATCH 001/111] temp --- sql/analyzer/interpreter.go | 33 +-- sql/expression/procedurereference.go | 2 + sql/plan/call.go | 15 ++ sql/planbuilder/create_ddl.go | 2 +- sql/planbuilder/from.go | 5 + sql/planbuilder/proc.go | 6 + sql/procedures.go | 31 +++ sql/procedures/interpreter_operation.go | 51 +++++ sql/procedures/interpreter_stack.go | 133 +++++++++++ sql/procedures/parse.go | 47 ++++ sql/procedures/statements.go | 282 ++++++++++++++++++++++++ sql/rowexec/proc.go | 7 + 12 files changed, 597 insertions(+), 17 deletions(-) create mode 100644 sql/procedures/interpreter_operation.go create mode 100644 sql/procedures/interpreter_stack.go create mode 100644 sql/procedures/parse.go create mode 100644 sql/procedures/statements.go diff --git a/sql/analyzer/interpreter.go b/sql/analyzer/interpreter.go index 80d07c620e..b6642448b1 100644 --- a/sql/analyzer/interpreter.go +++ b/sql/analyzer/interpreter.go @@ -15,7 +15,6 @@ package analyzer import ( - "github.com/dolthub/vitess/go/vt/sqlparser" "github.com/dolthub/go-mysql-server/sql/transform" @@ -23,25 +22,27 @@ import ( "github.com/dolthub/go-mysql-server/sql/plan" ) -// Interpreter is an interface that implements an interpreter. These are typically used for functions (which may be -// implemented as a set of operations that are interpreted during runtime). -type Interpreter interface { - SetStatementRunner(ctx *sql.Context, runner StatementRunner) sql.Expression -} - -// StatementRunner is essentially an interface that the engine will implement. We cannot directly reference the engine -// here as it will cause an import cycle, so this may be updated to suit any function changes that the engine -// experiences. -type StatementRunner interface { - QueryWithBindings(ctx *sql.Context, query string, parsed sqlparser.Statement, bindings map[string]sqlparser.Expr, qFlags *sql.QueryFlags) (sql.Schema, sql.RowIter, *sql.QueryFlags, error) -} - // interpreter hands the engine to any interpreter expressions. func interpreter(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector, qFlags *sql.QueryFlags) (sql.Node, transform.TreeIdentity, error) { - return transform.NodeExprs(n, func(expr sql.Expression) (sql.Expression, transform.TreeIdentity, error) { - if interp, ok := expr.(Interpreter); ok { + newNode, sameNode, err := transform.Node(n, func(node sql.Node) (sql.Node, transform.TreeIdentity, error) { + if interp, ok := node.(sql.InterpreterNode); ok { + return interp.SetStatementRunner(ctx, a.Runner), transform.NewTree, nil + } + return node, transform.SameTree, nil + }) + if err != nil { + return nil, transform.SameTree, err + } + + newNode, sameExpr, err := transform.NodeExprs(newNode, func(expr sql.Expression) (sql.Expression, transform.TreeIdentity, error) { + if interp, ok := expr.(sql.Interpreter); ok { return interp.SetStatementRunner(ctx, a.Runner), transform.NewTree, nil } return expr, transform.SameTree, nil }) + if err != nil { + return nil, transform.SameTree, err + } + + return newNode, sameNode && sameExpr, err } diff --git a/sql/expression/procedurereference.go b/sql/expression/procedurereference.go index bed05f478b..3e3d82c28b 100644 --- a/sql/expression/procedurereference.go +++ b/sql/expression/procedurereference.go @@ -23,6 +23,8 @@ import ( "github.com/dolthub/go-mysql-server/sql/types" ) +// TODO: instead of procedure reference, copy stack from doltgres + // ProcedureReference contains the state for a single CALL statement of a stored procedure. type ProcedureReference struct { InnermostScope *procedureScope diff --git a/sql/plan/call.go b/sql/plan/call.go index 450ae38af3..114680bb5e 100644 --- a/sql/plan/call.go +++ b/sql/plan/call.go @@ -23,6 +23,10 @@ import ( "github.com/dolthub/go-mysql-server/sql/expression" ) + +// TODO: we need different types of calls: one for extrenal procedures one for stored procedures + + type Call struct { db sql.Database Name string @@ -32,11 +36,15 @@ type Call struct { Pref *expression.ProcedureReference cat sql.Catalog Analyzed bool + + // this will have list of parsed operations to run + runner sql.StatementRunner } var _ sql.Node = (*Call)(nil) var _ sql.CollationCoercible = (*Call)(nil) var _ sql.Expressioner = (*Call)(nil) +var _ sql.InterpreterNode = (*Call)(nil) var _ Versionable = (*Call)(nil) // NewCall returns a *Call node. @@ -197,3 +205,10 @@ func (c *Call) Dispose() { disposeNode(c.Procedure) } } + +// SetStatementRunner implements the sql.InterpreterNode interface. +func (c *Call) SetStatementRunner(ctx *sql.Context, runner sql.StatementRunner) sql.Node { + nc := *c + nc.runner = runner + return &nc +} diff --git a/sql/planbuilder/create_ddl.go b/sql/planbuilder/create_ddl.go index 93017cfa25..3dac989095 100644 --- a/sql/planbuilder/create_ddl.go +++ b/sql/planbuilder/create_ddl.go @@ -263,7 +263,7 @@ func (b *Builder) validateStatement(inScope *scope, stmt ast.Statement) { if s.TriggerSpec != nil { b.handleErr(fmt.Errorf("can't create a TRIGGER from within another stored routine")) } - b.handleErr(fmt.Errorf("CREATE statements in CREATE PROCEDURE not yet supported")) + //b.handleErr(fmt.Errorf("CREATE statements in CREATE PROCEDURE not yet supported")) default: b.handleErr(fmt.Errorf("DDL in CREATE PROCEDURE not yet supported")) } diff --git a/sql/planbuilder/from.go b/sql/planbuilder/from.go index c7070fe69f..8912d8cfac 100644 --- a/sql/planbuilder/from.go +++ b/sql/planbuilder/from.go @@ -705,6 +705,11 @@ func (b *Builder) buildResolvedTable(inScope *scope, db, schema, name string, as b.TriggerCtx().UnresolvedTables = append(b.TriggerCtx().UnresolvedTables, name) return outScope, true } + // TODO: do the same for stored procedures + if b.procCtx != nil { + outScope.node = plan.NewUnresolvedTable(name, db) + return outScope, true + } return outScope, false } else { b.handleErr(tableResolveErr) diff --git a/sql/planbuilder/proc.go b/sql/planbuilder/proc.go index 8f8d9045b4..816f053940 100644 --- a/sql/planbuilder/proc.go +++ b/sql/planbuilder/proc.go @@ -24,6 +24,7 @@ import ( "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/expression" "github.com/dolthub/go-mysql-server/sql/plan" + "github.com/dolthub/go-mysql-server/sql/procedures" "github.com/dolthub/go-mysql-server/sql/types" ) @@ -254,6 +255,10 @@ func BuildProcedureHelper(ctx *sql.Context, cat sql.Catalog, isCreateProc bool, stmt, _, _, _ := b.parser.ParseWithOptions(b.ctx, procDetails.CreateStatement, ';', false, b.parserOpts) procStmt := stmt.(*ast.DDL) + // TODO: convert ast to operations + + procedures.Parse(procStmt.ProcedureSpec.Body) + procParams := b.buildProcedureParams(procStmt.ProcedureSpec.Params) characteristics, securityType, comment := b.buildProcedureCharacteristics(procStmt.ProcedureSpec.Characteristics) @@ -321,6 +326,7 @@ func (b *Builder) buildCall(inScope *scope, c *ast.Call) (outScope *scope) { } if esp != nil { proc, err = resolveExternalStoredProcedure(*esp) + // TODO: return plan.NewExternalCall here } else if spdb, ok := db.(sql.StoredProcedureDatabase); ok { var procDetails sql.StoredProcedureDetails procDetails, ok, err = spdb.GetStoredProcedure(b.ctx, procName) diff --git a/sql/procedures.go b/sql/procedures.go index ff49c5867d..d02d854ffb 100644 --- a/sql/procedures.go +++ b/sql/procedures.go @@ -17,8 +17,39 @@ package sql import ( "fmt" "time" + + "github.com/dolthub/vitess/go/vt/sqlparser" ) +// Interpreter is an interface that implements an interpreter. These are typically used for functions (which may be +// implemented as a set of operations that are interpreted during runtime). +type Interpreter interface { + SetStatementRunner(ctx *Context, runner StatementRunner) Expression +} + +// TODO: InterpreterNode interface +// TODO: alternatively have plan.Call just have an interpreter expression + +// InterpreterNode is an interface that implements an interpreter. These are typically used for functions (which may be +// implemented as a set of operations that are interpreted during runtime). +type InterpreterNode interface { + SetStatementRunner(ctx *Context, runner StatementRunner) Node +} + +// StatementRunner is essentially an interface that the engine will implement. We cannot directly reference the engine +// here as it will cause an import cycle, so this may be updated to suit any function changes that the engine +// experiences. +type StatementRunner interface { + QueryWithBindings(ctx *Context, query string, parsed sqlparser.Statement, bindings map[string]sqlparser.Expr, qFlags *QueryFlags) (Schema, RowIter, *QueryFlags, error) +} + + + + + + + + // StoredProcedureDetails are the details of the stored procedure. Integrators only need to store and retrieve the given // details for a stored procedure, as the engine handles all parsing and processing. type StoredProcedureDetails struct { diff --git a/sql/procedures/interpreter_operation.go b/sql/procedures/interpreter_operation.go new file mode 100644 index 0000000000..7521a22691 --- /dev/null +++ b/sql/procedures/interpreter_operation.go @@ -0,0 +1,51 @@ +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package procedures + +// OpCode states the operation to be performed. Most operations have a direct analogue to a Pl/pgSQL operation, however +// some exist only in Doltgres (specific to our interpreter implementation). +type OpCode uint16 + +const ( + OpCode_Alias OpCode = iota // https://www.postgresql.org/docs/15/plpgsql-declarations.html#PLPGSQL-DECLARATION-ALIAS + OpCode_Assign // https://www.postgresql.org/docs/15/plpgsql-statements.html#PLPGSQL-STATEMENTS-ASSIGNMENT + OpCode_Case // https://www.postgresql.org/docs/15/plpgsql-control-structures.html#PLPGSQL-CONDITIONALS + OpCode_Declare // https://www.postgresql.org/docs/15/plpgsql-declarations.html + OpCode_DeleteInto // https://www.postgresql.org/docs/15/plpgsql-statements.html + OpCode_Exception // https://www.postgresql.org/docs/15/plpgsql-control-structures.html#PLPGSQL-ERROR-TRAPPING + OpCode_Execute // Executing a standard SQL statement (expects no rows returned unless Target is specified) + OpCode_ExecuteDynamic // https://www.postgresql.org/docs/15/plpgsql-statements.html#PLPGSQL-STATEMENTS-EXECUTING-DYN + OpCode_For // https://www.postgresql.org/docs/15/plpgsql-control-structures.html#PLPGSQL-CONTROL-STRUCTURES-LOOPS + OpCode_Foreach // https://www.postgresql.org/docs/15/plpgsql-control-structures.html#PLPGSQL-CONTROL-STRUCTURES-LOOPS + OpCode_Get // https://www.postgresql.org/docs/15/plpgsql-statements.html#PLPGSQL-STATEMENTS-DIAGNOSTICS + OpCode_Goto // All control-flow structures can be represented using Goto + OpCode_If // https://www.postgresql.org/docs/15/plpgsql-control-structures.html#PLPGSQL-CONDITIONALS + OpCode_InsertInto // https://www.postgresql.org/docs/15/plpgsql-statements.html + OpCode_Loop // https://www.postgresql.org/docs/15/plpgsql-control-structures.html#PLPGSQL-CONTROL-STRUCTURES-LOOPS + OpCode_Perform // https://www.postgresql.org/docs/15/plpgsql-statements.html + OpCode_Query // This is just a standard query, nothing special + OpCode_Return // https://www.postgresql.org/docs/15/plpgsql-control-structures.html#PLPGSQL-STATEMENTS-RETURNING + OpCode_ScopeBegin // This is used for scope control, specific to Doltgres + OpCode_ScopeEnd // This is used for scope control, specific to Doltgres + OpCode_SelectInto // https://www.postgresql.org/docs/15/plpgsql-statements.html + OpCode_When // https://www.postgresql.org/docs/15/plpgsql-control-structures.html#PLPGSQL-CONDITIONALS + OpCode_While // https://www.postgresql.org/docs/15/plpgsql-control-structures.html#PLPGSQL-CONTROL-STRUCTURES-LOOPS + OpCode_UpdateInto // https://www.postgresql.org/docs/15/plpgsql-statements.html +) + +// InterpreterOperation is an operation that will be performed by the interpreter. +type InterpreterOperation struct { + OpCode OpCode + PrimaryData string // This will represent the "main" data, such as the query for PERFORM, expression for IF, etc. + SecondaryData []string // This represents auxiliary data, such as bindings, strictness, etc. + Target string // This is the variable that will store the results (if applicable) + Index int // This is the index that should be set for operations that move the function counter +} \ No newline at end of file diff --git a/sql/procedures/interpreter_stack.go b/sql/procedures/interpreter_stack.go new file mode 100644 index 0000000000..99370f6f17 --- /dev/null +++ b/sql/procedures/interpreter_stack.go @@ -0,0 +1,133 @@ +// Copyright 2025 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package procedures + +import ( + "fmt" + + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/analyzer" + + pgtypes "github.com/dolthub/doltgresql/server/types" + "github.com/dolthub/doltgresql/utils" +) + +// InterpreterVariable is a variable that lives on the stack. +type InterpreterVariable struct { + Type *pgtypes.DoltgresType + Value any +} + +// InterpreterScopeDetails contains all of the details that are relevant to a particular scope. +type InterpreterScopeDetails struct { + variables map[string]*InterpreterVariable +} + +// InterpreterStack represents the working information that an interpreter will use during execution. It is not exactly +// the same as a stack in the traditional programming sense, but rather is a loose abstraction that serves the same +// general purpose. +type InterpreterStack struct { + stack *utils.Stack[*InterpreterScopeDetails] + runner analyzer.StatementRunner +} + +// NewInterpreterStack creates a new InterpreterStack. +func NewInterpreterStack(runner analyzer.StatementRunner) InterpreterStack { + stack := utils.NewStack[*InterpreterScopeDetails]() + // This first push represents the function base, including parameters + stack.Push(&InterpreterScopeDetails{ + variables: make(map[string]*InterpreterVariable), + }) + return InterpreterStack{ + stack: stack, + runner: runner, + } +} + +// Details returns the details for the current scope. +func (is *InterpreterStack) Details() *InterpreterScopeDetails { + return is.stack.Peek() +} + +// Runner returns the runner that is being used for the function's execution. +func (is *InterpreterStack) Runner() analyzer.StatementRunner { + return is.runner +} + +// GetVariable traverses the stack (starting from the top) to find a variable with a matching name. Returns nil if no +// variable was found. +func (is *InterpreterStack) GetVariable(name string) *InterpreterVariable { + for i := 0; i < is.stack.Len(); i++ { + if iv, ok := is.stack.PeekDepth(i).variables[name]; ok { + return iv + } + } + return nil +} + +// ListVariables returns a map with the names of all variables. +func (is *InterpreterStack) ListVariables() map[string]struct{} { + seen := make(map[string]struct{}) + for i := 0; i < is.stack.Len(); i++ { + for varName := range is.stack.PeekDepth(i).variables { + seen[varName] = struct{}{} + } + } + return seen +} + +// NewVariable creates a new variable in the current scope. If a variable with the same name exists in a previous scope, +// then that variable will be shadowed until the current scope exits. +func (is *InterpreterStack) NewVariable(name string, typ *pgtypes.DoltgresType) { + is.NewVariableWithValue(name, typ, typ.Zero()) +} + +// NewVariableWithValue creates a new variable in the current scope, setting its initial value to the one given. +func (is *InterpreterStack) NewVariableWithValue(name string, typ *pgtypes.DoltgresType, val any) { + is.stack.Peek().variables[name] = &InterpreterVariable{ + Type: typ, + Value: val, + } +} + +// NewVariableAlias creates a new variable alias, named |alias|, in the current frame of this stack, +// pointing to the specified |variable|. +func (is *InterpreterStack) NewVariableAlias(alias string, variable *InterpreterVariable) { + is.stack.Peek().variables[alias] = variable +} + +// PushScope creates a new scope. +func (is *InterpreterStack) PushScope() { + is.stack.Push(&InterpreterScopeDetails{ + variables: make(map[string]*InterpreterVariable), + }) +} + +// PopScope removes the current scope. +func (is *InterpreterStack) PopScope() { + is.stack.Pop() +} + +// SetVariable sets the first variable found, with a matching name, to the value given. This does not ensure that the +// value matches the expectations of the type, so it should be validated before this is called. Returns an error if the +// variable cannot be found. +func (is *InterpreterStack) SetVariable(ctx *sql.Context, name string, val any) error { + iv := is.GetVariable(name) + if iv == nil { + return fmt.Errorf("variable `%s` could not be found", name) + } + iv.Value = val + return nil +} \ No newline at end of file diff --git a/sql/procedures/parse.go b/sql/procedures/parse.go new file mode 100644 index 0000000000..396b50eca6 --- /dev/null +++ b/sql/procedures/parse.go @@ -0,0 +1,47 @@ +// Copyright 2025 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package procedures + +import ( + ast "github.com/dolthub/vitess/go/vt/sqlparser" +) + +func ConvertStmt(stmt *ast.Statement) (Block, error) { + block := Block{} + switch s := stmt.(type) { + case *ast.BeginEndBlock: + // TODO: convert this into what? operations? + + } + + + return block, nil +} + +// Parse parses the given CREATE FUNCTION string (which must be the entire string, not just the body) into a Block +// containing the contents of the body. +func Parse(stmt *ast.Statement) ([]InterpreterOperation, error) { + block, err := ConvertStmt(stmt) + if err != nil { + return nil, err + } + + ops := make([]InterpreterOperation, 0, len(block.Body)+len(block.Variable)) + stack := NewInterpreterStack(nil) + if err := block.AppendOperations(&ops, &stack); err != nil { + return nil, err + } + return ops, nil +} \ No newline at end of file diff --git a/sql/procedures/statements.go b/sql/procedures/statements.go new file mode 100644 index 0000000000..fe0133eaac --- /dev/null +++ b/sql/procedures/statements.go @@ -0,0 +1,282 @@ +// Copyright 2025 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package procedures + +import ( + "fmt"pg_query "github.com/pganalyze/pg_query_go/v6" + +) + +// Statement represents a PL/pgSQL statement. +type Statement interface { + // OperationSize reports the number of operations that the statement will convert to. + OperationSize() int32 + // AppendOperations adds the statement to the operation slice. + AppendOperations(ops *[]InterpreterOperation, stack *InterpreterStack) error +} + +// Assignment represents an assignment statement. +type Assignment struct { + VariableName string + Expression string + VariableIndex int32 // TODO: figure out what this is used for, probably to get around shadowed variables? +} + +var _ Statement = Assignment{} + +// OperationSize implements the interface Statement. +func (Assignment) OperationSize() int32 { + return 1 +} + +// AppendOperations implements the interface Statement. +func (stmt Assignment) AppendOperations(ops *[]InterpreterOperation, stack *InterpreterStack) error { + expression, referencedVariables, err := substituteVariableReferences(stmt.Expression, stack) + if err != nil { + return err + } + + *ops = append(*ops, InterpreterOperation{ + OpCode: OpCode_Assign, + PrimaryData: "SELECT " + expression + ";", + SecondaryData: referencedVariables, + Target: stmt.VariableName, + }) + return nil +} + +// Block contains a collection of statements, alongside the variables that were declared for the block. Only the +// top-level block will contain parameter variables. +type Block struct { + Variable []Variable + Body []Statement +} + +var _ Statement = Block{} + +// OperationSize implements the interface Statement. +func (stmt Block) OperationSize() int32 { + total := int32(2) // We start with 2 since we'll have ScopeBegin and ScopeEnd + for _, variable := range stmt.Variable { + if !variable.IsParameter { + total++ + } + } + for _, innerStmt := range stmt.Body { + total += innerStmt.OperationSize() + } + return total +} + +// AppendOperations implements the interface Statement. +func (stmt Block) AppendOperations(ops *[]InterpreterOperation, stack *InterpreterStack) error { + stack.PushScope() + *ops = append(*ops, InterpreterOperation{ + OpCode: OpCode_ScopeBegin, + }) + for _, variable := range stmt.Variable { + if !variable.IsParameter { + *ops = append(*ops, InterpreterOperation{ + OpCode: OpCode_Declare, + PrimaryData: variable.Type, + Target: variable.Name, + }) + } + stack.NewVariableWithValue(variable.Name, nil, nil) + } + for _, innerStmt := range stmt.Body { + if err := innerStmt.AppendOperations(ops, stack); err != nil { + return err + } + } + *ops = append(*ops, InterpreterOperation{ + OpCode: OpCode_ScopeEnd, + }) + stack.PopScope() + return nil +} + +// ExecuteSQL represents a standard SQL statement's execution (including the INTO syntax). +type ExecuteSQL struct { + Statement string + Target string +} + +var _ Statement = ExecuteSQL{} + +// OperationSize implements the interface Statement. +func (ExecuteSQL) OperationSize() int32 { + return 1 +} + +// AppendOperations implements the interface Statement. +func (stmt ExecuteSQL) AppendOperations(ops *[]InterpreterOperation, stack *InterpreterStack) error { + statementStr, referencedVariables, err := substituteVariableReferences(stmt.Statement, stack) + if err != nil { + return err + } + *ops = append(*ops, InterpreterOperation{ + OpCode: OpCode_Execute, + PrimaryData: statementStr, + SecondaryData: referencedVariables, + Target: stmt.Target, + }) + return nil +} + +// Goto jumps to the counter at the given offset. +type Goto struct { + Offset int32 +} + +var _ Statement = Goto{} + +// OperationSize implements the interface Statement. +func (Goto) OperationSize() int32 { + return 1 +} + +// AppendOperations implements the interface Statement. +func (stmt Goto) AppendOperations(ops *[]InterpreterOperation, stack *InterpreterStack) error { + *ops = append(*ops, InterpreterOperation{ + OpCode: OpCode_Goto, + Index: len(*ops) + int(stmt.Offset), + }) + return nil +} + +// If represents an IF condition, alongside its Goto offset if the condition is true. +type If struct { + Condition string + GotoOffset int32 +} + +var _ Statement = If{} + +// OperationSize implements the interface Statement. +func (If) OperationSize() int32 { + return 1 +} + +// AppendOperations implements the interface Statement. +func (stmt If) AppendOperations(ops *[]InterpreterOperation, stack *InterpreterStack) error { + condition, referencedVariables, err := substituteVariableReferences(stmt.Condition, stack) + if err != nil { + return err + } + + *ops = append(*ops, InterpreterOperation{ + OpCode: OpCode_If, + PrimaryData: "SELECT " + condition + ";", + SecondaryData: referencedVariables, + Index: len(*ops) + int(stmt.GotoOffset), + }) + return nil +} + +// Perform represents a PERFORM statement. +type Perform struct { + Statement string +} + +var _ Statement = Perform{} + +// OperationSize implements the interface Statement. +func (Perform) OperationSize() int32 { + return 1 +} + +// AppendOperations implements the interface Statement. +func (stmt Perform) AppendOperations(ops *[]InterpreterOperation, stack *InterpreterStack) error { + statementStr, referencedVariables, err := substituteVariableReferences(stmt.Statement, stack) + if err != nil { + return err + } + + *ops = append(*ops, InterpreterOperation{ + OpCode: OpCode_Perform, + PrimaryData: statementStr, + SecondaryData: referencedVariables, + }) + return nil +} + +// Return represents a RETURN statement. +type Return struct { + Expression string +} + +var _ Statement = Return{} + +// OperationSize implements the interface Statement. +func (Return) OperationSize() int32 { + return 1 +} + +// AppendOperations implements the interface Statement. +func (stmt Return) AppendOperations(ops *[]InterpreterOperation, stack *InterpreterStack) error { + expression, referencedVariables, err := substituteVariableReferences(stmt.Expression, stack) + if err != nil { + return err + } + if len(expression) > 0 { + expression = "SELECT " + expression + ";" + } + *ops = append(*ops, InterpreterOperation{ + OpCode: OpCode_Return, + PrimaryData: expression, + SecondaryData: referencedVariables, + }) + return nil +} + +// Variable represents a variable. These are exclusively found within Block. +type Variable struct { + Name string + Type string + IsParameter bool +} + +// OperationSizeForStatements returns the sum of OperationSize for every statement. +func OperationSizeForStatements(stmts []Statement) int32 { + total := int32(0) + for _, stmt := range stmts { + total += stmt.OperationSize() + } + return total +} + +// substituteVariableReferences parses the specified |expression| and replaces +// any token that matches a variable name in the |stack| with "$N", where N +// indicates which variable in the returned |referenceVars| slice is used. +func substituteVariableReferences(expression string, stack *InterpreterStack) (newExpression string, referencedVars []string, err error) { + scanResult, err := pg_query.Scan(expression) + if err != nil { + return "", nil, err + } + + varMap := stack.ListVariables() + for _, token := range scanResult.Tokens { + substring := expression[token.Start:token.End] + if _, ok := varMap[substring]; ok { + referencedVars = append(referencedVars, substring) + newExpression += fmt.Sprintf("$%d ", len(referencedVars)) + } else { + newExpression += substring + " " + } + } + + return newExpression, referencedVars, nil +} \ No newline at end of file diff --git a/sql/rowexec/proc.go b/sql/rowexec/proc.go index 4531605586..6732949fc9 100644 --- a/sql/rowexec/proc.go +++ b/sql/rowexec/proc.go @@ -197,10 +197,17 @@ func (b *BaseBuilder) buildCall(ctx *sql.Context, n *plan.Call, row sql.Row) (sq n.Pref.PushScope() defer n.Pref.PopScope(ctx) + // TODO: mirror plpgsql interpreter_logic.go Call() + // TODO: instead of building, run the actual operations + // This means call the runner.QueryWithBindings innerIter, err := b.buildNodeExec(ctx, n.Procedure, row) if err != nil { return nil, err } + + // TODO: save any select ast rowIters to be returned later + + return &callIter{ call: n, innerIter: innerIter, From 24f2c2ac8a685a0e3e2852fc5e1a1ab52593ffc5 Mon Sep 17 00:00:00 2001 From: James Cor Date: Fri, 14 Feb 2025 10:40:11 -0800 Subject: [PATCH 002/111] progress --- sql/plan/call.go | 49 ++++++-- sql/plan/procedure.go | 40 +++---- sql/planbuilder/proc.go | 26 ++--- sql/procedures.go | 6 + sql/procedures/interpreter_logic.go | 144 ++++++++++++++++++++++++ sql/procedures/interpreter_operation.go | 23 +--- sql/procedures/interpreter_stack.go | 6 +- sql/procedures/parse.go | 53 ++++++--- sql/procedures/statements.go | 63 +---------- sql/rowexec/proc.go | 21 +++- 10 files changed, 282 insertions(+), 149 deletions(-) create mode 100644 sql/procedures/interpreter_logic.go diff --git a/sql/plan/call.go b/sql/plan/call.go index 114680bb5e..4296e9aa93 100644 --- a/sql/plan/call.go +++ b/sql/plan/call.go @@ -17,15 +17,14 @@ package plan import ( "fmt" - "github.com/dolthub/go-mysql-server/sql/types" - "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/expression" + "github.com/dolthub/go-mysql-server/sql/procedures" + "github.com/dolthub/go-mysql-server/sql/types" ) -// TODO: we need different types of calls: one for extrenal procedures one for stored procedures - +// TODO: we need different types of calls: one for external procedures one for stored procedures type Call struct { db sql.Database @@ -38,17 +37,18 @@ type Call struct { Analyzed bool // this will have list of parsed operations to run - runner sql.StatementRunner + Runner sql.StatementRunner + Ops []procedures.InterpreterOperation } var _ sql.Node = (*Call)(nil) var _ sql.CollationCoercible = (*Call)(nil) var _ sql.Expressioner = (*Call)(nil) -var _ sql.InterpreterNode = (*Call)(nil) +var _ procedures.InterpreterNode = (*Call)(nil) var _ Versionable = (*Call)(nil) // NewCall returns a *Call node. -func NewCall(db sql.Database, name string, params []sql.Expression, proc *Procedure, asOf sql.Expression, catalog sql.Catalog) *Call { +func NewCall(db sql.Database, name string, params []sql.Expression, proc *Procedure, asOf sql.Expression, catalog sql.Catalog, ops []procedures.InterpreterOperation) *Call { return &Call{ db: db, Name: name, @@ -56,6 +56,7 @@ func NewCall(db sql.Database, name string, params []sql.Expression, proc *Proced Procedure: proc, asOf: asOf, cat: catalog, + Ops: ops, } } @@ -178,9 +179,6 @@ func (c *Call) DebugString() string { } else { tp.WriteNode("CALL %s.%s(%s)", c.db.Name(), c.Name, paramStr) } - if c.Procedure != nil { - tp.WriteChildren(sql.DebugString(c.Procedure.Body)) - } return tp.String() } @@ -209,6 +207,35 @@ func (c *Call) Dispose() { // SetStatementRunner implements the sql.InterpreterNode interface. func (c *Call) SetStatementRunner(ctx *sql.Context, runner sql.StatementRunner) sql.Node { nc := *c - nc.runner = runner + nc.Runner = runner return &nc } + +// GetRunner implements the sql.InterpreterNode interface. +func (c *Call) GetRunner() sql.StatementRunner { + return c.Runner +} + +// GetParameters implements the sql.InterpreterNode interface. +func (c *Call) GetParameters() []sql.Type { + return nil +} + +// GetParameterNames implements the sql.InterpreterNode interface. +func (c *Call) GetParameterNames() []string { + return nil +} + +// GetStatements implements the sql.InterpreterNode interface. +func (c *Call) GetStatements() []procedures.InterpreterOperation { + return c.Ops +} + +// GetReturn implements the sql.InterpreterNode interface. +func (c *Call) GetReturn() sql.Type { + return nil +} + + + + diff --git a/sql/plan/procedure.go b/sql/plan/procedure.go index fc21b5b96a..5fb6d8a88b 100644 --- a/sql/plan/procedure.go +++ b/sql/plan/procedure.go @@ -20,9 +20,10 @@ import ( "strings" "time" - "github.com/dolthub/go-mysql-server/sql/expression" - "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/expression" + "github.com/dolthub/go-mysql-server/sql/procedures" + "github.com/dolthub/go-mysql-server/sql/types" ) // ProcedureSecurityContext determines whether the stored procedure is executed using the privileges of the definer or @@ -80,7 +81,8 @@ type Procedure struct { Comment string Characteristics []Characteristic CreateProcedureString string - Body sql.Node + Ops []procedures.InterpreterOperation + ExternalProc sql.Node CreatedAt time.Time ModifiedAt time.Time ValidationError error @@ -100,7 +102,7 @@ func NewProcedure( comment string, characteristics []Characteristic, createProcedureString string, - body sql.Node, + ops []procedures.InterpreterOperation, createdAt time.Time, modifiedAt time.Time, ) *Procedure { @@ -121,7 +123,7 @@ func NewProcedure( Comment: comment, Characteristics: characteristics, CreateProcedureString: createProcedureString, - Body: body, + Ops: ops, CreatedAt: createdAt, ModifiedAt: modifiedAt, } @@ -129,47 +131,41 @@ func NewProcedure( // Resolved implements the sql.Node interface. func (p *Procedure) Resolved() bool { - return p.Body.Resolved() + return true } func (p *Procedure) IsReadOnly() bool { - return p.Body.IsReadOnly() + return false } // String implements the sql.Node interface. func (p *Procedure) String() string { - return p.Body.String() + return "" } // DebugString implements the sql.DebugStringer interface. func (p *Procedure) DebugString() string { - return sql.DebugString(p.Body) + return sql.DebugString(p.Ops) } // Schema implements the sql.Node interface. func (p *Procedure) Schema() sql.Schema { - return p.Body.Schema() + return types.OkResultSchema } // Children implements the sql.Node interface. func (p *Procedure) Children() []sql.Node { - return []sql.Node{p.Body} + return nil } // WithChildren implements the sql.Node interface. func (p *Procedure) WithChildren(children ...sql.Node) (sql.Node, error) { - if len(children) != 1 { - return nil, sql.ErrInvalidChildrenNumber.New(p, len(children), 1) - } - - np := *p - np.Body = children[0] - return &np, nil + return p, nil } // CollationCoercibility implements the interface sql.CollationCoercible. func (p *Procedure) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { - return sql.GetCoercibility(ctx, p.Body) + return sql.GetCoercibility(ctx, p.Ops) } // implementsRepresentsBlock implements the RepresentsBlock interface. @@ -181,9 +177,9 @@ func (p *Procedure) ExtendVariadic(ctx *sql.Context, length int) *Procedure { return p } np := *p - body := p.Body.(*ExternalProcedure) + body := p.ExternalProc.(*ExternalProcedure) newBody := *body - np.Body = &newBody + np.ExternalProc = &newBody newParamDefinitions := make([]ProcedureParam, length) newParams := make([]*expression.ProcedureParam, length) @@ -226,7 +222,7 @@ func (p *Procedure) HasVariadicParameter() bool { // IsExternal returns whether the stored procedure is external. func (p *Procedure) IsExternal() bool { - if _, ok := p.Body.(*ExternalProcedure); ok { + if _, ok := p.ExternalProc.(*ExternalProcedure); ok { return true } return false diff --git a/sql/planbuilder/proc.go b/sql/planbuilder/proc.go index 816f053940..799981217c 100644 --- a/sql/planbuilder/proc.go +++ b/sql/planbuilder/proc.go @@ -257,24 +257,14 @@ func BuildProcedureHelper(ctx *sql.Context, cat sql.Catalog, isCreateProc bool, // TODO: convert ast to operations - procedures.Parse(procStmt.ProcedureSpec.Body) + ops, err := procedures.Parse(procStmt.ProcedureSpec.Body) + if err != nil { + b.handleErr(err) + } procParams := b.buildProcedureParams(procStmt.ProcedureSpec.Params) characteristics, securityType, comment := b.buildProcedureCharacteristics(procStmt.ProcedureSpec.Characteristics) - // populate inScope with the procedure parameters. this will be - // subject maybe a bug where an inner procedure has access to - // outer procedure parameters. - if inScope == nil { - inScope = b.newScope() - } - inScope.initProc() - for _, p := range procParams { - inScope.proc.AddVar(expression.NewProcedureParam(strings.ToLower(p.Name), p.Type)) - } - - bodyStr := strings.TrimSpace(procDetails.CreateStatement[procStmt.SubStatementPositionStart:procStmt.SubStatementPositionEnd]) - bodyScope := b.buildSubquery(inScope, procStmt.ProcedureSpec.Body, bodyStr, procDetails.CreateStatement) proc = plan.NewProcedure( procDetails.Name, @@ -284,12 +274,12 @@ func BuildProcedureHelper(ctx *sql.Context, cat sql.Catalog, isCreateProc bool, comment, characteristics, procDetails.CreateStatement, - bodyScope.node, + ops, procDetails.CreatedAt, procDetails.ModifiedAt, ) - qFlags = b.qFlags - return + + return proc, qFlags, nil } func (b *Builder) buildCall(inScope *scope, c *ast.Call) (outScope *scope) { @@ -332,7 +322,7 @@ func (b *Builder) buildCall(inScope *scope, c *ast.Call) (outScope *scope) { procDetails, ok, err = spdb.GetStoredProcedure(b.ctx, procName) if err == nil { if ok { - proc, innerQFlags, err = BuildProcedureHelper(b.ctx, b.cat, false, inScope, db, asOf, procDetails) + ops, innerQFlags, err = BuildProcedureHelper(b.ctx, b.cat, false, inScope, db, asOf, procDetails) // TODO: somewhat hacky way of preserving this flag // This is necessary so that the resolveSubqueries analyzer rule // will apply NodeExecBuilder to Subqueries in procedure body diff --git a/sql/procedures.go b/sql/procedures.go index d02d854ffb..874486dcc5 100644 --- a/sql/procedures.go +++ b/sql/procedures.go @@ -18,6 +18,8 @@ import ( "fmt" "time" + "github.com/dolthub/go-mysql-server/sql/procedures" + "github.com/dolthub/vitess/go/vt/sqlparser" ) @@ -34,6 +36,10 @@ type Interpreter interface { // implemented as a set of operations that are interpreted during runtime). type InterpreterNode interface { SetStatementRunner(ctx *Context, runner StatementRunner) Node + GetParameters() []Type + GetParameterNames() []string + GetReturn() Type + GetStatements() []procedures.InterpreterOperation } // StatementRunner is essentially an interface that the engine will implement. We cannot directly reference the engine diff --git a/sql/procedures/interpreter_logic.go b/sql/procedures/interpreter_logic.go new file mode 100644 index 0000000000..346ae23ba1 --- /dev/null +++ b/sql/procedures/interpreter_logic.go @@ -0,0 +1,144 @@ +// Copyright 2025 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package procedures + +import ( + "fmt" + "github.com/dolthub/doltgresql/core" +"github.com/dolthub/doltgresql/core/id" +"github.com/dolthub/go-mysql-server/sql" +) + +// InterpreterNode is an interface that implements an interpreter. These are typically used for functions (which may be +// implemented as a set of operations that are interpreted during runtime). +type InterpreterNode interface { + GetRunner() sql.StatementRunner + GetParameters() []sql.Type + GetParameterNames() []string + GetReturn() sql.Type + GetStatements() []InterpreterOperation +} + +// Call runs the contained operations on the given runner. +func Call(ctx *sql.Context, call sql.InterpreterNode, runner sql.StatementRunner) (any, error) { + // Set up the initial state of the function + counter := -1 // We increment before accessing, so start at -1 + stack := NewInterpreterStack() + // Add the parameters + parameterTypes := call.GetParameters() + parameterNames := call.GetParameterNames() + if len(vals) != len(parameterTypes) { + return nil, fmt.Errorf("parameter count mismatch: expected %d got %d", len(parameterTypes), len(vals)) + } + for i := range vals { + stack.NewVariableWithValue(parameterNames[i], parameterTypes[i], vals[i]) + } + // Run the statements + statements := call.Ops + for { + counter++ + if counter >= len(statements) { + break + } else if counter < 0 { + panic("negative function counter") + } + + operation := statements[counter] + switch operation.OpCode { + case OpCode_Select: + + case OpCode_Declare: + typeCollection, err := core.GetTypesCollectionFromContext(ctx) + if err != nil { + return nil, err + } + resolvedType, exists := typeCollection.GetType(id.NewType("pg_catalog", operation.PrimaryData)) + if !exists { + return nil, pgtypes.ErrTypeDoesNotExist.New(operation.PrimaryData) + } + stack.NewVariable(operation.Target, resolvedType) + case OpCode_Exception: + // TODO: implement + case OpCode_Execute: + if len(operation.Target) > 0 { + target := stack.GetVariable(operation.Target) + if target == nil { + return nil, fmt.Errorf("variable `%s` could not be found", operation.Target) + } + retVal, err := iFunc.QuerySingleReturn(ctx, stack, operation.PrimaryData, target.Type, operation.SecondaryData) + if err != nil { + return nil, err + } + err = stack.SetVariable(ctx, operation.Target, retVal) + if err != nil { + return nil, err + } + } else { + rowIter, err := iFunc.QueryMultiReturn(ctx, stack, operation.PrimaryData, operation.SecondaryData) + if err != nil { + return nil, err + } + if _, err = sql.RowIterToRows(ctx, rowIter); err != nil { + return nil, err + } + } + + case OpCode_Goto: + // We must compare to the index - 1, so that the increment hits our target + if counter <= operation.Index { + for ; counter < operation.Index-1; counter++ { + switch statements[counter].OpCode { + case OpCode_ScopeBegin: + stack.PushScope() + case OpCode_ScopeEnd: + stack.PopScope() + } + } + } else { + for ; counter > operation.Index-1; counter-- { + switch statements[counter].OpCode { + case OpCode_ScopeBegin: + stack.PopScope() + case OpCode_ScopeEnd: + stack.PushScope() + } + } + } + case OpCode_If: + sch, rowIter, _, err := runner.QueryWithBindings(ctx, "", operation.PrimaryData, nil, nil) + if err != nil { + return nil, err + } + row, err := rowIter.Next(ctx) + if err != nil { + return nil, err + } + rowIter.Close(ctx) + if retVal.(bool) { + // We're never changing the scope, so we can just assign it directly. + // Also, we must assign to index-1, so that the increment hits our target. + counter = operation.Index - 1 + } + + case OpCode_ScopeBegin: + stack.PushScope() + case OpCode_ScopeEnd: + stack.PopScope() + default: + panic("unimplemented opcode") + } + } + return nil, nil +} diff --git a/sql/procedures/interpreter_operation.go b/sql/procedures/interpreter_operation.go index 7521a22691..3893c0a748 100644 --- a/sql/procedures/interpreter_operation.go +++ b/sql/procedures/interpreter_operation.go @@ -10,41 +10,28 @@ package procedures +import ast "github.com/dolthub/vitess/go/vt/sqlparser" + // OpCode states the operation to be performed. Most operations have a direct analogue to a Pl/pgSQL operation, however // some exist only in Doltgres (specific to our interpreter implementation). type OpCode uint16 const ( - OpCode_Alias OpCode = iota // https://www.postgresql.org/docs/15/plpgsql-declarations.html#PLPGSQL-DECLARATION-ALIAS - OpCode_Assign // https://www.postgresql.org/docs/15/plpgsql-statements.html#PLPGSQL-STATEMENTS-ASSIGNMENT - OpCode_Case // https://www.postgresql.org/docs/15/plpgsql-control-structures.html#PLPGSQL-CONDITIONALS + OpCode_Select OpCode = iota OpCode_Declare // https://www.postgresql.org/docs/15/plpgsql-declarations.html - OpCode_DeleteInto // https://www.postgresql.org/docs/15/plpgsql-statements.html OpCode_Exception // https://www.postgresql.org/docs/15/plpgsql-control-structures.html#PLPGSQL-ERROR-TRAPPING - OpCode_Execute // Executing a standard SQL statement (expects no rows returned unless Target is specified) - OpCode_ExecuteDynamic // https://www.postgresql.org/docs/15/plpgsql-statements.html#PLPGSQL-STATEMENTS-EXECUTING-DYN - OpCode_For // https://www.postgresql.org/docs/15/plpgsql-control-structures.html#PLPGSQL-CONTROL-STRUCTURES-LOOPS - OpCode_Foreach // https://www.postgresql.org/docs/15/plpgsql-control-structures.html#PLPGSQL-CONTROL-STRUCTURES-LOOPS - OpCode_Get // https://www.postgresql.org/docs/15/plpgsql-statements.html#PLPGSQL-STATEMENTS-DIAGNOSTICS + OpCode_Execute // Everything that's not a SELECT OpCode_Goto // All control-flow structures can be represented using Goto OpCode_If // https://www.postgresql.org/docs/15/plpgsql-control-structures.html#PLPGSQL-CONDITIONALS - OpCode_InsertInto // https://www.postgresql.org/docs/15/plpgsql-statements.html - OpCode_Loop // https://www.postgresql.org/docs/15/plpgsql-control-structures.html#PLPGSQL-CONTROL-STRUCTURES-LOOPS - OpCode_Perform // https://www.postgresql.org/docs/15/plpgsql-statements.html - OpCode_Query // This is just a standard query, nothing special OpCode_Return // https://www.postgresql.org/docs/15/plpgsql-control-structures.html#PLPGSQL-STATEMENTS-RETURNING OpCode_ScopeBegin // This is used for scope control, specific to Doltgres OpCode_ScopeEnd // This is used for scope control, specific to Doltgres - OpCode_SelectInto // https://www.postgresql.org/docs/15/plpgsql-statements.html - OpCode_When // https://www.postgresql.org/docs/15/plpgsql-control-structures.html#PLPGSQL-CONDITIONALS - OpCode_While // https://www.postgresql.org/docs/15/plpgsql-control-structures.html#PLPGSQL-CONTROL-STRUCTURES-LOOPS - OpCode_UpdateInto // https://www.postgresql.org/docs/15/plpgsql-statements.html ) // InterpreterOperation is an operation that will be performed by the interpreter. type InterpreterOperation struct { OpCode OpCode - PrimaryData string // This will represent the "main" data, such as the query for PERFORM, expression for IF, etc. + PrimaryData ast.Statement // This will represent the "main" data, such as the query for PERFORM, expression for IF, etc. SecondaryData []string // This represents auxiliary data, such as bindings, strictness, etc. Target string // This is the variable that will store the results (if applicable) Index int // This is the index that should be set for operations that move the function counter diff --git a/sql/procedures/interpreter_stack.go b/sql/procedures/interpreter_stack.go index 99370f6f17..e904ffc0e3 100644 --- a/sql/procedures/interpreter_stack.go +++ b/sql/procedures/interpreter_stack.go @@ -40,11 +40,10 @@ type InterpreterScopeDetails struct { // general purpose. type InterpreterStack struct { stack *utils.Stack[*InterpreterScopeDetails] - runner analyzer.StatementRunner } // NewInterpreterStack creates a new InterpreterStack. -func NewInterpreterStack(runner analyzer.StatementRunner) InterpreterStack { +func NewInterpreterStack() InterpreterStack { stack := utils.NewStack[*InterpreterScopeDetails]() // This first push represents the function base, including parameters stack.Push(&InterpreterScopeDetails{ @@ -52,7 +51,6 @@ func NewInterpreterStack(runner analyzer.StatementRunner) InterpreterStack { }) return InterpreterStack{ stack: stack, - runner: runner, } } @@ -62,7 +60,7 @@ func (is *InterpreterStack) Details() *InterpreterScopeDetails { } // Runner returns the runner that is being used for the function's execution. -func (is *InterpreterStack) Runner() analyzer.StatementRunner { +func (is *InterpreterStack) Runner() sql.StatementRunner { return is.runner } diff --git a/sql/procedures/parse.go b/sql/procedures/parse.go index 396b50eca6..86bd6ad2da 100644 --- a/sql/procedures/parse.go +++ b/sql/procedures/parse.go @@ -18,30 +18,57 @@ import ( ast "github.com/dolthub/vitess/go/vt/sqlparser" ) -func ConvertStmt(stmt *ast.Statement) (Block, error) { - block := Block{} +func ConvertStmt(ops *[]InterpreterOperation, stack *InterpreterStack, stmt ast.Statement) error { switch s := stmt.(type) { case *ast.BeginEndBlock: - // TODO: convert this into what? operations? + stack.PushScope() + startOP := InterpreterOperation{ + OpCode: OpCode_ScopeBegin, + } + *ops = append(*ops, startOP) - } + // TODO: add declares + for _, ss := range s.Statements { + if err := ConvertStmt(ops, stack, ss); err != nil { + return err + } + } + endOp := InterpreterOperation{ + OpCode: OpCode_ScopeEnd, + } + *ops = append(*ops, endOp) + stack.PopScope() + case *ast.Select: + selectOp := InterpreterOperation{ + OpCode: OpCode_Select, + PrimaryData: s, + } + *ops = append(*ops, selectOp) + + case *ast.Declare: + // TODO: + //declareOp := InterpreterOperation{} + //stack.NewVariable + default: + execOp := InterpreterOperation{ + OpCode: OpCode_Execute, + PrimaryData: s, + } + *ops = append(*ops, execOp) + } - return block, nil + return nil } // Parse parses the given CREATE FUNCTION string (which must be the entire string, not just the body) into a Block // containing the contents of the body. -func Parse(stmt *ast.Statement) ([]InterpreterOperation, error) { - block, err := ConvertStmt(stmt) +func Parse(stmt ast.Statement) ([]InterpreterOperation, error) { + ops := make([]InterpreterOperation, 0, 64) + stack := NewInterpreterStack() + err := ConvertStmt(&ops, &stack, stmt) if err != nil { return nil, err } - - ops := make([]InterpreterOperation, 0, len(block.Body)+len(block.Variable)) - stack := NewInterpreterStack(nil) - if err := block.AppendOperations(&ops, &stack); err != nil { - return nil, err - } return ops, nil } \ No newline at end of file diff --git a/sql/procedures/statements.go b/sql/procedures/statements.go index fe0133eaac..8a16384d51 100644 --- a/sql/procedures/statements.go +++ b/sql/procedures/statements.go @@ -14,11 +14,6 @@ package procedures -import ( - "fmt"pg_query "github.com/pganalyze/pg_query_go/v6" - -) - // Statement represents a PL/pgSQL statement. type Statement interface { // OperationSize reports the number of operations that the statement will convert to. @@ -43,15 +38,8 @@ func (Assignment) OperationSize() int32 { // AppendOperations implements the interface Statement. func (stmt Assignment) AppendOperations(ops *[]InterpreterOperation, stack *InterpreterStack) error { - expression, referencedVariables, err := substituteVariableReferences(stmt.Expression, stack) - if err != nil { - return err - } - *ops = append(*ops, InterpreterOperation{ OpCode: OpCode_Assign, - PrimaryData: "SELECT " + expression + ";", - SecondaryData: referencedVariables, Target: stmt.VariableName, }) return nil @@ -123,14 +111,8 @@ func (ExecuteSQL) OperationSize() int32 { // AppendOperations implements the interface Statement. func (stmt ExecuteSQL) AppendOperations(ops *[]InterpreterOperation, stack *InterpreterStack) error { - statementStr, referencedVariables, err := substituteVariableReferences(stmt.Statement, stack) - if err != nil { - return err - } *ops = append(*ops, InterpreterOperation{ OpCode: OpCode_Execute, - PrimaryData: statementStr, - SecondaryData: referencedVariables, Target: stmt.Target, }) return nil @@ -172,15 +154,10 @@ func (If) OperationSize() int32 { // AppendOperations implements the interface Statement. func (stmt If) AppendOperations(ops *[]InterpreterOperation, stack *InterpreterStack) error { - condition, referencedVariables, err := substituteVariableReferences(stmt.Condition, stack) - if err != nil { - return err - } *ops = append(*ops, InterpreterOperation{ OpCode: OpCode_If, - PrimaryData: "SELECT " + condition + ";", - SecondaryData: referencedVariables, + PrimaryData: "SELECT ;", Index: len(*ops) + int(stmt.GotoOffset), }) return nil @@ -200,15 +177,9 @@ func (Perform) OperationSize() int32 { // AppendOperations implements the interface Statement. func (stmt Perform) AppendOperations(ops *[]InterpreterOperation, stack *InterpreterStack) error { - statementStr, referencedVariables, err := substituteVariableReferences(stmt.Statement, stack) - if err != nil { - return err - } *ops = append(*ops, InterpreterOperation{ OpCode: OpCode_Perform, - PrimaryData: statementStr, - SecondaryData: referencedVariables, }) return nil } @@ -227,17 +198,8 @@ func (Return) OperationSize() int32 { // AppendOperations implements the interface Statement. func (stmt Return) AppendOperations(ops *[]InterpreterOperation, stack *InterpreterStack) error { - expression, referencedVariables, err := substituteVariableReferences(stmt.Expression, stack) - if err != nil { - return err - } - if len(expression) > 0 { - expression = "SELECT " + expression + ";" - } *ops = append(*ops, InterpreterOperation{ OpCode: OpCode_Return, - PrimaryData: expression, - SecondaryData: referencedVariables, }) return nil } @@ -256,27 +218,4 @@ func OperationSizeForStatements(stmts []Statement) int32 { total += stmt.OperationSize() } return total -} - -// substituteVariableReferences parses the specified |expression| and replaces -// any token that matches a variable name in the |stack| with "$N", where N -// indicates which variable in the returned |referenceVars| slice is used. -func substituteVariableReferences(expression string, stack *InterpreterStack) (newExpression string, referencedVars []string, err error) { - scanResult, err := pg_query.Scan(expression) - if err != nil { - return "", nil, err - } - - varMap := stack.ListVariables() - for _, token := range scanResult.Tokens { - substring := expression[token.Start:token.End] - if _, ok := varMap[substring]; ok { - referencedVars = append(referencedVars, substring) - newExpression += fmt.Sprintf("$%d ", len(referencedVars)) - } else { - newExpression += substring + " " - } - } - - return newExpression, referencedVars, nil } \ No newline at end of file diff --git a/sql/rowexec/proc.go b/sql/rowexec/proc.go index 6732949fc9..837da36ae2 100644 --- a/sql/rowexec/proc.go +++ b/sql/rowexec/proc.go @@ -17,7 +17,8 @@ package rowexec import ( "errors" "fmt" - "io" + "github.com/dolthub/go-mysql-server/sql/procedures" +"io" "strings" "github.com/dolthub/go-mysql-server/sql" @@ -197,6 +198,24 @@ func (b *BaseBuilder) buildCall(ctx *sql.Context, n *plan.Call, row sql.Row) (sq n.Pref.PushScope() defer n.Pref.PopScope(ctx) + procedures.Call(ctx, n, n.Runner) + + for _, stmt := range n.Ops { + _, rowIter, _, err := n.Runner.QueryWithBindings(ctx, "", stmt.PrimaryData, nil, nil) + if err != nil { + return nil, err + } + for { + if _, err = rowIter.Next(ctx); err != nil { + if err == io.EOF { + break + } + return nil, err + } + } + } + + // TODO: mirror plpgsql interpreter_logic.go Call() // TODO: instead of building, run the actual operations // This means call the runner.QueryWithBindings From d9efee57274bbe9cca2f2b052753b6065b81f849 Mon Sep 17 00:00:00 2001 From: James Cor Date: Fri, 14 Feb 2025 12:39:32 -0800 Subject: [PATCH 003/111] compiling? --- sql/procedures/interpreter_logic.go | 87 ++++++++++++++--------------- sql/procedures/interpreter_stack.go | 78 +++++++++++++++++++++----- 2 files changed, 106 insertions(+), 59 deletions(-) diff --git a/sql/procedures/interpreter_logic.go b/sql/procedures/interpreter_logic.go index 346ae23ba1..7b67cbac62 100644 --- a/sql/procedures/interpreter_logic.go +++ b/sql/procedures/interpreter_logic.go @@ -16,9 +16,10 @@ package procedures import ( "fmt" - "github.com/dolthub/doltgresql/core" -"github.com/dolthub/doltgresql/core/id" -"github.com/dolthub/go-mysql-server/sql" + "io" + + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/types" ) // InterpreterNode is an interface that implements an interpreter. These are typically used for functions (which may be @@ -32,13 +33,13 @@ type InterpreterNode interface { } // Call runs the contained operations on the given runner. -func Call(ctx *sql.Context, call sql.InterpreterNode, runner sql.StatementRunner) (any, error) { +func Call(ctx *sql.Context, iNode sql.InterpreterNode, runner sql.StatementRunner, vals []any) (any, error) { // Set up the initial state of the function counter := -1 // We increment before accessing, so start at -1 stack := NewInterpreterStack() // Add the parameters - parameterTypes := call.GetParameters() - parameterNames := call.GetParameterNames() + parameterTypes := iNode.GetParameters() + parameterNames := iNode.GetParameterNames() if len(vals) != len(parameterTypes) { return nil, fmt.Errorf("parameter count mismatch: expected %d got %d", len(parameterTypes), len(vals)) } @@ -46,7 +47,7 @@ func Call(ctx *sql.Context, call sql.InterpreterNode, runner sql.StatementRunner stack.NewVariableWithValue(parameterNames[i], parameterTypes[i], vals[i]) } // Run the statements - statements := call.Ops + statements := iNode.GetStatements() for { counter++ if counter >= len(statements) { @@ -58,43 +59,28 @@ func Call(ctx *sql.Context, call sql.InterpreterNode, runner sql.StatementRunner operation := statements[counter] switch operation.OpCode { case OpCode_Select: - + // TODO case OpCode_Declare: - typeCollection, err := core.GetTypesCollectionFromContext(ctx) - if err != nil { - return nil, err - } - resolvedType, exists := typeCollection.GetType(id.NewType("pg_catalog", operation.PrimaryData)) - if !exists { - return nil, pgtypes.ErrTypeDoesNotExist.New(operation.PrimaryData) - } + resolvedType := types.Uint32 // TODO: figure out actual type from operation stack.NewVariable(operation.Target, resolvedType) case OpCode_Exception: // TODO: implement case OpCode_Execute: - if len(operation.Target) > 0 { - target := stack.GetVariable(operation.Target) - if target == nil { - return nil, fmt.Errorf("variable `%s` could not be found", operation.Target) - } - retVal, err := iFunc.QuerySingleReturn(ctx, stack, operation.PrimaryData, target.Type, operation.SecondaryData) - if err != nil { - return nil, err - } - err = stack.SetVariable(ctx, operation.Target, retVal) - if err != nil { - return nil, err - } - } else { - rowIter, err := iFunc.QueryMultiReturn(ctx, stack, operation.PrimaryData, operation.SecondaryData) - if err != nil { - return nil, err - } - if _, err = sql.RowIterToRows(ctx, rowIter); err != nil { + _, rowIter, _, err := runner.QueryWithBindings(ctx, "", operation.PrimaryData, nil, nil) + if err != nil { + return nil, err + } + for { + if _, rErr := rowIter.Next(ctx); rErr != nil { + if rErr == io.EOF { + break + } return nil, err } } - + if err = rowIter.Close(ctx); err != nil { + return nil, err + } case OpCode_Goto: // We must compare to the index - 1, so that the increment hits our target if counter <= operation.Index { @@ -104,6 +90,8 @@ func Call(ctx *sql.Context, call sql.InterpreterNode, runner sql.StatementRunner stack.PushScope() case OpCode_ScopeEnd: stack.PopScope() + default: + // No-op } } } else { @@ -113,24 +101,33 @@ func Call(ctx *sql.Context, call sql.InterpreterNode, runner sql.StatementRunner stack.PopScope() case OpCode_ScopeEnd: stack.PushScope() + default: + // No-op } } } case OpCode_If: - sch, rowIter, _, err := runner.QueryWithBindings(ctx, "", operation.PrimaryData, nil, nil) + _, rowIter, _, err := runner.QueryWithBindings(ctx, "", operation.PrimaryData, nil, nil) if err != nil { return nil, err } - row, err := rowIter.Next(ctx) - if err != nil { - return nil, err + for { + if _, rErr := rowIter.Next(ctx); rErr != nil { + if rErr == io.EOF { + break + } + return nil, err + } } - rowIter.Close(ctx) - if retVal.(bool) { - // We're never changing the scope, so we can just assign it directly. - // Also, we must assign to index-1, so that the increment hits our target. - counter = operation.Index - 1 + if err = rowIter.Close(ctx); err != nil { + return nil, err } + // TODO: ensure there is exactly one result that is a bool + //if retVal.(bool) { + // // We're never changing the scope, so we can just assign it directly. + // // Also, we must assign to index-1, so that the increment hits our target. + // counter = operation.Index - 1 + //} case OpCode_ScopeBegin: stack.PushScope() diff --git a/sql/procedures/interpreter_stack.go b/sql/procedures/interpreter_stack.go index e904ffc0e3..7f0459a0f8 100644 --- a/sql/procedures/interpreter_stack.go +++ b/sql/procedures/interpreter_stack.go @@ -18,15 +18,70 @@ import ( "fmt" "github.com/dolthub/go-mysql-server/sql" - "github.com/dolthub/go-mysql-server/sql/analyzer" - - pgtypes "github.com/dolthub/doltgresql/server/types" - "github.com/dolthub/doltgresql/utils" ) +// Stack is a generic stack. +type Stack[T any] struct { + values []T +} + +// NewStack creates a new, empty stack. +func NewStack[T any]() *Stack[T] { + return &Stack[T]{} +} + +// Len returns the size of the stack. +func (s *Stack[T]) Len() int { + return len(s.values) +} + +// Peek returns the top value on the stack without removing it. +func (s *Stack[T]) Peek() (value T) { + if len(s.values) == 0 { + return + } + return s.values[len(s.values)-1] +} + +// PeekDepth returns the n-th value from the top. PeekDepth(0) is equivalent to the standard Peek(). +func (s *Stack[T]) PeekDepth(depth int) (value T) { + if len(s.values) <= depth { + return + } + return s.values[len(s.values)-(1+depth)] +} + +// PeekReference returns a reference to the top value on the stack without removing it. +func (s *Stack[T]) PeekReference() *T { + if len(s.values) == 0 { + return nil + } + return &s.values[len(s.values)-1] +} + +// Pop returns the top value on the stack while also removing it from the stack. +func (s *Stack[T]) Pop() (value T) { + if len(s.values) == 0 { + return + } + value = s.values[len(s.values)-1] + s.values = s.values[:len(s.values)-1] + return +} + +// Push adds the given value to the stack. +func (s *Stack[T]) Push(value T) { + s.values = append(s.values, value) +} + +// Empty returns whether the stack is empty. +func (s *Stack[T]) Empty() bool { + return len(s.values) == 0 +} + // InterpreterVariable is a variable that lives on the stack. type InterpreterVariable struct { - Type *pgtypes.DoltgresType + Type sql.Type Value any } @@ -39,12 +94,12 @@ type InterpreterScopeDetails struct { // the same as a stack in the traditional programming sense, but rather is a loose abstraction that serves the same // general purpose. type InterpreterStack struct { - stack *utils.Stack[*InterpreterScopeDetails] + stack *Stack[*InterpreterScopeDetails] } // NewInterpreterStack creates a new InterpreterStack. func NewInterpreterStack() InterpreterStack { - stack := utils.NewStack[*InterpreterScopeDetails]() + stack := NewStack[*InterpreterScopeDetails]() // This first push represents the function base, including parameters stack.Push(&InterpreterScopeDetails{ variables: make(map[string]*InterpreterVariable), @@ -59,11 +114,6 @@ func (is *InterpreterStack) Details() *InterpreterScopeDetails { return is.stack.Peek() } -// Runner returns the runner that is being used for the function's execution. -func (is *InterpreterStack) Runner() sql.StatementRunner { - return is.runner -} - // GetVariable traverses the stack (starting from the top) to find a variable with a matching name. Returns nil if no // variable was found. func (is *InterpreterStack) GetVariable(name string) *InterpreterVariable { @@ -88,12 +138,12 @@ func (is *InterpreterStack) ListVariables() map[string]struct{} { // NewVariable creates a new variable in the current scope. If a variable with the same name exists in a previous scope, // then that variable will be shadowed until the current scope exits. -func (is *InterpreterStack) NewVariable(name string, typ *pgtypes.DoltgresType) { +func (is *InterpreterStack) NewVariable(name string, typ sql.Type) { is.NewVariableWithValue(name, typ, typ.Zero()) } // NewVariableWithValue creates a new variable in the current scope, setting its initial value to the one given. -func (is *InterpreterStack) NewVariableWithValue(name string, typ *pgtypes.DoltgresType, val any) { +func (is *InterpreterStack) NewVariableWithValue(name string, typ sql.Type, val any) { is.stack.Peek().variables[name] = &InterpreterVariable{ Type: typ, Value: val, From 0ce24166c2f7ec7d6fee758eba9969531d550b8e Mon Sep 17 00:00:00 2001 From: jycor Date: Fri, 14 Feb 2025 20:41:30 +0000 Subject: [PATCH 004/111] [ga-format-pr] Run ./format_repo.sh to fix formatting --- sql/analyzer/interpreter.go | 1 - sql/plan/call.go | 5 ----- sql/planbuilder/proc.go | 1 - sql/procedures.go | 7 ------- sql/procedures/interpreter_operation.go | 28 ++++++++++++------------- sql/procedures/interpreter_stack.go | 6 +++--- sql/procedures/parse.go | 6 +++--- sql/procedures/statements.go | 20 +++++++++--------- sql/rowexec/proc.go | 7 +++---- 9 files changed, 33 insertions(+), 48 deletions(-) diff --git a/sql/analyzer/interpreter.go b/sql/analyzer/interpreter.go index b6642448b1..6ebb7a47fc 100644 --- a/sql/analyzer/interpreter.go +++ b/sql/analyzer/interpreter.go @@ -15,7 +15,6 @@ package analyzer import ( - "github.com/dolthub/go-mysql-server/sql/transform" "github.com/dolthub/go-mysql-server/sql" diff --git a/sql/plan/call.go b/sql/plan/call.go index 4296e9aa93..4fd641c18a 100644 --- a/sql/plan/call.go +++ b/sql/plan/call.go @@ -23,7 +23,6 @@ import ( "github.com/dolthub/go-mysql-server/sql/types" ) - // TODO: we need different types of calls: one for external procedures one for stored procedures type Call struct { @@ -235,7 +234,3 @@ func (c *Call) GetStatements() []procedures.InterpreterOperation { func (c *Call) GetReturn() sql.Type { return nil } - - - - diff --git a/sql/planbuilder/proc.go b/sql/planbuilder/proc.go index 799981217c..55d95e50ce 100644 --- a/sql/planbuilder/proc.go +++ b/sql/planbuilder/proc.go @@ -265,7 +265,6 @@ func BuildProcedureHelper(ctx *sql.Context, cat sql.Catalog, isCreateProc bool, procParams := b.buildProcedureParams(procStmt.ProcedureSpec.Params) characteristics, securityType, comment := b.buildProcedureCharacteristics(procStmt.ProcedureSpec.Characteristics) - proc = plan.NewProcedure( procDetails.Name, procStmt.ProcedureSpec.Definer, diff --git a/sql/procedures.go b/sql/procedures.go index 874486dcc5..eabfcb7f59 100644 --- a/sql/procedures.go +++ b/sql/procedures.go @@ -49,13 +49,6 @@ type StatementRunner interface { QueryWithBindings(ctx *Context, query string, parsed sqlparser.Statement, bindings map[string]sqlparser.Expr, qFlags *QueryFlags) (Schema, RowIter, *QueryFlags, error) } - - - - - - - // StoredProcedureDetails are the details of the stored procedure. Integrators only need to store and retrieve the given // details for a stored procedure, as the engine handles all parsing and processing. type StoredProcedureDetails struct { diff --git a/sql/procedures/interpreter_operation.go b/sql/procedures/interpreter_operation.go index 3893c0a748..20573188e7 100644 --- a/sql/procedures/interpreter_operation.go +++ b/sql/procedures/interpreter_operation.go @@ -17,22 +17,22 @@ import ast "github.com/dolthub/vitess/go/vt/sqlparser" type OpCode uint16 const ( - OpCode_Select OpCode = iota - OpCode_Declare // https://www.postgresql.org/docs/15/plpgsql-declarations.html - OpCode_Exception // https://www.postgresql.org/docs/15/plpgsql-control-structures.html#PLPGSQL-ERROR-TRAPPING - OpCode_Execute // Everything that's not a SELECT - OpCode_Goto // All control-flow structures can be represented using Goto - OpCode_If // https://www.postgresql.org/docs/15/plpgsql-control-structures.html#PLPGSQL-CONDITIONALS - OpCode_Return // https://www.postgresql.org/docs/15/plpgsql-control-structures.html#PLPGSQL-STATEMENTS-RETURNING - OpCode_ScopeBegin // This is used for scope control, specific to Doltgres - OpCode_ScopeEnd // This is used for scope control, specific to Doltgres + OpCode_Select OpCode = iota + OpCode_Declare // https://www.postgresql.org/docs/15/plpgsql-declarations.html + OpCode_Exception // https://www.postgresql.org/docs/15/plpgsql-control-structures.html#PLPGSQL-ERROR-TRAPPING + OpCode_Execute // Everything that's not a SELECT + OpCode_Goto // All control-flow structures can be represented using Goto + OpCode_If // https://www.postgresql.org/docs/15/plpgsql-control-structures.html#PLPGSQL-CONDITIONALS + OpCode_Return // https://www.postgresql.org/docs/15/plpgsql-control-structures.html#PLPGSQL-STATEMENTS-RETURNING + OpCode_ScopeBegin // This is used for scope control, specific to Doltgres + OpCode_ScopeEnd // This is used for scope control, specific to Doltgres ) // InterpreterOperation is an operation that will be performed by the interpreter. type InterpreterOperation struct { OpCode OpCode - PrimaryData ast.Statement // This will represent the "main" data, such as the query for PERFORM, expression for IF, etc. - SecondaryData []string // This represents auxiliary data, such as bindings, strictness, etc. - Target string // This is the variable that will store the results (if applicable) - Index int // This is the index that should be set for operations that move the function counter -} \ No newline at end of file + PrimaryData ast.Statement // This will represent the "main" data, such as the query for PERFORM, expression for IF, etc. + SecondaryData []string // This represents auxiliary data, such as bindings, strictness, etc. + Target string // This is the variable that will store the results (if applicable) + Index int // This is the index that should be set for operations that move the function counter +} diff --git a/sql/procedures/interpreter_stack.go b/sql/procedures/interpreter_stack.go index 7f0459a0f8..5fb7ae9d6a 100644 --- a/sql/procedures/interpreter_stack.go +++ b/sql/procedures/interpreter_stack.go @@ -94,7 +94,7 @@ type InterpreterScopeDetails struct { // the same as a stack in the traditional programming sense, but rather is a loose abstraction that serves the same // general purpose. type InterpreterStack struct { - stack *Stack[*InterpreterScopeDetails] + stack *Stack[*InterpreterScopeDetails] } // NewInterpreterStack creates a new InterpreterStack. @@ -105,7 +105,7 @@ func NewInterpreterStack() InterpreterStack { variables: make(map[string]*InterpreterVariable), }) return InterpreterStack{ - stack: stack, + stack: stack, } } @@ -178,4 +178,4 @@ func (is *InterpreterStack) SetVariable(ctx *sql.Context, name string, val any) } iv.Value = val return nil -} \ No newline at end of file +} diff --git a/sql/procedures/parse.go b/sql/procedures/parse.go index 86bd6ad2da..e8a3453497 100644 --- a/sql/procedures/parse.go +++ b/sql/procedures/parse.go @@ -40,7 +40,7 @@ func ConvertStmt(ops *[]InterpreterOperation, stack *InterpreterStack, stmt ast. stack.PopScope() case *ast.Select: selectOp := InterpreterOperation{ - OpCode: OpCode_Select, + OpCode: OpCode_Select, PrimaryData: s, } *ops = append(*ops, selectOp) @@ -52,7 +52,7 @@ func ConvertStmt(ops *[]InterpreterOperation, stack *InterpreterStack, stmt ast. default: execOp := InterpreterOperation{ - OpCode: OpCode_Execute, + OpCode: OpCode_Execute, PrimaryData: s, } *ops = append(*ops, execOp) @@ -71,4 +71,4 @@ func Parse(stmt ast.Statement) ([]InterpreterOperation, error) { return nil, err } return ops, nil -} \ No newline at end of file +} diff --git a/sql/procedures/statements.go b/sql/procedures/statements.go index 8a16384d51..a7c0217bbf 100644 --- a/sql/procedures/statements.go +++ b/sql/procedures/statements.go @@ -39,8 +39,8 @@ func (Assignment) OperationSize() int32 { // AppendOperations implements the interface Statement. func (stmt Assignment) AppendOperations(ops *[]InterpreterOperation, stack *InterpreterStack) error { *ops = append(*ops, InterpreterOperation{ - OpCode: OpCode_Assign, - Target: stmt.VariableName, + OpCode: OpCode_Assign, + Target: stmt.VariableName, }) return nil } @@ -112,8 +112,8 @@ func (ExecuteSQL) OperationSize() int32 { // AppendOperations implements the interface Statement. func (stmt ExecuteSQL) AppendOperations(ops *[]InterpreterOperation, stack *InterpreterStack) error { *ops = append(*ops, InterpreterOperation{ - OpCode: OpCode_Execute, - Target: stmt.Target, + OpCode: OpCode_Execute, + Target: stmt.Target, }) return nil } @@ -156,9 +156,9 @@ func (If) OperationSize() int32 { func (stmt If) AppendOperations(ops *[]InterpreterOperation, stack *InterpreterStack) error { *ops = append(*ops, InterpreterOperation{ - OpCode: OpCode_If, - PrimaryData: "SELECT ;", - Index: len(*ops) + int(stmt.GotoOffset), + OpCode: OpCode_If, + PrimaryData: "SELECT ;", + Index: len(*ops) + int(stmt.GotoOffset), }) return nil } @@ -179,7 +179,7 @@ func (Perform) OperationSize() int32 { func (stmt Perform) AppendOperations(ops *[]InterpreterOperation, stack *InterpreterStack) error { *ops = append(*ops, InterpreterOperation{ - OpCode: OpCode_Perform, + OpCode: OpCode_Perform, }) return nil } @@ -199,7 +199,7 @@ func (Return) OperationSize() int32 { // AppendOperations implements the interface Statement. func (stmt Return) AppendOperations(ops *[]InterpreterOperation, stack *InterpreterStack) error { *ops = append(*ops, InterpreterOperation{ - OpCode: OpCode_Return, + OpCode: OpCode_Return, }) return nil } @@ -218,4 +218,4 @@ func OperationSizeForStatements(stmts []Statement) int32 { total += stmt.OperationSize() } return total -} \ No newline at end of file +} diff --git a/sql/rowexec/proc.go b/sql/rowexec/proc.go index 837da36ae2..e33f37383a 100644 --- a/sql/rowexec/proc.go +++ b/sql/rowexec/proc.go @@ -17,10 +17,11 @@ package rowexec import ( "errors" "fmt" - "github.com/dolthub/go-mysql-server/sql/procedures" -"io" + "io" "strings" + "github.com/dolthub/go-mysql-server/sql/procedures" + "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/expression" "github.com/dolthub/go-mysql-server/sql/plan" @@ -215,7 +216,6 @@ func (b *BaseBuilder) buildCall(ctx *sql.Context, n *plan.Call, row sql.Row) (sq } } - // TODO: mirror plpgsql interpreter_logic.go Call() // TODO: instead of building, run the actual operations // This means call the runner.QueryWithBindings @@ -226,7 +226,6 @@ func (b *BaseBuilder) buildCall(ctx *sql.Context, n *plan.Call, row sql.Row) (sq // TODO: save any select ast rowIters to be returned later - return &callIter{ call: n, innerIter: innerIter, From e9ef94107183af07d93f2c9a8f5da11a205cf2a2 Mon Sep 17 00:00:00 2001 From: James Cor Date: Fri, 14 Feb 2025 13:30:42 -0800 Subject: [PATCH 005/111] compiling now? --- sql/plan/call.go | 2 +- sql/procedures.go | 21 +-------------- sql/procedures/interpreter_logic.go | 8 ++++-- sql/procedures/statements.go | 40 ++++++++++++++--------------- sql/rowexec/proc.go | 24 +++++++---------- 5 files changed, 36 insertions(+), 59 deletions(-) diff --git a/sql/plan/call.go b/sql/plan/call.go index 4296e9aa93..4f44dd4add 100644 --- a/sql/plan/call.go +++ b/sql/plan/call.go @@ -228,7 +228,7 @@ func (c *Call) GetParameterNames() []string { // GetStatements implements the sql.InterpreterNode interface. func (c *Call) GetStatements() []procedures.InterpreterOperation { - return c.Ops + return c.Procedure.Ops } // GetReturn implements the sql.InterpreterNode interface. diff --git a/sql/procedures.go b/sql/procedures.go index 874486dcc5..4b0c488ef1 100644 --- a/sql/procedures.go +++ b/sql/procedures.go @@ -18,9 +18,7 @@ import ( "fmt" "time" - "github.com/dolthub/go-mysql-server/sql/procedures" - - "github.com/dolthub/vitess/go/vt/sqlparser" + "github.com/dolthub/vitess/go/vt/sqlparser" ) // Interpreter is an interface that implements an interpreter. These are typically used for functions (which may be @@ -32,16 +30,6 @@ type Interpreter interface { // TODO: InterpreterNode interface // TODO: alternatively have plan.Call just have an interpreter expression -// InterpreterNode is an interface that implements an interpreter. These are typically used for functions (which may be -// implemented as a set of operations that are interpreted during runtime). -type InterpreterNode interface { - SetStatementRunner(ctx *Context, runner StatementRunner) Node - GetParameters() []Type - GetParameterNames() []string - GetReturn() Type - GetStatements() []procedures.InterpreterOperation -} - // StatementRunner is essentially an interface that the engine will implement. We cannot directly reference the engine // here as it will cause an import cycle, so this may be updated to suit any function changes that the engine // experiences. @@ -49,13 +37,6 @@ type StatementRunner interface { QueryWithBindings(ctx *Context, query string, parsed sqlparser.Statement, bindings map[string]sqlparser.Expr, qFlags *QueryFlags) (Schema, RowIter, *QueryFlags, error) } - - - - - - - // StoredProcedureDetails are the details of the stored procedure. Integrators only need to store and retrieve the given // details for a stored procedure, as the engine handles all parsing and processing. type StoredProcedureDetails struct { diff --git a/sql/procedures/interpreter_logic.go b/sql/procedures/interpreter_logic.go index 7b67cbac62..66d6704c0a 100644 --- a/sql/procedures/interpreter_logic.go +++ b/sql/procedures/interpreter_logic.go @@ -33,7 +33,7 @@ type InterpreterNode interface { } // Call runs the contained operations on the given runner. -func Call(ctx *sql.Context, iNode sql.InterpreterNode, runner sql.StatementRunner, vals []any) (any, error) { +func Call(ctx *sql.Context, iNode InterpreterNode, runner sql.StatementRunner, vals []any) (any, error) { // Set up the initial state of the function counter := -1 // We increment before accessing, so start at -1 stack := NewInterpreterStack() @@ -59,7 +59,11 @@ func Call(ctx *sql.Context, iNode sql.InterpreterNode, runner sql.StatementRunne operation := statements[counter] switch operation.OpCode { case OpCode_Select: - // TODO + _, rowIter, _, err := runner.QueryWithBindings(ctx, "", operation.PrimaryData, nil, nil) + if err != nil { + return nil, err + } + return rowIter, nil case OpCode_Declare: resolvedType := types.Uint32 // TODO: figure out actual type from operation stack.NewVariable(operation.Target, resolvedType) diff --git a/sql/procedures/statements.go b/sql/procedures/statements.go index 8a16384d51..2ac2594c0b 100644 --- a/sql/procedures/statements.go +++ b/sql/procedures/statements.go @@ -38,10 +38,10 @@ func (Assignment) OperationSize() int32 { // AppendOperations implements the interface Statement. func (stmt Assignment) AppendOperations(ops *[]InterpreterOperation, stack *InterpreterStack) error { - *ops = append(*ops, InterpreterOperation{ - OpCode: OpCode_Assign, - Target: stmt.VariableName, - }) + //*ops = append(*ops, InterpreterOperation{ + // OpCode: OpCode_Assign, + // Target: stmt.VariableName, + //}) return nil } @@ -75,13 +75,13 @@ func (stmt Block) AppendOperations(ops *[]InterpreterOperation, stack *Interpret OpCode: OpCode_ScopeBegin, }) for _, variable := range stmt.Variable { - if !variable.IsParameter { - *ops = append(*ops, InterpreterOperation{ - OpCode: OpCode_Declare, - PrimaryData: variable.Type, - Target: variable.Name, - }) - } + //if !variable.IsParameter { + // *ops = append(*ops, InterpreterOperation{ + // OpCode: OpCode_Declare, + // PrimaryData: variable.Type, + // Target: variable.Name, + // }) + //} stack.NewVariableWithValue(variable.Name, nil, nil) } for _, innerStmt := range stmt.Body { @@ -154,12 +154,11 @@ func (If) OperationSize() int32 { // AppendOperations implements the interface Statement. func (stmt If) AppendOperations(ops *[]InterpreterOperation, stack *InterpreterStack) error { - - *ops = append(*ops, InterpreterOperation{ - OpCode: OpCode_If, - PrimaryData: "SELECT ;", - Index: len(*ops) + int(stmt.GotoOffset), - }) + //*ops = append(*ops, InterpreterOperation{ + // OpCode: OpCode_If, + // PrimaryData: "SELECT ;", + // Index: len(*ops) + int(stmt.GotoOffset), + //}) return nil } @@ -177,10 +176,9 @@ func (Perform) OperationSize() int32 { // AppendOperations implements the interface Statement. func (stmt Perform) AppendOperations(ops *[]InterpreterOperation, stack *InterpreterStack) error { - - *ops = append(*ops, InterpreterOperation{ - OpCode: OpCode_Perform, - }) + //*ops = append(*ops, InterpreterOperation{ + // OpCode: OpCode_Perform, + //}) return nil } diff --git a/sql/rowexec/proc.go b/sql/rowexec/proc.go index 837da36ae2..52f0bda73e 100644 --- a/sql/rowexec/proc.go +++ b/sql/rowexec/proc.go @@ -198,23 +198,17 @@ func (b *BaseBuilder) buildCall(ctx *sql.Context, n *plan.Call, row sql.Row) (sq n.Pref.PushScope() defer n.Pref.PopScope(ctx) - procedures.Call(ctx, n, n.Runner) - - for _, stmt := range n.Ops { - _, rowIter, _, err := n.Runner.QueryWithBindings(ctx, "", stmt.PrimaryData, nil, nil) - if err != nil { - return nil, err - } - for { - if _, err = rowIter.Next(ctx); err != nil { - if err == io.EOF { - break - } - return nil, err - } - } + rowIter, err := procedures.Call(ctx, n, n.Runner, nil) + if err != nil { + return nil, err } + return &callIter{ + call: n, + innerIter: rowIter.(sql.RowIter), + }, nil + + // TODO: mirror plpgsql interpreter_logic.go Call() // TODO: instead of building, run the actual operations From 117410ef9fc0b2a701b4784781626553b3a40ce4 Mon Sep 17 00:00:00 2001 From: James Cor Date: Fri, 14 Feb 2025 15:36:34 -0800 Subject: [PATCH 006/111] simplest select statement working --- engine.go | 2 +- sql/analyzer/analyzer.go | 2 +- sql/analyzer/interpreter.go | 6 +++--- sql/plan/common.go | 2 +- sql/plan/procedure.go | 5 +++-- sql/planbuilder/proc.go | 8 +++----- .../resolve_external_stored_procedures.go | 15 ++++++++------- sql/procedures/interpreter_logic.go | 2 ++ sql/rowexec/rel.go | 5 ++++- 9 files changed, 26 insertions(+), 21 deletions(-) diff --git a/engine.go b/engine.go index 80c7685d0b..9e7946a3e7 100644 --- a/engine.go +++ b/engine.go @@ -150,7 +150,7 @@ type Engine struct { Parser sql.Parser } -var _ analyzer.StatementRunner = (*Engine)(nil) +var _ sql.StatementRunner = (*Engine)(nil) type ColumnWithRawDefault struct { SqlColumn *sql.Column diff --git a/sql/analyzer/analyzer.go b/sql/analyzer/analyzer.go index 5bd7df72ea..d9441e77e0 100644 --- a/sql/analyzer/analyzer.go +++ b/sql/analyzer/analyzer.go @@ -287,7 +287,7 @@ type Analyzer struct { // ExecBuilder converts a sql.Node tree into an executable iterator. ExecBuilder sql.NodeExecBuilder // Runner represents the engine, which is represented as a separate interface to work around circular dependencies - Runner StatementRunner + Runner sql.StatementRunner } // NewDefault creates a default Analyzer instance with all default Rules and configuration. diff --git a/sql/analyzer/interpreter.go b/sql/analyzer/interpreter.go index 6ebb7a47fc..9a2b559bf4 100644 --- a/sql/analyzer/interpreter.go +++ b/sql/analyzer/interpreter.go @@ -15,16 +15,16 @@ package analyzer import ( - "github.com/dolthub/go-mysql-server/sql/transform" - "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/plan" + "github.com/dolthub/go-mysql-server/sql/procedures" + "github.com/dolthub/go-mysql-server/sql/transform" ) // interpreter hands the engine to any interpreter expressions. func interpreter(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector, qFlags *sql.QueryFlags) (sql.Node, transform.TreeIdentity, error) { newNode, sameNode, err := transform.Node(n, func(node sql.Node) (sql.Node, transform.TreeIdentity, error) { - if interp, ok := node.(sql.InterpreterNode); ok { + if interp, ok := node.(procedures.InterpreterNode); ok { return interp.SetStatementRunner(ctx, a.Runner), transform.NewTree, nil } return node, transform.SameTree, nil diff --git a/sql/plan/common.go b/sql/plan/common.go index f8b954a91c..9fabb03626 100644 --- a/sql/plan/common.go +++ b/sql/plan/common.go @@ -106,7 +106,7 @@ func NodeRepresentsSelect(s sql.Node) bool { case *Call: return NodeRepresentsSelect(node.Procedure) case *Procedure: - return NodeRepresentsSelect(node.Body) + return NodeRepresentsSelect(node.ExternalProc) case *Block: for _, stmt := range node.statements { if NodeRepresentsSelect(stmt) { diff --git a/sql/plan/procedure.go b/sql/plan/procedure.go index 5fb6d8a88b..c0fab99f5f 100644 --- a/sql/plan/procedure.go +++ b/sql/plan/procedure.go @@ -102,9 +102,9 @@ func NewProcedure( comment string, characteristics []Characteristic, createProcedureString string, - ops []procedures.InterpreterOperation, createdAt time.Time, modifiedAt time.Time, + ops []procedures.InterpreterOperation, ) *Procedure { lowercasedParams := make([]ProcedureParam, len(params)) for i, param := range params { @@ -123,9 +123,10 @@ func NewProcedure( Comment: comment, Characteristics: characteristics, CreateProcedureString: createProcedureString, - Ops: ops, CreatedAt: createdAt, ModifiedAt: modifiedAt, + + Ops: ops, } } diff --git a/sql/planbuilder/proc.go b/sql/planbuilder/proc.go index 55d95e50ce..0bdf9b80cb 100644 --- a/sql/planbuilder/proc.go +++ b/sql/planbuilder/proc.go @@ -255,8 +255,6 @@ func BuildProcedureHelper(ctx *sql.Context, cat sql.Catalog, isCreateProc bool, stmt, _, _, _ := b.parser.ParseWithOptions(b.ctx, procDetails.CreateStatement, ';', false, b.parserOpts) procStmt := stmt.(*ast.DDL) - // TODO: convert ast to operations - ops, err := procedures.Parse(procStmt.ProcedureSpec.Body) if err != nil { b.handleErr(err) @@ -273,9 +271,9 @@ func BuildProcedureHelper(ctx *sql.Context, cat sql.Catalog, isCreateProc bool, comment, characteristics, procDetails.CreateStatement, - ops, procDetails.CreatedAt, procDetails.ModifiedAt, + ops, ) return proc, qFlags, nil @@ -321,7 +319,7 @@ func (b *Builder) buildCall(inScope *scope, c *ast.Call) (outScope *scope) { procDetails, ok, err = spdb.GetStoredProcedure(b.ctx, procName) if err == nil { if ok { - ops, innerQFlags, err = BuildProcedureHelper(b.ctx, b.cat, false, inScope, db, asOf, procDetails) + proc, innerQFlags, err = BuildProcedureHelper(b.ctx, b.cat, false, inScope, db, asOf, procDetails) // TODO: somewhat hacky way of preserving this flag // This is necessary so that the resolveSubqueries analyzer rule // will apply NodeExecBuilder to Subqueries in procedure body @@ -353,7 +351,7 @@ func (b *Builder) buildCall(inScope *scope, c *ast.Call) (outScope *scope) { } outScope = inScope.push() - outScope.node = plan.NewCall(db, procName, params, proc, asOf, b.cat) + outScope.node = plan.NewCall(db, procName, params, proc, asOf, b.cat, nil) return outScope } diff --git a/sql/planbuilder/resolve_external_stored_procedures.go b/sql/planbuilder/resolve_external_stored_procedures.go index bbeacba7bd..56e60e7195 100644 --- a/sql/planbuilder/resolve_external_stored_procedures.go +++ b/sql/planbuilder/resolve_external_stored_procedures.go @@ -145,7 +145,7 @@ func resolveExternalStoredProcedure(externalProcedure sql.ExternalStoredProcedur } } - procedure := plan.NewProcedure( + proc := plan.NewProcedure( externalProcedure.Name, "root", paramDefinitions, @@ -153,13 +153,14 @@ func resolveExternalStoredProcedure(externalProcedure sql.ExternalStoredProcedur "External stored procedure", nil, externalProcedure.FakeCreateProcedureStmt(), - &plan.ExternalProcedure{ - ExternalStoredProcedureDetails: externalProcedure, - ParamDefinitions: paramDefinitions, - Params: paramReferences, - }, time.Unix(1, 0), time.Unix(1, 0), + nil, ) - return procedure, nil + proc.ExternalProc = &plan.ExternalProcedure{ + ExternalStoredProcedureDetails: externalProcedure, + ParamDefinitions: paramDefinitions, + Params: paramReferences, + } + return proc, nil } diff --git a/sql/procedures/interpreter_logic.go b/sql/procedures/interpreter_logic.go index 66d6704c0a..7268c39c71 100644 --- a/sql/procedures/interpreter_logic.go +++ b/sql/procedures/interpreter_logic.go @@ -30,6 +30,8 @@ type InterpreterNode interface { GetParameterNames() []string GetReturn() sql.Type GetStatements() []InterpreterOperation + + SetStatementRunner(ctx *sql.Context, runner sql.StatementRunner) sql.Node } // Call runs the contained operations on the given runner. diff --git a/sql/rowexec/rel.go b/sql/rowexec/rel.go index 6c15ba28ac..c95adc2825 100644 --- a/sql/rowexec/rel.go +++ b/sql/rowexec/rel.go @@ -330,7 +330,10 @@ func (b *BaseBuilder) buildVirtualColumnTable(ctx *sql.Context, n *plan.VirtualC } func (b *BaseBuilder) buildProcedure(ctx *sql.Context, n *plan.Procedure, row sql.Row) (sql.RowIter, error) { - return b.buildNodeExec(ctx, n.Body, row) + if n.ExternalProc == nil { + return nil, nil + } + return b.buildNodeExec(ctx, n.ExternalProc, row) } func (b *BaseBuilder) buildRecursiveTable(ctx *sql.Context, n *plan.RecursiveTable, row sql.Row) (sql.RowIter, error) { From 513f53db2ef9ca76dddd2724c9c5a1a38043f986 Mon Sep 17 00:00:00 2001 From: James Cor Date: Mon, 17 Feb 2025 12:06:51 -0800 Subject: [PATCH 007/111] better --- enginetest/queries/procedure_queries.go | 112 +++++++++++++++++++++--- sql/planbuilder/create_ddl.go | 3 - sql/procedures/interpreter_logic.go | 36 ++++++-- 3 files changed, 130 insertions(+), 21 deletions(-) diff --git a/enginetest/queries/procedure_queries.go b/enginetest/queries/procedure_queries.go index 8db4281534..5df50e0eb5 100644 --- a/enginetest/queries/procedure_queries.go +++ b/enginetest/queries/procedure_queries.go @@ -18,6 +18,7 @@ import ( "time" "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/plan" "github.com/dolthub/go-mysql-server/sql/types" ) @@ -75,17 +76,13 @@ var ProcedureLogicTests = []ScriptTest{ { Query: "CALL testabc(2, 3)", Expected: []sql.Row{ - { - 6.0, - }, + {6.0}, }, }, { Query: "CALL testabc(9, 9.5)", Expected: []sql.Row{ - { - 85.5, - }, + {85.5}, }, }, }, @@ -2831,19 +2828,112 @@ var ProcedureCreateInSubroutineTests = []ScriptTest{ }, }, }, + { - Name: "procedure must not contain CREATE TABLE", + Name: "table ddl statements in stored procedures", Assertions: []ScriptTestAssertion{ { - Query: "create procedure p() create table t (pk int);", - ExpectedErrStr: "CREATE statements in CREATE PROCEDURE not yet supported", + Query: "create procedure create_proc() create table t (i int primary key, j int);", + Expected: []sql.Row{ + {types.NewOkResult(0)}, + }, }, { - Query: "create procedure p() begin create table t (pk int); end;", - ExpectedErrStr: "CREATE statements in CREATE PROCEDURE not yet supported", + Query: "call create_proc()", + Expected: []sql.Row{ + {types.NewOkResult(0)}, + }, + }, + { + Query: "show create table t;", + Expected: []sql.Row{ + {"t", "CREATE TABLE `t` (\n" + + " `i` int NOT NULL,\n" + + " `j` int,\n" + + " PRIMARY KEY (`i`)\n" + + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}, + }, + }, + { + Query: "call create_proc()", + ExpectedErrStr: "table with name t already exists", + }, + + { + Query: "create procedure insert_proc() insert into t values (1, 1), (2, 2), (3, 3);", + Expected: []sql.Row{ + {types.NewOkResult(0)}, + }, + }, + { + Query: "call insert_proc()", + Expected: []sql.Row{ + {types.NewOkResult(3)}, + }, + }, + { + Query: "select * from t", + Expected: []sql.Row{ + {1, 1}, + {2, 2}, + {3, 3}, + }, + }, + { + Query: "call insert_proc()", + ExpectedErrStr: "duplicate primary key given: [1]", + }, + + { + Query: "create procedure update_proc() update t set j = 999 where i > 1;", + Expected: []sql.Row{ + {types.NewOkResult(0)}, + }, + }, + { + Query: "call update_proc()", + Expected: []sql.Row{ + {types.OkResult{RowsAffected: 2, Info: plan.UpdateInfo{Matched: 2, Updated: 2}}}, + }, + }, + { + Query: "select * from t", + Expected: []sql.Row{ + {1, 1}, + {2, 999}, + {3, 999}, + }, + }, + { + Query: "call update_proc()", + Expected: []sql.Row{ + {types.OkResult{RowsAffected: 0, Info: plan.UpdateInfo{Matched: 2}}}, + }, + }, + + { + Query: "create procedure drop_proc() drop table t;", + Expected: []sql.Row{ + {types.NewOkResult(0)}, + }, + }, + { + Query: "call drop_proc()", + Expected: []sql.Row{ + {types.NewOkResult(0)}, + }, + }, + { + Query: "show tables like 't'", + Expected: []sql.Row{}, + }, + { + Query: "call drop_proc()", + ExpectedErrStr: "Unknown table 't'", }, }, }, + { Name: "procedure must not contain CREATE TRIGGER", SetUpScript: []string{ diff --git a/sql/planbuilder/create_ddl.go b/sql/planbuilder/create_ddl.go index 3dac989095..330246fab1 100644 --- a/sql/planbuilder/create_ddl.go +++ b/sql/planbuilder/create_ddl.go @@ -263,9 +263,6 @@ func (b *Builder) validateStatement(inScope *scope, stmt ast.Statement) { if s.TriggerSpec != nil { b.handleErr(fmt.Errorf("can't create a TRIGGER from within another stored routine")) } - //b.handleErr(fmt.Errorf("CREATE statements in CREATE PROCEDURE not yet supported")) - default: - b.handleErr(fmt.Errorf("DDL in CREATE PROCEDURE not yet supported")) } case *ast.DBDDL: b.handleErr(fmt.Errorf("DBDDL in CREATE PROCEDURE not yet supported")) diff --git a/sql/procedures/interpreter_logic.go b/sql/procedures/interpreter_logic.go index 7268c39c71..78fd8fd217 100644 --- a/sql/procedures/interpreter_logic.go +++ b/sql/procedures/interpreter_logic.go @@ -48,14 +48,18 @@ func Call(ctx *sql.Context, iNode InterpreterNode, runner sql.StatementRunner, v for i := range vals { stack.NewVariableWithValue(parameterNames[i], parameterTypes[i], vals[i]) } + // TODO: eventually return multiple sql.RowIters + var resultRowIter sql.RowIter + // Run the statements statements := iNode.GetStatements() for { counter++ + if counter < 0 { + panic("negative function counter") + } if counter >= len(statements) { break - } else if counter < 0 { - panic("negative function counter") } operation := statements[counter] @@ -65,7 +69,21 @@ func Call(ctx *sql.Context, iNode InterpreterNode, runner sql.StatementRunner, v if err != nil { return nil, err } - return rowIter, nil + var rows []sql.Row + for { + row, rErr := rowIter.Next(ctx) + if rErr != nil { + if rErr == io.EOF { + break + } + return nil, rErr + } + rows = append(rows, row) + } + if err = rowIter.Close(ctx); err != nil { + return nil, err + } + resultRowIter = sql.RowsToRowIter(rows...) case OpCode_Declare: resolvedType := types.Uint32 // TODO: figure out actual type from operation stack.NewVariable(operation.Target, resolvedType) @@ -76,17 +94,21 @@ func Call(ctx *sql.Context, iNode InterpreterNode, runner sql.StatementRunner, v if err != nil { return nil, err } + var rows []sql.Row for { - if _, rErr := rowIter.Next(ctx); rErr != nil { + row, rErr := rowIter.Next(ctx) + if rErr != nil { if rErr == io.EOF { break } - return nil, err + return nil, rErr } + rows = append(rows, row) } if err = rowIter.Close(ctx); err != nil { return nil, err } + resultRowIter = sql.RowsToRowIter(rows...) case OpCode_Goto: // We must compare to the index - 1, so that the increment hits our target if counter <= operation.Index { @@ -122,7 +144,7 @@ func Call(ctx *sql.Context, iNode InterpreterNode, runner sql.StatementRunner, v if rErr == io.EOF { break } - return nil, err + return nil, rErr } } if err = rowIter.Close(ctx); err != nil { @@ -143,5 +165,5 @@ func Call(ctx *sql.Context, iNode InterpreterNode, runner sql.StatementRunner, v panic("unimplemented opcode") } } - return nil, nil + return resultRowIter, nil } From eb6d741c19a1821e4df2b7a9debe2d15cfe3b52d Mon Sep 17 00:00:00 2001 From: James Cor Date: Tue, 18 Feb 2025 13:47:07 -0800 Subject: [PATCH 008/111] progress --- sql/procedures/interpreter_logic.go | 126 ++++++++++++++++++---------- sql/procedures/interpreter_stack.go | 24 ++++++ sql/procedures/parse.go | 8 +- sql/rowexec/proc.go | 19 +++-- 4 files changed, 120 insertions(+), 57 deletions(-) diff --git a/sql/procedures/interpreter_logic.go b/sql/procedures/interpreter_logic.go index 78fd8fd217..53ad7b6c09 100644 --- a/sql/procedures/interpreter_logic.go +++ b/sql/procedures/interpreter_logic.go @@ -15,9 +15,10 @@ package procedures import ( - "fmt" "io" + ast "github.com/dolthub/vitess/go/vt/sqlparser" + "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/types" ) @@ -26,32 +27,81 @@ import ( // implemented as a set of operations that are interpreted during runtime). type InterpreterNode interface { GetRunner() sql.StatementRunner - GetParameters() []sql.Type - GetParameterNames() []string GetReturn() sql.Type GetStatements() []InterpreterOperation - SetStatementRunner(ctx *sql.Context, runner sql.StatementRunner) sql.Node } +type Parameter struct { + Name string + Type sql.Type + Value any +} + +func replaceVariablesInExpr(stack *InterpreterStack, expr ast.SQLNode) (ast.SQLNode, error) { + switch e := expr.(type) { + case *ast.AliasedExpr: + newExpr, err := replaceVariablesInExpr(stack, e.Expr) + if err != nil { + return nil, err + } + e.Expr = newExpr.(ast.Expr) + case *ast.BinaryExpr: + newLeftExpr, err := replaceVariablesInExpr(stack, e.Left) + if err != nil { + return nil, err + } + newRightExpr, err := replaceVariablesInExpr(stack, e.Right) + if err != nil { + return nil, err + } + e.Left = newLeftExpr.(ast.Expr) + e.Right = newRightExpr.(ast.Expr) + case *ast.ColName: + iv := stack.GetVariable(e.Name.String()) + if iv == nil { + return expr, nil + } + return iv.ToAST(), nil + } + return expr, nil +} + +func query(ctx *sql.Context, runner sql.StatementRunner, stmt ast.Statement) (sql.RowIter, error) { + _, rowIter, _, err := runner.QueryWithBindings(ctx, "", stmt, nil, nil) + if err != nil { + return nil, err + } + var rows []sql.Row + for { + row, rErr := rowIter.Next(ctx) + if rErr != nil { + if rErr == io.EOF { + break + } + return nil, rErr + } + rows = append(rows, row) + } + if err = rowIter.Close(ctx); err != nil { + return nil, err + } + return sql.RowsToRowIter(rows...), nil +} + // Call runs the contained operations on the given runner. -func Call(ctx *sql.Context, iNode InterpreterNode, runner sql.StatementRunner, vals []any) (any, error) { +func Call(ctx *sql.Context, iNode InterpreterNode, params []*Parameter) (any, error) { // Set up the initial state of the function counter := -1 // We increment before accessing, so start at -1 stack := NewInterpreterStack() - // Add the parameters - parameterTypes := iNode.GetParameters() - parameterNames := iNode.GetParameterNames() - if len(vals) != len(parameterTypes) { - return nil, fmt.Errorf("parameter count mismatch: expected %d got %d", len(parameterTypes), len(vals)) - } - for i := range vals { - stack.NewVariableWithValue(parameterNames[i], parameterTypes[i], vals[i]) + for _, param := range params { + stack.NewVariableWithValue(param.Name, param.Type, param.Value) } - // TODO: eventually return multiple sql.RowIters - var resultRowIter sql.RowIter // Run the statements + // TODO: eventually return multiple sql.RowIters + var rowIters []sql.RowIter + runner := iNode.GetRunner() statements := iNode.GetStatements() for { counter++ @@ -65,50 +115,34 @@ func Call(ctx *sql.Context, iNode InterpreterNode, runner sql.StatementRunner, v operation := statements[counter] switch operation.OpCode { case OpCode_Select: - _, rowIter, _, err := runner.QueryWithBindings(ctx, "", operation.PrimaryData, nil, nil) - if err != nil { - return nil, err + selectStmt := operation.PrimaryData.(*ast.Select) + if selectStmt.SelectExprs == nil { + panic("select stmt with no select exprs") } - var rows []sql.Row - for { - row, rErr := rowIter.Next(ctx) - if rErr != nil { - if rErr == io.EOF { - break - } - return nil, rErr + for i := range selectStmt.SelectExprs { + newNode, err := replaceVariablesInExpr(&stack, selectStmt.SelectExprs[i]) + if err != nil { + return nil, err } - rows = append(rows, row) + selectStmt.SelectExprs[i] = newNode.(ast.SelectExpr) } - if err = rowIter.Close(ctx); err != nil { + rowIter, err := query(ctx, runner, selectStmt) + if err != nil { return nil, err } - resultRowIter = sql.RowsToRowIter(rows...) + rowIters = append(rowIters, rowIter) case OpCode_Declare: resolvedType := types.Uint32 // TODO: figure out actual type from operation stack.NewVariable(operation.Target, resolvedType) case OpCode_Exception: // TODO: implement case OpCode_Execute: - _, rowIter, _, err := runner.QueryWithBindings(ctx, "", operation.PrimaryData, nil, nil) + // TODO: replace variables + rowIter, err := query(ctx, runner, operation.PrimaryData) if err != nil { return nil, err } - var rows []sql.Row - for { - row, rErr := rowIter.Next(ctx) - if rErr != nil { - if rErr == io.EOF { - break - } - return nil, rErr - } - rows = append(rows, row) - } - if err = rowIter.Close(ctx); err != nil { - return nil, err - } - resultRowIter = sql.RowsToRowIter(rows...) + rowIters = append(rowIters, rowIter) case OpCode_Goto: // We must compare to the index - 1, so that the increment hits our target if counter <= operation.Index { @@ -165,5 +199,5 @@ func Call(ctx *sql.Context, iNode InterpreterNode, runner sql.StatementRunner, v panic("unimplemented opcode") } } - return resultRowIter, nil + return rowIters[0], nil } diff --git a/sql/procedures/interpreter_stack.go b/sql/procedures/interpreter_stack.go index 5fb7ae9d6a..8ca4a75983 100644 --- a/sql/procedures/interpreter_stack.go +++ b/sql/procedures/interpreter_stack.go @@ -16,8 +16,12 @@ package procedures import ( "fmt" + "strconv" "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/types" + + ast "github.com/dolthub/vitess/go/vt/sqlparser" ) // Stack is a generic stack. @@ -85,6 +89,26 @@ type InterpreterVariable struct { Value any } +func (iv *InterpreterVariable) ToAST() *ast.SQLVal { + var astType ast.ValType + var astVal []byte + if types.IsInteger(iv.Type) { + intStr := fmt.Sprintf("%d", iv.Value) + return ast.NewIntVal([]byte(intStr)) + } else if types.IsFloat(iv.Type) { + floatStr := strconv.FormatFloat(iv.Value.(float64), 'f', -1, 64) + return ast.NewFloatVal([]byte(floatStr)) + } else { + astType = ast.StrVal + astVal = []byte(fmt.Sprintf("%s", iv.Value)) + } + + return &ast.SQLVal{ + Type: astType, + Val: astVal, + } +} + // InterpreterScopeDetails contains all of the details that are relevant to a particular scope. type InterpreterScopeDetails struct { variables map[string]*InterpreterVariable diff --git a/sql/procedures/parse.go b/sql/procedures/parse.go index e8a3453497..f23793a3be 100644 --- a/sql/procedures/parse.go +++ b/sql/procedures/parse.go @@ -46,9 +46,11 @@ func ConvertStmt(ops *[]InterpreterOperation, stack *InterpreterStack, stmt ast. *ops = append(*ops, selectOp) case *ast.Declare: - // TODO: - //declareOp := InterpreterOperation{} - //stack.NewVariable + declareOp := InterpreterOperation{ + OpCode: OpCode_Declare, + PrimaryData: s, + } + *ops = append(*ops, declareOp) default: execOp := InterpreterOperation{ diff --git a/sql/rowexec/proc.go b/sql/rowexec/proc.go index 4d4fcde4b4..de9da6c5b6 100644 --- a/sql/rowexec/proc.go +++ b/sql/rowexec/proc.go @@ -182,23 +182,26 @@ func (b *BaseBuilder) buildProcedureResolvedTable(ctx *sql.Context, n *plan.Proc } func (b *BaseBuilder) buildCall(ctx *sql.Context, n *plan.Call, row sql.Row) (sql.RowIter, error) { + procParams := make([]*procedures.Parameter, len(n.Params)) for i, paramExpr := range n.Params { - val, err := paramExpr.Eval(ctx, row) + paramName := strings.ToLower(n.Procedure.Params[i].Name) + paramType := n.Procedure.Params[i].Type + paramVal, err := paramExpr.Eval(ctx, row) if err != nil { return nil, err } - paramName := n.Procedure.Params[i].Name - paramType := n.Procedure.Params[i].Type - err = n.Pref.InitializeVariable(paramName, paramType, val) + paramVal, _, err = paramType.Convert(paramVal) if err != nil { return nil, err } + procParams[i] = &procedures.Parameter{ + Name: paramName, + Value: paramVal, + Type: paramType, + } } - n.Pref.PushScope() - defer n.Pref.PopScope(ctx) - - rowIter, err := procedures.Call(ctx, n, n.Runner, nil) + rowIter, err := procedures.Call(ctx, n, procParams) if err != nil { return nil, err } From b6531549a1ccbc2e18dd05039221efe453850f85 Mon Sep 17 00:00:00 2001 From: James Cor Date: Tue, 18 Feb 2025 23:49:15 -0800 Subject: [PATCH 009/111] simple declare --- sql/procedures/interpreter_logic.go | 19 +++++++++++++++++-- sql/procedures/interpreter_stack.go | 4 ++++ sql/procedures/parse.go | 1 + 3 files changed, 22 insertions(+), 2 deletions(-) diff --git a/sql/procedures/interpreter_logic.go b/sql/procedures/interpreter_logic.go index 53ad7b6c09..75eed3b95c 100644 --- a/sql/procedures/interpreter_logic.go +++ b/sql/procedures/interpreter_logic.go @@ -16,6 +16,7 @@ package procedures import ( "io" + "strings" ast "github.com/dolthub/vitess/go/vt/sqlparser" @@ -132,8 +133,22 @@ func Call(ctx *sql.Context, iNode InterpreterNode, params []*Parameter) (any, er } rowIters = append(rowIters, rowIter) case OpCode_Declare: - resolvedType := types.Uint32 // TODO: figure out actual type from operation - stack.NewVariable(operation.Target, resolvedType) + declareStmt := operation.PrimaryData.(*ast.Declare) + for _, decl := range declareStmt.Variables.Names { + var varType sql.Type + switch declareStmt.Variables.VarType.Type { + case "int": + varType = types.Int32 + default: + panic("unimplemented type") + } + varName := strings.ToLower(decl.String()) + if declareStmt.Variables.VarType.Default != nil { + stack.NewVariableWithValue(varName, varType, declareStmt.Variables.VarType.Default) + } else { + stack.NewVariable(varName, varType) + } + } case OpCode_Exception: // TODO: implement case OpCode_Execute: diff --git a/sql/procedures/interpreter_stack.go b/sql/procedures/interpreter_stack.go index 8ca4a75983..62b48b07a3 100644 --- a/sql/procedures/interpreter_stack.go +++ b/sql/procedures/interpreter_stack.go @@ -90,6 +90,10 @@ type InterpreterVariable struct { } func (iv *InterpreterVariable) ToAST() *ast.SQLVal { + if sqlVal, isSQLVal := iv.Value.(*ast.SQLVal); isSQLVal { + return sqlVal + } + var astType ast.ValType var astVal []byte if types.IsInteger(iv.Type) { diff --git a/sql/procedures/parse.go b/sql/procedures/parse.go index f23793a3be..881efaee49 100644 --- a/sql/procedures/parse.go +++ b/sql/procedures/parse.go @@ -38,6 +38,7 @@ func ConvertStmt(ops *[]InterpreterOperation, stack *InterpreterStack, stmt ast. } *ops = append(*ops, endOp) stack.PopScope() + case *ast.Select: selectOp := InterpreterOperation{ OpCode: OpCode_Select, From ceb8331085a5d3e36448189a55174471673114ba Mon Sep 17 00:00:00 2001 From: James Cor Date: Wed, 19 Feb 2025 12:24:44 -0800 Subject: [PATCH 010/111] implement if else --- sql/plan/call.go | 2 +- sql/plan/procedure.go | 4 +-- sql/procedures/interpreter_logic.go | 56 ++++++++++++++++++++--------- sql/procedures/parse.go | 54 +++++++++++++++++++++++----- 4 files changed, 89 insertions(+), 27 deletions(-) diff --git a/sql/plan/call.go b/sql/plan/call.go index f04be8f06e..ee37d085d1 100644 --- a/sql/plan/call.go +++ b/sql/plan/call.go @@ -226,7 +226,7 @@ func (c *Call) GetParameterNames() []string { } // GetStatements implements the sql.InterpreterNode interface. -func (c *Call) GetStatements() []procedures.InterpreterOperation { +func (c *Call) GetStatements() []*procedures.InterpreterOperation { return c.Procedure.Ops } diff --git a/sql/plan/procedure.go b/sql/plan/procedure.go index c0fab99f5f..b7eb377800 100644 --- a/sql/plan/procedure.go +++ b/sql/plan/procedure.go @@ -81,7 +81,7 @@ type Procedure struct { Comment string Characteristics []Characteristic CreateProcedureString string - Ops []procedures.InterpreterOperation + Ops []*procedures.InterpreterOperation ExternalProc sql.Node CreatedAt time.Time ModifiedAt time.Time @@ -104,7 +104,7 @@ func NewProcedure( createProcedureString string, createdAt time.Time, modifiedAt time.Time, - ops []procedures.InterpreterOperation, + ops []*procedures.InterpreterOperation, ) *Procedure { lowercasedParams := make([]ProcedureParam, len(params)) for i, param := range params { diff --git a/sql/procedures/interpreter_logic.go b/sql/procedures/interpreter_logic.go index 75eed3b95c..501940ac45 100644 --- a/sql/procedures/interpreter_logic.go +++ b/sql/procedures/interpreter_logic.go @@ -29,7 +29,7 @@ import ( type InterpreterNode interface { GetRunner() sql.StatementRunner GetReturn() sql.Type - GetStatements() []InterpreterOperation + GetStatements() []*InterpreterOperation SetStatementRunner(ctx *sql.Context, runner sql.StatementRunner) sql.Node } @@ -58,6 +58,17 @@ func replaceVariablesInExpr(stack *InterpreterStack, expr ast.SQLNode) (ast.SQLN } e.Left = newLeftExpr.(ast.Expr) e.Right = newRightExpr.(ast.Expr) + case *ast.ComparisonExpr: + newLeftExpr, err := replaceVariablesInExpr(stack, e.Left) + if err != nil { + return nil, err + } + newRightExpr, err := replaceVariablesInExpr(stack, e.Right) + if err != nil { + return nil, err + } + e.Left = newLeftExpr.(ast.Expr) + e.Right = newRightExpr.(ast.Expr) case *ast.ColName: iv := stack.GetVariable(e.Name.String()) if iv == nil { @@ -184,28 +195,38 @@ func Call(ctx *sql.Context, iNode InterpreterNode, params []*Parameter) (any, er } } case OpCode_If: - _, rowIter, _, err := runner.QueryWithBindings(ctx, "", operation.PrimaryData, nil, nil) + selectStmt := operation.PrimaryData.(*ast.Select) + if selectStmt.SelectExprs == nil { + panic("select stmt with no select exprs") + } + for i := range selectStmt.SelectExprs { + newNode, err := replaceVariablesInExpr(&stack, selectStmt.SelectExprs[i]) + if err != nil { + return nil, err + } + selectStmt.SelectExprs[i] = newNode.(ast.SelectExpr) + } + _, rowIter, _, err := runner.QueryWithBindings(ctx, "", selectStmt, nil, nil) if err != nil { return nil, err } - for { - if _, rErr := rowIter.Next(ctx); rErr != nil { - if rErr == io.EOF { - break - } - return nil, rErr - } + // TODO: exactly one result that is a bool for now + row, err := rowIter.Next(ctx) + if err != nil { + return nil, err + } + if _, err = rowIter.Next(ctx); err != io.EOF { + return nil, err } if err = rowIter.Close(ctx); err != nil { return nil, err } - // TODO: ensure there is exactly one result that is a bool - //if retVal.(bool) { - // // We're never changing the scope, so we can just assign it directly. - // // Also, we must assign to index-1, so that the increment hits our target. - // counter = operation.Index - 1 - //} + // go to the appropriate block + cond := row[0].(bool) + if !cond { + counter = operation.Index - 1 // index of the else block, offset by 1 + } case OpCode_ScopeBegin: stack.PushScope() case OpCode_ScopeEnd: @@ -214,5 +235,8 @@ func Call(ctx *sql.Context, iNode InterpreterNode, params []*Parameter) (any, er panic("unimplemented opcode") } } - return rowIters[0], nil + if len(rowIters) == 0 { + panic("no rowIters") + } + return rowIters[len(rowIters)-1], nil } diff --git a/sql/procedures/parse.go b/sql/procedures/parse.go index 881efaee49..3b3d6eb502 100644 --- a/sql/procedures/parse.go +++ b/sql/procedures/parse.go @@ -18,11 +18,11 @@ import ( ast "github.com/dolthub/vitess/go/vt/sqlparser" ) -func ConvertStmt(ops *[]InterpreterOperation, stack *InterpreterStack, stmt ast.Statement) error { +func ConvertStmt(ops *[]*InterpreterOperation, stack *InterpreterStack, stmt ast.Statement) error { switch s := stmt.(type) { case *ast.BeginEndBlock: stack.PushScope() - startOP := InterpreterOperation{ + startOP := &InterpreterOperation{ OpCode: OpCode_ScopeBegin, } *ops = append(*ops, startOP) @@ -33,28 +33,66 @@ func ConvertStmt(ops *[]InterpreterOperation, stack *InterpreterStack, stmt ast. return err } } - endOp := InterpreterOperation{ + endOp := &InterpreterOperation{ OpCode: OpCode_ScopeEnd, } *ops = append(*ops, endOp) stack.PopScope() case *ast.Select: - selectOp := InterpreterOperation{ + selectOp := &InterpreterOperation{ OpCode: OpCode_Select, PrimaryData: s, } *ops = append(*ops, selectOp) case *ast.Declare: - declareOp := InterpreterOperation{ + declareOp := &InterpreterOperation{ OpCode: OpCode_Declare, PrimaryData: s, } *ops = append(*ops, declareOp) + case *ast.IfStatement: + // TODO: assume exactly one condition for now + ifCond := s.Conditions[0] + // TODO: convert condition into a select query + selectIfCond := &ast.Select{ + SelectExprs: ast.SelectExprs{ + &ast.AliasedExpr{ + Expr: ifCond.Expr, + }, + }, + } + ifOp := &InterpreterOperation{ + OpCode: OpCode_If, + PrimaryData: selectIfCond, + } + *ops = append(*ops, ifOp) + + for _, ifStmt := range ifCond.Statements { + if err := ConvertStmt(ops, stack, ifStmt); err != nil { + return err + } + } + gotoOp := &InterpreterOperation{ + OpCode: OpCode_Goto, + } + *ops = append(*ops, gotoOp) + + ifOp.Index = len(*ops) + for _, elseStmt := range s.Else { + if err := ConvertStmt(ops, stack, elseStmt); err != nil { + return err + } + } + + gotoOp.Index = len(*ops) + + // TODO: update the indexes, now that we know where the goto should go + default: - execOp := InterpreterOperation{ + execOp := &InterpreterOperation{ OpCode: OpCode_Execute, PrimaryData: s, } @@ -66,8 +104,8 @@ func ConvertStmt(ops *[]InterpreterOperation, stack *InterpreterStack, stmt ast. // Parse parses the given CREATE FUNCTION string (which must be the entire string, not just the body) into a Block // containing the contents of the body. -func Parse(stmt ast.Statement) ([]InterpreterOperation, error) { - ops := make([]InterpreterOperation, 0, 64) +func Parse(stmt ast.Statement) ([]*InterpreterOperation, error) { + ops := make([]*InterpreterOperation, 0, 64) stack := NewInterpreterStack() err := ConvertStmt(&ops, &stack, stmt) if err != nil { From 337227ff37de1a8a8e08b5816a9d31afaadbbd12 Mon Sep 17 00:00:00 2001 From: jycor Date: Wed, 19 Feb 2025 20:26:03 +0000 Subject: [PATCH 011/111] [ga-format-pr] Run ./format_repo.sh to fix formatting --- enginetest/queries/procedure_queries.go | 28 ++++++++++++------------- sql/plan/procedure.go | 2 +- sql/procedures.go | 2 +- sql/procedures/interpreter_stack.go | 2 +- sql/rowexec/proc.go | 1 - 5 files changed, 17 insertions(+), 18 deletions(-) diff --git a/enginetest/queries/procedure_queries.go b/enginetest/queries/procedure_queries.go index 5df50e0eb5..deb46bbd77 100644 --- a/enginetest/queries/procedure_queries.go +++ b/enginetest/queries/procedure_queries.go @@ -2833,13 +2833,13 @@ var ProcedureCreateInSubroutineTests = []ScriptTest{ Name: "table ddl statements in stored procedures", Assertions: []ScriptTestAssertion{ { - Query: "create procedure create_proc() create table t (i int primary key, j int);", + Query: "create procedure create_proc() create table t (i int primary key, j int);", Expected: []sql.Row{ {types.NewOkResult(0)}, }, }, { - Query: "call create_proc()", + Query: "call create_proc()", Expected: []sql.Row{ {types.NewOkResult(0)}, }, @@ -2855,24 +2855,24 @@ var ProcedureCreateInSubroutineTests = []ScriptTest{ }, }, { - Query: "call create_proc()", + Query: "call create_proc()", ExpectedErrStr: "table with name t already exists", }, { - Query: "create procedure insert_proc() insert into t values (1, 1), (2, 2), (3, 3);", + Query: "create procedure insert_proc() insert into t values (1, 1), (2, 2), (3, 3);", Expected: []sql.Row{ {types.NewOkResult(0)}, }, }, { - Query: "call insert_proc()", + Query: "call insert_proc()", Expected: []sql.Row{ {types.NewOkResult(3)}, }, }, { - Query: "select * from t", + Query: "select * from t", Expected: []sql.Row{ {1, 1}, {2, 2}, @@ -2880,24 +2880,24 @@ var ProcedureCreateInSubroutineTests = []ScriptTest{ }, }, { - Query: "call insert_proc()", + Query: "call insert_proc()", ExpectedErrStr: "duplicate primary key given: [1]", }, { - Query: "create procedure update_proc() update t set j = 999 where i > 1;", + Query: "create procedure update_proc() update t set j = 999 where i > 1;", Expected: []sql.Row{ {types.NewOkResult(0)}, }, }, { - Query: "call update_proc()", + Query: "call update_proc()", Expected: []sql.Row{ {types.OkResult{RowsAffected: 2, Info: plan.UpdateInfo{Matched: 2, Updated: 2}}}, }, }, { - Query: "select * from t", + Query: "select * from t", Expected: []sql.Row{ {1, 1}, {2, 999}, @@ -2905,20 +2905,20 @@ var ProcedureCreateInSubroutineTests = []ScriptTest{ }, }, { - Query: "call update_proc()", + Query: "call update_proc()", Expected: []sql.Row{ {types.OkResult{RowsAffected: 0, Info: plan.UpdateInfo{Matched: 2}}}, }, }, { - Query: "create procedure drop_proc() drop table t;", + Query: "create procedure drop_proc() drop table t;", Expected: []sql.Row{ {types.NewOkResult(0)}, }, }, { - Query: "call drop_proc()", + Query: "call drop_proc()", Expected: []sql.Row{ {types.NewOkResult(0)}, }, @@ -2928,7 +2928,7 @@ var ProcedureCreateInSubroutineTests = []ScriptTest{ Expected: []sql.Row{}, }, { - Query: "call drop_proc()", + Query: "call drop_proc()", ExpectedErrStr: "Unknown table 't'", }, }, diff --git a/sql/plan/procedure.go b/sql/plan/procedure.go index b7eb377800..7b1f1af772 100644 --- a/sql/plan/procedure.go +++ b/sql/plan/procedure.go @@ -126,7 +126,7 @@ func NewProcedure( CreatedAt: createdAt, ModifiedAt: modifiedAt, - Ops: ops, + Ops: ops, } } diff --git a/sql/procedures.go b/sql/procedures.go index 4b0c488ef1..893f832141 100644 --- a/sql/procedures.go +++ b/sql/procedures.go @@ -18,7 +18,7 @@ import ( "fmt" "time" - "github.com/dolthub/vitess/go/vt/sqlparser" + "github.com/dolthub/vitess/go/vt/sqlparser" ) // Interpreter is an interface that implements an interpreter. These are typically used for functions (which may be diff --git a/sql/procedures/interpreter_stack.go b/sql/procedures/interpreter_stack.go index 62b48b07a3..e4ba5f87c9 100644 --- a/sql/procedures/interpreter_stack.go +++ b/sql/procedures/interpreter_stack.go @@ -95,7 +95,7 @@ func (iv *InterpreterVariable) ToAST() *ast.SQLVal { } var astType ast.ValType - var astVal []byte + var astVal []byte if types.IsInteger(iv.Type) { intStr := fmt.Sprintf("%d", iv.Value) return ast.NewIntVal([]byte(intStr)) diff --git a/sql/rowexec/proc.go b/sql/rowexec/proc.go index de9da6c5b6..8dfdd046e1 100644 --- a/sql/rowexec/proc.go +++ b/sql/rowexec/proc.go @@ -211,7 +211,6 @@ func (b *BaseBuilder) buildCall(ctx *sql.Context, n *plan.Call, row sql.Row) (sq innerIter: rowIter.(sql.RowIter), }, nil - // TODO: mirror plpgsql interpreter_logic.go Call() // TODO: instead of building, run the actual operations // This means call the runner.QueryWithBindings From f204d8c07a1d42c8fbfbbeb556691605acc95527 Mon Sep 17 00:00:00 2001 From: James Cor Date: Wed, 19 Feb 2025 13:39:30 -0800 Subject: [PATCH 012/111] implement while loops --- sql/procedures/parse.go | 33 ++++++++++++++++++++++++++++++--- 1 file changed, 30 insertions(+), 3 deletions(-) diff --git a/sql/procedures/parse.go b/sql/procedures/parse.go index 3b3d6eb502..41dd339a62 100644 --- a/sql/procedures/parse.go +++ b/sql/procedures/parse.go @@ -57,7 +57,7 @@ func ConvertStmt(ops *[]*InterpreterOperation, stack *InterpreterStack, stmt ast // TODO: assume exactly one condition for now ifCond := s.Conditions[0] // TODO: convert condition into a select query - selectIfCond := &ast.Select{ + selectCond := &ast.Select{ SelectExprs: ast.SelectExprs{ &ast.AliasedExpr{ Expr: ifCond.Expr, @@ -66,7 +66,7 @@ func ConvertStmt(ops *[]*InterpreterOperation, stack *InterpreterStack, stmt ast } ifOp := &InterpreterOperation{ OpCode: OpCode_If, - PrimaryData: selectIfCond, + PrimaryData: selectCond, } *ops = append(*ops, ifOp) @@ -89,8 +89,35 @@ func ConvertStmt(ops *[]*InterpreterOperation, stack *InterpreterStack, stmt ast gotoOp.Index = len(*ops) - // TODO: update the indexes, now that we know where the goto should go + case *ast.While: + loopStart := len(*ops) + whileCond := s.Condition + selectCond := &ast.Select{ + SelectExprs: ast.SelectExprs{ + &ast.AliasedExpr{ + Expr: whileCond, + }, + }, + } + whileOp := &InterpreterOperation{ + OpCode: OpCode_If, + PrimaryData: selectCond, + } + *ops = append(*ops, whileOp) + + for _, whileStmt := range s.Statements { + if err := ConvertStmt(ops, stack, whileStmt); err != nil { + return err + } + } + gotoOp := &InterpreterOperation{ + OpCode: OpCode_Goto, + Index: loopStart, + } + *ops = append(*ops, gotoOp) + + whileOp.Index = len(*ops) default: execOp := &InterpreterOperation{ OpCode: OpCode_Execute, From 5b3f7202034ea901fe62663af7be85642cf1d0eb Mon Sep 17 00:00:00 2001 From: James Cor Date: Wed, 19 Feb 2025 15:04:18 -0800 Subject: [PATCH 013/111] tmp --- enginetest/memory_engine_test.go | 27 ++++++++++++++++++++++----- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/enginetest/memory_engine_test.go b/enginetest/memory_engine_test.go index bd1f10eafa..ac527be63d 100644 --- a/enginetest/memory_engine_test.go +++ b/enginetest/memory_engine_test.go @@ -198,17 +198,32 @@ func TestSingleQueryPrepared(t *testing.T) { // Convenience test for debugging a single query. Unskip and set to the desired query. func TestSingleScript(t *testing.T) { - t.Skip() + //t.Skip() var scripts = []queries.ScriptTest{ { - Name: "test script", + Name: "Simple SELECT", SetUpScript: []string{ - "create table t (i int);", + `create procedure proc() +begin + set @x = 0; + while @x < 10 do + set @x = @x + 1; + end while; +end;`, + //"create procedure proc(x int) select x > 1;", }, Assertions: []queries.ScriptTestAssertion{ { - Query: "select 1 into @a", - Expected: []sql.Row{}, + Query: "call proc();", + Expected: []sql.Row{ + {types.NewOkResult(0)}, + }, + }, + { + Query: "select @x;", + Expected: []sql.Row{ + {10}, + }, }, }, }, @@ -216,6 +231,8 @@ func TestSingleScript(t *testing.T) { for _, test := range scripts { harness := enginetest.NewMemoryHarness("", 1, testNumPartitions, true, nil) + // TODO: fix this + //harness.UseServer() engine, err := harness.NewEngine(t) if err != nil { panic(err) From 9b8cc897e0659fedb7c418e2a4c7889c5f725848 Mon Sep 17 00:00:00 2001 From: James Cor Date: Thu, 20 Feb 2025 12:24:24 -0800 Subject: [PATCH 014/111] repeat test --- enginetest/memory_engine_test.go | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/enginetest/memory_engine_test.go b/enginetest/memory_engine_test.go index ac527be63d..bf27009ee0 100644 --- a/enginetest/memory_engine_test.go +++ b/enginetest/memory_engine_test.go @@ -203,18 +203,19 @@ func TestSingleScript(t *testing.T) { { Name: "Simple SELECT", SetUpScript: []string{ - `create procedure proc() + ` +create procedure proc(i int) begin set @x = 0; - while @x < 10 do - set @x = @x + 1; - end while; -end;`, + repeat set @x = @x + 1; + until @x > i + end repeat; +end`, //"create procedure proc(x int) select x > 1;", }, Assertions: []queries.ScriptTestAssertion{ { - Query: "call proc();", + Query: "call proc(10);", Expected: []sql.Row{ {types.NewOkResult(0)}, }, @@ -222,7 +223,7 @@ end;`, { Query: "select @x;", Expected: []sql.Row{ - {10}, + {11}, }, }, }, From b19b8b5a5c292e311fce9924c924d92ad2285528 Mon Sep 17 00:00:00 2001 From: James Cor Date: Thu, 20 Feb 2025 13:00:32 -0800 Subject: [PATCH 015/111] implemented repeat --- sql/procedures/interpreter_logic.go | 6 ++++++ sql/procedures/parse.go | 32 +++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+) diff --git a/sql/procedures/interpreter_logic.go b/sql/procedures/interpreter_logic.go index 501940ac45..5df10221ad 100644 --- a/sql/procedures/interpreter_logic.go +++ b/sql/procedures/interpreter_logic.go @@ -69,6 +69,12 @@ func replaceVariablesInExpr(stack *InterpreterStack, expr ast.SQLNode) (ast.SQLN } e.Left = newLeftExpr.(ast.Expr) e.Right = newRightExpr.(ast.Expr) + case *ast.NotExpr: + newExpr, err := replaceVariablesInExpr(stack, e.Expr) + if err != nil { + return nil, err + } + e.Expr = newExpr.(ast.Expr) case *ast.ColName: iv := stack.GetVariable(e.Name.String()) if iv == nil { diff --git a/sql/procedures/parse.go b/sql/procedures/parse.go index 41dd339a62..b3cf7a02bc 100644 --- a/sql/procedures/parse.go +++ b/sql/procedures/parse.go @@ -118,6 +118,38 @@ func ConvertStmt(ops *[]*InterpreterOperation, stack *InterpreterStack, stmt ast *ops = append(*ops, gotoOp) whileOp.Index = len(*ops) + + case *ast.Repeat: + loopStart := len(*ops) + + repeatCond := &ast.NotExpr{Expr: s.Condition} + selectCond := &ast.Select{ + SelectExprs: ast.SelectExprs{ + &ast.AliasedExpr{ + Expr: repeatCond, + }, + }, + } + repeatOp := &InterpreterOperation{ + OpCode: OpCode_If, + PrimaryData: selectCond, + } + *ops = append(*ops, repeatOp) + + for _, repeatStmt := range s.Statements { + if err := ConvertStmt(ops, stack, repeatStmt); err != nil { + return err + } + } + + gotoOp := &InterpreterOperation{ + OpCode: OpCode_Goto, + Index: loopStart, + } + *ops = append(*ops, gotoOp) + + repeatOp.Index = len(*ops) + default: execOp := &InterpreterOperation{ OpCode: OpCode_Execute, From da6db05a019a3e6601a512b34dd557c6bf5c4ed8 Mon Sep 17 00:00:00 2001 From: James Cor Date: Fri, 21 Feb 2025 14:37:40 -0800 Subject: [PATCH 016/111] double replace issue --- sql/procedures/interpreter_logic.go | 130 ++++++++++++++++++++++-- sql/procedures/interpreter_operation.go | 1 + sql/procedures/interpreter_stack.go | 2 + sql/procedures/parse.go | 116 ++++++++++++++++++++- 4 files changed, 237 insertions(+), 12 deletions(-) diff --git a/sql/procedures/interpreter_logic.go b/sql/procedures/interpreter_logic.go index 5df10221ad..7a55a0c52a 100644 --- a/sql/procedures/interpreter_logic.go +++ b/sql/procedures/interpreter_logic.go @@ -39,6 +39,42 @@ type Parameter struct { Value any } +func unreplaceVariablesInExpr(stack *InterpreterStack, expr ast.SQLNode) ast.SQLNode { + switch e := expr.(type) { + case *ast.AliasedExpr: + newExpr := unreplaceVariablesInExpr(stack, e.Expr) + e.Expr = newExpr.(ast.Expr) + case *ast.BinaryExpr: + newLeftExpr := unreplaceVariablesInExpr(stack, e.Left) + newRightExpr := unreplaceVariablesInExpr(stack, e.Right) + e.Left = newLeftExpr.(ast.Expr) + e.Right = newRightExpr.(ast.Expr) + case *ast.ComparisonExpr: + newLeftExpr := unreplaceVariablesInExpr(stack, e.Left) + newRightExpr := unreplaceVariablesInExpr(stack, e.Right) + e.Left = newLeftExpr.(ast.Expr) + e.Right = newRightExpr.(ast.Expr) + case *ast.FuncExpr: + for i := range e.Exprs { + newExpr := unreplaceVariablesInExpr(stack, e.Exprs[i]) + e.Exprs[i] = newExpr.(ast.SelectExpr) + } + case *ast.NotExpr: + newExpr := unreplaceVariablesInExpr(stack, e.Expr) + e.Expr = newExpr.(ast.Expr) + case *ast.Set: + for _, setExpr := range e.Exprs { + newExpr := unreplaceVariablesInExpr(stack, setExpr.Expr) + setExpr.Expr = newExpr.(ast.Expr) + } + case *ast.SQLVal: + if oldVal, ok := stack.replaceMap[expr]; ok { + return oldVal + } + } + return expr +} + func replaceVariablesInExpr(stack *InterpreterStack, expr ast.SQLNode) (ast.SQLNode, error) { switch e := expr.(type) { case *ast.AliasedExpr: @@ -69,18 +105,39 @@ func replaceVariablesInExpr(stack *InterpreterStack, expr ast.SQLNode) (ast.SQLN } e.Left = newLeftExpr.(ast.Expr) e.Right = newRightExpr.(ast.Expr) + case *ast.FuncExpr: + for i := range e.Exprs { + newExpr, err := replaceVariablesInExpr(stack, e.Exprs[i]) + if err != nil { + return nil, err + } + e.Exprs[i] = newExpr.(ast.SelectExpr) + } case *ast.NotExpr: newExpr, err := replaceVariablesInExpr(stack, e.Expr) if err != nil { return nil, err } e.Expr = newExpr.(ast.Expr) + case *ast.Set: + for _, setExpr := range e.Exprs { + newExpr, err := replaceVariablesInExpr(stack, setExpr.Expr) + if err != nil { + return nil, err + } + err = stack.SetVariable(nil, setExpr.Name.String(), newExpr) + if err != nil { + return nil, err + } + } case *ast.ColName: iv := stack.GetVariable(e.Name.String()) if iv == nil { return expr, nil } - return iv.ToAST(), nil + newExpr := iv.ToAST() + stack.replaceMap[newExpr] = e + return newExpr, nil } return expr, nil } @@ -149,15 +206,19 @@ func Call(ctx *sql.Context, iNode InterpreterNode, params []*Parameter) (any, er return nil, err } rowIters = append(rowIters, rowIter) + + for i := range selectStmt.SelectExprs { + newNode := unreplaceVariablesInExpr(&stack, selectStmt.SelectExprs[i]) + selectStmt.SelectExprs[i] = newNode.(ast.SelectExpr) + } + stack.replaceMap = map[ast.SQLNode]ast.SQLNode{} + case OpCode_Declare: declareStmt := operation.PrimaryData.(*ast.Declare) for _, decl := range declareStmt.Variables.Names { - var varType sql.Type - switch declareStmt.Variables.VarType.Type { - case "int": - varType = types.Int32 - default: - panic("unimplemented type") + varType, err := types.ColumnTypeToType(&declareStmt.Variables.VarType) + if err != nil { + return nil, err } varName := strings.ToLower(decl.String()) if declareStmt.Variables.VarType.Default != nil { @@ -166,15 +227,61 @@ func Call(ctx *sql.Context, iNode InterpreterNode, params []*Parameter) (any, er stack.NewVariable(varName, varType) } } + case OpCode_Set: + selectStmt := operation.PrimaryData.(*ast.Select) + if selectStmt.SelectExprs == nil { + panic("select stmt with no select exprs") + } + for i := range selectStmt.SelectExprs { + newNode, err := replaceVariablesInExpr(&stack, selectStmt.SelectExprs[i]) + if err != nil { + return nil, err + } + selectStmt.SelectExprs[i] = newNode.(ast.SelectExpr) + } + _, rowIter, _, err := runner.QueryWithBindings(ctx, "", selectStmt, nil, nil) + if err != nil { + return nil, err + } + row, err := rowIter.Next(ctx) + if err != nil { + return nil, err + } + if _, err = rowIter.Next(ctx); err != io.EOF { + return nil, err + } + if err = rowIter.Close(ctx); err != nil { + return nil, err + } + + err = stack.SetVariable(nil, strings.ToLower(operation.Target), row[0]) + if err != nil { + return nil, err + } + + for i := range selectStmt.SelectExprs { + newNode := unreplaceVariablesInExpr(&stack, selectStmt.SelectExprs[i]) + selectStmt.SelectExprs[i] = newNode.(ast.SelectExpr) + } + stack.replaceMap = map[ast.SQLNode]ast.SQLNode{} + case OpCode_Exception: // TODO: implement case OpCode_Execute: // TODO: replace variables - rowIter, err := query(ctx, runner, operation.PrimaryData) + stmt, err := replaceVariablesInExpr(&stack, operation.PrimaryData) + if err != nil { + return nil, err + } + rowIter, err := query(ctx, runner, stmt.(ast.Statement)) if err != nil { return nil, err } rowIters = append(rowIters, rowIter) + + stmt = unreplaceVariablesInExpr(&stack, stmt) + stack.replaceMap = map[ast.SQLNode]ast.SQLNode{} + case OpCode_Goto: // We must compare to the index - 1, so that the increment hits our target if counter <= operation.Index { @@ -233,6 +340,13 @@ func Call(ctx *sql.Context, iNode InterpreterNode, params []*Parameter) (any, er if !cond { counter = operation.Index - 1 // index of the else block, offset by 1 } + + for i := range selectStmt.SelectExprs { + newNode := unreplaceVariablesInExpr(&stack, selectStmt.SelectExprs[i]) + selectStmt.SelectExprs[i] = newNode.(ast.SelectExpr) + } + stack.replaceMap = map[ast.SQLNode]ast.SQLNode{} + case OpCode_ScopeBegin: stack.PushScope() case OpCode_ScopeEnd: diff --git a/sql/procedures/interpreter_operation.go b/sql/procedures/interpreter_operation.go index 20573188e7..4f37f51870 100644 --- a/sql/procedures/interpreter_operation.go +++ b/sql/procedures/interpreter_operation.go @@ -19,6 +19,7 @@ type OpCode uint16 const ( OpCode_Select OpCode = iota OpCode_Declare // https://www.postgresql.org/docs/15/plpgsql-declarations.html + OpCode_Set OpCode_Exception // https://www.postgresql.org/docs/15/plpgsql-control-structures.html#PLPGSQL-ERROR-TRAPPING OpCode_Execute // Everything that's not a SELECT OpCode_Goto // All control-flow structures can be represented using Goto diff --git a/sql/procedures/interpreter_stack.go b/sql/procedures/interpreter_stack.go index e4ba5f87c9..f9573abd6c 100644 --- a/sql/procedures/interpreter_stack.go +++ b/sql/procedures/interpreter_stack.go @@ -123,6 +123,7 @@ type InterpreterScopeDetails struct { // general purpose. type InterpreterStack struct { stack *Stack[*InterpreterScopeDetails] + replaceMap map[ast.SQLNode]ast.SQLNode } // NewInterpreterStack creates a new InterpreterStack. @@ -134,6 +135,7 @@ func NewInterpreterStack() InterpreterStack { }) return InterpreterStack{ stack: stack, + replaceMap: map[ast.SQLNode]ast.SQLNode{}, } } diff --git a/sql/procedures/parse.go b/sql/procedures/parse.go index b3cf7a02bc..667f004651 100644 --- a/sql/procedures/parse.go +++ b/sql/procedures/parse.go @@ -53,6 +53,33 @@ func ConvertStmt(ops *[]*InterpreterOperation, stack *InterpreterStack, stmt ast } *ops = append(*ops, declareOp) + case *ast.Set: + if len(s.Exprs) != 1 { + panic("unexpected number of set expressions") + } + setExpr := s.Exprs[0] + var setOp *InterpreterOperation + if len(setExpr.Scope) != 0 { + setOp = &InterpreterOperation{ + OpCode: OpCode_Execute, + PrimaryData: s, + } + } else { + selectStmt := &ast.Select{ + SelectExprs: ast.SelectExprs{ + &ast.AliasedExpr{ + Expr: setExpr.Expr, + }, + }, + } + setOp = &InterpreterOperation{ + OpCode: OpCode_Set, + PrimaryData: selectStmt, + Target: setExpr.Name.String(), + } + } + *ops = append(*ops, setOp) + case *ast.IfStatement: // TODO: assume exactly one condition for now ifCond := s.Conditions[0] @@ -80,14 +107,63 @@ func ConvertStmt(ops *[]*InterpreterOperation, stack *InterpreterStack, stmt ast } *ops = append(*ops, gotoOp) - ifOp.Index = len(*ops) + ifOp.Index = len(*ops) // start of else block for _, elseStmt := range s.Else { if err := ConvertStmt(ops, stack, elseStmt); err != nil { return err } } - gotoOp.Index = len(*ops) + gotoOp.Index = len(*ops) // end of if statement + + case *ast.CaseStatement: + var caseGotoOps []*InterpreterOperation + for _, caseStmt := range s.Cases { + caseExpr := caseStmt.Case + if s.Expr != nil { + caseExpr = &ast.ComparisonExpr{ + Operator: ast.EqualStr, + Left: s.Expr, + Right: caseExpr, + } + } + caseCond := &ast.Select{ + SelectExprs: ast.SelectExprs{ + &ast.AliasedExpr{ + Expr: caseExpr, + }, + }, + } + caseOp := &InterpreterOperation{ + OpCode: OpCode_If, + PrimaryData: caseCond, + } + *ops = append(*ops, caseOp) + + for _, ifStmt := range caseStmt.Statements { + if err := ConvertStmt(ops, stack, ifStmt); err != nil { + return err + } + } + gotoOp := &InterpreterOperation{ + OpCode: OpCode_Goto, + } + caseGotoOps = append(caseGotoOps, gotoOp) + *ops = append(*ops, gotoOp) + + caseOp.Index = len(*ops) // start of next case + } + if s.Else != nil { + for _, elseStmt := range s.Else { + if err := ConvertStmt(ops, stack, elseStmt); err != nil { + return err + } + } + } + + for _, gotoOp := range caseGotoOps { + gotoOp.Index = len(*ops) // end of case block + } case *ast.While: loopStart := len(*ops) @@ -117,7 +193,7 @@ func ConvertStmt(ops *[]*InterpreterOperation, stack *InterpreterStack, stmt ast } *ops = append(*ops, gotoOp) - whileOp.Index = len(*ops) + whileOp.Index = len(*ops) // end of while block case *ast.Repeat: loopStart := len(*ops) @@ -148,8 +224,40 @@ func ConvertStmt(ops *[]*InterpreterOperation, stack *InterpreterStack, stmt ast } *ops = append(*ops, gotoOp) - repeatOp.Index = len(*ops) + repeatOp.Index = len(*ops) // end of repeat block + + case *ast.Loop: + loopStart := len(*ops) + for _, loopStmt := range s.Statements { + if err := ConvertStmt(ops, stack, loopStmt); err != nil { + return err + } + } + gotoOp := &InterpreterOperation{ + OpCode: OpCode_Goto, + Index: loopStart, + } + *ops = append(*ops, gotoOp) + + // perform second pass over loop statements to add labels + for idx := loopStart; idx < len(*ops); idx++ { + op := (*ops)[idx] + switch op.OpCode { + case OpCode_Goto: + if op.Target == s.Label { + (*ops)[idx].Index = len(*ops) + } + default: + continue + } + } + case *ast.Leave: + leaveOp := &InterpreterOperation{ + OpCode: OpCode_Goto, + Target: s.Label, // hacky? way to signal a leave + } + *ops = append(*ops, leaveOp) default: execOp := &InterpreterOperation{ OpCode: OpCode_Execute, From 347cdd3cc54ed48ee4f7587618ff537e26e64c6e Mon Sep 17 00:00:00 2001 From: jycor Date: Fri, 21 Feb 2025 22:38:57 +0000 Subject: [PATCH 017/111] [ga-format-pr] Run ./format_repo.sh to fix formatting --- sql/procedures/interpreter_operation.go | 18 +++++++++--------- sql/procedures/interpreter_stack.go | 4 ++-- sql/procedures/parse.go | 8 ++++---- 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/sql/procedures/interpreter_operation.go b/sql/procedures/interpreter_operation.go index 4f37f51870..ac2adde6c6 100644 --- a/sql/procedures/interpreter_operation.go +++ b/sql/procedures/interpreter_operation.go @@ -17,16 +17,16 @@ import ast "github.com/dolthub/vitess/go/vt/sqlparser" type OpCode uint16 const ( - OpCode_Select OpCode = iota - OpCode_Declare // https://www.postgresql.org/docs/15/plpgsql-declarations.html + OpCode_Select OpCode = iota + OpCode_Declare // https://www.postgresql.org/docs/15/plpgsql-declarations.html OpCode_Set - OpCode_Exception // https://www.postgresql.org/docs/15/plpgsql-control-structures.html#PLPGSQL-ERROR-TRAPPING - OpCode_Execute // Everything that's not a SELECT - OpCode_Goto // All control-flow structures can be represented using Goto - OpCode_If // https://www.postgresql.org/docs/15/plpgsql-control-structures.html#PLPGSQL-CONDITIONALS - OpCode_Return // https://www.postgresql.org/docs/15/plpgsql-control-structures.html#PLPGSQL-STATEMENTS-RETURNING - OpCode_ScopeBegin // This is used for scope control, specific to Doltgres - OpCode_ScopeEnd // This is used for scope control, specific to Doltgres + OpCode_Exception // https://www.postgresql.org/docs/15/plpgsql-control-structures.html#PLPGSQL-ERROR-TRAPPING + OpCode_Execute // Everything that's not a SELECT + OpCode_Goto // All control-flow structures can be represented using Goto + OpCode_If // https://www.postgresql.org/docs/15/plpgsql-control-structures.html#PLPGSQL-CONDITIONALS + OpCode_Return // https://www.postgresql.org/docs/15/plpgsql-control-structures.html#PLPGSQL-STATEMENTS-RETURNING + OpCode_ScopeBegin // This is used for scope control, specific to Doltgres + OpCode_ScopeEnd // This is used for scope control, specific to Doltgres ) // InterpreterOperation is an operation that will be performed by the interpreter. diff --git a/sql/procedures/interpreter_stack.go b/sql/procedures/interpreter_stack.go index f9573abd6c..712e4add68 100644 --- a/sql/procedures/interpreter_stack.go +++ b/sql/procedures/interpreter_stack.go @@ -122,7 +122,7 @@ type InterpreterScopeDetails struct { // the same as a stack in the traditional programming sense, but rather is a loose abstraction that serves the same // general purpose. type InterpreterStack struct { - stack *Stack[*InterpreterScopeDetails] + stack *Stack[*InterpreterScopeDetails] replaceMap map[ast.SQLNode]ast.SQLNode } @@ -134,7 +134,7 @@ func NewInterpreterStack() InterpreterStack { variables: make(map[string]*InterpreterVariable), }) return InterpreterStack{ - stack: stack, + stack: stack, replaceMap: map[ast.SQLNode]ast.SQLNode{}, } } diff --git a/sql/procedures/parse.go b/sql/procedures/parse.go index 667f004651..f3c4e5074b 100644 --- a/sql/procedures/parse.go +++ b/sql/procedures/parse.go @@ -61,7 +61,7 @@ func ConvertStmt(ops *[]*InterpreterOperation, stack *InterpreterStack, stmt ast var setOp *InterpreterOperation if len(setExpr.Scope) != 0 { setOp = &InterpreterOperation{ - OpCode: OpCode_Execute, + OpCode: OpCode_Execute, PrimaryData: s, } } else { @@ -189,7 +189,7 @@ func ConvertStmt(ops *[]*InterpreterOperation, stack *InterpreterStack, stmt ast } gotoOp := &InterpreterOperation{ OpCode: OpCode_Goto, - Index: loopStart, + Index: loopStart, } *ops = append(*ops, gotoOp) @@ -220,7 +220,7 @@ func ConvertStmt(ops *[]*InterpreterOperation, stack *InterpreterStack, stmt ast gotoOp := &InterpreterOperation{ OpCode: OpCode_Goto, - Index: loopStart, + Index: loopStart, } *ops = append(*ops, gotoOp) @@ -235,7 +235,7 @@ func ConvertStmt(ops *[]*InterpreterOperation, stack *InterpreterStack, stmt ast } gotoOp := &InterpreterOperation{ OpCode: OpCode_Goto, - Index: loopStart, + Index: loopStart, } *ops = append(*ops, gotoOp) From 75caa0797552fd066e804fdb6b18626519c9ad1a Mon Sep 17 00:00:00 2001 From: James Cor Date: Mon, 24 Feb 2025 10:43:40 -0800 Subject: [PATCH 018/111] introduce new ast type --- enginetest/memory_engine_test.go | 185 ++++++++++++++++++++++++++++--- sql/planbuilder/scalar.go | 3 + sql/planbuilder/show.go | 3 + 3 files changed, 175 insertions(+), 16 deletions(-) diff --git a/enginetest/memory_engine_test.go b/enginetest/memory_engine_test.go index bf27009ee0..e8c5ddddb3 100644 --- a/enginetest/memory_engine_test.go +++ b/enginetest/memory_engine_test.go @@ -201,31 +201,184 @@ func TestSingleScript(t *testing.T) { //t.Skip() var scripts = []queries.ScriptTest{ { - Name: "Simple SELECT", + Name: "CASE statements", SetUpScript: []string{ ` -create procedure proc(i int) +create procedure proc() begin - set @x = 0; - repeat set @x = @x + 1; - until @x > i - end repeat; -end`, - //"create procedure proc(x int) select x > 1;", + declare x int default 0; + tloop: loop + case + when x = 0 then + set x = 1; + else + leave tloop; + end case; + end loop; + select x; +end;`, + + +// `CREATE PROCEDURE p1(IN a BIGINT) +//BEGIN +// DECLARE b VARCHAR(200) DEFAULT ""; +// tloop: LOOP +// CASE +// WHEN a < 4 THEN +// SET b = CONCAT(b, "a"); +// SET a = a + 1; +// WHEN a < 8 THEN +// SET b = CONCAT(b, "b"); +// SET a = a + 1; +// ELSE +// LEAVE tloop; +// END CASE; +// END LOOP; +// SELECT b; +//END;`, +// `CREATE PROCEDURE p2(IN a BIGINT) +//BEGIN +// DECLARE b VARCHAR(200) DEFAULT ""; +// tloop: LOOP +// CASE a +// WHEN 1 THEN +// SET b = CONCAT(b, "a"); +// SET a = a + 1; +// WHEN 2 THEN +// SET b = CONCAT(b, "b"); +// SET a = a + 1; +// WHEN 3 THEN +// SET b = CONCAT(b, "c"); +// SET a = a + 1; +// ELSE +// LEAVE tloop; +// END CASE; +// END LOOP; +// SELECT b; +//END;`, +// `CREATE PROCEDURE p3(IN a BIGINT) +//BEGIN +// DECLARE b VARCHAR(200) DEFAULT ""; +// tloop: LOOP +// CASE a +// WHEN 1 THEN +// SET b = CONCAT(b, "a"); +// SET a = a + 1; +// END CASE; +// END LOOP; +// SELECT b; +//END;`, +// `CREATE PROCEDURE p4(IN a BIGINT) +//BEGIN +// DECLARE b VARCHAR(200) DEFAULT ""; +// tloop: LOOP +// CASE +// WHEN a = 1 THEN +// SET b = CONCAT(b, "a"); +// SET a = a + 1; +// END CASE; +// END LOOP; +// SELECT b; +//END;`, +// `CREATE PROCEDURE p5(IN a BIGINT) +//BEGIN +// DECLARE b VARCHAR(200) DEFAULT ""; +// REPEAT +// CASE +// WHEN a <= 1 THEN +// SET b = CONCAT(b, "a"); +// SET a = a + 1; +// END CASE; +// UNTIL a > 1 +// END REPEAT; +// SELECT b; +//END;`, }, Assertions: []queries.ScriptTestAssertion{ { - Query: "call proc(10);", + Query: "CALL proc", Expected: []sql.Row{ - {types.NewOkResult(0)}, - }, - }, - { - Query: "select @x;", - Expected: []sql.Row{ - {11}, + {1}, }, }, + + //{ + // Query: "CALL p1(0)", + // Expected: []sql.Row{ + // {"aaaabbbb"}, + // }, + //}, + + //{ + // Query: "CALL p1(3)", + // Expected: []sql.Row{ + // {"abbbb"}, + // }, + //}, + //{ + // Query: "CALL p1(6)", + // Expected: []sql.Row{ + // {"bb"}, + // }, + //}, + //{ + // Query: "CALL p1(9)", + // Expected: []sql.Row{ + // {""}, + // }, + //}, + //{ + // Query: "CALL p2(1)", + // Expected: []sql.Row{ + // {"abc"}, + // }, + //}, + //{ + // Query: "CALL p2(2)", + // Expected: []sql.Row{ + // {"bc"}, + // }, + //}, + //{ + // Query: "CALL p2(3)", + // Expected: []sql.Row{ + // {"c"}, + // }, + //}, + //{ + // Query: "CALL p2(4)", + // Expected: []sql.Row{ + // {""}, + // }, + //}, + //{ + // Query: "CALL p3(1)", + // ExpectedErrStr: "Case not found for CASE statement (errno 1339) (sqlstate 20000)", + //}, + //{ + // Query: "CALL p3(2)", + // ExpectedErrStr: "Case not found for CASE statement (errno 1339) (sqlstate 20000)", + //}, + //{ + // Query: "CALL p4(1)", + // ExpectedErrStr: "Case not found for CASE statement (errno 1339) (sqlstate 20000)", + //}, + //{ + // Query: "CALL p4(-1)", + // ExpectedErrStr: "Case not found for CASE statement (errno 1339) (sqlstate 20000)", + //}, + //{ + // Query: "CALL p5(0)", + // Expected: []sql.Row{ + // {"aa"}, + // }, + //}, + //{ + // Query: "CALL p5(1)", + // Expected: []sql.Row{ + // {"a"}, + // }, + //}, }, }, } diff --git a/sql/planbuilder/scalar.go b/sql/planbuilder/scalar.go index 1939f6eab2..11088542a2 100644 --- a/sql/planbuilder/scalar.go +++ b/sql/planbuilder/scalar.go @@ -106,6 +106,9 @@ func (b *Builder) buildScalar(inScope *scope, e ast.Expr) (ex sql.Expression) { case *ast.NullVal: return expression.NewLiteral(nil, types.Null) case *ast.ColName: + if v.Metadata != nil { + return b.ConvertVal(v.Metadata) + } dbName := strings.ToLower(v.Qualifier.DbQualifier.String()) tblName := strings.ToLower(v.Qualifier.Name.String()) colName := strings.ToLower(v.Name.String()) diff --git a/sql/planbuilder/show.go b/sql/planbuilder/show.go index 2a04a0cf49..ae5200d6cf 100644 --- a/sql/planbuilder/show.go +++ b/sql/planbuilder/show.go @@ -604,6 +604,9 @@ func (b *Builder) buildAsOfExpr(inScope *scope, time ast.Expr) sql.Expression { } return expression.NewLiteral(ret.(string), types.LongText) case *ast.ColName: + if v.Metadata != nil { + return b.buildAsOfExpr(inScope, v.Metadata) + } sysVar, _, ok := b.buildSysVar(v, ast.SetScope_None) if ok { return sysVar From c41fed47ae93d4b3f6e79e62bd5423c424946bca Mon Sep 17 00:00:00 2001 From: jycor Date: Tue, 11 Mar 2025 19:24:49 +0000 Subject: [PATCH 019/111] [ga-format-pr] Run ./format_repo.sh to fix formatting --- enginetest/memory_engine_test.go | 149 +++++++++++++++---------------- 1 file changed, 74 insertions(+), 75 deletions(-) diff --git a/enginetest/memory_engine_test.go b/enginetest/memory_engine_test.go index e8c5ddddb3..a6f0832533 100644 --- a/enginetest/memory_engine_test.go +++ b/enginetest/memory_engine_test.go @@ -218,81 +218,80 @@ begin select x; end;`, - -// `CREATE PROCEDURE p1(IN a BIGINT) -//BEGIN -// DECLARE b VARCHAR(200) DEFAULT ""; -// tloop: LOOP -// CASE -// WHEN a < 4 THEN -// SET b = CONCAT(b, "a"); -// SET a = a + 1; -// WHEN a < 8 THEN -// SET b = CONCAT(b, "b"); -// SET a = a + 1; -// ELSE -// LEAVE tloop; -// END CASE; -// END LOOP; -// SELECT b; -//END;`, -// `CREATE PROCEDURE p2(IN a BIGINT) -//BEGIN -// DECLARE b VARCHAR(200) DEFAULT ""; -// tloop: LOOP -// CASE a -// WHEN 1 THEN -// SET b = CONCAT(b, "a"); -// SET a = a + 1; -// WHEN 2 THEN -// SET b = CONCAT(b, "b"); -// SET a = a + 1; -// WHEN 3 THEN -// SET b = CONCAT(b, "c"); -// SET a = a + 1; -// ELSE -// LEAVE tloop; -// END CASE; -// END LOOP; -// SELECT b; -//END;`, -// `CREATE PROCEDURE p3(IN a BIGINT) -//BEGIN -// DECLARE b VARCHAR(200) DEFAULT ""; -// tloop: LOOP -// CASE a -// WHEN 1 THEN -// SET b = CONCAT(b, "a"); -// SET a = a + 1; -// END CASE; -// END LOOP; -// SELECT b; -//END;`, -// `CREATE PROCEDURE p4(IN a BIGINT) -//BEGIN -// DECLARE b VARCHAR(200) DEFAULT ""; -// tloop: LOOP -// CASE -// WHEN a = 1 THEN -// SET b = CONCAT(b, "a"); -// SET a = a + 1; -// END CASE; -// END LOOP; -// SELECT b; -//END;`, -// `CREATE PROCEDURE p5(IN a BIGINT) -//BEGIN -// DECLARE b VARCHAR(200) DEFAULT ""; -// REPEAT -// CASE -// WHEN a <= 1 THEN -// SET b = CONCAT(b, "a"); -// SET a = a + 1; -// END CASE; -// UNTIL a > 1 -// END REPEAT; -// SELECT b; -//END;`, + // `CREATE PROCEDURE p1(IN a BIGINT) + //BEGIN + // DECLARE b VARCHAR(200) DEFAULT ""; + // tloop: LOOP + // CASE + // WHEN a < 4 THEN + // SET b = CONCAT(b, "a"); + // SET a = a + 1; + // WHEN a < 8 THEN + // SET b = CONCAT(b, "b"); + // SET a = a + 1; + // ELSE + // LEAVE tloop; + // END CASE; + // END LOOP; + // SELECT b; + //END;`, + // `CREATE PROCEDURE p2(IN a BIGINT) + //BEGIN + // DECLARE b VARCHAR(200) DEFAULT ""; + // tloop: LOOP + // CASE a + // WHEN 1 THEN + // SET b = CONCAT(b, "a"); + // SET a = a + 1; + // WHEN 2 THEN + // SET b = CONCAT(b, "b"); + // SET a = a + 1; + // WHEN 3 THEN + // SET b = CONCAT(b, "c"); + // SET a = a + 1; + // ELSE + // LEAVE tloop; + // END CASE; + // END LOOP; + // SELECT b; + //END;`, + // `CREATE PROCEDURE p3(IN a BIGINT) + //BEGIN + // DECLARE b VARCHAR(200) DEFAULT ""; + // tloop: LOOP + // CASE a + // WHEN 1 THEN + // SET b = CONCAT(b, "a"); + // SET a = a + 1; + // END CASE; + // END LOOP; + // SELECT b; + //END;`, + // `CREATE PROCEDURE p4(IN a BIGINT) + //BEGIN + // DECLARE b VARCHAR(200) DEFAULT ""; + // tloop: LOOP + // CASE + // WHEN a = 1 THEN + // SET b = CONCAT(b, "a"); + // SET a = a + 1; + // END CASE; + // END LOOP; + // SELECT b; + //END;`, + // `CREATE PROCEDURE p5(IN a BIGINT) + //BEGIN + // DECLARE b VARCHAR(200) DEFAULT ""; + // REPEAT + // CASE + // WHEN a <= 1 THEN + // SET b = CONCAT(b, "a"); + // SET a = a + 1; + // END CASE; + // UNTIL a > 1 + // END REPEAT; + // SELECT b; + //END;`, }, Assertions: []queries.ScriptTestAssertion{ { From bf5a416b87ca9a262c768423cd67c987abb506b3 Mon Sep 17 00:00:00 2001 From: James Cor Date: Tue, 11 Mar 2025 16:41:05 -0700 Subject: [PATCH 020/111] implement case errors --- sql/planbuilder/scalar.go | 4 +- sql/planbuilder/show.go | 4 +- sql/procedures/interpreter_logic.go | 67 +++---------------------- sql/procedures/interpreter_operation.go | 1 + sql/procedures/interpreter_stack.go | 2 - sql/procedures/parse.go | 13 ++++- 6 files changed, 24 insertions(+), 67 deletions(-) diff --git a/sql/planbuilder/scalar.go b/sql/planbuilder/scalar.go index 11088542a2..00ef68a336 100644 --- a/sql/planbuilder/scalar.go +++ b/sql/planbuilder/scalar.go @@ -106,8 +106,8 @@ func (b *Builder) buildScalar(inScope *scope, e ast.Expr) (ex sql.Expression) { case *ast.NullVal: return expression.NewLiteral(nil, types.Null) case *ast.ColName: - if v.Metadata != nil { - return b.ConvertVal(v.Metadata) + if v.StoredProcVal != nil { + return b.ConvertVal(v.StoredProcVal) } dbName := strings.ToLower(v.Qualifier.DbQualifier.String()) tblName := strings.ToLower(v.Qualifier.Name.String()) diff --git a/sql/planbuilder/show.go b/sql/planbuilder/show.go index ae5200d6cf..bdd6729b6c 100644 --- a/sql/planbuilder/show.go +++ b/sql/planbuilder/show.go @@ -604,8 +604,8 @@ func (b *Builder) buildAsOfExpr(inScope *scope, time ast.Expr) sql.Expression { } return expression.NewLiteral(ret.(string), types.LongText) case *ast.ColName: - if v.Metadata != nil { - return b.buildAsOfExpr(inScope, v.Metadata) + if v.StoredProcVal != nil { + return b.buildAsOfExpr(inScope, v.StoredProcVal) } sysVar, _, ok := b.buildSysVar(v, ast.SetScope_None) if ok { diff --git a/sql/procedures/interpreter_logic.go b/sql/procedures/interpreter_logic.go index 7a55a0c52a..a79ae39773 100644 --- a/sql/procedures/interpreter_logic.go +++ b/sql/procedures/interpreter_logic.go @@ -39,42 +39,6 @@ type Parameter struct { Value any } -func unreplaceVariablesInExpr(stack *InterpreterStack, expr ast.SQLNode) ast.SQLNode { - switch e := expr.(type) { - case *ast.AliasedExpr: - newExpr := unreplaceVariablesInExpr(stack, e.Expr) - e.Expr = newExpr.(ast.Expr) - case *ast.BinaryExpr: - newLeftExpr := unreplaceVariablesInExpr(stack, e.Left) - newRightExpr := unreplaceVariablesInExpr(stack, e.Right) - e.Left = newLeftExpr.(ast.Expr) - e.Right = newRightExpr.(ast.Expr) - case *ast.ComparisonExpr: - newLeftExpr := unreplaceVariablesInExpr(stack, e.Left) - newRightExpr := unreplaceVariablesInExpr(stack, e.Right) - e.Left = newLeftExpr.(ast.Expr) - e.Right = newRightExpr.(ast.Expr) - case *ast.FuncExpr: - for i := range e.Exprs { - newExpr := unreplaceVariablesInExpr(stack, e.Exprs[i]) - e.Exprs[i] = newExpr.(ast.SelectExpr) - } - case *ast.NotExpr: - newExpr := unreplaceVariablesInExpr(stack, e.Expr) - e.Expr = newExpr.(ast.Expr) - case *ast.Set: - for _, setExpr := range e.Exprs { - newExpr := unreplaceVariablesInExpr(stack, setExpr.Expr) - setExpr.Expr = newExpr.(ast.Expr) - } - case *ast.SQLVal: - if oldVal, ok := stack.replaceMap[expr]; ok { - return oldVal - } - } - return expr -} - func replaceVariablesInExpr(stack *InterpreterStack, expr ast.SQLNode) (ast.SQLNode, error) { switch e := expr.(type) { case *ast.AliasedExpr: @@ -136,8 +100,11 @@ func replaceVariablesInExpr(stack *InterpreterStack, expr ast.SQLNode) (ast.SQLN return expr, nil } newExpr := iv.ToAST() - stack.replaceMap[newExpr] = e - return newExpr, nil + return &ast.ColName{ + Name: e.Name, + Qualifier: e.Qualifier, + StoredProcVal: newExpr, + }, nil } return expr, nil } @@ -207,12 +174,6 @@ func Call(ctx *sql.Context, iNode InterpreterNode, params []*Parameter) (any, er } rowIters = append(rowIters, rowIter) - for i := range selectStmt.SelectExprs { - newNode := unreplaceVariablesInExpr(&stack, selectStmt.SelectExprs[i]) - selectStmt.SelectExprs[i] = newNode.(ast.SelectExpr) - } - stack.replaceMap = map[ast.SQLNode]ast.SQLNode{} - case OpCode_Declare: declareStmt := operation.PrimaryData.(*ast.Declare) for _, decl := range declareStmt.Variables.Names { @@ -259,14 +220,9 @@ func Call(ctx *sql.Context, iNode InterpreterNode, params []*Parameter) (any, er return nil, err } - for i := range selectStmt.SelectExprs { - newNode := unreplaceVariablesInExpr(&stack, selectStmt.SelectExprs[i]) - selectStmt.SelectExprs[i] = newNode.(ast.SelectExpr) - } - stack.replaceMap = map[ast.SQLNode]ast.SQLNode{} - case OpCode_Exception: - // TODO: implement + return nil, operation.Error + case OpCode_Execute: // TODO: replace variables stmt, err := replaceVariablesInExpr(&stack, operation.PrimaryData) @@ -279,9 +235,6 @@ func Call(ctx *sql.Context, iNode InterpreterNode, params []*Parameter) (any, er } rowIters = append(rowIters, rowIter) - stmt = unreplaceVariablesInExpr(&stack, stmt) - stack.replaceMap = map[ast.SQLNode]ast.SQLNode{} - case OpCode_Goto: // We must compare to the index - 1, so that the increment hits our target if counter <= operation.Index { @@ -341,12 +294,6 @@ func Call(ctx *sql.Context, iNode InterpreterNode, params []*Parameter) (any, er counter = operation.Index - 1 // index of the else block, offset by 1 } - for i := range selectStmt.SelectExprs { - newNode := unreplaceVariablesInExpr(&stack, selectStmt.SelectExprs[i]) - selectStmt.SelectExprs[i] = newNode.(ast.SelectExpr) - } - stack.replaceMap = map[ast.SQLNode]ast.SQLNode{} - case OpCode_ScopeBegin: stack.PushScope() case OpCode_ScopeEnd: diff --git a/sql/procedures/interpreter_operation.go b/sql/procedures/interpreter_operation.go index ac2adde6c6..9ef4d12e25 100644 --- a/sql/procedures/interpreter_operation.go +++ b/sql/procedures/interpreter_operation.go @@ -36,4 +36,5 @@ type InterpreterOperation struct { SecondaryData []string // This represents auxiliary data, such as bindings, strictness, etc. Target string // This is the variable that will store the results (if applicable) Index int // This is the index that should be set for operations that move the function counter + Error error // This is the error that should be returned for OpCode_Exception } diff --git a/sql/procedures/interpreter_stack.go b/sql/procedures/interpreter_stack.go index 712e4add68..83747d71f8 100644 --- a/sql/procedures/interpreter_stack.go +++ b/sql/procedures/interpreter_stack.go @@ -123,7 +123,6 @@ type InterpreterScopeDetails struct { // general purpose. type InterpreterStack struct { stack *Stack[*InterpreterScopeDetails] - replaceMap map[ast.SQLNode]ast.SQLNode } // NewInterpreterStack creates a new InterpreterStack. @@ -135,7 +134,6 @@ func NewInterpreterStack() InterpreterStack { }) return InterpreterStack{ stack: stack, - replaceMap: map[ast.SQLNode]ast.SQLNode{}, } } diff --git a/sql/procedures/parse.go b/sql/procedures/parse.go index f3c4e5074b..6182a1e013 100644 --- a/sql/procedures/parse.go +++ b/sql/procedures/parse.go @@ -15,6 +15,8 @@ package procedures import ( + "github.com/dolthub/vitess/go/mysql" + ast "github.com/dolthub/vitess/go/vt/sqlparser" ) @@ -153,7 +155,16 @@ func ConvertStmt(ops *[]*InterpreterOperation, stack *InterpreterStack, stmt ast caseOp.Index = len(*ops) // start of next case } - if s.Else != nil { + if s.Else == nil { + // throw an error if when there is no else block + // this is just an empty case statement that will always hit the else + // todo: alternatively, use an error opcode + errOp := &InterpreterOperation{ + OpCode: OpCode_Exception, + Error: mysql.NewSQLError(1339, "20000", "Case not found for CASE statement"), + } + *ops = append(*ops, errOp) + } else { for _, elseStmt := range s.Else { if err := ConvertStmt(ops, stack, elseStmt); err != nil { return err From 176bd604685b5af260548364b839a8bb4dd8212d Mon Sep 17 00:00:00 2001 From: jycor Date: Tue, 11 Mar 2025 23:42:23 +0000 Subject: [PATCH 021/111] [ga-format-pr] Run ./format_repo.sh to fix formatting --- sql/procedures/interpreter_stack.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/procedures/interpreter_stack.go b/sql/procedures/interpreter_stack.go index 83747d71f8..e4ba5f87c9 100644 --- a/sql/procedures/interpreter_stack.go +++ b/sql/procedures/interpreter_stack.go @@ -122,7 +122,7 @@ type InterpreterScopeDetails struct { // the same as a stack in the traditional programming sense, but rather is a loose abstraction that serves the same // general purpose. type InterpreterStack struct { - stack *Stack[*InterpreterScopeDetails] + stack *Stack[*InterpreterScopeDetails] } // NewInterpreterStack creates a new InterpreterStack. @@ -133,7 +133,7 @@ func NewInterpreterStack() InterpreterStack { variables: make(map[string]*InterpreterVariable), }) return InterpreterStack{ - stack: stack, + stack: stack, } } From dd870f77b64abd58b4e297c313b6b5daaf8d4cdc Mon Sep 17 00:00:00 2001 From: James Cor Date: Wed, 12 Mar 2025 12:44:35 -0700 Subject: [PATCH 022/111] fix --- sql/analyzer/optimization_rules.go | 7 +++++-- sql/procedures/interpreter_logic.go | 6 ++++++ 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/sql/analyzer/optimization_rules.go b/sql/analyzer/optimization_rules.go index 9d53294a9b..aeaa9aa735 100644 --- a/sql/analyzer/optimization_rules.go +++ b/sql/analyzer/optimization_rules.go @@ -15,7 +15,8 @@ package analyzer import ( - "strings" + "github.com/dolthub/go-mysql-server/memory" +"strings" "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/expression" @@ -43,7 +44,9 @@ func eraseProjection(ctx *sql.Context, a *Analyzer, node sql.Node, scope *plan.S return transform.Node(node, func(node sql.Node) (sql.Node, transform.TreeIdentity, error) { project, ok := node.(*plan.Project) if ok { - if project.Schema().CaseSensitiveEquals(project.Child.Schema()) { + projSch := project.Schema() + childSch := project.Child.Schema() + if projSch.CaseSensitiveEquals(childSch) && !childSch.CaseSensitiveEquals(memory.DualTableSchema.Schema) { a.Log("project erased") return project.Child, transform.NewTree, nil } diff --git a/sql/procedures/interpreter_logic.go b/sql/procedures/interpreter_logic.go index a79ae39773..d5fa841005 100644 --- a/sql/procedures/interpreter_logic.go +++ b/sql/procedures/interpreter_logic.go @@ -85,6 +85,10 @@ func replaceVariablesInExpr(stack *InterpreterStack, expr ast.SQLNode) (ast.SQLN e.Expr = newExpr.(ast.Expr) case *ast.Set: for _, setExpr := range e.Exprs { + // TODO: properly handle user scope variables + if setExpr.Scope == ast.SetScope_User { + continue + } newExpr, err := replaceVariablesInExpr(stack, setExpr.Expr) if err != nil { return nil, err @@ -188,6 +192,7 @@ func Call(ctx *sql.Context, iNode InterpreterNode, params []*Parameter) (any, er stack.NewVariable(varName, varType) } } + case OpCode_Set: selectStmt := operation.PrimaryData.(*ast.Select) if selectStmt.SelectExprs == nil { @@ -260,6 +265,7 @@ func Call(ctx *sql.Context, iNode InterpreterNode, params []*Parameter) (any, er } } } + case OpCode_If: selectStmt := operation.PrimaryData.(*ast.Select) if selectStmt.SelectExprs == nil { From 4b451f6426918fd293d391ec7f0766b32f2e659f Mon Sep 17 00:00:00 2001 From: James Cor Date: Wed, 12 Mar 2025 12:45:02 -0700 Subject: [PATCH 023/111] asdf --- enginetest/memory_engine_test.go | 181 ++----------------------------- 1 file changed, 12 insertions(+), 169 deletions(-) diff --git a/enginetest/memory_engine_test.go b/enginetest/memory_engine_test.go index a6f0832533..55895a7ffe 100644 --- a/enginetest/memory_engine_test.go +++ b/enginetest/memory_engine_test.go @@ -203,181 +203,21 @@ func TestSingleScript(t *testing.T) { { Name: "CASE statements", SetUpScript: []string{ - ` -create procedure proc() -begin - declare x int default 0; - tloop: loop - case - when x = 0 then - set x = 1; - else - leave tloop; - end case; - end loop; - select x; -end;`, - - // `CREATE PROCEDURE p1(IN a BIGINT) - //BEGIN - // DECLARE b VARCHAR(200) DEFAULT ""; - // tloop: LOOP - // CASE - // WHEN a < 4 THEN - // SET b = CONCAT(b, "a"); - // SET a = a + 1; - // WHEN a < 8 THEN - // SET b = CONCAT(b, "b"); - // SET a = a + 1; - // ELSE - // LEAVE tloop; - // END CASE; - // END LOOP; - // SELECT b; - //END;`, - // `CREATE PROCEDURE p2(IN a BIGINT) - //BEGIN - // DECLARE b VARCHAR(200) DEFAULT ""; - // tloop: LOOP - // CASE a - // WHEN 1 THEN - // SET b = CONCAT(b, "a"); - // SET a = a + 1; - // WHEN 2 THEN - // SET b = CONCAT(b, "b"); - // SET a = a + 1; - // WHEN 3 THEN - // SET b = CONCAT(b, "c"); - // SET a = a + 1; - // ELSE - // LEAVE tloop; - // END CASE; - // END LOOP; - // SELECT b; - //END;`, - // `CREATE PROCEDURE p3(IN a BIGINT) - //BEGIN - // DECLARE b VARCHAR(200) DEFAULT ""; - // tloop: LOOP - // CASE a - // WHEN 1 THEN - // SET b = CONCAT(b, "a"); - // SET a = a + 1; - // END CASE; - // END LOOP; - // SELECT b; - //END;`, - // `CREATE PROCEDURE p4(IN a BIGINT) - //BEGIN - // DECLARE b VARCHAR(200) DEFAULT ""; - // tloop: LOOP - // CASE - // WHEN a = 1 THEN - // SET b = CONCAT(b, "a"); - // SET a = a + 1; - // END CASE; - // END LOOP; - // SELECT b; - //END;`, - // `CREATE PROCEDURE p5(IN a BIGINT) - //BEGIN - // DECLARE b VARCHAR(200) DEFAULT ""; - // REPEAT - // CASE - // WHEN a <= 1 THEN - // SET b = CONCAT(b, "a"); - // SET a = a + 1; - // END CASE; - // UNTIL a > 1 - // END REPEAT; - // SELECT b; - //END;`, + ` +CREATE PROCEDURE p1() +BEGIN + DECLARE b INT DEFAULT ""; + SELECT b; +END; +`, }, Assertions: []queries.ScriptTestAssertion{ { - Query: "CALL proc", + Query: "call p1()", Expected: []sql.Row{ - {1}, + {""}, }, }, - - //{ - // Query: "CALL p1(0)", - // Expected: []sql.Row{ - // {"aaaabbbb"}, - // }, - //}, - - //{ - // Query: "CALL p1(3)", - // Expected: []sql.Row{ - // {"abbbb"}, - // }, - //}, - //{ - // Query: "CALL p1(6)", - // Expected: []sql.Row{ - // {"bb"}, - // }, - //}, - //{ - // Query: "CALL p1(9)", - // Expected: []sql.Row{ - // {""}, - // }, - //}, - //{ - // Query: "CALL p2(1)", - // Expected: []sql.Row{ - // {"abc"}, - // }, - //}, - //{ - // Query: "CALL p2(2)", - // Expected: []sql.Row{ - // {"bc"}, - // }, - //}, - //{ - // Query: "CALL p2(3)", - // Expected: []sql.Row{ - // {"c"}, - // }, - //}, - //{ - // Query: "CALL p2(4)", - // Expected: []sql.Row{ - // {""}, - // }, - //}, - //{ - // Query: "CALL p3(1)", - // ExpectedErrStr: "Case not found for CASE statement (errno 1339) (sqlstate 20000)", - //}, - //{ - // Query: "CALL p3(2)", - // ExpectedErrStr: "Case not found for CASE statement (errno 1339) (sqlstate 20000)", - //}, - //{ - // Query: "CALL p4(1)", - // ExpectedErrStr: "Case not found for CASE statement (errno 1339) (sqlstate 20000)", - //}, - //{ - // Query: "CALL p4(-1)", - // ExpectedErrStr: "Case not found for CASE statement (errno 1339) (sqlstate 20000)", - //}, - //{ - // Query: "CALL p5(0)", - // Expected: []sql.Row{ - // {"aa"}, - // }, - //}, - //{ - // Query: "CALL p5(1)", - // Expected: []sql.Row{ - // {"a"}, - // }, - //}, }, }, } @@ -391,6 +231,9 @@ end;`, panic(err) } + engine.EngineAnalyzer().Debug = true + engine.EngineAnalyzer().Verbose = true + enginetest.TestScriptWithEngine(t, engine, harness, test) } } From 0ae1290baa370e1c6cc8ecf4a64a4124d664e640 Mon Sep 17 00:00:00 2001 From: James Cor Date: Wed, 12 Mar 2025 14:27:18 -0700 Subject: [PATCH 024/111] out tests --- enginetest/memory_engine_test.go | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/enginetest/memory_engine_test.go b/enginetest/memory_engine_test.go index 55895a7ffe..299da0e5d3 100644 --- a/enginetest/memory_engine_test.go +++ b/enginetest/memory_engine_test.go @@ -203,21 +203,28 @@ func TestSingleScript(t *testing.T) { { Name: "CASE statements", SetUpScript: []string{ + "SET @x = 0", ` -CREATE PROCEDURE p1() +CREATE PROCEDURE p1(OUT x INT) BEGIN - DECLARE b INT DEFAULT ""; - SELECT b; -END; -`, + SET x = 123; + SELECT x; +END;`, }, Assertions: []queries.ScriptTestAssertion{ { - Query: "call p1()", + Query: "CALL p1(@x)", Expected: []sql.Row{ - {""}, + {123}, }, }, + { + Query: "SELECT @x", + Expected: []sql.Row{ + {123}, + }, + }, + }, }, } @@ -231,8 +238,8 @@ END; panic(err) } - engine.EngineAnalyzer().Debug = true - engine.EngineAnalyzer().Verbose = true + //engine.EngineAnalyzer().Debug = true + //engine.EngineAnalyzer().Verbose = true enginetest.TestScriptWithEngine(t, engine, harness, test) } From 37981e09f0b6d11806939b97172c525e7e2909c5 Mon Sep 17 00:00:00 2001 From: James Cor Date: Wed, 12 Mar 2025 16:06:11 -0700 Subject: [PATCH 025/111] fix out params --- sql/procedures/interpreter_logic.go | 57 ++++++++++++++++------------- sql/procedures/interpreter_stack.go | 30 +++++++-------- sql/procedures/parse.go | 2 +- sql/rowexec/proc.go | 44 +++++++++++++--------- sql/rowexec/proc_iters.go | 47 ------------------------ 5 files changed, 74 insertions(+), 106 deletions(-) diff --git a/sql/procedures/interpreter_logic.go b/sql/procedures/interpreter_logic.go index d5fa841005..c7095f98ca 100644 --- a/sql/procedures/interpreter_logic.go +++ b/sql/procedures/interpreter_logic.go @@ -136,7 +136,7 @@ func query(ctx *sql.Context, runner sql.StatementRunner, stmt ast.Statement) (sq } // Call runs the contained operations on the given runner. -func Call(ctx *sql.Context, iNode InterpreterNode, params []*Parameter) (any, error) { +func Call(ctx *sql.Context, iNode InterpreterNode, params []*Parameter) (any, *InterpreterStack, error) { // Set up the initial state of the function counter := -1 // We increment before accessing, so start at -1 stack := NewInterpreterStack() @@ -166,15 +166,15 @@ func Call(ctx *sql.Context, iNode InterpreterNode, params []*Parameter) (any, er panic("select stmt with no select exprs") } for i := range selectStmt.SelectExprs { - newNode, err := replaceVariablesInExpr(&stack, selectStmt.SelectExprs[i]) + newNode, err := replaceVariablesInExpr(stack, selectStmt.SelectExprs[i]) if err != nil { - return nil, err + return nil, nil, err } selectStmt.SelectExprs[i] = newNode.(ast.SelectExpr) } rowIter, err := query(ctx, runner, selectStmt) if err != nil { - return nil, err + return nil, nil, err } rowIters = append(rowIters, rowIter) @@ -183,7 +183,7 @@ func Call(ctx *sql.Context, iNode InterpreterNode, params []*Parameter) (any, er for _, decl := range declareStmt.Variables.Names { varType, err := types.ColumnTypeToType(&declareStmt.Variables.VarType) if err != nil { - return nil, err + return nil, nil, err } varName := strings.ToLower(decl.String()) if declareStmt.Variables.VarType.Default != nil { @@ -199,44 +199,41 @@ func Call(ctx *sql.Context, iNode InterpreterNode, params []*Parameter) (any, er panic("select stmt with no select exprs") } for i := range selectStmt.SelectExprs { - newNode, err := replaceVariablesInExpr(&stack, selectStmt.SelectExprs[i]) + newNode, err := replaceVariablesInExpr(stack, selectStmt.SelectExprs[i]) if err != nil { - return nil, err + return nil, nil, err } selectStmt.SelectExprs[i] = newNode.(ast.SelectExpr) } _, rowIter, _, err := runner.QueryWithBindings(ctx, "", selectStmt, nil, nil) if err != nil { - return nil, err + return nil, nil, err } row, err := rowIter.Next(ctx) if err != nil { - return nil, err + return nil, nil, err } if _, err = rowIter.Next(ctx); err != io.EOF { - return nil, err + return nil, nil, err } if err = rowIter.Close(ctx); err != nil { - return nil, err + return nil, nil, err } err = stack.SetVariable(nil, strings.ToLower(operation.Target), row[0]) if err != nil { - return nil, err + return nil, nil, err } - case OpCode_Exception: - return nil, operation.Error - case OpCode_Execute: // TODO: replace variables - stmt, err := replaceVariablesInExpr(&stack, operation.PrimaryData) + stmt, err := replaceVariablesInExpr(stack, operation.PrimaryData) if err != nil { - return nil, err + return nil, nil, err } rowIter, err := query(ctx, runner, stmt.(ast.Statement)) if err != nil { - return nil, err + return nil, nil, err } rowIters = append(rowIters, rowIter) @@ -272,26 +269,26 @@ func Call(ctx *sql.Context, iNode InterpreterNode, params []*Parameter) (any, er panic("select stmt with no select exprs") } for i := range selectStmt.SelectExprs { - newNode, err := replaceVariablesInExpr(&stack, selectStmt.SelectExprs[i]) + newNode, err := replaceVariablesInExpr(stack, selectStmt.SelectExprs[i]) if err != nil { - return nil, err + return nil, nil, err } selectStmt.SelectExprs[i] = newNode.(ast.SelectExpr) } _, rowIter, _, err := runner.QueryWithBindings(ctx, "", selectStmt, nil, nil) if err != nil { - return nil, err + return nil, nil, err } // TODO: exactly one result that is a bool for now row, err := rowIter.Next(ctx) if err != nil { - return nil, err + return nil, nil, err } if _, err = rowIter.Next(ctx); err != io.EOF { - return nil, err + return nil, nil, err } if err = rowIter.Close(ctx); err != nil { - return nil, err + return nil, nil, err } // go to the appropriate block @@ -300,6 +297,8 @@ func Call(ctx *sql.Context, iNode InterpreterNode, params []*Parameter) (any, er counter = operation.Index - 1 // index of the else block, offset by 1 } + case OpCode_Exception: + return nil, nil, operation.Error case OpCode_ScopeBegin: stack.PushScope() case OpCode_ScopeEnd: @@ -308,8 +307,14 @@ func Call(ctx *sql.Context, iNode InterpreterNode, params []*Parameter) (any, er panic("unimplemented opcode") } } + + + // TODO: Set all user and system variables from INOUT and OUT params. + // Copy logic from proc_iters.go: callIter.Close() + if len(rowIters) == 0 { - panic("no rowIters") + rowIters = append(rowIters, sql.RowsToRowIter(sql.Row{types.NewOkResult(0)})) } - return rowIters[len(rowIters)-1], nil + + return rowIters[len(rowIters)-1], stack, nil } diff --git a/sql/procedures/interpreter_stack.go b/sql/procedures/interpreter_stack.go index 83747d71f8..641911145e 100644 --- a/sql/procedures/interpreter_stack.go +++ b/sql/procedures/interpreter_stack.go @@ -94,23 +94,23 @@ func (iv *InterpreterVariable) ToAST() *ast.SQLVal { return sqlVal } - var astType ast.ValType var astVal []byte if types.IsInteger(iv.Type) { - intStr := fmt.Sprintf("%d", iv.Value) - return ast.NewIntVal([]byte(intStr)) - } else if types.IsFloat(iv.Type) { - floatStr := strconv.FormatFloat(iv.Value.(float64), 'f', -1, 64) - return ast.NewFloatVal([]byte(floatStr)) - } else { - astType = ast.StrVal - astVal = []byte(fmt.Sprintf("%s", iv.Value)) + if iv.Value != nil { + astVal = []byte(fmt.Sprintf("%d", iv.Value)) + } + return ast.NewIntVal(astVal) } - - return &ast.SQLVal{ - Type: astType, - Val: astVal, + if types.IsFloat(iv.Type) { + if iv.Value != nil { + astVal = []byte(strconv.FormatFloat(iv.Value.(float64), 'f', -1, 64)) + } + return ast.NewFloatVal(astVal) + } + if iv.Value != nil { + astVal = []byte(fmt.Sprintf("%s", iv.Value)) } + return ast.NewStrVal(astVal) } // InterpreterScopeDetails contains all of the details that are relevant to a particular scope. @@ -126,13 +126,13 @@ type InterpreterStack struct { } // NewInterpreterStack creates a new InterpreterStack. -func NewInterpreterStack() InterpreterStack { +func NewInterpreterStack() *InterpreterStack { stack := NewStack[*InterpreterScopeDetails]() // This first push represents the function base, including parameters stack.Push(&InterpreterScopeDetails{ variables: make(map[string]*InterpreterVariable), }) - return InterpreterStack{ + return &InterpreterStack{ stack: stack, } } diff --git a/sql/procedures/parse.go b/sql/procedures/parse.go index 6182a1e013..94e41478be 100644 --- a/sql/procedures/parse.go +++ b/sql/procedures/parse.go @@ -285,7 +285,7 @@ func ConvertStmt(ops *[]*InterpreterOperation, stack *InterpreterStack, stmt ast func Parse(stmt ast.Statement) ([]*InterpreterOperation, error) { ops := make([]*InterpreterOperation, 0, 64) stack := NewInterpreterStack() - err := ConvertStmt(&ops, &stack, stmt) + err := ConvertStmt(&ops, stack, stmt) if err != nil { return nil, err } diff --git a/sql/rowexec/proc.go b/sql/rowexec/proc.go index 8dfdd046e1..0323c57cdf 100644 --- a/sql/rowexec/proc.go +++ b/sql/rowexec/proc.go @@ -184,8 +184,9 @@ func (b *BaseBuilder) buildProcedureResolvedTable(ctx *sql.Context, n *plan.Proc func (b *BaseBuilder) buildCall(ctx *sql.Context, n *plan.Call, row sql.Row) (sql.RowIter, error) { procParams := make([]*procedures.Parameter, len(n.Params)) for i, paramExpr := range n.Params { - paramName := strings.ToLower(n.Procedure.Params[i].Name) - paramType := n.Procedure.Params[i].Type + param := n.Procedure.Params[i] + paramName := strings.ToLower(param.Name) + paramType := param.Type paramVal, err := paramExpr.Eval(ctx, row) if err != nil { return nil, err @@ -201,29 +202,38 @@ func (b *BaseBuilder) buildCall(ctx *sql.Context, n *plan.Call, row sql.Row) (sq } } - rowIter, err := procedures.Call(ctx, n, procParams) + rowIter, stack, err := procedures.Call(ctx, n, procParams) if err != nil { return nil, err } - return &callIter{ - call: n, - innerIter: rowIter.(sql.RowIter), - }, nil - - // TODO: mirror plpgsql interpreter_logic.go Call() - // TODO: instead of building, run the actual operations - // This means call the runner.QueryWithBindings - innerIter, err := b.buildNodeExec(ctx, n.Procedure, row) - if err != nil { - return nil, err + for i, param := range n.Params { + procParam := n.Procedure.Params[i] + if procParam.Direction == plan.ProcedureParamDirection_In { + continue + } + // Set all user and system variables from INOUT and OUT params + stackVar := stack.GetVariable(procParam.Name) // TODO: ToLower? + switch p := param.(type) { + case *expression.ProcedureParam: + err = p.Set(stackVar.Value, stackVar.Type) + if err != nil { + return nil, err + } + case *expression.UserVar: + err = ctx.SetUserVariable(ctx, p.Name, stackVar.Value, stackVar.Type) + if err != nil { + return nil, err + } + case *expression.SystemVar: + // This should have been caught by the analyzer, so a major bug exists somewhere + return nil, fmt.Errorf("unable to set `%s` as it is a system variable", p.Name) + } } - // TODO: save any select ast rowIters to be returned later - return &callIter{ call: n, - innerIter: innerIter, + innerIter: rowIter.(sql.RowIter), }, nil } diff --git a/sql/rowexec/proc_iters.go b/sql/rowexec/proc_iters.go index 07a6cb853d..5ec79d1463 100644 --- a/sql/rowexec/proc_iters.go +++ b/sql/rowexec/proc_iters.go @@ -134,53 +134,6 @@ func (ci *callIter) Close(ctx *sql.Context) error { if err != nil { return err } - - // Set all user and system variables from INOUT and OUT params - for i, param := range ci.call.Procedure.Params { - if param.Direction == plan.ProcedureParamDirection_Inout || - (param.Direction == plan.ProcedureParamDirection_Out && ci.call.Pref.VariableHasBeenSet(param.Name)) { - val, err := ci.call.Pref.GetVariableValue(param.Name) - if err != nil { - return err - } - - typ := ci.call.Pref.GetVariableType(param.Name) - - switch callParam := ci.call.Params[i].(type) { - case *expression.UserVar: - err = ctx.SetUserVariable(ctx, callParam.Name, val, typ) - if err != nil { - return err - } - case *expression.SystemVar: - // This should have been caught by the analyzer, so a major bug exists somewhere - return fmt.Errorf("unable to set `%s` as it is a system variable", callParam.Name) - case *expression.ProcedureParam: - err = callParam.Set(val, param.Type) - if err != nil { - return err - } - } - } else if param.Direction == plan.ProcedureParamDirection_Out { // VariableHasBeenSet was false - // For OUT only, if a var was not set within the procedure body, then we set the vars to nil. - // If the var had a value before the call then it is basically removed. - switch callParam := ci.call.Params[i].(type) { - case *expression.UserVar: - err = ctx.SetUserVariable(ctx, callParam.Name, nil, ci.call.Pref.GetVariableType(param.Name)) - if err != nil { - return err - } - case *expression.SystemVar: - // This should have been caught by the analyzer, so a major bug exists somewhere - return fmt.Errorf("unable to set `%s` as it is a system variable", callParam.Name) - case *expression.ProcedureParam: - err := callParam.Set(nil, param.Type) - if err != nil { - return err - } - } - } - } return nil } From ba3fc0589877cd66aa3b5743c2042f611dffb9cb Mon Sep 17 00:00:00 2001 From: James Cor Date: Thu, 13 Mar 2025 11:25:13 -0700 Subject: [PATCH 026/111] convert booleans --- sql/planbuilder/scalar.go | 7 ++++++- sql/procedures/interpreter_logic.go | 7 +++++-- sql/procedures/interpreter_stack.go | 22 +++++++--------------- 3 files changed, 18 insertions(+), 18 deletions(-) diff --git a/sql/planbuilder/scalar.go b/sql/planbuilder/scalar.go index 00ef68a336..f201cd4ec8 100644 --- a/sql/planbuilder/scalar.go +++ b/sql/planbuilder/scalar.go @@ -107,7 +107,12 @@ func (b *Builder) buildScalar(inScope *scope, e ast.Expr) (ex sql.Expression) { return expression.NewLiteral(nil, types.Null) case *ast.ColName: if v.StoredProcVal != nil { - return b.ConvertVal(v.StoredProcVal) + switch val := v.StoredProcVal.(type) { + case *ast.SQLVal: + return b.ConvertVal(val) + case *ast.NullVal: + return expression.NewLiteral(nil, types.Null) + } } dbName := strings.ToLower(v.Qualifier.DbQualifier.String()) tblName := strings.ToLower(v.Qualifier.Name.String()) diff --git a/sql/procedures/interpreter_logic.go b/sql/procedures/interpreter_logic.go index c7095f98ca..a0bb16159b 100644 --- a/sql/procedures/interpreter_logic.go +++ b/sql/procedures/interpreter_logic.go @@ -292,8 +292,11 @@ func Call(ctx *sql.Context, iNode InterpreterNode, params []*Parameter) (any, *I } // go to the appropriate block - cond := row[0].(bool) - if !cond { + cond, _, err := types.Boolean.Convert(row[0]) + if err != nil { + return nil, nil, err + } + if cond == nil || cond.(int8) == 0 { counter = operation.Index - 1 // index of the else block, offset by 1 } diff --git a/sql/procedures/interpreter_stack.go b/sql/procedures/interpreter_stack.go index 641911145e..3f9ac43c83 100644 --- a/sql/procedures/interpreter_stack.go +++ b/sql/procedures/interpreter_stack.go @@ -89,28 +89,20 @@ type InterpreterVariable struct { Value any } -func (iv *InterpreterVariable) ToAST() *ast.SQLVal { +func (iv *InterpreterVariable) ToAST() ast.Expr { if sqlVal, isSQLVal := iv.Value.(*ast.SQLVal); isSQLVal { return sqlVal } - - var astVal []byte + if iv.Value == nil { + return &ast.NullVal{} + } if types.IsInteger(iv.Type) { - if iv.Value != nil { - astVal = []byte(fmt.Sprintf("%d", iv.Value)) - } - return ast.NewIntVal(astVal) + return ast.NewIntVal([]byte(fmt.Sprintf("%d", iv.Value))) } if types.IsFloat(iv.Type) { - if iv.Value != nil { - astVal = []byte(strconv.FormatFloat(iv.Value.(float64), 'f', -1, 64)) - } - return ast.NewFloatVal(astVal) - } - if iv.Value != nil { - astVal = []byte(fmt.Sprintf("%s", iv.Value)) + return ast.NewFloatVal([]byte(strconv.FormatFloat(iv.Value.(float64), 'f', -1, 64))) } - return ast.NewStrVal(astVal) + return ast.NewStrVal([]byte(fmt.Sprintf("%s", iv.Value))) } // InterpreterScopeDetails contains all of the details that are relevant to a particular scope. From 913db9e7591446e6c3e68a7fb9863ddede0c4ef4 Mon Sep 17 00:00:00 2001 From: James Cor Date: Thu, 13 Mar 2025 11:29:39 -0700 Subject: [PATCH 027/111] vitess bump --- go.mod | 2 +- go.sum | 8 ++------ 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/go.mod b/go.mod index 485ba6d789..b983ed1720 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,7 @@ require ( github.com/dolthub/go-icu-regex v0.0.0-20250303123116-549b8d7cad00 github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71 github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81 - github.com/dolthub/vitess v0.0.0-20250304211657-920ca9ec2b9a + github.com/dolthub/vitess v0.0.0-20250313182551-a36df89b8084 github.com/go-kit/kit v0.10.0 github.com/go-sql-driver/mysql v1.7.2-0.20231213112541-0004702b931d github.com/gocraft/dbr/v2 v2.7.2 diff --git a/go.sum b/go.sum index 772cefdc8e..0a89c67a25 100644 --- a/go.sum +++ b/go.sum @@ -58,12 +58,8 @@ github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71 h1:bMGS25NWAGTE github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71/go.mod h1:2/2zjLQ/JOOSbbSboojeg+cAwcRV0fDLzIiWch/lhqI= github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81 h1:7/v8q9XGFa6q5Ap4Z/OhNkAMBaK5YeuEzwJt+NZdhiE= github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81/go.mod h1:siLfyv2c92W1eN/R4QqG/+RjjX5W2+gCTRjZxBjI3TY= -github.com/dolthub/vitess v0.0.0-20250228011932-c4f6bba87730 h1:GtlMVB7+Z7fZZj7BHRFd2rzxZ574dJ8cB/EHWdq1kbY= -github.com/dolthub/vitess v0.0.0-20250228011932-c4f6bba87730/go.mod h1:1gQZs/byeHLMSul3Lvl3MzioMtOW1je79QYGyi2fd70= -github.com/dolthub/vitess v0.0.0-20250303224041-5cc89c183bc4 h1:wtS9ZWEyEeYzLCcqdGUo+7i3hAV5MWuY9Z7tYbQa65A= -github.com/dolthub/vitess v0.0.0-20250303224041-5cc89c183bc4/go.mod h1:1gQZs/byeHLMSul3Lvl3MzioMtOW1je79QYGyi2fd70= -github.com/dolthub/vitess v0.0.0-20250304211657-920ca9ec2b9a h1:HIH9g4z+yXr4DIFyT6L5qOIEGJ1zVtlj6baPyHAG4Yw= -github.com/dolthub/vitess v0.0.0-20250304211657-920ca9ec2b9a/go.mod h1:1gQZs/byeHLMSul3Lvl3MzioMtOW1je79QYGyi2fd70= +github.com/dolthub/vitess v0.0.0-20250313182551-a36df89b8084 h1:j0GMko/b2Uk8wQ57TTLZHZckwmMjkc5AYmz4vZIYxHU= +github.com/dolthub/vitess v0.0.0-20250313182551-a36df89b8084/go.mod h1:1gQZs/byeHLMSul3Lvl3MzioMtOW1je79QYGyi2fd70= github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/eapache/go-resiliency v1.1.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs= github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU= From 0d935c6b4680f91c80a3b133d5f731eff9f0f645 Mon Sep 17 00:00:00 2001 From: James Cor Date: Thu, 13 Mar 2025 12:10:22 -0700 Subject: [PATCH 028/111] test --- enginetest/memory_engine_test.go | 46 +++++++++++++++++--------------- 1 file changed, 25 insertions(+), 21 deletions(-) diff --git a/enginetest/memory_engine_test.go b/enginetest/memory_engine_test.go index 299da0e5d3..4be83c4bf3 100644 --- a/enginetest/memory_engine_test.go +++ b/enginetest/memory_engine_test.go @@ -200,33 +200,37 @@ func TestSingleQueryPrepared(t *testing.T) { func TestSingleScript(t *testing.T) { //t.Skip() var scripts = []queries.ScriptTest{ - { - Name: "CASE statements", - SetUpScript: []string{ - "SET @x = 0", + { + Name: "IF/ELSE with 1 SELECT at end", + SetUpScript: []string{ + "SET @outparam = ''", ` -CREATE PROCEDURE p1(OUT x INT) +CREATE PROCEDURE p1(OUT s VARCHAR(200), N DOUBLE, m DOUBLE) BEGIN - SET x = 123; - SELECT x; + SET s = ''; + IF n = m THEN + SET s = 'equals'; + ELSE + IF n > m THEN + SET s = 'greater'; + ELSE + SET s = 'less'; + END IF; + SET s = CONCAT('is ', s, ' than'); + END IF; + SET s = CONCAT(n, ' ', s, ' ', m, '.'); + SELECT s; END;`, - }, - Assertions: []queries.ScriptTestAssertion{ - { - Query: "CALL p1(@x)", - Expected: []sql.Row{ - {123}, - }, - }, - { - Query: "SELECT @x", - Expected: []sql.Row{ - {123}, - }, + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "CALL p1(@outparam, null, 2)", + Expected: []sql.Row{ + {nil}, }, - }, }, + }, } for _, test := range scripts { From 1c4f9d554cb8675b4a02807212a41d0171758927 Mon Sep 17 00:00:00 2001 From: jycor Date: Thu, 13 Mar 2025 19:13:15 +0000 Subject: [PATCH 029/111] [ga-format-pr] Run ./format_repo.sh to fix formatting --- enginetest/memory_engine_test.go | 24 ++++++++++++------------ sql/analyzer/optimization_rules.go | 4 +--- sql/procedures/interpreter_logic.go | 1 - 3 files changed, 13 insertions(+), 16 deletions(-) diff --git a/enginetest/memory_engine_test.go b/enginetest/memory_engine_test.go index 4be83c4bf3..c3214ccd74 100644 --- a/enginetest/memory_engine_test.go +++ b/enginetest/memory_engine_test.go @@ -200,11 +200,11 @@ func TestSingleQueryPrepared(t *testing.T) { func TestSingleScript(t *testing.T) { //t.Skip() var scripts = []queries.ScriptTest{ - { - Name: "IF/ELSE with 1 SELECT at end", - SetUpScript: []string{ - "SET @outparam = ''", - ` + { + Name: "IF/ELSE with 1 SELECT at end", + SetUpScript: []string{ + "SET @outparam = ''", + ` CREATE PROCEDURE p1(OUT s VARCHAR(200), N DOUBLE, m DOUBLE) BEGIN SET s = ''; @@ -221,16 +221,16 @@ BEGIN SET s = CONCAT(n, ' ', s, ' ', m, '.'); SELECT s; END;`, - }, - Assertions: []queries.ScriptTestAssertion{ - { - Query: "CALL p1(@outparam, null, 2)", - Expected: []sql.Row{ - {nil}, + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "CALL p1(@outparam, null, 2)", + Expected: []sql.Row{ + {nil}, + }, }, }, }, - }, } for _, test := range scripts { diff --git a/sql/analyzer/optimization_rules.go b/sql/analyzer/optimization_rules.go index d17597ba79..98dc7e58b8 100644 --- a/sql/analyzer/optimization_rules.go +++ b/sql/analyzer/optimization_rules.go @@ -15,11 +15,9 @@ package analyzer import ( - "github.com/dolthub/go-mysql-server/memory" -"strings" + "strings" "github.com/dolthub/go-mysql-server/memory" - "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/expression" "github.com/dolthub/go-mysql-server/sql/plan" diff --git a/sql/procedures/interpreter_logic.go b/sql/procedures/interpreter_logic.go index a0bb16159b..534f4fd0fd 100644 --- a/sql/procedures/interpreter_logic.go +++ b/sql/procedures/interpreter_logic.go @@ -311,7 +311,6 @@ func Call(ctx *sql.Context, iNode InterpreterNode, params []*Parameter) (any, *I } } - // TODO: Set all user and system variables from INOUT and OUT params. // Copy logic from proc_iters.go: callIter.Close() From 8c0243b5c501a8f1887eeabca97821bd1e13a176 Mon Sep 17 00:00:00 2001 From: James Cor Date: Thu, 13 Mar 2025 12:52:14 -0700 Subject: [PATCH 030/111] handle else if --- sql/analyzer/optimization_rules.go | 4 +-- sql/procedures/interpreter_logic.go | 12 ++++--- sql/procedures/parse.go | 52 ++++++++++++++++------------- 3 files changed, 36 insertions(+), 32 deletions(-) diff --git a/sql/analyzer/optimization_rules.go b/sql/analyzer/optimization_rules.go index d17597ba79..98dc7e58b8 100644 --- a/sql/analyzer/optimization_rules.go +++ b/sql/analyzer/optimization_rules.go @@ -15,11 +15,9 @@ package analyzer import ( - "github.com/dolthub/go-mysql-server/memory" -"strings" + "strings" "github.com/dolthub/go-mysql-server/memory" - "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/expression" "github.com/dolthub/go-mysql-server/sql/plan" diff --git a/sql/procedures/interpreter_logic.go b/sql/procedures/interpreter_logic.go index a0bb16159b..b5a6fc8a34 100644 --- a/sql/procedures/interpreter_logic.go +++ b/sql/procedures/interpreter_logic.go @@ -144,6 +144,9 @@ func Call(ctx *sql.Context, iNode InterpreterNode, params []*Parameter) (any, *I stack.NewVariableWithValue(param.Name, param.Type, param.Value) } + // TODO: remove this; track last selectRowIter + var selIter sql.RowIter + // Run the statements // TODO: eventually return multiple sql.RowIters var rowIters []sql.RowIter @@ -177,6 +180,7 @@ func Call(ctx *sql.Context, iNode InterpreterNode, params []*Parameter) (any, *I return nil, nil, err } rowIters = append(rowIters, rowIter) + selIter = rowIter case OpCode_Declare: declareStmt := operation.PrimaryData.(*ast.Declare) @@ -311,13 +315,11 @@ func Call(ctx *sql.Context, iNode InterpreterNode, params []*Parameter) (any, *I } } - - // TODO: Set all user and system variables from INOUT and OUT params. - // Copy logic from proc_iters.go: callIter.Close() - + if selIter != nil { + return selIter, stack, nil + } if len(rowIters) == 0 { rowIters = append(rowIters, sql.RowsToRowIter(sql.Row{types.NewOkResult(0)})) } - return rowIters[len(rowIters)-1], stack, nil } diff --git a/sql/procedures/parse.go b/sql/procedures/parse.go index 94e41478be..cf035896d5 100644 --- a/sql/procedures/parse.go +++ b/sql/procedures/parse.go @@ -83,40 +83,44 @@ func ConvertStmt(ops *[]*InterpreterOperation, stack *InterpreterStack, stmt ast *ops = append(*ops, setOp) case *ast.IfStatement: - // TODO: assume exactly one condition for now - ifCond := s.Conditions[0] - // TODO: convert condition into a select query - selectCond := &ast.Select{ - SelectExprs: ast.SelectExprs{ - &ast.AliasedExpr{ - Expr: ifCond.Expr, + // TODO: each subsequent condition is an else if + var ifElseGotoOps []*InterpreterOperation + for _, ifCond := range s.Conditions { + selectCond := &ast.Select{ + SelectExprs: ast.SelectExprs{ + &ast.AliasedExpr{ + Expr: ifCond.Expr, + }, }, - }, - } - ifOp := &InterpreterOperation{ - OpCode: OpCode_If, - PrimaryData: selectCond, - } - *ops = append(*ops, ifOp) + } + ifOp := &InterpreterOperation{ + OpCode: OpCode_If, + PrimaryData: selectCond, + } + *ops = append(*ops, ifOp) - for _, ifStmt := range ifCond.Statements { - if err := ConvertStmt(ops, stack, ifStmt); err != nil { - return err + for _, ifStmt := range ifCond.Statements { + if err := ConvertStmt(ops, stack, ifStmt); err != nil { + return err + } } - } - gotoOp := &InterpreterOperation{ - OpCode: OpCode_Goto, - } - *ops = append(*ops, gotoOp) + gotoOp := &InterpreterOperation{ + OpCode: OpCode_Goto, + } + ifElseGotoOps = append(ifElseGotoOps, gotoOp) + *ops = append(*ops, gotoOp) - ifOp.Index = len(*ops) // start of else block + ifOp.Index = len(*ops) // start of next if statement + } for _, elseStmt := range s.Else { if err := ConvertStmt(ops, stack, elseStmt); err != nil { return err } } - gotoOp.Index = len(*ops) // end of if statement + for _, gotoOp := range ifElseGotoOps { + gotoOp.Index = len(*ops) // end of if statement + } case *ast.CaseStatement: var caseGotoOps []*InterpreterOperation From d125dec1d325bce9709292d262c54412fd2859cd Mon Sep 17 00:00:00 2001 From: James Cor Date: Thu, 13 Mar 2025 13:42:52 -0700 Subject: [PATCH 031/111] handle insert --- sql/procedures/interpreter_logic.go | 49 ++++++++++++++++++++++------- 1 file changed, 38 insertions(+), 11 deletions(-) diff --git a/sql/procedures/interpreter_logic.go b/sql/procedures/interpreter_logic.go index b5a6fc8a34..dcd22bca57 100644 --- a/sql/procedures/interpreter_logic.go +++ b/sql/procedures/interpreter_logic.go @@ -41,6 +41,17 @@ type Parameter struct { func replaceVariablesInExpr(stack *InterpreterStack, expr ast.SQLNode) (ast.SQLNode, error) { switch e := expr.(type) { + case *ast.ColName: + iv := stack.GetVariable(strings.ToLower(e.Name.String())) + if iv == nil { + return expr, nil + } + newExpr := iv.ToAST() + return &ast.ColName{ + Name: e.Name, + Qualifier: e.Qualifier, + StoredProcVal: newExpr, + }, nil case *ast.AliasedExpr: newExpr, err := replaceVariablesInExpr(stack, e.Expr) if err != nil { @@ -98,17 +109,34 @@ func replaceVariablesInExpr(stack *InterpreterStack, expr ast.SQLNode) (ast.SQLN return nil, err } } - case *ast.ColName: - iv := stack.GetVariable(e.Name.String()) - if iv == nil { - return expr, nil + case *ast.Call: + for i := range e.Params { + newExpr, err := replaceVariablesInExpr(stack, e.Params[i]) + if err != nil { + return nil, err + } + e.Params[i] = newExpr.(ast.Expr) + } + case ast.ValTuple: + for i := range e { + newExpr, err := replaceVariablesInExpr(stack, e[i]) + if err != nil { + return nil, err + } + e[i] = newExpr.(ast.Expr) + } + case *ast.Insert: + switch insRows := e.Rows.(type) { + case *ast.AliasedValues: + for i := range insRows.Values { + newExpr, err := replaceVariablesInExpr(stack, insRows.Values[i]) + if err != nil { + return nil, err + } + insRows.Values[i] = newExpr.(ast.ValTuple) + } + e.Rows = insRows } - newExpr := iv.ToAST() - return &ast.ColName{ - Name: e.Name, - Qualifier: e.Qualifier, - StoredProcVal: newExpr, - }, nil } return expr, nil } @@ -230,7 +258,6 @@ func Call(ctx *sql.Context, iNode InterpreterNode, params []*Parameter) (any, *I } case OpCode_Execute: - // TODO: replace variables stmt, err := replaceVariablesInExpr(stack, operation.PrimaryData) if err != nil { return nil, nil, err From 7752521608156fc76187c40cf54d4a76cd27e08f Mon Sep 17 00:00:00 2001 From: James Cor Date: Thu, 13 Mar 2025 15:08:11 -0700 Subject: [PATCH 032/111] error --- sql/planbuilder/dml.go | 4 ++ sql/procedures/interpreter_logic.go | 70 ++++++++++++++++++++--------- 2 files changed, 54 insertions(+), 20 deletions(-) diff --git a/sql/planbuilder/dml.go b/sql/planbuilder/dml.go index 71b8f9b56d..b6bf24ed56 100644 --- a/sql/planbuilder/dml.go +++ b/sql/planbuilder/dml.go @@ -734,6 +734,10 @@ func (b *Builder) buildInto(inScope *scope, into *ast.Into) { if strings.HasPrefix(val.String(), "@") { vars[i] = expression.NewUserVar(strings.TrimPrefix(val.String(), "@")) } else { + if inScope.proc == nil { + err := sql.ErrExternalProcedureMissingContextParam.New(val.String()) + b.handleErr(err) + } col, ok := inScope.proc.GetVar(val.String()) if !ok { err := sql.ErrExternalProcedureMissingContextParam.New(val.String()) diff --git a/sql/procedures/interpreter_logic.go b/sql/procedures/interpreter_logic.go index dcd22bca57..958235656e 100644 --- a/sql/procedures/interpreter_logic.go +++ b/sql/procedures/interpreter_logic.go @@ -117,6 +117,40 @@ func replaceVariablesInExpr(stack *InterpreterStack, expr ast.SQLNode) (ast.SQLN } e.Params[i] = newExpr.(ast.Expr) } + case *ast.Into: + for i := range e.Variables { + newExpr, err := replaceVariablesInExpr(stack, e.Variables[i]) + if err != nil { + return nil, err + } + e.Variables[i] = newExpr.(ast.ColIdent) + } + case *ast.Select: + for i := range e.SelectExprs { + newExpr, err := replaceVariablesInExpr(stack, e.SelectExprs[i]) + if err != nil { + return nil, err + } + e.SelectExprs[i] = newExpr.(ast.SelectExpr) + } + if e.Into != nil { + newExpr, err := replaceVariablesInExpr(stack, e.Into) + if err != nil { + return nil, err + } + e.Into = newExpr.(*ast.Into) + } + case *ast.SetOp: + newLeftExpr, err := replaceVariablesInExpr(stack, e.Left) + if err != nil { + return nil, err + } + newRightExpr, err := replaceVariablesInExpr(stack, e.Right) + if err != nil { + return nil, err + } + e.Left = newLeftExpr.(ast.SelectStatement) + e.Right = newRightExpr.(ast.SelectStatement) case ast.ValTuple: for i := range e { newExpr, err := replaceVariablesInExpr(stack, e[i]) @@ -125,18 +159,20 @@ func replaceVariablesInExpr(stack *InterpreterStack, expr ast.SQLNode) (ast.SQLN } e[i] = newExpr.(ast.Expr) } - case *ast.Insert: - switch insRows := e.Rows.(type) { - case *ast.AliasedValues: - for i := range insRows.Values { - newExpr, err := replaceVariablesInExpr(stack, insRows.Values[i]) - if err != nil { - return nil, err - } - insRows.Values[i] = newExpr.(ast.ValTuple) + case *ast.AliasedValues: + for i := range e.Values { + newExpr, err := replaceVariablesInExpr(stack, e.Values[i]) + if err != nil { + return nil, err } - e.Rows = insRows + e.Values[i] = newExpr.(ast.ValTuple) } + case *ast.Insert: + newExpr, err := replaceVariablesInExpr(stack, e.Rows) + if err != nil { + return nil, err + } + e.Rows = newExpr.(ast.InsertRows) } return expr, nil } @@ -193,17 +229,11 @@ func Call(ctx *sql.Context, iNode InterpreterNode, params []*Parameter) (any, *I switch operation.OpCode { case OpCode_Select: selectStmt := operation.PrimaryData.(*ast.Select) - if selectStmt.SelectExprs == nil { - panic("select stmt with no select exprs") - } - for i := range selectStmt.SelectExprs { - newNode, err := replaceVariablesInExpr(stack, selectStmt.SelectExprs[i]) - if err != nil { - return nil, nil, err - } - selectStmt.SelectExprs[i] = newNode.(ast.SelectExpr) + newSelectStmt, err := replaceVariablesInExpr(stack, selectStmt) + if err != nil { + return nil, nil, err } - rowIter, err := query(ctx, runner, selectStmt) + rowIter, err := query(ctx, runner, newSelectStmt.(*ast.Select)) if err != nil { return nil, nil, err } From 08f8e72335be027a066f8ebc261caffce81a7aae Mon Sep 17 00:00:00 2001 From: James Cor Date: Thu, 13 Mar 2025 15:08:46 -0700 Subject: [PATCH 033/111] test --- enginetest/memory_engine_test.go | 45 ++++++++++++-------------------- 1 file changed, 17 insertions(+), 28 deletions(-) diff --git a/enginetest/memory_engine_test.go b/enginetest/memory_engine_test.go index 4be83c4bf3..de3d63eb35 100644 --- a/enginetest/memory_engine_test.go +++ b/enginetest/memory_engine_test.go @@ -200,37 +200,26 @@ func TestSingleQueryPrepared(t *testing.T) { func TestSingleScript(t *testing.T) { //t.Skip() var scripts = []queries.ScriptTest{ - { - Name: "IF/ELSE with 1 SELECT at end", - SetUpScript: []string{ - "SET @outparam = ''", - ` -CREATE PROCEDURE p1(OUT s VARCHAR(200), N DOUBLE, m DOUBLE) -BEGIN - SET s = ''; - IF n = m THEN - SET s = 'equals'; - ELSE - IF n > m THEN - SET s = 'greater'; - ELSE - SET s = 'less'; - END IF; - SET s = CONCAT('is ', s, ' than'); - END IF; - SET s = CONCAT(n, ' ', s, ' ', m, '.'); - SELECT s; -END;`, - }, - Assertions: []queries.ScriptTestAssertion{ - { - Query: "CALL p1(@outparam, null, 2)", - Expected: []sql.Row{ - {nil}, + { + Name: "Simple SELECT INTO", + SetUpScript: []string{ + "CREATE PROCEDURE testabc(IN x DOUBLE, IN y DOUBLE, OUT abc DOUBLE) SELECT x*y INTO abc", + }, + Assertions: []queries.ScriptTestAssertion{ + //{ + // Query: "select 1 into @x", + // Expected: []sql.Row{}, + //}, + { + Query: "CALL testabc(2, 3, @res1)", + Expected: []sql.Row{{float64(6)}}, + }, + { + Query: "CALL testabc(9, 9.5, @res2)", + Expected: []sql.Row{{float64(85.5)}}, }, }, }, - }, } for _, test := range scripts { From b96bbc3e177163d84b4942d5e5b1f6ce046e36cc Mon Sep 17 00:00:00 2001 From: James Cor Date: Fri, 14 Mar 2025 11:29:41 -0700 Subject: [PATCH 034/111] fix user vras --- sql/procedures/interpreter_logic.go | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/sql/procedures/interpreter_logic.go b/sql/procedures/interpreter_logic.go index 958235656e..736bf1002e 100644 --- a/sql/procedures/interpreter_logic.go +++ b/sql/procedures/interpreter_logic.go @@ -97,13 +97,14 @@ func replaceVariablesInExpr(stack *InterpreterStack, expr ast.SQLNode) (ast.SQLN case *ast.Set: for _, setExpr := range e.Exprs { // TODO: properly handle user scope variables - if setExpr.Scope == ast.SetScope_User { - continue - } newExpr, err := replaceVariablesInExpr(stack, setExpr.Expr) if err != nil { return nil, err } + setExpr.Expr = newExpr.(ast.Expr) + if setExpr.Scope == ast.SetScope_User { + continue + } err = stack.SetVariable(nil, setExpr.Name.String(), newExpr) if err != nil { return nil, err @@ -118,6 +119,7 @@ func replaceVariablesInExpr(stack *InterpreterStack, expr ast.SQLNode) (ast.SQLN e.Params[i] = newExpr.(ast.Expr) } case *ast.Into: + // TODO: somehow support select into variables for i := range e.Variables { newExpr, err := replaceVariablesInExpr(stack, e.Variables[i]) if err != nil { @@ -140,6 +142,12 @@ func replaceVariablesInExpr(stack *InterpreterStack, expr ast.SQLNode) (ast.SQLN } e.Into = newExpr.(*ast.Into) } + case *ast.Subquery: + newExpr, err := replaceVariablesInExpr(stack, e.Select) + if err != nil { + return nil, err + } + e.Select = newExpr.(*ast.Select) case *ast.SetOp: newLeftExpr, err := replaceVariablesInExpr(stack, e.Left) if err != nil { From 586c990ea59e75eeeb75467aa57d58c601288198 Mon Sep 17 00:00:00 2001 From: James Cor Date: Fri, 14 Mar 2025 11:30:13 -0700 Subject: [PATCH 035/111] asdf --- enginetest/memory_engine_test.go | 26 +++++++++++++++---------- enginetest/queries/procedure_queries.go | 8 +++++--- 2 files changed, 21 insertions(+), 13 deletions(-) diff --git a/enginetest/memory_engine_test.go b/enginetest/memory_engine_test.go index de3d63eb35..b224134021 100644 --- a/enginetest/memory_engine_test.go +++ b/enginetest/memory_engine_test.go @@ -201,22 +201,28 @@ func TestSingleScript(t *testing.T) { //t.Skip() var scripts = []queries.ScriptTest{ { - Name: "Simple SELECT INTO", + Name: "Subquery on SET user variable captures parameter", SetUpScript: []string{ - "CREATE PROCEDURE testabc(IN x DOUBLE, IN y DOUBLE, OUT abc DOUBLE) SELECT x*y INTO abc", + ` +CREATE PROCEDURE p1(x VARCHAR(20)) +BEGIN + SET @randomvar = (SELECT LENGTH(x)); + SELECT @randomvar; +END;`, }, Assertions: []queries.ScriptTestAssertion{ - //{ - // Query: "select 1 into @x", - // Expected: []sql.Row{}, - //}, { - Query: "CALL testabc(2, 3, @res1)", - Expected: []sql.Row{{float64(6)}}, + SkipResultCheckOnServerEngine: true, // the user var has null type, which returns nil value over the wire. + Query: "CALL p1('hi')", + Expected: []sql.Row{ + {int64(2)}, + }, }, { - Query: "CALL testabc(9, 9.5, @res2)", - Expected: []sql.Row{{float64(85.5)}}, + Query: "CALL p1('hello')", + Expected: []sql.Row{ + {int64(5)}, + }, }, }, }, diff --git a/enginetest/queries/procedure_queries.go b/enginetest/queries/procedure_queries.go index deb46bbd77..c5c4b79e7f 100644 --- a/enginetest/queries/procedure_queries.go +++ b/enginetest/queries/procedure_queries.go @@ -903,9 +903,11 @@ END;`, SetUpScript: []string{ "CREATE TABLE inventory (item_id int primary key, shelf_id int, item varchar(10))", "INSERT INTO inventory VALUES (1, 1, 'a'), (2, 1, 'b'), (3, 2, 'c'), (4, 1, 'd'), (5, 4, 'e')", - `CREATE PROCEDURE count_and_print(IN p_shelf_id INT, OUT p_count INT) BEGIN -SELECT item FROM inventory WHERE shelf_id = p_shelf_id ORDER BY item ASC; -SELECT COUNT(*) INTO p_count FROM inventory WHERE shelf_id = p_shelf_id; + ` +CREATE PROCEDURE count_and_print(IN p_shelf_id INT, OUT p_count INT) +BEGIN + SELECT item FROM inventory WHERE shelf_id = p_shelf_id ORDER BY item ASC; + SELECT COUNT(*) INTO p_count FROM inventory WHERE shelf_id = p_shelf_id; END`, }, Assertions: []ScriptTestAssertion{ From fae1af91359c90560cfff845bdc70f2b80980ef7 Mon Sep 17 00:00:00 2001 From: James Cor Date: Fri, 14 Mar 2025 12:57:51 -0700 Subject: [PATCH 036/111] declare --- enginetest/memory_engine_test.go | 57 ++++++++++++++++++------- enginetest/queries/procedure_queries.go | 4 +- sql/procedures/interpreter_logic.go | 7 +++ 3 files changed, 52 insertions(+), 16 deletions(-) diff --git a/enginetest/memory_engine_test.go b/enginetest/memory_engine_test.go index b224134021..51c8b9c087 100644 --- a/enginetest/memory_engine_test.go +++ b/enginetest/memory_engine_test.go @@ -201,28 +201,55 @@ func TestSingleScript(t *testing.T) { //t.Skip() var scripts = []queries.ScriptTest{ { - Name: "Subquery on SET user variable captures parameter", + Name: "DECLARE CONDITION", SetUpScript: []string{ - ` -CREATE PROCEDURE p1(x VARCHAR(20)) -BEGIN - SET @randomvar = (SELECT LENGTH(x)); - SELECT @randomvar; + `CREATE PROCEDURE p1(x INT) +BEGIN + DECLARE specialty CONDITION FOR SQLSTATE '45000'; + DECLARE specialty2 CONDITION FOR SQLSTATE '02000'; + IF x = 0 THEN + SIGNAL SQLSTATE '01000'; + ELSEIF x = 1 THEN + SIGNAL SQLSTATE '45000' + SET MESSAGE_TEXT = 'A custom error occurred 1'; + ELSEIF x = 2 THEN + SIGNAL specialty + SET MESSAGE_TEXT = 'A custom error occurred 2', MYSQL_ERRNO = 1002; + ELSEIF x = 3 THEN + SIGNAL specialty; + ELSEIF x = 4 THEN + SIGNAL specialty2; + ELSE + SIGNAL SQLSTATE '01000' + SET MESSAGE_TEXT = 'A warning occurred', MYSQL_ERRNO = 1000; + SIGNAL SQLSTATE '45000' + SET MESSAGE_TEXT = 'An error occurred', MYSQL_ERRNO = 1001; + END IF; + BEGIN + DECLARE specialty3 CONDITION FOR SQLSTATE '45000'; + END; END;`, }, Assertions: []queries.ScriptTestAssertion{ { - SkipResultCheckOnServerEngine: true, // the user var has null type, which returns nil value over the wire. - Query: "CALL p1('hi')", - Expected: []sql.Row{ - {int64(2)}, - }, + Query: "CALL p1(0)", + ExpectedErrStr: "warnings not yet implemented", }, { - Query: "CALL p1('hello')", - Expected: []sql.Row{ - {int64(5)}, - }, + Query: "CALL p1(1)", + ExpectedErrStr: "A custom error occurred 1 (errno 1644) (sqlstate 45000)", + }, + { + Query: "CALL p1(2)", + ExpectedErrStr: "A custom error occurred 2 (errno 1002) (sqlstate 45000)", + }, + { + Query: "CALL p1(3)", + ExpectedErrStr: "Unhandled user-defined exception condition (errno 1644) (sqlstate 45000)", + }, + { + Query: "CALL p1(4)", + ExpectedErrStr: "Unhandled user-defined not found condition (errno 1643) (sqlstate 02000)", }, }, }, diff --git a/enginetest/queries/procedure_queries.go b/enginetest/queries/procedure_queries.go index c5c4b79e7f..998384a029 100644 --- a/enginetest/queries/procedure_queries.go +++ b/enginetest/queries/procedure_queries.go @@ -989,7 +989,9 @@ END;`, }, { Query: "CALL p1(@x);", - Expected: []sql.Row{{}}, + Expected: []sql.Row{ + {types.NewOkResult(0)}, + }, }, { Query: "SELECT @x;", diff --git a/sql/procedures/interpreter_logic.go b/sql/procedures/interpreter_logic.go index 736bf1002e..0d61888300 100644 --- a/sql/procedures/interpreter_logic.go +++ b/sql/procedures/interpreter_logic.go @@ -142,6 +142,13 @@ func replaceVariablesInExpr(stack *InterpreterStack, expr ast.SQLNode) (ast.SQLN } e.Into = newExpr.(*ast.Into) } + if e.Where != nil { + newExpr, err := replaceVariablesInExpr(stack, e.Where.Expr) + if err != nil { + return nil, err + } + e.Where.Expr = newExpr.(ast.Expr) + } case *ast.Subquery: newExpr, err := replaceVariablesInExpr(stack, e.Select) if err != nil { From ba4fedf2ba3ab1cb63a947096ffe4c299ab365a8 Mon Sep 17 00:00:00 2001 From: James Cor Date: Mon, 17 Mar 2025 11:13:51 -0700 Subject: [PATCH 037/111] fix some typing issues --- sql/planbuilder/scalar.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/sql/planbuilder/scalar.go b/sql/planbuilder/scalar.go index f201cd4ec8..d757d69b57 100644 --- a/sql/planbuilder/scalar.go +++ b/sql/planbuilder/scalar.go @@ -109,7 +109,11 @@ func (b *Builder) buildScalar(inScope *scope, e ast.Expr) (ex sql.Expression) { if v.StoredProcVal != nil { switch val := v.StoredProcVal.(type) { case *ast.SQLVal: - return b.ConvertVal(val) + resVal := b.ConvertVal(val) + if lit, isLit := resVal.(*expression.Literal); isLit && val.Type == ast.FloatVal { + return expression.NewLiteral(lit.Value(), types.Float64) + } + return resVal case *ast.NullVal: return expression.NewLiteral(nil, types.Null) } From 27f7b4605daccccd0b8eb6e526e20764ec531ac8 Mon Sep 17 00:00:00 2001 From: jycor Date: Mon, 17 Mar 2025 18:15:20 +0000 Subject: [PATCH 038/111] [ga-format-pr] Run ./format_repo.sh to fix formatting --- enginetest/queries/procedure_queries.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/enginetest/queries/procedure_queries.go b/enginetest/queries/procedure_queries.go index 998384a029..44c6303de9 100644 --- a/enginetest/queries/procedure_queries.go +++ b/enginetest/queries/procedure_queries.go @@ -988,7 +988,7 @@ END;`, Expected: []sql.Row{{}}, }, { - Query: "CALL p1(@x);", + Query: "CALL p1(@x);", Expected: []sql.Row{ {types.NewOkResult(0)}, }, From ea7551d7fa79c8270ba79ab29c316a11246a9d56 Mon Sep 17 00:00:00 2001 From: James Cor Date: Tue, 18 Mar 2025 14:36:31 -0700 Subject: [PATCH 039/111] select into and some declare --- sql/planbuilder/proc.go | 4 ++ sql/procedures/interpreter_logic.go | 75 +++++++++++++++++++++++------ sql/procedures/interpreter_stack.go | 19 +++++++- 3 files changed, 80 insertions(+), 18 deletions(-) diff --git a/sql/planbuilder/proc.go b/sql/planbuilder/proc.go index 0bdf9b80cb..3c15e4f5af 100644 --- a/sql/planbuilder/proc.go +++ b/sql/planbuilder/proc.go @@ -665,6 +665,10 @@ func (b *Builder) buildSignal(inScope *scope, s *ast.Signal) (outScope *scope) { sqlStateValue := s.SqlStateValue if s.ConditionName != "" { signalName := strings.ToLower(s.ConditionName) + if inScope.proc == nil { + err := sql.ErrDeclareConditionNotFound.New(signalName) + b.handleErr(err) + } condition := inScope.proc.GetCondition(signalName) if condition == nil { err := sql.ErrDeclareConditionNotFound.New(signalName) diff --git a/sql/procedures/interpreter_logic.go b/sql/procedures/interpreter_logic.go index 0d61888300..0ea3eee24e 100644 --- a/sql/procedures/interpreter_logic.go +++ b/sql/procedures/interpreter_logic.go @@ -105,7 +105,7 @@ func replaceVariablesInExpr(stack *InterpreterStack, expr ast.SQLNode) (ast.SQLN if setExpr.Scope == ast.SetScope_User { continue } - err = stack.SetVariable(nil, setExpr.Name.String(), newExpr) + err = stack.SetVariable(setExpr.Name.String(), newExpr) if err != nil { return nil, err } @@ -244,32 +244,75 @@ func Call(ctx *sql.Context, iNode InterpreterNode, params []*Parameter) (any, *I switch operation.OpCode { case OpCode_Select: selectStmt := operation.PrimaryData.(*ast.Select) - newSelectStmt, err := replaceVariablesInExpr(stack, selectStmt) + if newSelectStmt, err := replaceVariablesInExpr(stack, selectStmt); err == nil { + selectStmt = newSelectStmt.(*ast.Select) + } else { + return nil, nil, err + } + + if selectStmt.Into == nil { + rowIter, err := query(ctx, runner, selectStmt) + if err != nil { + return nil, nil, err + } + rowIters = append(rowIters, rowIter) + selIter = rowIter + continue + } + + selectInto := selectStmt.Into + selectStmt.Into = nil + _, rowIter, _, err := runner.QueryWithBindings(ctx, "", selectStmt, nil, nil) if err != nil { return nil, nil, err } - rowIter, err := query(ctx, runner, newSelectStmt.(*ast.Select)) + row, err := rowIter.Next(ctx) if err != nil { return nil, nil, err } - rowIters = append(rowIters, rowIter) - selIter = rowIter - - case OpCode_Declare: - declareStmt := operation.PrimaryData.(*ast.Declare) - for _, decl := range declareStmt.Variables.Names { - varType, err := types.ColumnTypeToType(&declareStmt.Variables.VarType) + if _, err = rowIter.Next(ctx); err != io.EOF { + return nil, nil, err + } + if err = rowIter.Close(ctx); err != nil { + return nil, nil, err + } + if len(row) != len(selectInto.Variables) { + return nil, nil, sql.ErrColumnNumberDoesNotMatch.New() + } + for i := range selectInto.Variables { + intoVar := strings.ToLower(selectInto.Variables[i].String()) + if strings.HasPrefix(intoVar, "@") { + // TODO + continue + } + err = stack.SetVariable(intoVar, row[i]) if err != nil { return nil, nil, err } - varName := strings.ToLower(decl.String()) - if declareStmt.Variables.VarType.Default != nil { - stack.NewVariableWithValue(varName, varType, declareStmt.Variables.VarType.Default) - } else { - stack.NewVariable(varName, varType) + } + + case OpCode_Declare: + declareStmt := operation.PrimaryData.(*ast.Declare) + if declareStmt.Condition != nil { + // TODO: copy error checks from buildDeclareCondition + stack.NewCondition(strings.ToLower(declareStmt.Condition.Name), declareStmt.Condition.SqlStateValue, 0) + } + if declareStmt.Variables != nil { + for _, decl := range declareStmt.Variables.Names { + varType, err := types.ColumnTypeToType(&declareStmt.Variables.VarType) + if err != nil { + return nil, nil, err + } + varName := strings.ToLower(decl.String()) + if declareStmt.Variables.VarType.Default != nil { + stack.NewVariableWithValue(varName, varType, declareStmt.Variables.VarType.Default) + } else { + stack.NewVariable(varName, varType) + } } } + case OpCode_Set: selectStmt := operation.PrimaryData.(*ast.Select) if selectStmt.SelectExprs == nil { @@ -297,7 +340,7 @@ func Call(ctx *sql.Context, iNode InterpreterNode, params []*Parameter) (any, *I return nil, nil, err } - err = stack.SetVariable(nil, strings.ToLower(operation.Target), row[0]) + err = stack.SetVariable(strings.ToLower(operation.Target), row[0]) if err != nil { return nil, nil, err } diff --git a/sql/procedures/interpreter_stack.go b/sql/procedures/interpreter_stack.go index 3166572d12..f740cf73aa 100644 --- a/sql/procedures/interpreter_stack.go +++ b/sql/procedures/interpreter_stack.go @@ -105,9 +105,16 @@ func (iv *InterpreterVariable) ToAST() ast.Expr { return ast.NewStrVal([]byte(fmt.Sprintf("%s", iv.Value))) } +// InterpreterCondition is a declare condition with custom SQLState and ErrorCode. +type InterpreterCondition struct { + SQLState string + MySQLErrCode int64 +} + // InterpreterScopeDetails contains all of the details that are relevant to a particular scope. type InterpreterScopeDetails struct { - variables map[string]*InterpreterVariable + variables map[string]*InterpreterVariable + conditions map[string]*InterpreterCondition } // InterpreterStack represents the working information that an interpreter will use during execution. It is not exactly @@ -176,6 +183,14 @@ func (is *InterpreterStack) NewVariableAlias(alias string, variable *Interpreter is.stack.Peek().variables[alias] = variable } +// NewCondition creates a new variable in the current scope. +func (is *InterpreterStack) NewCondition(name string, sqlState string, mysqlErrCode int64) { + is.stack.Peek().conditions[name] = &InterpreterCondition{ + SQLState: sqlState, + MySQLErrCode: mysqlErrCode, + } +} + // PushScope creates a new scope. func (is *InterpreterStack) PushScope() { is.stack.Push(&InterpreterScopeDetails{ @@ -191,7 +206,7 @@ func (is *InterpreterStack) PopScope() { // SetVariable sets the first variable found, with a matching name, to the value given. This does not ensure that the // value matches the expectations of the type, so it should be validated before this is called. Returns an error if the // variable cannot be found. -func (is *InterpreterStack) SetVariable(ctx *sql.Context, name string, val any) error { +func (is *InterpreterStack) SetVariable(name string, val any) error { iv := is.GetVariable(name) if iv == nil { return fmt.Errorf("variable `%s` could not be found", name) From 784cbd2e191453eb0d98c66685ff6779e8facbd0 Mon Sep 17 00:00:00 2001 From: James Cor Date: Thu, 20 Mar 2025 02:35:26 -0700 Subject: [PATCH 040/111] condition and signal logic --- sql/procedures/interpreter_logic.go | 121 +++++++++++++++++++++++- sql/procedures/interpreter_operation.go | 17 ++-- sql/procedures/interpreter_stack.go | 17 +++- sql/procedures/parse.go | 9 ++ 4 files changed, 150 insertions(+), 14 deletions(-) diff --git a/sql/procedures/interpreter_logic.go b/sql/procedures/interpreter_logic.go index 0ea3eee24e..83686d1845 100644 --- a/sql/procedures/interpreter_logic.go +++ b/sql/procedures/interpreter_logic.go @@ -15,13 +15,16 @@ package procedures import ( + "fmt" "io" + "strconv" "strings" - ast "github.com/dolthub/vitess/go/vt/sqlparser" - "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/types" + + "github.com/dolthub/vitess/go/mysql" + ast "github.com/dolthub/vitess/go/vt/sqlparser" ) // InterpreterNode is an interface that implements an interpreter. These are typically used for functions (which may be @@ -293,10 +296,31 @@ func Call(ctx *sql.Context, iNode InterpreterNode, params []*Parameter) (any, *I case OpCode_Declare: declareStmt := operation.PrimaryData.(*ast.Declare) + if declareStmt.Condition != nil { - // TODO: copy error checks from buildDeclareCondition - stack.NewCondition(strings.ToLower(declareStmt.Condition.Name), declareStmt.Condition.SqlStateValue, 0) + cond := declareStmt.Condition + condName := strings.ToLower(cond.Name) + stateVal := cond.SqlStateValue + var num int64 + var err error + if stateVal != "" { + if len(stateVal) != 5 { + return nil, nil, fmt.Errorf("SQLSTATE VALUE must be a string with length 5 consisting of only integers") + } + if stateVal[0:2] == "00" { + return nil, nil, fmt.Errorf("invalid SQLSTATE VALUE: '%s'", stateVal) + } + } else { + // use our own error + num, err = strconv.ParseInt(string(cond.MysqlErrorCode.Val), 10, 64) + if err != nil || num == 0 { + err = fmt.Errorf("invalid value '%s' for MySQL error code", string(cond.MysqlErrorCode.Val)) + return nil, nil, err + } + } + stack.NewCondition(condName, stateVal, num) } + if declareStmt.Variables != nil { for _, decl := range declareStmt.Variables.Names { varType, err := types.ColumnTypeToType(&declareStmt.Variables.VarType) @@ -312,6 +336,95 @@ func Call(ctx *sql.Context, iNode InterpreterNode, params []*Parameter) (any, *I } } + case OpCode_Signal: + // TODO: copy logic from planbuilder/proc.go: buildSignal() + signalStmt := operation.PrimaryData.(*ast.Signal) + var msgTxt string + var sqlState string + var mysqlErrNo int + if signalStmt.ConditionName == "" { + sqlState = signalStmt.SqlStateValue + if sqlState[0:2] == "01" { + return nil, nil, fmt.Errorf("warnings not yet implemented") + } + } else { + cond := stack.GetCondition(strings.ToLower(signalStmt.ConditionName)) + if cond == nil { + return nil, nil, sql.ErrDeclareConditionNotFound.New(signalStmt.ConditionName) + } + mysqlErrNo = int(cond.MySQLErrCode) + sqlState = cond.SQLState + } + + if len(sqlState) != 5 { + return nil, nil, fmt.Errorf("SQLSTATE VALUE must be a string with length 5 consisting of only integers") + } + + for _, item := range signalStmt.Info { + switch item.ConditionItemName { + case ast.SignalConditionItemName_MysqlErrno: + switch val := item.Value.(type) { + case *ast.SQLVal: + num, err := strconv.ParseInt(string(val.Val), 10, 64) + if err != nil || num == 0 { + return nil, nil, fmt.Errorf("invalid value '%s' for MySQL error code", string(val.Val)) + } + mysqlErrNo = int(num) + case *ast.ColName: + return nil, nil, fmt.Errorf("unsupported signal message text type: %T", val) + default: + return nil, nil, fmt.Errorf("invalid value '%v' for signal condition information item MESSAGE_TEXT", val) + } + case ast.SignalConditionItemName_MessageText: + switch val := item.Value.(type) { + case *ast.SQLVal: + msgTxt = string(val.Val) + if len(msgTxt) > 128 { + return nil, nil, fmt.Errorf("signal condition information item MESSAGE_TEXT has max length of 128") + } + case *ast.ColName: + return nil, nil, fmt.Errorf("unsupported signal message text type: %T", val) + default: + return nil, nil, fmt.Errorf("invalid value '%v' for signal condition information item MESSAGE_TEXT", val) + } + default: + switch val := item.Value.(type) { + case *ast.SQLVal: + msgTxt = string(val.Val) + if len(msgTxt) > 64 { + return nil, nil, fmt.Errorf("signal condition information item %s has max length of 64", strings.ToUpper(string(item.ConditionItemName))) + } + default: + return nil, nil, fmt.Errorf("invalid value '%v' for signal condition information item '%s''", item.Value, strings.ToUpper(string(item.ConditionItemName))) + } + } + } + + if mysqlErrNo == 0 { + switch sqlState[0:2] { + case "01": + mysqlErrNo = 1642 + case "02": + mysqlErrNo = 1643 + default: + mysqlErrNo = 1644 + } + } + + if msgTxt == "" { + switch sqlState[0:2] { + case "00": + return nil, nil, fmt.Errorf("invalid SQLSTATE VALUE: '%s'", sqlState) + case "01": + msgTxt = "Unhandled user-defined warning condition" + case "02": + msgTxt = "Unhandled user-defined not found condition" + default: + msgTxt = "Unhandled user-defined exception condition" + } + } + + return nil, nil, mysql.NewSQLError(mysqlErrNo, sqlState, msgTxt) case OpCode_Set: selectStmt := operation.PrimaryData.(*ast.Select) diff --git a/sql/procedures/interpreter_operation.go b/sql/procedures/interpreter_operation.go index 9ef4d12e25..ce0443d9c3 100644 --- a/sql/procedures/interpreter_operation.go +++ b/sql/procedures/interpreter_operation.go @@ -17,16 +17,17 @@ import ast "github.com/dolthub/vitess/go/vt/sqlparser" type OpCode uint16 const ( - OpCode_Select OpCode = iota - OpCode_Declare // https://www.postgresql.org/docs/15/plpgsql-declarations.html + OpCode_Select OpCode = iota + OpCode_Declare // https://www.postgresql.org/docs/15/plpgsql-declarations.html + OpCode_Signal OpCode_Set - OpCode_Exception // https://www.postgresql.org/docs/15/plpgsql-control-structures.html#PLPGSQL-ERROR-TRAPPING - OpCode_Execute // Everything that's not a SELECT - OpCode_Goto // All control-flow structures can be represented using Goto - OpCode_If // https://www.postgresql.org/docs/15/plpgsql-control-structures.html#PLPGSQL-CONDITIONALS - OpCode_Return // https://www.postgresql.org/docs/15/plpgsql-control-structures.html#PLPGSQL-STATEMENTS-RETURNING + OpCode_Exception // https://www.postgresql.org/docs/15/plpgsql-control-structures.html#PLPGSQL-ERROR-TRAPPING + OpCode_Execute // Everything that's not a SELECT + OpCode_Goto // All control-flow structures can be represented using Goto + OpCode_If // https://www.postgresql.org/docs/15/plpgsql-control-structures.html#PLPGSQL-CONDITIONALS + OpCode_Return // https://www.postgresql.org/docs/15/plpgsql-control-structures.html#PLPGSQL-STATEMENTS-RETURNING OpCode_ScopeBegin // This is used for scope control, specific to Doltgres - OpCode_ScopeEnd // This is used for scope control, specific to Doltgres + OpCode_ScopeEnd // This is used for scope control, specific to Doltgres ) // InterpreterOperation is an operation that will be performed by the interpreter. diff --git a/sql/procedures/interpreter_stack.go b/sql/procedures/interpreter_stack.go index f740cf73aa..f1ec8cb5ac 100644 --- a/sql/procedures/interpreter_stack.go +++ b/sql/procedures/interpreter_stack.go @@ -129,7 +129,8 @@ func NewInterpreterStack() *InterpreterStack { stack := NewStack[*InterpreterScopeDetails]() // This first push represents the function base, including parameters stack.Push(&InterpreterScopeDetails{ - variables: make(map[string]*InterpreterVariable), + variables: make(map[string]*InterpreterVariable), + conditions: make(map[string]*InterpreterCondition), }) return &InterpreterStack{ stack: stack, @@ -191,10 +192,22 @@ func (is *InterpreterStack) NewCondition(name string, sqlState string, mysqlErrC } } +// GetCondition traverses the stack (starting from the top) to find a condition with a matching name. Returns nil if no +// variable was found. +func (is *InterpreterStack) GetCondition(name string) *InterpreterCondition { + for i := 0; i < is.stack.Len(); i++ { + if ic, ok := is.stack.PeekDepth(i).conditions[name]; ok { + return ic + } + } + return nil +} + // PushScope creates a new scope. func (is *InterpreterStack) PushScope() { is.stack.Push(&InterpreterScopeDetails{ - variables: make(map[string]*InterpreterVariable), + variables: make(map[string]*InterpreterVariable), + conditions: make(map[string]*InterpreterCondition), }) } diff --git a/sql/procedures/parse.go b/sql/procedures/parse.go index cf035896d5..917a004c72 100644 --- a/sql/procedures/parse.go +++ b/sql/procedures/parse.go @@ -55,6 +55,13 @@ func ConvertStmt(ops *[]*InterpreterOperation, stack *InterpreterStack, stmt ast } *ops = append(*ops, declareOp) + case *ast.Signal: + signalOp := &InterpreterOperation{ + OpCode: OpCode_Signal, + PrimaryData: s, + } + *ops = append(*ops, signalOp) + case *ast.Set: if len(s.Exprs) != 1 { panic("unexpected number of set expressions") @@ -273,6 +280,8 @@ func ConvertStmt(ops *[]*InterpreterOperation, stack *InterpreterStack, stmt ast Target: s.Label, // hacky? way to signal a leave } *ops = append(*ops, leaveOp) + + default: execOp := &InterpreterOperation{ OpCode: OpCode_Execute, From 421ac42e42509d8dc2ec3c875ec156bc12e3a4b7 Mon Sep 17 00:00:00 2001 From: James Cor Date: Thu, 20 Mar 2025 10:59:19 -0700 Subject: [PATCH 041/111] fix --- sql/procedures/interpreter_logic.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/procedures/interpreter_logic.go b/sql/procedures/interpreter_logic.go index 83686d1845..be18f6bba4 100644 --- a/sql/procedures/interpreter_logic.go +++ b/sql/procedures/interpreter_logic.go @@ -352,8 +352,8 @@ func Call(ctx *sql.Context, iNode InterpreterNode, params []*Parameter) (any, *I if cond == nil { return nil, nil, sql.ErrDeclareConditionNotFound.New(signalStmt.ConditionName) } - mysqlErrNo = int(cond.MySQLErrCode) sqlState = cond.SQLState + mysqlErrNo = int(cond.MySQLErrCode) } if len(sqlState) != 5 { From 57ef575ffff66eff35ed87b61a00aa811a79f1d2 Mon Sep 17 00:00:00 2001 From: James Cor Date: Thu, 20 Mar 2025 10:59:31 -0700 Subject: [PATCH 042/111] updating test --- enginetest/queries/procedure_queries.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/enginetest/queries/procedure_queries.go b/enginetest/queries/procedure_queries.go index 998384a029..1a3769d016 100644 --- a/enginetest/queries/procedure_queries.go +++ b/enginetest/queries/procedure_queries.go @@ -1021,8 +1021,6 @@ BEGIN ELSEIF x = 4 THEN SIGNAL specialty2; ELSE - SIGNAL SQLSTATE '01000' - SET MESSAGE_TEXT = 'A warning occurred', MYSQL_ERRNO = 1000; SIGNAL SQLSTATE '45000' SET MESSAGE_TEXT = 'An error occurred', MYSQL_ERRNO = 1001; END IF; @@ -1052,6 +1050,10 @@ END;`, Query: "CALL p1(4)", ExpectedErrStr: "Unhandled user-defined not found condition (errno 1643) (sqlstate 02000)", }, + { + Query: "CALL p1(5)", + ExpectedErrStr: "An error occurred (errno 1001) (sqlstate 45000)", + }, }, }, { From a7256153bfc4e2be3cc4a7771b0181a3b1109305 Mon Sep 17 00:00:00 2001 From: James Cor Date: Thu, 20 Mar 2025 11:11:04 -0700 Subject: [PATCH 043/111] cursor implementation --- enginetest/memory_engine_test.go | 60 +++++++++----------------------- 1 file changed, 17 insertions(+), 43 deletions(-) diff --git a/enginetest/memory_engine_test.go b/enginetest/memory_engine_test.go index 51c8b9c087..9c09b82299 100644 --- a/enginetest/memory_engine_test.go +++ b/enginetest/memory_engine_test.go @@ -201,55 +201,29 @@ func TestSingleScript(t *testing.T) { //t.Skip() var scripts = []queries.ScriptTest{ { - Name: "DECLARE CONDITION", + Name: "FETCH multiple rows", SetUpScript: []string{ - `CREATE PROCEDURE p1(x INT) + `CREATE TABLE t1 (pk BIGINT PRIMARY KEY);`, + ` +CREATE PROCEDURE p1() BEGIN - DECLARE specialty CONDITION FOR SQLSTATE '45000'; - DECLARE specialty2 CONDITION FOR SQLSTATE '02000'; - IF x = 0 THEN - SIGNAL SQLSTATE '01000'; - ELSEIF x = 1 THEN - SIGNAL SQLSTATE '45000' - SET MESSAGE_TEXT = 'A custom error occurred 1'; - ELSEIF x = 2 THEN - SIGNAL specialty - SET MESSAGE_TEXT = 'A custom error occurred 2', MYSQL_ERRNO = 1002; - ELSEIF x = 3 THEN - SIGNAL specialty; - ELSEIF x = 4 THEN - SIGNAL specialty2; - ELSE - SIGNAL SQLSTATE '01000' - SET MESSAGE_TEXT = 'A warning occurred', MYSQL_ERRNO = 1000; - SIGNAL SQLSTATE '45000' - SET MESSAGE_TEXT = 'An error occurred', MYSQL_ERRNO = 1001; - END IF; - BEGIN - DECLARE specialty3 CONDITION FOR SQLSTATE '45000'; - END; + DECLARE a, b INT; + DECLARE cur1 CURSOR FOR SELECT pk FROM t1; + DELETE FROM t1; + INSERT INTO t1 VALUES (1), (2); + OPEN cur1; + FETCH cur1 INTO a; + FETCH cur1 INTO b; + CLOSE cur1; + SELECT a, b; END;`, }, Assertions: []queries.ScriptTestAssertion{ { - Query: "CALL p1(0)", - ExpectedErrStr: "warnings not yet implemented", - }, - { - Query: "CALL p1(1)", - ExpectedErrStr: "A custom error occurred 1 (errno 1644) (sqlstate 45000)", - }, - { - Query: "CALL p1(2)", - ExpectedErrStr: "A custom error occurred 2 (errno 1002) (sqlstate 45000)", - }, - { - Query: "CALL p1(3)", - ExpectedErrStr: "Unhandled user-defined exception condition (errno 1644) (sqlstate 45000)", - }, - { - Query: "CALL p1(4)", - ExpectedErrStr: "Unhandled user-defined not found condition (errno 1643) (sqlstate 02000)", + Query: "CALL p1();", + Expected: []sql.Row{ + {1, 2}, + }, }, }, }, From 250128078c750e7fd44c330e9000b51267215f4d Mon Sep 17 00:00:00 2001 From: James Cor Date: Thu, 20 Mar 2025 14:27:23 -0700 Subject: [PATCH 044/111] replace limit vals --- sql/planbuilder/select.go | 3 + sql/procedures/interpreter_logic.go | 185 +++++++++++++++++++----- sql/procedures/interpreter_operation.go | 9 +- sql/procedures/interpreter_stack.go | 48 ++++-- sql/procedures/parse.go | 21 +++ 5 files changed, 213 insertions(+), 53 deletions(-) diff --git a/sql/planbuilder/select.go b/sql/planbuilder/select.go index 92bf0cafd6..a3dfb9c82d 100644 --- a/sql/planbuilder/select.go +++ b/sql/planbuilder/select.go @@ -155,6 +155,9 @@ func (b *Builder) buildOffset(inScope *scope, limit *ast.Limit) sql.Expression { func (b *Builder) buildLimitVal(inScope *scope, e ast.Expr) sql.Expression { switch e := e.(type) { case *ast.ColName: + if e.StoredProcVal != nil { + return b.buildLimitVal(inScope, e.StoredProcVal) + } if inScope.procActive() { if col, ok := inScope.proc.GetVar(e.String()); ok { // proc param is OK diff --git a/sql/procedures/interpreter_logic.go b/sql/procedures/interpreter_logic.go index be18f6bba4..ed373a1ab7 100644 --- a/sql/procedures/interpreter_logic.go +++ b/sql/procedures/interpreter_logic.go @@ -16,15 +16,15 @@ package procedures import ( "fmt" - "io" + "github.com/dolthub/vitess/go/mysql" +"io" "strconv" "strings" "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/types" - "github.com/dolthub/vitess/go/mysql" - ast "github.com/dolthub/vitess/go/vt/sqlparser" + ast "github.com/dolthub/vitess/go/vt/sqlparser" ) // InterpreterNode is an interface that implements an interpreter. These are typically used for functions (which may be @@ -121,6 +121,21 @@ func replaceVariablesInExpr(stack *InterpreterStack, expr ast.SQLNode) (ast.SQLN } e.Params[i] = newExpr.(ast.Expr) } + case *ast.Limit: + newOffset, err := replaceVariablesInExpr(stack, e.Offset) + if err != nil { + return nil, err + } + newRowCount, err := replaceVariablesInExpr(stack, e.Rowcount) + if err != nil { + return nil, err + } + if newOffset != nil { + e.Offset = newOffset.(ast.Expr) + } + if newRowCount != nil { + e.Rowcount = newRowCount.(ast.Expr) + } case *ast.Into: // TODO: somehow support select into variables for i := range e.Variables { @@ -152,6 +167,13 @@ func replaceVariablesInExpr(stack *InterpreterStack, expr ast.SQLNode) (ast.SQLN } e.Where.Expr = newExpr.(ast.Expr) } + if e.Limit != nil { + newExpr, err := replaceVariablesInExpr(stack, e.Limit) + if err != nil { + return nil, err + } + e.Limit = newExpr.(*ast.Limit) + } case *ast.Subquery: newExpr, err := replaceVariablesInExpr(stack, e.Select) if err != nil { @@ -265,7 +287,7 @@ func Call(ctx *sql.Context, iNode InterpreterNode, params []*Parameter) (any, *I selectInto := selectStmt.Into selectStmt.Into = nil - _, rowIter, _, err := runner.QueryWithBindings(ctx, "", selectStmt, nil, nil) + schema, rowIter, _, err := runner.QueryWithBindings(ctx, "", selectStmt, nil, nil) if err != nil { return nil, nil, err } @@ -285,8 +307,10 @@ func Call(ctx *sql.Context, iNode InterpreterNode, params []*Parameter) (any, *I for i := range selectInto.Variables { intoVar := strings.ToLower(selectInto.Variables[i].String()) if strings.HasPrefix(intoVar, "@") { - // TODO - continue + err = ctx.SetUserVariable(ctx, intoVar, row[i], schema[i].Type) + if err != nil { + return nil, nil, err + } } err = stack.SetVariable(intoVar, row[i]) if err != nil { @@ -297,6 +321,7 @@ func Call(ctx *sql.Context, iNode InterpreterNode, params []*Parameter) (any, *I case OpCode_Declare: declareStmt := operation.PrimaryData.(*ast.Declare) + // TODO: duplicate conditions? if declareStmt.Condition != nil { cond := declareStmt.Condition condName := strings.ToLower(cond.Name) @@ -321,6 +346,18 @@ func Call(ctx *sql.Context, iNode InterpreterNode, params []*Parameter) (any, *I stack.NewCondition(condName, stateVal, num) } + // TODO: duplicate cursors? + if declareStmt.Cursor != nil { + cursor := declareStmt.Cursor + cursorName := strings.ToLower(cursor.Name) + stack.NewCursor(cursorName, cursor.SelectStmt) + } + + if declareStmt.Handler != nil { + // TODO + } + + // TODO: duplicate variables? if declareStmt.Variables != nil { for _, decl := range declareStmt.Variables.Names { varType, err := types.ColumnTypeToType(&declareStmt.Variables.VarType) @@ -426,6 +463,71 @@ func Call(ctx *sql.Context, iNode InterpreterNode, params []*Parameter) (any, *I return nil, nil, mysql.NewSQLError(mysqlErrNo, sqlState, msgTxt) + case OpCode_Open: + openCur := operation.PrimaryData.(*ast.OpenCursor) + cursor := stack.GetCursor(strings.ToLower(openCur.Name)) + if cursor == nil { + return nil, nil, sql.ErrCursorNotFound.New(openCur.Name) + } + if cursor.RowIter != nil { + return nil, nil, sql.ErrCursorAlreadyOpen.New(openCur.Name) + } + stmt, err := replaceVariablesInExpr(stack, cursor.SelectStmt) + if err != nil { + return nil, nil, err + } + schema, rowIter, _, err := runner.QueryWithBindings(ctx, "", stmt.(ast.Statement), nil, nil) + if err != nil { + return nil, nil, err + } + cursor.Schema = schema + cursor.RowIter = rowIter + + case OpCode_Fetch: + fetchCur := operation.PrimaryData.(*ast.FetchCursor) + cursor := stack.GetCursor(strings.ToLower(fetchCur.Name)) + if cursor == nil { + return nil, nil, sql.ErrCursorNotFound.New(fetchCur.Name) + } + if cursor.RowIter == nil { + return nil, nil, sql.ErrCursorNotOpen.New(fetchCur.Name) + } + row, err := cursor.RowIter.Next(ctx) + if err != nil { + return nil, nil, err + } + if len(row) != len(fetchCur.Variables) { + return nil, nil, sql.ErrFetchIncorrectCount.New() + } + for i := range fetchCur.Variables { + varName := strings.ToLower(fetchCur.Variables[i]) + if strings.HasPrefix(varName, "@") { + err = ctx.SetUserVariable(ctx, varName, row[i], cursor.Schema[i].Type) + if err != nil { + return nil, nil, err + } + continue + } + err = stack.SetVariable(varName, row[i]) + if err != nil { + return nil, nil, err + } + } + + case OpCode_Close: + closeCur := operation.PrimaryData.(*ast.CloseCursor) + cursor := stack.GetCursor(strings.ToLower(closeCur.Name)) + if cursor == nil { + return nil, nil, sql.ErrCursorNotFound.New(closeCur.Name) + } + if cursor.RowIter == nil { + return nil, nil, sql.ErrCursorNotOpen.New(closeCur.Name) + } + if err := cursor.RowIter.Close(ctx); err != nil { + return nil, nil, err + } + cursor.RowIter = nil + case OpCode_Set: selectStmt := operation.PrimaryData.(*ast.Select) if selectStmt.SelectExprs == nil { @@ -458,16 +560,42 @@ func Call(ctx *sql.Context, iNode InterpreterNode, params []*Parameter) (any, *I return nil, nil, err } - case OpCode_Execute: - stmt, err := replaceVariablesInExpr(stack, operation.PrimaryData) + case OpCode_If: + selectStmt := operation.PrimaryData.(*ast.Select) + if selectStmt.SelectExprs == nil { + panic("select stmt with no select exprs") + } + for i := range selectStmt.SelectExprs { + newNode, err := replaceVariablesInExpr(stack, selectStmt.SelectExprs[i]) + if err != nil { + return nil, nil, err + } + selectStmt.SelectExprs[i] = newNode.(ast.SelectExpr) + } + _, rowIter, _, err := runner.QueryWithBindings(ctx, "", selectStmt, nil, nil) if err != nil { return nil, nil, err } - rowIter, err := query(ctx, runner, stmt.(ast.Statement)) + // TODO: exactly one result that is a bool for now + row, err := rowIter.Next(ctx) if err != nil { return nil, nil, err } - rowIters = append(rowIters, rowIter) + if _, err = rowIter.Next(ctx); err != io.EOF { + return nil, nil, err + } + if err = rowIter.Close(ctx); err != nil { + return nil, nil, err + } + + // go to the appropriate block + cond, _, err := types.Boolean.Convert(row[0]) + if err != nil { + return nil, nil, err + } + if cond == nil || cond.(int8) == 0 { + counter = operation.Index - 1 // index of the else block, offset by 1 + } case OpCode_Goto: // We must compare to the index - 1, so that the increment hits our target @@ -495,49 +623,26 @@ func Call(ctx *sql.Context, iNode InterpreterNode, params []*Parameter) (any, *I } } - case OpCode_If: - selectStmt := operation.PrimaryData.(*ast.Select) - if selectStmt.SelectExprs == nil { - panic("select stmt with no select exprs") - } - for i := range selectStmt.SelectExprs { - newNode, err := replaceVariablesInExpr(stack, selectStmt.SelectExprs[i]) - if err != nil { - return nil, nil, err - } - selectStmt.SelectExprs[i] = newNode.(ast.SelectExpr) - } - _, rowIter, _, err := runner.QueryWithBindings(ctx, "", selectStmt, nil, nil) - if err != nil { - return nil, nil, err - } - // TODO: exactly one result that is a bool for now - row, err := rowIter.Next(ctx) + case OpCode_Execute: + stmt, err := replaceVariablesInExpr(stack, operation.PrimaryData) if err != nil { return nil, nil, err } - if _, err = rowIter.Next(ctx); err != io.EOF { - return nil, nil, err - } - if err = rowIter.Close(ctx); err != nil { - return nil, nil, err - } - - // go to the appropriate block - cond, _, err := types.Boolean.Convert(row[0]) + rowIter, err := query(ctx, runner, stmt.(ast.Statement)) if err != nil { return nil, nil, err } - if cond == nil || cond.(int8) == 0 { - counter = operation.Index - 1 // index of the else block, offset by 1 - } + rowIters = append(rowIters, rowIter) case OpCode_Exception: return nil, nil, operation.Error + case OpCode_ScopeBegin: stack.PushScope() + case OpCode_ScopeEnd: stack.PopScope() + default: panic("unimplemented opcode") } diff --git a/sql/procedures/interpreter_operation.go b/sql/procedures/interpreter_operation.go index ce0443d9c3..6d2b9925ea 100644 --- a/sql/procedures/interpreter_operation.go +++ b/sql/procedures/interpreter_operation.go @@ -20,11 +20,14 @@ const ( OpCode_Select OpCode = iota OpCode_Declare // https://www.postgresql.org/docs/15/plpgsql-declarations.html OpCode_Signal + OpCode_Open + OpCode_Fetch + OpCode_Close OpCode_Set - OpCode_Exception // https://www.postgresql.org/docs/15/plpgsql-control-structures.html#PLPGSQL-ERROR-TRAPPING - OpCode_Execute // Everything that's not a SELECT - OpCode_Goto // All control-flow structures can be represented using Goto OpCode_If // https://www.postgresql.org/docs/15/plpgsql-control-structures.html#PLPGSQL-CONDITIONALS + OpCode_Goto // All control-flow structures can be represented using Goto + OpCode_Execute // Everything that's not a SELECT + OpCode_Exception // https://www.postgresql.org/docs/15/plpgsql-control-structures.html#PLPGSQL-ERROR-TRAPPING OpCode_Return // https://www.postgresql.org/docs/15/plpgsql-control-structures.html#PLPGSQL-STATEMENTS-RETURNING OpCode_ScopeBegin // This is used for scope control, specific to Doltgres OpCode_ScopeEnd // This is used for scope control, specific to Doltgres diff --git a/sql/procedures/interpreter_stack.go b/sql/procedures/interpreter_stack.go index f1ec8cb5ac..d9980eb8a9 100644 --- a/sql/procedures/interpreter_stack.go +++ b/sql/procedures/interpreter_stack.go @@ -83,6 +83,19 @@ func (s *Stack[T]) Empty() bool { return len(s.values) == 0 } +// InterpreterCondition is a declare condition with custom SQLState and ErrorCode. +type InterpreterCondition struct { + SQLState string + MySQLErrCode int64 +} + +// InterpreterCursor is a declare condition with custom SQLState and ErrorCode. +type InterpreterCursor struct { + SelectStmt ast.SelectStatement + RowIter sql.RowIter + Schema sql.Schema +} + // InterpreterVariable is a variable that lives on the stack. type InterpreterVariable struct { Type sql.Type @@ -105,16 +118,11 @@ func (iv *InterpreterVariable) ToAST() ast.Expr { return ast.NewStrVal([]byte(fmt.Sprintf("%s", iv.Value))) } -// InterpreterCondition is a declare condition with custom SQLState and ErrorCode. -type InterpreterCondition struct { - SQLState string - MySQLErrCode int64 -} - // InterpreterScopeDetails contains all of the details that are relevant to a particular scope. type InterpreterScopeDetails struct { - variables map[string]*InterpreterVariable conditions map[string]*InterpreterCondition + cursors map[string]*InterpreterCursor + variables map[string]*InterpreterVariable } // InterpreterStack represents the working information that an interpreter will use during execution. It is not exactly @@ -129,8 +137,9 @@ func NewInterpreterStack() *InterpreterStack { stack := NewStack[*InterpreterScopeDetails]() // This first push represents the function base, including parameters stack.Push(&InterpreterScopeDetails{ - variables: make(map[string]*InterpreterVariable), conditions: make(map[string]*InterpreterCondition), + cursors: make(map[string]*InterpreterCursor), + variables: make(map[string]*InterpreterVariable), }) return &InterpreterStack{ stack: stack, @@ -184,7 +193,7 @@ func (is *InterpreterStack) NewVariableAlias(alias string, variable *Interpreter is.stack.Peek().variables[alias] = variable } -// NewCondition creates a new variable in the current scope. +// NewCondition creates a new condition in the current scope. func (is *InterpreterStack) NewCondition(name string, sqlState string, mysqlErrCode int64) { is.stack.Peek().conditions[name] = &InterpreterCondition{ SQLState: sqlState, @@ -203,11 +212,30 @@ func (is *InterpreterStack) GetCondition(name string) *InterpreterCondition { return nil } +// NewCursor creates a new cursor in the current scope. +func (is *InterpreterStack) NewCursor(name string, selStmt ast.SelectStatement) { + is.stack.Peek().cursors[name] = &InterpreterCursor{ + SelectStmt: selStmt, + } +} + +// GetCursor traverses the stack (starting from the top) to find a condition with a matching name. Returns nil if no +// variable was found. +func (is *InterpreterStack) GetCursor(name string) *InterpreterCursor { + for i := 0; i < is.stack.Len(); i++ { + if ic, ok := is.stack.PeekDepth(i).cursors[name]; ok { + return ic + } + } + return nil +} + // PushScope creates a new scope. func (is *InterpreterStack) PushScope() { is.stack.Push(&InterpreterScopeDetails{ - variables: make(map[string]*InterpreterVariable), conditions: make(map[string]*InterpreterCondition), + cursors: make(map[string]*InterpreterCursor), + variables: make(map[string]*InterpreterVariable), }) } diff --git a/sql/procedures/parse.go b/sql/procedures/parse.go index 917a004c72..ddd39ba029 100644 --- a/sql/procedures/parse.go +++ b/sql/procedures/parse.go @@ -55,6 +55,27 @@ func ConvertStmt(ops *[]*InterpreterOperation, stack *InterpreterStack, stmt ast } *ops = append(*ops, declareOp) + case *ast.OpenCursor: + openOp := &InterpreterOperation{ + OpCode: OpCode_Open, + PrimaryData: s, + } + *ops = append(*ops, openOp) + + case *ast.FetchCursor: + fetchOp := &InterpreterOperation{ + OpCode: OpCode_Fetch, + PrimaryData: s, + } + *ops = append(*ops, fetchOp) + + case *ast.CloseCursor: + closeOp := &InterpreterOperation{ + OpCode: OpCode_Close, + PrimaryData: s, + } + *ops = append(*ops, closeOp) + case *ast.Signal: signalOp := &InterpreterOperation{ OpCode: OpCode_Signal, From 227e6c3e5e534444da93779476a09edd69f1708f Mon Sep 17 00:00:00 2001 From: James Cor Date: Thu, 20 Mar 2025 16:05:06 -0700 Subject: [PATCH 045/111] fix declare logic --- sql/procedures/parse.go | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/sql/procedures/parse.go b/sql/procedures/parse.go index ddd39ba029..bf0ac93843 100644 --- a/sql/procedures/parse.go +++ b/sql/procedures/parse.go @@ -239,8 +239,14 @@ func ConvertStmt(ops *[]*InterpreterOperation, stack *InterpreterStack, stmt ast whileOp.Index = len(*ops) // end of while block case *ast.Repeat: - loopStart := len(*ops) + // repeat statements always run at least once + for _, repeatStmt := range s.Statements { + if err := ConvertStmt(ops, stack, repeatStmt); err != nil { + return err + } + } + loopStart := len(*ops) repeatCond := &ast.NotExpr{Expr: s.Condition} selectCond := &ast.Select{ SelectExprs: ast.SelectExprs{ From 7e618d8e2c5c33e30761f512f9d2c41f47ad7880 Mon Sep 17 00:00:00 2001 From: James Cor Date: Thu, 20 Mar 2025 17:50:05 -0700 Subject: [PATCH 046/111] handlers --- sql/planbuilder/proc.go | 6 ++--- sql/procedures/interpreter_logic.go | 42 ++++++++++++++++++++--------- sql/procedures/interpreter_stack.go | 33 ++++++++++++++++++++++- 3 files changed, 64 insertions(+), 17 deletions(-) diff --git a/sql/planbuilder/proc.go b/sql/planbuilder/proc.go index 3c15e4f5af..c8fb549cb7 100644 --- a/sql/planbuilder/proc.go +++ b/sql/planbuilder/proc.go @@ -71,6 +71,7 @@ func (p *procCtx) NewState(state declareState) { err := sql.ErrDeclareCursorOrderInvalid.New() p.s.b.handleErr(err) } + default: } p.lastState = state } @@ -464,14 +465,11 @@ func (b *Builder) buildDeclareHandler(inScope *scope, d *ast.Declare, query stri action = expression.DeclareHandlerAction_Exit case ast.DeclareHandlerAction_Undo: action = expression.DeclareHandlerAction_Undo + b.handleErr(sql.ErrDeclareHandlerUndo.New()) default: err := fmt.Errorf("unknown DECLARE ... HANDLER action: %v", dHandler.Action) b.handleErr(err) } - if action == expression.DeclareHandlerAction_Undo { - err := sql.ErrDeclareHandlerUndo.New() - b.handleErr(err) - } handler := &plan.DeclareHandler{ Action: action, diff --git a/sql/procedures/interpreter_logic.go b/sql/procedures/interpreter_logic.go index ed373a1ab7..7c5d24c675 100644 --- a/sql/procedures/interpreter_logic.go +++ b/sql/procedures/interpreter_logic.go @@ -322,8 +322,7 @@ func Call(ctx *sql.Context, iNode InterpreterNode, params []*Parameter) (any, *I declareStmt := operation.PrimaryData.(*ast.Declare) // TODO: duplicate conditions? - if declareStmt.Condition != nil { - cond := declareStmt.Condition + if cond := declareStmt.Condition; cond != nil { condName := strings.ToLower(cond.Name) stateVal := cond.SqlStateValue var num int64 @@ -347,29 +346,48 @@ func Call(ctx *sql.Context, iNode InterpreterNode, params []*Parameter) (any, *I } // TODO: duplicate cursors? - if declareStmt.Cursor != nil { - cursor := declareStmt.Cursor + if cursor := declareStmt.Cursor; cursor != nil { cursorName := strings.ToLower(cursor.Name) stack.NewCursor(cursorName, cursor.SelectStmt) } - if declareStmt.Handler != nil { - // TODO + // TODO: duplicate handlers? + if handler := declareStmt.Handler; handler != nil { + if len(handler.ConditionValues) != 1 { + return nil, nil, sql.ErrUnsupportedSyntax.New(ast.String(declareStmt)) + } + + hCond := handler.ConditionValues[0] + switch hCond.ValueType { + case ast.DeclareHandlerCondition_NotFound: + case ast.DeclareHandlerCondition_SqlState: + default: + return nil, nil, sql.ErrUnsupportedSyntax.New(ast.String(declareStmt)) + } + + switch handler.Action { + case ast.DeclareHandlerAction_Continue: + case ast.DeclareHandlerAction_Exit: + case ast.DeclareHandlerAction_Undo: + return nil, nil, fmt.Errorf("unsupported handler action: %s", handler.Action) + } + + stack.NewHandler(string(hCond.ValueType), string(handler.Action), handler.Statement) } // TODO: duplicate variables? - if declareStmt.Variables != nil { - for _, decl := range declareStmt.Variables.Names { - varType, err := types.ColumnTypeToType(&declareStmt.Variables.VarType) + if vars := declareStmt.Variables; vars != nil { + for _, decl := range vars.Names { + varType, err := types.ColumnTypeToType(&vars.VarType) if err != nil { return nil, nil, err } varName := strings.ToLower(decl.String()) - if declareStmt.Variables.VarType.Default != nil { - stack.NewVariableWithValue(varName, varType, declareStmt.Variables.VarType.Default) - } else { + if vars.VarType.Default == nil { stack.NewVariable(varName, varType) + continue } + stack.NewVariableWithValue(varName, varType, vars.VarType.Default) } } diff --git a/sql/procedures/interpreter_stack.go b/sql/procedures/interpreter_stack.go index d9980eb8a9..afc2bfbcc7 100644 --- a/sql/procedures/interpreter_stack.go +++ b/sql/procedures/interpreter_stack.go @@ -89,13 +89,20 @@ type InterpreterCondition struct { MySQLErrCode int64 } -// InterpreterCursor is a declare condition with custom SQLState and ErrorCode. +// InterpreterCursor is a declare cursor. type InterpreterCursor struct { SelectStmt ast.SelectStatement RowIter sql.RowIter Schema sql.Schema } +// InterpreterHandler is a declare handler that specifies an Action during an error Condition. +type InterpreterHandler struct { + Condition string + Action string + Statement ast.Statement +} + // InterpreterVariable is a variable that lives on the stack. type InterpreterVariable struct { Type sql.Type @@ -122,6 +129,7 @@ func (iv *InterpreterVariable) ToAST() ast.Expr { type InterpreterScopeDetails struct { conditions map[string]*InterpreterCondition cursors map[string]*InterpreterCursor + handlers []*InterpreterHandler variables map[string]*InterpreterVariable } @@ -139,6 +147,7 @@ func NewInterpreterStack() *InterpreterStack { stack.Push(&InterpreterScopeDetails{ conditions: make(map[string]*InterpreterCondition), cursors: make(map[string]*InterpreterCursor), + handlers: make([]*InterpreterHandler), variables: make(map[string]*InterpreterVariable), }) return &InterpreterStack{ @@ -212,6 +221,8 @@ func (is *InterpreterStack) GetCondition(name string) *InterpreterCondition { return nil } + + // NewCursor creates a new cursor in the current scope. func (is *InterpreterStack) NewCursor(name string, selStmt ast.SelectStatement) { is.stack.Peek().cursors[name] = &InterpreterCursor{ @@ -230,6 +241,26 @@ func (is *InterpreterStack) GetCursor(name string) *InterpreterCursor { return nil } +// NewHandler creates a new handler in the current scope. +func (is *InterpreterStack) NewHandler(cond string, action string, stmt ast.Statement) { + is.stack.Peek().handlers = append(is.stack.Peek().handlers, &InterpreterHandler{ + Condition: cond, + Action: action, + Statement: stmt, + }) +} + +// ListHandlers returns a map with the names of all handlers. +func (is *InterpreterStack) ListHandlers() []*InterpreterHandler { + handlers := make([]*InterpreterHandler, 0) + for i := 0; i < is.stack.Len(); i++ { + for _, handler := range is.stack.PeekDepth(i).handlers { + handlers = append(handlers, handler) + } + } + return handlers +} + // PushScope creates a new scope. func (is *InterpreterStack) PushScope() { is.stack.Push(&InterpreterScopeDetails{ From 909d7dfc4369bf03d069393b220289d79a47d190 Mon Sep 17 00:00:00 2001 From: James Cor Date: Fri, 21 Mar 2025 02:26:52 -0700 Subject: [PATCH 047/111] fixing some handlers --- enginetest/queries/procedure_queries.go | 6 +- sql/procedures/interpreter_logic.go | 810 +++++++++++++----------- sql/procedures/interpreter_stack.go | 8 +- sql/rowexec/proc.go | 2 +- 4 files changed, 454 insertions(+), 372 deletions(-) diff --git a/enginetest/queries/procedure_queries.go b/enginetest/queries/procedure_queries.go index 1a3769d016..b84570bce4 100644 --- a/enginetest/queries/procedure_queries.go +++ b/enginetest/queries/procedure_queries.go @@ -1269,11 +1269,11 @@ END;`, `CREATE PROCEDURE duplicate_key() BEGIN DECLARE a, b INT DEFAULT 1; - BEGIN + BEGIN DECLARE EXIT HANDLER FOR SQLEXCEPTION SET a = 7; INSERT INTO t1 values (0); - END; - SELECT a; + END; + SELECT a; END;`, }, Assertions: []ScriptTestAssertion{ diff --git a/sql/procedures/interpreter_logic.go b/sql/procedures/interpreter_logic.go index 7c5d24c675..e890f56acb 100644 --- a/sql/procedures/interpreter_logic.go +++ b/sql/procedures/interpreter_logic.go @@ -15,16 +15,18 @@ package procedures import ( + "errors" "fmt" - "github.com/dolthub/vitess/go/mysql" -"io" + "io" "strconv" "strings" "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/expression" "github.com/dolthub/go-mysql-server/sql/types" - ast "github.com/dolthub/vitess/go/vt/sqlparser" + "github.com/dolthub/vitess/go/mysql" + ast "github.com/dolthub/vitess/go/vt/sqlparser" ) // InterpreterNode is an interface that implements an interpreter. These are typically used for functions (which may be @@ -239,431 +241,511 @@ func query(ctx *sql.Context, runner sql.StatementRunner, stmt ast.Statement) (sq return sql.RowsToRowIter(rows...), nil } -// Call runs the contained operations on the given runner. -func Call(ctx *sql.Context, iNode InterpreterNode, params []*Parameter) (any, *InterpreterStack, error) { - // Set up the initial state of the function - counter := -1 // We increment before accessing, so start at -1 - stack := NewInterpreterStack() - for _, param := range params { - stack.NewVariableWithValue(param.Name, param.Type, param.Value) +// handleError handles errors that occur during the execution of a procedure according to the defined handlers. +func handleError(ctx *sql.Context, stack *InterpreterStack, runner sql.StatementRunner, err error) error { + // TODO: just copy logic from expression/procedurereference.go + if err == nil { + return nil } - // TODO: remove this; track last selectRowIter - var selIter sql.RowIter - - // Run the statements - // TODO: eventually return multiple sql.RowIters - var rowIters []sql.RowIter - runner := iNode.GetRunner() - statements := iNode.GetStatements() - for { - counter++ - if counter < 0 { - panic("negative function counter") + var matchingHandler *InterpreterHandler + for _, handler := range stack.ListHandlers() { + if errors.Is(err, expression.FetchEOF) && handler.Condition == ast.DeclareHandlerCondition_NotFound { + matchingHandler = handler + break } - if counter >= len(statements) { + switch handler.Condition { + case ast.DeclareHandlerCondition_MysqlErrorCode: + case ast.DeclareHandlerCondition_SqlState: + case ast.DeclareHandlerCondition_ConditionName: + case ast.DeclareHandlerCondition_SqlWarning: + case ast.DeclareHandlerCondition_NotFound: + case ast.DeclareHandlerCondition_SqlException: + matchingHandler = handler break } + } - operation := statements[counter] - switch operation.OpCode { - case OpCode_Select: - selectStmt := operation.PrimaryData.(*ast.Select) - if newSelectStmt, err := replaceVariablesInExpr(stack, selectStmt); err == nil { - selectStmt = newSelectStmt.(*ast.Select) - } else { - return nil, nil, err - } + if matchingHandler == nil { + return err + } - if selectStmt.Into == nil { - rowIter, err := query(ctx, runner, selectStmt) - if err != nil { - return nil, nil, err - } - rowIters = append(rowIters, rowIter) - selIter = rowIter - continue - } + handlerOps := make([]*InterpreterOperation, 0, 1) + err = ConvertStmt(&handlerOps, stack, matchingHandler.Statement) + if err != nil { + return err + } - selectInto := selectStmt.Into - selectStmt.Into = nil - schema, rowIter, _, err := runner.QueryWithBindings(ctx, "", selectStmt, nil, nil) + _, _, rowIter, err := execOp(ctx, runner, stack, handlerOps[0], handlerOps, -1) + if err != nil { + return err + } + if rowIter != nil { + for { + _, err = rowIter.Next(ctx) if err != nil { - return nil, nil, err + return err } - row, err := rowIter.Next(ctx) + } + } + + switch matchingHandler.Action { + case ast.DeclareHandlerAction_Continue: + return nil + case ast.DeclareHandlerAction_Exit: + return io.EOF + case ast.DeclareHandlerAction_Undo: + return fmt.Errorf("DECLARE UNDO HANDLER is not supported") + } + + return nil +} + +func execOp(ctx *sql.Context, runner sql.StatementRunner, stack *InterpreterStack, operation *InterpreterOperation, statements []*InterpreterOperation, counter int) (int, sql.RowIter, sql.RowIter, error) { + switch operation.OpCode { + case OpCode_Select: + selectStmt := operation.PrimaryData.(*ast.Select) + if newSelectStmt, err := replaceVariablesInExpr(stack, selectStmt); err == nil { + selectStmt = newSelectStmt.(*ast.Select) + } else { + return 0, nil, nil, err + } + + if selectStmt.Into == nil { + rowIter, err := query(ctx, runner, selectStmt) if err != nil { - return nil, nil, err + return 0, nil, nil, err } - if _, err = rowIter.Next(ctx); err != io.EOF { - return nil, nil, err - } - if err = rowIter.Close(ctx); err != nil { - return nil, nil, err + return counter, rowIter, rowIter, nil + } + + selectInto := selectStmt.Into + selectStmt.Into = nil + schema, rowIter, _, err := runner.QueryWithBindings(ctx, "", selectStmt, nil, nil) + if err != nil { + return 0, nil, nil, err + } + row, err := rowIter.Next(ctx) + if err != nil { + return 0, nil, nil, err + } + if _, err = rowIter.Next(ctx); err != io.EOF { + return 0, nil, nil, err + } + if err = rowIter.Close(ctx); err != nil { + return 0, nil, nil, err + } + if len(row) != len(selectInto.Variables) { + return 0, nil, nil, sql.ErrColumnNumberDoesNotMatch.New() + } + for i := range selectInto.Variables { + intoVar := strings.ToLower(selectInto.Variables[i].String()) + if strings.HasPrefix(intoVar, "@") { + err = ctx.SetUserVariable(ctx, intoVar, row[i], schema[i].Type) + if err != nil { + return 0, nil, nil, err + } } - if len(row) != len(selectInto.Variables) { - return nil, nil, sql.ErrColumnNumberDoesNotMatch.New() + err = stack.SetVariable(intoVar, row[i]) + if err != nil { + return 0, nil, nil, err } - for i := range selectInto.Variables { - intoVar := strings.ToLower(selectInto.Variables[i].String()) - if strings.HasPrefix(intoVar, "@") { - err = ctx.SetUserVariable(ctx, intoVar, row[i], schema[i].Type) - if err != nil { - return nil, nil, err - } + } + + case OpCode_Declare: + declareStmt := operation.PrimaryData.(*ast.Declare) + + // TODO: duplicate conditions? + if cond := declareStmt.Condition; cond != nil { + condName := strings.ToLower(cond.Name) + stateVal := cond.SqlStateValue + var num int64 + var err error + if stateVal != "" { + if len(stateVal) != 5 { + return 0, nil, nil, fmt.Errorf("SQLSTATE VALUE must be a string with length 5 consisting of only integers") } - err = stack.SetVariable(intoVar, row[i]) - if err != nil { - return nil, nil, err + if stateVal[0:2] == "00" { + return 0, nil, nil, fmt.Errorf("invalid SQLSTATE VALUE: '%s'", stateVal) + } + } else { + // use our own error + num, err = strconv.ParseInt(string(cond.MysqlErrorCode.Val), 10, 64) + if err != nil || num == 0 { + err = fmt.Errorf("invalid value '%s' for MySQL error code", string(cond.MysqlErrorCode.Val)) + return 0, nil, nil, err } } + stack.NewCondition(condName, stateVal, num) + } - case OpCode_Declare: - declareStmt := operation.PrimaryData.(*ast.Declare) + // TODO: duplicate cursors? + if cursor := declareStmt.Cursor; cursor != nil { + cursorName := strings.ToLower(cursor.Name) + stack.NewCursor(cursorName, cursor.SelectStmt) + } - // TODO: duplicate conditions? - if cond := declareStmt.Condition; cond != nil { - condName := strings.ToLower(cond.Name) - stateVal := cond.SqlStateValue - var num int64 - var err error - if stateVal != "" { - if len(stateVal) != 5 { - return nil, nil, fmt.Errorf("SQLSTATE VALUE must be a string with length 5 consisting of only integers") - } - if stateVal[0:2] == "00" { - return nil, nil, fmt.Errorf("invalid SQLSTATE VALUE: '%s'", stateVal) - } - } else { - // use our own error - num, err = strconv.ParseInt(string(cond.MysqlErrorCode.Val), 10, 64) - if err != nil || num == 0 { - err = fmt.Errorf("invalid value '%s' for MySQL error code", string(cond.MysqlErrorCode.Val)) - return nil, nil, err - } - } - stack.NewCondition(condName, stateVal, num) + // TODO: duplicate handlers? + if handler := declareStmt.Handler; handler != nil { + if len(handler.ConditionValues) != 1 { + return 0, nil, nil, sql.ErrUnsupportedSyntax.New(ast.String(declareStmt)) } - // TODO: duplicate cursors? - if cursor := declareStmt.Cursor; cursor != nil { - cursorName := strings.ToLower(cursor.Name) - stack.NewCursor(cursorName, cursor.SelectStmt) + hCond := handler.ConditionValues[0] + switch hCond.ValueType { + case ast.DeclareHandlerCondition_NotFound: + case ast.DeclareHandlerCondition_SqlException: + default: + return 0, nil, nil, sql.ErrUnsupportedSyntax.New(ast.String(declareStmt)) } - // TODO: duplicate handlers? - if handler := declareStmt.Handler; handler != nil { - if len(handler.ConditionValues) != 1 { - return nil, nil, sql.ErrUnsupportedSyntax.New(ast.String(declareStmt)) - } + switch handler.Action { + case ast.DeclareHandlerAction_Continue: + case ast.DeclareHandlerAction_Exit: + case ast.DeclareHandlerAction_Undo: + return 0, nil, nil, fmt.Errorf("unsupported handler action: %s", handler.Action) + } - hCond := handler.ConditionValues[0] - switch hCond.ValueType { - case ast.DeclareHandlerCondition_NotFound: - case ast.DeclareHandlerCondition_SqlState: - default: - return nil, nil, sql.ErrUnsupportedSyntax.New(ast.String(declareStmt)) - } + stack.NewHandler(hCond.ValueType, handler.Action, handler.Statement) + } - switch handler.Action { - case ast.DeclareHandlerAction_Continue: - case ast.DeclareHandlerAction_Exit: - case ast.DeclareHandlerAction_Undo: - return nil, nil, fmt.Errorf("unsupported handler action: %s", handler.Action) + // TODO: duplicate variables? + if vars := declareStmt.Variables; vars != nil { + for _, decl := range vars.Names { + varType, err := types.ColumnTypeToType(&vars.VarType) + if err != nil { + return 0, nil, nil, err } - - stack.NewHandler(string(hCond.ValueType), string(handler.Action), handler.Statement) - } - - // TODO: duplicate variables? - if vars := declareStmt.Variables; vars != nil { - for _, decl := range vars.Names { - varType, err := types.ColumnTypeToType(&vars.VarType) - if err != nil { - return nil, nil, err - } - varName := strings.ToLower(decl.String()) - if vars.VarType.Default == nil { - stack.NewVariable(varName, varType) - continue - } - stack.NewVariableWithValue(varName, varType, vars.VarType.Default) + varName := strings.ToLower(decl.String()) + if vars.VarType.Default == nil { + stack.NewVariable(varName, varType) + continue } + stack.NewVariableWithValue(varName, varType, vars.VarType.Default) } + } - case OpCode_Signal: - // TODO: copy logic from planbuilder/proc.go: buildSignal() - signalStmt := operation.PrimaryData.(*ast.Signal) - var msgTxt string - var sqlState string - var mysqlErrNo int - if signalStmt.ConditionName == "" { - sqlState = signalStmt.SqlStateValue - if sqlState[0:2] == "01" { - return nil, nil, fmt.Errorf("warnings not yet implemented") - } - } else { - cond := stack.GetCondition(strings.ToLower(signalStmt.ConditionName)) - if cond == nil { - return nil, nil, sql.ErrDeclareConditionNotFound.New(signalStmt.ConditionName) - } - sqlState = cond.SQLState - mysqlErrNo = int(cond.MySQLErrCode) - } - - if len(sqlState) != 5 { - return nil, nil, fmt.Errorf("SQLSTATE VALUE must be a string with length 5 consisting of only integers") - } - - for _, item := range signalStmt.Info { - switch item.ConditionItemName { - case ast.SignalConditionItemName_MysqlErrno: - switch val := item.Value.(type) { - case *ast.SQLVal: - num, err := strconv.ParseInt(string(val.Val), 10, 64) - if err != nil || num == 0 { - return nil, nil, fmt.Errorf("invalid value '%s' for MySQL error code", string(val.Val)) - } - mysqlErrNo = int(num) - case *ast.ColName: - return nil, nil, fmt.Errorf("unsupported signal message text type: %T", val) - default: - return nil, nil, fmt.Errorf("invalid value '%v' for signal condition information item MESSAGE_TEXT", val) - } - case ast.SignalConditionItemName_MessageText: - switch val := item.Value.(type) { - case *ast.SQLVal: - msgTxt = string(val.Val) - if len(msgTxt) > 128 { - return nil, nil, fmt.Errorf("signal condition information item MESSAGE_TEXT has max length of 128") - } - case *ast.ColName: - return nil, nil, fmt.Errorf("unsupported signal message text type: %T", val) - default: - return nil, nil, fmt.Errorf("invalid value '%v' for signal condition information item MESSAGE_TEXT", val) + case OpCode_Signal: + // TODO: copy logic from planbuilder/proc.go: buildSignal() + signalStmt := operation.PrimaryData.(*ast.Signal) + var msgTxt string + var sqlState string + var mysqlErrNo int + if signalStmt.ConditionName == "" { + sqlState = signalStmt.SqlStateValue + if sqlState[0:2] == "01" { + return 0, nil, nil, fmt.Errorf("warnings not yet implemented") + } + } else { + cond := stack.GetCondition(strings.ToLower(signalStmt.ConditionName)) + if cond == nil { + return 0, nil, nil, sql.ErrDeclareConditionNotFound.New(signalStmt.ConditionName) + } + sqlState = cond.SQLState + mysqlErrNo = int(cond.MySQLErrCode) + } + + if len(sqlState) != 5 { + return 0, nil, nil, fmt.Errorf("SQLSTATE VALUE must be a string with length 5 consisting of only integers") + } + + for _, item := range signalStmt.Info { + switch item.ConditionItemName { + case ast.SignalConditionItemName_MysqlErrno: + switch val := item.Value.(type) { + case *ast.SQLVal: + num, err := strconv.ParseInt(string(val.Val), 10, 64) + if err != nil || num == 0 { + return 0, nil, nil, fmt.Errorf("invalid value '%s' for MySQL error code", string(val.Val)) } + mysqlErrNo = int(num) + case *ast.ColName: + return 0, nil, nil, fmt.Errorf("unsupported signal message text type: %T", val) default: - switch val := item.Value.(type) { - case *ast.SQLVal: - msgTxt = string(val.Val) - if len(msgTxt) > 64 { - return nil, nil, fmt.Errorf("signal condition information item %s has max length of 64", strings.ToUpper(string(item.ConditionItemName))) - } - default: - return nil, nil, fmt.Errorf("invalid value '%v' for signal condition information item '%s''", item.Value, strings.ToUpper(string(item.ConditionItemName))) - } + return 0, nil, nil, fmt.Errorf("invalid value '%v' for signal condition information item MESSAGE_TEXT", val) } - } - - if mysqlErrNo == 0 { - switch sqlState[0:2] { - case "01": - mysqlErrNo = 1642 - case "02": - mysqlErrNo = 1643 + case ast.SignalConditionItemName_MessageText: + switch val := item.Value.(type) { + case *ast.SQLVal: + msgTxt = string(val.Val) + if len(msgTxt) > 128 { + return 0, nil, nil, fmt.Errorf("signal condition information item MESSAGE_TEXT has max length of 128") + } + case *ast.ColName: + return 0, nil, nil, fmt.Errorf("unsupported signal message text type: %T", val) default: - mysqlErrNo = 1644 + return 0, nil, nil, fmt.Errorf("invalid value '%v' for signal condition information item MESSAGE_TEXT", val) } - } - - if msgTxt == "" { - switch sqlState[0:2] { - case "00": - return nil, nil, fmt.Errorf("invalid SQLSTATE VALUE: '%s'", sqlState) - case "01": - msgTxt = "Unhandled user-defined warning condition" - case "02": - msgTxt = "Unhandled user-defined not found condition" + default: + switch val := item.Value.(type) { + case *ast.SQLVal: + msgTxt = string(val.Val) + if len(msgTxt) > 64 { + return 0, nil, nil, fmt.Errorf("signal condition information item %s has max length of 64", strings.ToUpper(string(item.ConditionItemName))) + } default: - msgTxt = "Unhandled user-defined exception condition" + return 0, nil, nil, fmt.Errorf("invalid value '%v' for signal condition information item '%s''", item.Value, strings.ToUpper(string(item.ConditionItemName))) } } + } - return nil, nil, mysql.NewSQLError(mysqlErrNo, sqlState, msgTxt) - - case OpCode_Open: - openCur := operation.PrimaryData.(*ast.OpenCursor) - cursor := stack.GetCursor(strings.ToLower(openCur.Name)) - if cursor == nil { - return nil, nil, sql.ErrCursorNotFound.New(openCur.Name) - } - if cursor.RowIter != nil { - return nil, nil, sql.ErrCursorAlreadyOpen.New(openCur.Name) - } - stmt, err := replaceVariablesInExpr(stack, cursor.SelectStmt) - if err != nil { - return nil, nil, err + if mysqlErrNo == 0 { + switch sqlState[0:2] { + case "01": + mysqlErrNo = 1642 + case "02": + mysqlErrNo = 1643 + default: + mysqlErrNo = 1644 } - schema, rowIter, _, err := runner.QueryWithBindings(ctx, "", stmt.(ast.Statement), nil, nil) - if err != nil { - return nil, nil, err - } - cursor.Schema = schema - cursor.RowIter = rowIter + } - case OpCode_Fetch: - fetchCur := operation.PrimaryData.(*ast.FetchCursor) - cursor := stack.GetCursor(strings.ToLower(fetchCur.Name)) - if cursor == nil { - return nil, nil, sql.ErrCursorNotFound.New(fetchCur.Name) - } - if cursor.RowIter == nil { - return nil, nil, sql.ErrCursorNotOpen.New(fetchCur.Name) - } - row, err := cursor.RowIter.Next(ctx) - if err != nil { - return nil, nil, err - } - if len(row) != len(fetchCur.Variables) { - return nil, nil, sql.ErrFetchIncorrectCount.New() - } - for i := range fetchCur.Variables { - varName := strings.ToLower(fetchCur.Variables[i]) - if strings.HasPrefix(varName, "@") { - err = ctx.SetUserVariable(ctx, varName, row[i], cursor.Schema[i].Type) - if err != nil { - return nil, nil, err - } - continue - } - err = stack.SetVariable(varName, row[i]) - if err != nil { - return nil, nil, err - } + if msgTxt == "" { + switch sqlState[0:2] { + case "00": + return 0, nil, nil, fmt.Errorf("invalid SQLSTATE VALUE: '%s'", sqlState) + case "01": + msgTxt = "Unhandled user-defined warning condition" + case "02": + msgTxt = "Unhandled user-defined not found condition" + default: + msgTxt = "Unhandled user-defined exception condition" } + } - case OpCode_Close: - closeCur := operation.PrimaryData.(*ast.CloseCursor) - cursor := stack.GetCursor(strings.ToLower(closeCur.Name)) - if cursor == nil { - return nil, nil, sql.ErrCursorNotFound.New(closeCur.Name) - } - if cursor.RowIter == nil { - return nil, nil, sql.ErrCursorNotOpen.New(closeCur.Name) - } - if err := cursor.RowIter.Close(ctx); err != nil { - return nil, nil, err - } - cursor.RowIter = nil + return 0, nil, nil, mysql.NewSQLError(mysqlErrNo, sqlState, msgTxt) - case OpCode_Set: - selectStmt := operation.PrimaryData.(*ast.Select) - if selectStmt.SelectExprs == nil { - panic("select stmt with no select exprs") - } - for i := range selectStmt.SelectExprs { - newNode, err := replaceVariablesInExpr(stack, selectStmt.SelectExprs[i]) + case OpCode_Open: + openCur := operation.PrimaryData.(*ast.OpenCursor) + cursor := stack.GetCursor(strings.ToLower(openCur.Name)) + if cursor == nil { + return 0, nil, nil, sql.ErrCursorNotFound.New(openCur.Name) + } + if cursor.RowIter != nil { + return 0, nil, nil, sql.ErrCursorAlreadyOpen.New(openCur.Name) + } + stmt, err := replaceVariablesInExpr(stack, cursor.SelectStmt) + if err != nil { + return 0, nil, nil, err + } + schema, rowIter, _, err := runner.QueryWithBindings(ctx, "", stmt.(ast.Statement), nil, nil) + if err != nil { + return 0, nil, nil, err + } + cursor.Schema = schema + cursor.RowIter = rowIter + + case OpCode_Fetch: + fetchCur := operation.PrimaryData.(*ast.FetchCursor) + cursor := stack.GetCursor(strings.ToLower(fetchCur.Name)) + if cursor == nil { + return 0, nil, nil, sql.ErrCursorNotFound.New(fetchCur.Name) + } + if cursor.RowIter == nil { + return 0, nil, nil, sql.ErrCursorNotOpen.New(fetchCur.Name) + } + row, err := cursor.RowIter.Next(ctx) + if err != nil { + return 0, nil, nil, err + } + if len(row) != len(fetchCur.Variables) { + return 0, nil, nil, sql.ErrFetchIncorrectCount.New() + } + for i := range fetchCur.Variables { + varName := strings.ToLower(fetchCur.Variables[i]) + if strings.HasPrefix(varName, "@") { + err = ctx.SetUserVariable(ctx, varName, row[i], cursor.Schema[i].Type) if err != nil { - return nil, nil, err + return 0, nil, nil, err } - selectStmt.SelectExprs[i] = newNode.(ast.SelectExpr) - } - _, rowIter, _, err := runner.QueryWithBindings(ctx, "", selectStmt, nil, nil) - if err != nil { - return nil, nil, err + continue } - row, err := rowIter.Next(ctx) + err = stack.SetVariable(varName, row[i]) if err != nil { - return nil, nil, err - } - if _, err = rowIter.Next(ctx); err != io.EOF { - return nil, nil, err - } - if err = rowIter.Close(ctx); err != nil { - return nil, nil, err + return 0, nil, nil, err } + } - err = stack.SetVariable(strings.ToLower(operation.Target), row[0]) - if err != nil { - return nil, nil, err - } + case OpCode_Close: + closeCur := operation.PrimaryData.(*ast.CloseCursor) + cursor := stack.GetCursor(strings.ToLower(closeCur.Name)) + if cursor == nil { + return 0, nil, nil, sql.ErrCursorNotFound.New(closeCur.Name) + } + if cursor.RowIter == nil { + return 0, nil, nil, sql.ErrCursorNotOpen.New(closeCur.Name) + } + if err := cursor.RowIter.Close(ctx); err != nil { + return 0, nil, nil, err + } + cursor.RowIter = nil - case OpCode_If: - selectStmt := operation.PrimaryData.(*ast.Select) - if selectStmt.SelectExprs == nil { - panic("select stmt with no select exprs") - } - for i := range selectStmt.SelectExprs { - newNode, err := replaceVariablesInExpr(stack, selectStmt.SelectExprs[i]) - if err != nil { - return nil, nil, err - } - selectStmt.SelectExprs[i] = newNode.(ast.SelectExpr) - } - _, rowIter, _, err := runner.QueryWithBindings(ctx, "", selectStmt, nil, nil) - if err != nil { - return nil, nil, err - } - // TODO: exactly one result that is a bool for now - row, err := rowIter.Next(ctx) + case OpCode_Set: + selectStmt := operation.PrimaryData.(*ast.Select) + if selectStmt.SelectExprs == nil { + panic("select stmt with no select exprs") + } + for i := range selectStmt.SelectExprs { + newNode, err := replaceVariablesInExpr(stack, selectStmt.SelectExprs[i]) if err != nil { - return nil, nil, err - } - if _, err = rowIter.Next(ctx); err != io.EOF { - return nil, nil, err - } - if err = rowIter.Close(ctx); err != nil { - return nil, nil, err + return 0, nil, nil, err } + selectStmt.SelectExprs[i] = newNode.(ast.SelectExpr) + } + _, rowIter, _, err := runner.QueryWithBindings(ctx, "", selectStmt, nil, nil) + if err != nil { + return 0, nil, nil, err + } + row, err := rowIter.Next(ctx) + if err != nil { + return 0, nil, nil, err + } + if _, err = rowIter.Next(ctx); err != io.EOF { + return 0, nil, nil, err + } + if err = rowIter.Close(ctx); err != nil { + return 0, nil, nil, err + } - // go to the appropriate block - cond, _, err := types.Boolean.Convert(row[0]) + err = stack.SetVariable(strings.ToLower(operation.Target), row[0]) + if err != nil { + return 0, nil, nil, err + } + + case OpCode_If: + selectStmt := operation.PrimaryData.(*ast.Select) + if selectStmt.SelectExprs == nil { + panic("select stmt with no select exprs") + } + for i := range selectStmt.SelectExprs { + newNode, err := replaceVariablesInExpr(stack, selectStmt.SelectExprs[i]) if err != nil { - return nil, nil, err + return 0, nil, nil, err } - if cond == nil || cond.(int8) == 0 { - counter = operation.Index - 1 // index of the else block, offset by 1 - } - - case OpCode_Goto: - // We must compare to the index - 1, so that the increment hits our target - if counter <= operation.Index { - for ; counter < operation.Index-1; counter++ { - switch statements[counter].OpCode { - case OpCode_ScopeBegin: - stack.PushScope() - case OpCode_ScopeEnd: - stack.PopScope() - default: - // No-op - } + selectStmt.SelectExprs[i] = newNode.(ast.SelectExpr) + } + _, rowIter, _, err := runner.QueryWithBindings(ctx, "", selectStmt, nil, nil) + if err != nil { + return 0, nil, nil, err + } + // TODO: exactly one result that is a bool for now + row, err := rowIter.Next(ctx) + if err != nil { + return 0, nil, nil, err + } + if _, err = rowIter.Next(ctx); err != io.EOF { + return 0, nil, nil, err + } + if err = rowIter.Close(ctx); err != nil { + return 0, nil, nil, err + } + + // go to the appropriate block + cond, _, err := types.Boolean.Convert(row[0]) + if err != nil { + return 0, nil, nil, err + } + if cond == nil || cond.(int8) == 0 { + counter = operation.Index - 1 // index of the else block, offset by 1 + } + + case OpCode_Goto: + // We must compare to the index - 1, so that the increment hits our target + if counter <= operation.Index { + for ; counter < operation.Index-1; counter++ { + switch statements[counter].OpCode { + case OpCode_ScopeBegin: + stack.PushScope() + case OpCode_ScopeEnd: + stack.PopScope() + default: + // No-op } - } else { - for ; counter > operation.Index-1; counter-- { - switch statements[counter].OpCode { - case OpCode_ScopeBegin: - stack.PopScope() - case OpCode_ScopeEnd: - stack.PushScope() - default: - // No-op - } + } + } else { + for ; counter > operation.Index-1; counter-- { + switch statements[counter].OpCode { + case OpCode_ScopeBegin: + stack.PopScope() + case OpCode_ScopeEnd: + stack.PushScope() + default: + // No-op } } + } - case OpCode_Execute: - stmt, err := replaceVariablesInExpr(stack, operation.PrimaryData) - if err != nil { - return nil, nil, err - } - rowIter, err := query(ctx, runner, stmt.(ast.Statement)) - if err != nil { - return nil, nil, err - } - rowIters = append(rowIters, rowIter) + case OpCode_Execute: + stmt, err := replaceVariablesInExpr(stack, operation.PrimaryData) + if err != nil { + return 0, nil, nil, err + } + rowIter, err := query(ctx, runner, stmt.(ast.Statement)) + if err != nil { + return 0, nil, nil, err + } + return counter, nil, rowIter, err + + case OpCode_Exception: + return 0, nil, nil, operation.Error + + case OpCode_ScopeBegin: + stack.PushScope() + + case OpCode_ScopeEnd: + stack.PopScope() + + default: + panic("unimplemented opcode") + } - case OpCode_Exception: - return nil, nil, operation.Error + return counter, nil, nil, nil +} - case OpCode_ScopeBegin: - stack.PushScope() +// Call runs the contained operations on the given runner. +func Call(ctx *sql.Context, iNode InterpreterNode, params []*Parameter) (sql.RowIter, *InterpreterStack, error) { + // Set up the initial state of the function + counter := -1 // We increment before accessing, so start at -1 + stack := NewInterpreterStack() + for _, param := range params { + stack.NewVariableWithValue(param.Name, param.Type, param.Value) + } - case OpCode_ScopeEnd: - stack.PopScope() + // TODO: remove this; track last selectRowIter + var selIter sql.RowIter - default: - panic("unimplemented opcode") + // Run the statements + // TODO: eventually return multiple sql.RowIters + var rowIters []sql.RowIter + runner := iNode.GetRunner() + statements := iNode.GetStatements() + for { + counter++ + if counter < 0 { + panic("negative function counter") + } + if counter >= len(statements) { + break + } + + operation := statements[counter] + newCounter, newSelIter, rowIter, err := execOp(ctx, runner, stack, operation, statements, counter) + if err = handleError(ctx, stack, runner, err); err != nil { + if err != io.EOF { + return nil, nil, err + } + for counter < len(statements) && statements[counter].OpCode != OpCode_ScopeEnd { + counter++ + } + newCounter = counter + } + if rowIter != nil { + rowIters = append(rowIters, rowIter) + } + if newSelIter != nil { + selIter = newSelIter } + counter = newCounter } if selIter != nil { diff --git a/sql/procedures/interpreter_stack.go b/sql/procedures/interpreter_stack.go index afc2bfbcc7..1021729abc 100644 --- a/sql/procedures/interpreter_stack.go +++ b/sql/procedures/interpreter_stack.go @@ -98,8 +98,8 @@ type InterpreterCursor struct { // InterpreterHandler is a declare handler that specifies an Action during an error Condition. type InterpreterHandler struct { - Condition string - Action string + Condition ast.DeclareHandlerConditionValue + Action ast.DeclareHandlerAction Statement ast.Statement } @@ -147,7 +147,7 @@ func NewInterpreterStack() *InterpreterStack { stack.Push(&InterpreterScopeDetails{ conditions: make(map[string]*InterpreterCondition), cursors: make(map[string]*InterpreterCursor), - handlers: make([]*InterpreterHandler), + handlers: make([]*InterpreterHandler, 0), variables: make(map[string]*InterpreterVariable), }) return &InterpreterStack{ @@ -242,7 +242,7 @@ func (is *InterpreterStack) GetCursor(name string) *InterpreterCursor { } // NewHandler creates a new handler in the current scope. -func (is *InterpreterStack) NewHandler(cond string, action string, stmt ast.Statement) { +func (is *InterpreterStack) NewHandler(cond ast.DeclareHandlerConditionValue, action ast.DeclareHandlerAction, stmt ast.Statement) { is.stack.Peek().handlers = append(is.stack.Peek().handlers, &InterpreterHandler{ Condition: cond, Action: action, diff --git a/sql/rowexec/proc.go b/sql/rowexec/proc.go index 0323c57cdf..9b6dde7e30 100644 --- a/sql/rowexec/proc.go +++ b/sql/rowexec/proc.go @@ -233,7 +233,7 @@ func (b *BaseBuilder) buildCall(ctx *sql.Context, n *plan.Call, row sql.Row) (sq return &callIter{ call: n, - innerIter: rowIter.(sql.RowIter), + innerIter: rowIter, }, nil } From 967a4f545db3e59653ecc4acf20ede9329a0868c Mon Sep 17 00:00:00 2001 From: James Cor Date: Fri, 21 Mar 2025 02:27:04 -0700 Subject: [PATCH 048/111] asdf --- enginetest/memory_engine_test.go | 34 ++++++++++++++++++-------------- 1 file changed, 19 insertions(+), 15 deletions(-) diff --git a/enginetest/memory_engine_test.go b/enginetest/memory_engine_test.go index 9c09b82299..1ce80b9e42 100644 --- a/enginetest/memory_engine_test.go +++ b/enginetest/memory_engine_test.go @@ -201,29 +201,33 @@ func TestSingleScript(t *testing.T) { //t.Skip() var scripts = []queries.ScriptTest{ { - Name: "FETCH multiple rows", + Name: "SQLEXCEPTION declare handler", SetUpScript: []string{ + `DROP TABLE IF EXISTS t1;`, `CREATE TABLE t1 (pk BIGINT PRIMARY KEY);`, ` -CREATE PROCEDURE p1() +CREATE PROCEDURE eof() BEGIN - DECLARE a, b INT; - DECLARE cur1 CURSOR FOR SELECT pk FROM t1; - DELETE FROM t1; - INSERT INTO t1 VALUES (1), (2); - OPEN cur1; - FETCH cur1 INTO a; - FETCH cur1 INTO b; - CLOSE cur1; - SELECT a, b; + DECLARE a, b INT DEFAULT 1; + DECLARE cur1 CURSOR FOR SELECT * FROM t1; + OPEN cur1; + BEGIN + DECLARE EXIT HANDLER FOR SQLEXCEPTION SET a = 7; + tloop: LOOP + FETCH cur1 INTO b; + IF a > 1000 THEN + LEAVE tloop; + END IF; + END LOOP; + END; + CLOSE cur1; + SELECT a; END;`, }, Assertions: []queries.ScriptTestAssertion{ { - Query: "CALL p1();", - Expected: []sql.Row{ - {1, 2}, - }, + Query: "CALL eof();", + Expected: []sql.Row{}, }, }, }, From 4f915d277f1e462bc60f7c7453f1319fedb016ce Mon Sep 17 00:00:00 2001 From: James Cor Date: Fri, 21 Mar 2025 16:46:01 -0700 Subject: [PATCH 049/111] various fixes --- sql/procedures/interpreter_logic.go | 48 ++++++++++--- sql/procedures/interpreter_stack.go | 103 +++++++++++++++++++--------- sql/procedures/parse.go | 79 +++++++++++++++------ 3 files changed, 167 insertions(+), 63 deletions(-) diff --git a/sql/procedures/interpreter_logic.go b/sql/procedures/interpreter_logic.go index e890f56acb..b54bf01718 100644 --- a/sql/procedures/interpreter_logic.go +++ b/sql/procedures/interpreter_logic.go @@ -57,6 +57,18 @@ func replaceVariablesInExpr(stack *InterpreterStack, expr ast.SQLNode) (ast.SQLN Qualifier: e.Qualifier, StoredProcVal: newExpr, }, nil + case *ast.ParenExpr: + newExpr, err := replaceVariablesInExpr(stack, e.Expr) + if err != nil { + return nil, err + } + e.Expr = newExpr.(ast.Expr) + case *ast.AliasedTableExpr: + newExpr, err := replaceVariablesInExpr(stack, e.Expr) + if err != nil { + return nil, err + } + e.Expr = newExpr.(ast.SimpleTableExpr) case *ast.AliasedExpr: newExpr, err := replaceVariablesInExpr(stack, e.Expr) if err != nil { @@ -155,6 +167,15 @@ func replaceVariablesInExpr(stack *InterpreterStack, expr ast.SQLNode) (ast.SQLN } e.SelectExprs[i] = newExpr.(ast.SelectExpr) } + if e.With != nil { + for i := range e.With.Ctes { + newExpr, err := replaceVariablesInExpr(stack, e.With.Ctes[i].AliasedTableExpr) + if err != nil { + return nil, err + } + e.With.Ctes[i].AliasedTableExpr = newExpr.(*ast.AliasedTableExpr) + } + } if e.Into != nil { newExpr, err := replaceVariablesInExpr(stack, e.Into) if err != nil { @@ -250,8 +271,10 @@ func handleError(ctx *sql.Context, stack *InterpreterStack, runner sql.Statement var matchingHandler *InterpreterHandler for _, handler := range stack.ListHandlers() { - if errors.Is(err, expression.FetchEOF) && handler.Condition == ast.DeclareHandlerCondition_NotFound { - matchingHandler = handler + if errors.Is(err, expression.FetchEOF) { + if handler.Condition == ast.DeclareHandlerCondition_NotFound { + matchingHandler = handler + } break } switch handler.Condition { @@ -546,6 +569,9 @@ func execOp(ctx *sql.Context, runner sql.StatementRunner, stack *InterpreterStac } row, err := cursor.RowIter.Next(ctx) if err != nil { + if err == io.EOF { + return 0, nil, nil, expression.FetchEOF + } return 0, nil, nil, err } if len(row) != len(fetchCur.Variables) { @@ -664,6 +690,9 @@ func execOp(ctx *sql.Context, runner sql.StatementRunner, stack *InterpreterStac } } else { for ; counter > operation.Index-1; counter-- { + if counter == -1 { + print() + } switch statements[counter].OpCode { case OpCode_ScopeBegin: stack.PopScope() @@ -730,12 +759,15 @@ func Call(ctx *sql.Context, iNode InterpreterNode, params []*Parameter) (sql.Row operation := statements[counter] newCounter, newSelIter, rowIter, err := execOp(ctx, runner, stack, operation, statements, counter) - if err = handleError(ctx, stack, runner, err); err != nil { - if err != io.EOF { - return nil, nil, err - } - for counter < len(statements) && statements[counter].OpCode != OpCode_ScopeEnd { - counter++ + if err != nil { + hErr := handleError(ctx, stack, runner, err) + if hErr != nil { + if hErr != io.EOF { + return nil, nil, hErr + } + for counter < len(statements) && statements[counter].OpCode != OpCode_ScopeEnd { + counter++ + } } newCounter = counter } diff --git a/sql/procedures/interpreter_stack.go b/sql/procedures/interpreter_stack.go index 1021729abc..91ca860a47 100644 --- a/sql/procedures/interpreter_stack.go +++ b/sql/procedures/interpreter_stack.go @@ -117,7 +117,22 @@ func (iv *InterpreterVariable) ToAST() ast.Expr { return &ast.NullVal{} } if types.IsInteger(iv.Type) { - return ast.NewIntVal([]byte(fmt.Sprintf("%d", iv.Value))) + switch val := iv.Value.(type) { + case bool: + if val { + return ast.NewIntVal([]byte("1")) + } else { + return ast.NewIntVal([]byte("0")) + } + case ast.BoolVal: + if val { + return ast.NewIntVal([]byte("1")) + } else { + return ast.NewIntVal([]byte("0")) + } + default: + return ast.NewIntVal([]byte(fmt.Sprintf("%d", val))) + } } if types.IsFloat(iv.Type) { return ast.NewFloatVal([]byte(strconv.FormatFloat(iv.Value.(float64), 'f', -1, 64))) @@ -125,12 +140,15 @@ func (iv *InterpreterVariable) ToAST() ast.Expr { return ast.NewStrVal([]byte(fmt.Sprintf("%s", iv.Value))) } -// InterpreterScopeDetails contains all of the details that are relevant to a particular scope. +// InterpreterScopeDetails contains all the details that are relevant to a particular scope. type InterpreterScopeDetails struct { conditions map[string]*InterpreterCondition cursors map[string]*InterpreterCursor handlers []*InterpreterHandler variables map[string]*InterpreterVariable + + // labels mark the counter of the start of a loop or block. + labels map[string]int } // InterpreterStack represents the working information that an interpreter will use during execution. It is not exactly @@ -149,6 +167,8 @@ func NewInterpreterStack() *InterpreterStack { cursors: make(map[string]*InterpreterCursor), handlers: make([]*InterpreterHandler, 0), variables: make(map[string]*InterpreterVariable), + + labels: make(map[string]int), }) return &InterpreterStack{ stack: stack, @@ -160,6 +180,26 @@ func (is *InterpreterStack) Details() *InterpreterScopeDetails { return is.stack.Peek() } +// NewVariable creates a new variable in the current scope. If a variable with the same name exists in a previous scope, +// then that variable will be shadowed until the current scope exits. +func (is *InterpreterStack) NewVariable(name string, typ sql.Type) { + is.NewVariableWithValue(name, typ, typ.Zero()) +} + +// NewVariableWithValue creates a new variable in the current scope, setting its initial value to the one given. +func (is *InterpreterStack) NewVariableWithValue(name string, typ sql.Type, val any) { + is.stack.Peek().variables[name] = &InterpreterVariable{ + Type: typ, + Value: val, + } +} + +// NewVariableAlias creates a new variable alias, named |alias|, in the current frame of this stack, +// pointing to the specified |variable|. +func (is *InterpreterStack) NewVariableAlias(alias string, variable *InterpreterVariable) { + is.stack.Peek().variables[alias] = variable +} + // GetVariable traverses the stack (starting from the top) to find a variable with a matching name. Returns nil if no // variable was found. func (is *InterpreterStack) GetVariable(name string) *InterpreterVariable { @@ -182,24 +222,16 @@ func (is *InterpreterStack) ListVariables() map[string]struct{} { return seen } -// NewVariable creates a new variable in the current scope. If a variable with the same name exists in a previous scope, -// then that variable will be shadowed until the current scope exits. -func (is *InterpreterStack) NewVariable(name string, typ sql.Type) { - is.NewVariableWithValue(name, typ, typ.Zero()) -} - -// NewVariableWithValue creates a new variable in the current scope, setting its initial value to the one given. -func (is *InterpreterStack) NewVariableWithValue(name string, typ sql.Type, val any) { - is.stack.Peek().variables[name] = &InterpreterVariable{ - Type: typ, - Value: val, +// SetVariable sets the first variable found, with a matching name, to the value given. This does not ensure that the +// value matches the expectations of the type, so it should be validated before this is called. Returns an error if the +// variable cannot be found. +func (is *InterpreterStack) SetVariable(name string, val any) error { + iv := is.GetVariable(name) + if iv == nil { + return fmt.Errorf("variable `%s` could not be found", name) } -} - -// NewVariableAlias creates a new variable alias, named |alias|, in the current frame of this stack, -// pointing to the specified |variable|. -func (is *InterpreterStack) NewVariableAlias(alias string, variable *InterpreterVariable) { - is.stack.Peek().variables[alias] = variable + iv.Value = val + return nil } // NewCondition creates a new condition in the current scope. @@ -221,8 +253,6 @@ func (is *InterpreterStack) GetCondition(name string) *InterpreterCondition { return nil } - - // NewCursor creates a new cursor in the current scope. func (is *InterpreterStack) NewCursor(name string, selStmt ast.SelectStatement) { is.stack.Peek().cursors[name] = &InterpreterCursor{ @@ -261,12 +291,31 @@ func (is *InterpreterStack) ListHandlers() []*InterpreterHandler { return handlers } +// NewLabel creates a new label in the current scope. +func (is *InterpreterStack) NewLabel(name string, index int) { + is.stack.Peek().labels[name] = index +} + +// GetLabel traverses the stack (starting from the top) to find a label with a matching name. Returns -1 if no +// variable was found. +func (is *InterpreterStack) GetLabel(name string) int { + for i := 0; i < is.stack.Len(); i++ { + if index, ok := is.stack.PeekDepth(i).labels[name]; ok { + return index + } + } + return -1 +} + // PushScope creates a new scope. func (is *InterpreterStack) PushScope() { is.stack.Push(&InterpreterScopeDetails{ conditions: make(map[string]*InterpreterCondition), cursors: make(map[string]*InterpreterCursor), + handlers: make([]*InterpreterHandler, 0), variables: make(map[string]*InterpreterVariable), + + labels: make(map[string]int), }) } @@ -274,15 +323,3 @@ func (is *InterpreterStack) PushScope() { func (is *InterpreterStack) PopScope() { is.stack.Pop() } - -// SetVariable sets the first variable found, with a matching name, to the value given. This does not ensure that the -// value matches the expectations of the type, so it should be validated before this is called. Returns an error if the -// variable cannot be found. -func (is *InterpreterStack) SetVariable(name string, val any) error { - iv := is.GetVariable(name) - if iv == nil { - return fmt.Errorf("variable `%s` could not be found", name) - } - iv.Value = val - return nil -} diff --git a/sql/procedures/parse.go b/sql/procedures/parse.go index bf0ac93843..4abcf65b8c 100644 --- a/sql/procedures/parse.go +++ b/sql/procedures/parse.go @@ -20,26 +20,57 @@ import ( ast "github.com/dolthub/vitess/go/vt/sqlparser" ) +// resolveGoToIndexes will iterate over operations from start to end, and resolve the indexes of any OpCode_Goto +// operations, assigning either loopStart or loopEnd. +func resolveGoToIndexes(ops *[]*InterpreterOperation, label string, start, end, loopStart, loopEnd int) { + if label == "" { + return + } + for idx := start; idx < end; idx++ { + op := (*ops)[idx] + switch op.OpCode { + case OpCode_Goto: + if op.Target != label { + continue + } + switch op.Index { + case -1: // iterate + (*ops)[idx].Index = loopStart + case -2: // leave + (*ops)[idx].Index = loopEnd + default: + continue + } + default: + continue + } + } +} + func ConvertStmt(ops *[]*InterpreterOperation, stack *InterpreterStack, stmt ast.Statement) error { switch s := stmt.(type) { case *ast.BeginEndBlock: stack.PushScope() - startOP := &InterpreterOperation{ + startOp := &InterpreterOperation{ OpCode: OpCode_ScopeBegin, + Target: s.Label, } - *ops = append(*ops, startOP) + *ops = append(*ops, startOp) + startOp.Index = len(*ops) - // TODO: add declares for _, ss := range s.Statements { if err := ConvertStmt(ops, stack, ss); err != nil { return err } } + endOp := &InterpreterOperation{ OpCode: OpCode_ScopeEnd, + Target: s.Label, } *ops = append(*ops, endOp) - stack.PopScope() + endOp.Index = len(*ops) + resolveGoToIndexes(ops, s.Label, startOp.Index, endOp.Index, startOp.Index, endOp.Index) case *ast.Select: selectOp := &InterpreterOperation{ @@ -111,7 +142,6 @@ func ConvertStmt(ops *[]*InterpreterOperation, stack *InterpreterStack, stmt ast *ops = append(*ops, setOp) case *ast.IfStatement: - // TODO: each subsequent condition is an else if var ifElseGotoOps []*InterpreterOperation for _, ifCond := range s.Conditions { selectCond := &ast.Select{ @@ -235,11 +265,12 @@ func ConvertStmt(ops *[]*InterpreterOperation, stack *InterpreterStack, stmt ast Index: loopStart, } *ops = append(*ops, gotoOp) - whileOp.Index = len(*ops) // end of while block + resolveGoToIndexes(ops, s.Label, loopStart, whileOp.Index, loopStart, whileOp.Index) case *ast.Repeat: // repeat statements always run at least once + onceStart := len(*ops) for _, repeatStmt := range s.Statements { if err := ConvertStmt(ops, stack, repeatStmt); err != nil { return err @@ -247,6 +278,9 @@ func ConvertStmt(ops *[]*InterpreterOperation, stack *InterpreterStack, stmt ast } loopStart := len(*ops) + if s.Label != "" { + stack.NewLabel(s.Label, loopStart) + } repeatCond := &ast.NotExpr{Expr: s.Condition} selectCond := &ast.Select{ SelectExprs: ast.SelectExprs{ @@ -272,11 +306,14 @@ func ConvertStmt(ops *[]*InterpreterOperation, stack *InterpreterStack, stmt ast Index: loopStart, } *ops = append(*ops, gotoOp) - repeatOp.Index = len(*ops) // end of repeat block + resolveGoToIndexes(ops, s.Label, onceStart, repeatOp.Index, loopStart, repeatOp.Index) case *ast.Loop: loopStart := len(*ops) + if s.Label != "" { + stack.NewLabel(s.Label, loopStart) + } for _, loopStmt := range s.Statements { if err := ConvertStmt(ops, stack, loopStmt); err != nil { return err @@ -284,37 +321,35 @@ func ConvertStmt(ops *[]*InterpreterOperation, stack *InterpreterStack, stmt ast } gotoOp := &InterpreterOperation{ OpCode: OpCode_Goto, + Target: s.Label, Index: loopStart, } *ops = append(*ops, gotoOp) + loopEnd := len(*ops) + resolveGoToIndexes(ops, s.Label, loopStart, loopEnd, loopStart, loopEnd) - // perform second pass over loop statements to add labels - for idx := loopStart; idx < len(*ops); idx++ { - op := (*ops)[idx] - switch op.OpCode { - case OpCode_Goto: - if op.Target == s.Label { - (*ops)[idx].Index = len(*ops) - } - default: - continue - } + case *ast.Iterate: + iterateOp := &InterpreterOperation{ + OpCode: OpCode_Goto, + Target: s.Label, + Index: stack.GetLabel(s.Label), // possible this is -1, which will get resolved later } + *ops = append(*ops, iterateOp) case *ast.Leave: leaveOp := &InterpreterOperation{ OpCode: OpCode_Goto, - Target: s.Label, // hacky? way to signal a leave + Target: s.Label, + Index: -2, // -2 indicates that this is a leave statement with unknown target index } *ops = append(*ops, leaveOp) - default: - execOp := &InterpreterOperation{ + executeOp := &InterpreterOperation{ OpCode: OpCode_Execute, PrimaryData: s, } - *ops = append(*ops, execOp) + *ops = append(*ops, executeOp) } return nil From 95b5bd96ff6fb6c7ee27717c3be056e704d5ecea Mon Sep 17 00:00:00 2001 From: James Cor Date: Sat, 22 Mar 2025 11:39:23 -0700 Subject: [PATCH 050/111] debugging out params being null --- enginetest/memory_engine_test.go | 41 ++++++++----------------- enginetest/queries/procedure_queries.go | 4 +-- sql/procedures/interpreter_logic.go | 1 - 3 files changed, 14 insertions(+), 32 deletions(-) diff --git a/enginetest/memory_engine_test.go b/enginetest/memory_engine_test.go index 1ce80b9e42..12dea1f6f6 100644 --- a/enginetest/memory_engine_test.go +++ b/enginetest/memory_engine_test.go @@ -200,37 +200,22 @@ func TestSingleQueryPrepared(t *testing.T) { func TestSingleScript(t *testing.T) { //t.Skip() var scripts = []queries.ScriptTest{ - { - Name: "SQLEXCEPTION declare handler", - SetUpScript: []string{ - `DROP TABLE IF EXISTS t1;`, - `CREATE TABLE t1 (pk BIGINT PRIMARY KEY);`, - ` -CREATE PROCEDURE eof() -BEGIN - DECLARE a, b INT DEFAULT 1; - DECLARE cur1 CURSOR FOR SELECT * FROM t1; - OPEN cur1; - BEGIN - DECLARE EXIT HANDLER FOR SQLEXCEPTION SET a = 7; - tloop: LOOP - FETCH cur1 INTO b; - IF a > 1000 THEN - LEAVE tloop; - END IF; - END LOOP; - END; - CLOSE cur1; - SELECT a; -END;`, - }, - Assertions: []queries.ScriptTestAssertion{ - { - Query: "CALL eof();", - Expected: []sql.Row{}, + { + Name: "OUT param without SET", + SetUpScript: []string{ + "SET @outparam = 5", + "CREATE PROCEDURE testabc(OUT x BIGINT) SELECT x", + "CALL testabc(@outparam)", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "SELECT @outparam", + Expected: []sql.Row{ + {nil}, }, }, }, + }, } for _, test := range scripts { diff --git a/enginetest/queries/procedure_queries.go b/enginetest/queries/procedure_queries.go index 8f27d9531e..c602d960d5 100644 --- a/enginetest/queries/procedure_queries.go +++ b/enginetest/queries/procedure_queries.go @@ -2217,9 +2217,7 @@ var ProcedureCallTests = []ScriptTest{ { Query: "SELECT @outparam", Expected: []sql.Row{ - { - nil, - }, + {nil}, }, }, }, diff --git a/sql/procedures/interpreter_logic.go b/sql/procedures/interpreter_logic.go index b54bf01718..78e4180f03 100644 --- a/sql/procedures/interpreter_logic.go +++ b/sql/procedures/interpreter_logic.go @@ -264,7 +264,6 @@ func query(ctx *sql.Context, runner sql.StatementRunner, stmt ast.Statement) (sq // handleError handles errors that occur during the execution of a procedure according to the defined handlers. func handleError(ctx *sql.Context, stack *InterpreterStack, runner sql.StatementRunner, err error) error { - // TODO: just copy logic from expression/procedurereference.go if err == nil { return nil } From 9edbf921ace7f12484c5b004b518cad67bb43cb4 Mon Sep 17 00:00:00 2001 From: James Cor Date: Sat, 22 Mar 2025 12:10:24 -0700 Subject: [PATCH 051/111] fix assigning user vars --- sql/procedures/interpreter_stack.go | 6 ++++-- sql/rowexec/proc.go | 6 +++++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/sql/procedures/interpreter_stack.go b/sql/procedures/interpreter_stack.go index 91ca860a47..9361d286c1 100644 --- a/sql/procedures/interpreter_stack.go +++ b/sql/procedures/interpreter_stack.go @@ -105,8 +105,9 @@ type InterpreterHandler struct { // InterpreterVariable is a variable that lives on the stack. type InterpreterVariable struct { - Type sql.Type - Value any + Type sql.Type + Value any + HasBeenSet bool } func (iv *InterpreterVariable) ToAST() ast.Expr { @@ -231,6 +232,7 @@ func (is *InterpreterStack) SetVariable(name string, val any) error { return fmt.Errorf("variable `%s` could not be found", name) } iv.Value = val + iv.HasBeenSet = true return nil } diff --git a/sql/rowexec/proc.go b/sql/rowexec/proc.go index 9b6dde7e30..9c4a08bfeb 100644 --- a/sql/rowexec/proc.go +++ b/sql/rowexec/proc.go @@ -221,7 +221,11 @@ func (b *BaseBuilder) buildCall(ctx *sql.Context, n *plan.Call, row sql.Row) (sq return nil, err } case *expression.UserVar: - err = ctx.SetUserVariable(ctx, p.Name, stackVar.Value, stackVar.Type) + val := stackVar.Value + if procParam.Direction == plan.ProcedureParamDirection_Out && !stackVar.HasBeenSet { + val = nil + } + err = ctx.SetUserVariable(ctx, p.Name, val, stackVar.Type) if err != nil { return nil, err } From ca8e957da04afb9e5d14accb68b257ec23522e11 Mon Sep 17 00:00:00 2001 From: James Cor Date: Sat, 22 Mar 2025 12:10:47 -0700 Subject: [PATCH 052/111] fix views in procs --- enginetest/queries/procedure_queries.go | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/enginetest/queries/procedure_queries.go b/enginetest/queries/procedure_queries.go index c602d960d5..d2926570c5 100644 --- a/enginetest/queries/procedure_queries.go +++ b/enginetest/queries/procedure_queries.go @@ -2969,16 +2969,22 @@ var ProcedureCreateInSubroutineTests = []ScriptTest{ }, }, { - Name: "procedure must not contain CREATE VIEW", - SetUpScript: []string{}, + Name: "procedure can CREATE VIEW", + SetUpScript: []string{ + + }, Assertions: []ScriptTestAssertion{ { - Query: "create procedure p() create view v as select 1;", - ExpectedErrStr: "CREATE statements in CREATE PROCEDURE not yet supported", + Query: "create procedure p1() create view v as select 1;", + Expected: []sql.Row{ + {types.NewOkResult(0)}, + }, }, { - Query: "create procedure p() begin create view v as select 1; end;", - ExpectedErrStr: "CREATE statements in CREATE PROCEDURE not yet supported", + Query: "create procedure p() begin create view v as select 1; end;", + Expected: []sql.Row{ + {types.NewOkResult(0)}, + }, }, }, }, From fc036a02b9f06c9d027e8573dd4b5bd1060eac06 Mon Sep 17 00:00:00 2001 From: James Cor Date: Sat, 22 Mar 2025 16:01:01 -0700 Subject: [PATCH 053/111] fix table errors --- enginetest/queries/procedure_queries.go | 2 +- sql/planbuilder/from.go | 5 ----- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/enginetest/queries/procedure_queries.go b/enginetest/queries/procedure_queries.go index d2926570c5..7c0be4903a 100644 --- a/enginetest/queries/procedure_queries.go +++ b/enginetest/queries/procedure_queries.go @@ -2933,7 +2933,7 @@ var ProcedureCreateInSubroutineTests = []ScriptTest{ }, { Query: "call drop_proc()", - ExpectedErrStr: "Unknown table 't'", + ExpectedErrStr: "table not found: t", }, }, }, diff --git a/sql/planbuilder/from.go b/sql/planbuilder/from.go index 8912d8cfac..c7070fe69f 100644 --- a/sql/planbuilder/from.go +++ b/sql/planbuilder/from.go @@ -705,11 +705,6 @@ func (b *Builder) buildResolvedTable(inScope *scope, db, schema, name string, as b.TriggerCtx().UnresolvedTables = append(b.TriggerCtx().UnresolvedTables, name) return outScope, true } - // TODO: do the same for stored procedures - if b.procCtx != nil { - outScope.node = plan.NewUnresolvedTable(name, db) - return outScope, true - } return outScope, false } else { b.handleErr(tableResolveErr) From 4247bb70da436a5c67fef6223e3613ed9418b6ee Mon Sep 17 00:00:00 2001 From: James Cor Date: Mon, 24 Mar 2025 11:33:14 -0700 Subject: [PATCH 054/111] adsf --- enginetest/memory_engine_test.go | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/enginetest/memory_engine_test.go b/enginetest/memory_engine_test.go index 12dea1f6f6..f7f8e189ea 100644 --- a/enginetest/memory_engine_test.go +++ b/enginetest/memory_engine_test.go @@ -201,18 +201,23 @@ func TestSingleScript(t *testing.T) { //t.Skip() var scripts = []queries.ScriptTest{ { - Name: "OUT param without SET", - SetUpScript: []string{ - "SET @outparam = 5", - "CREATE PROCEDURE testabc(OUT x BIGINT) SELECT x", - "CALL testabc(@outparam)", - }, + Name: "creating invalid procedure doesn't error until it is called", Assertions: []queries.ScriptTestAssertion{ { - Query: "SELECT @outparam", - Expected: []sql.Row{ - {nil}, - }, + Query: `CREATE PROCEDURE proc1 (OUT out_count INT) READS SQL DATA SELECT COUNT(*) FROM mytable WHERE i = 1 AND s = 'first row' AND func1(i);`, + Expected: []sql.Row{{types.NewOkResult(0)}}, + }, + { + Query: "CALL proc1(@out_count);", + ExpectedErr: sql.ErrTableNotFound, + }, + { + Query: "CREATE TABLE mytable (i int, s varchar(128));", + Expected: []sql.Row{{types.NewOkResult(0)}}, + }, + { + Query: "CALL proc1(@out_count);", + ExpectedErr: sql.ErrFunctionNotFound, }, }, }, From db9536044c2277febccf097121384187c41467dd Mon Sep 17 00:00:00 2001 From: James Cor Date: Mon, 24 Mar 2025 13:42:33 -0700 Subject: [PATCH 055/111] bump --- go.mod | 2 +- go.sum | 12 ++---------- 2 files changed, 3 insertions(+), 11 deletions(-) diff --git a/go.mod b/go.mod index 94ad4ed084..c44e7d32cd 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,7 @@ require ( github.com/dolthub/go-icu-regex v0.0.0-20250319212010-451ea8d003fa github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71 github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81 - github.com/dolthub/vitess v0.0.0-20250320231804-0e77d549294c + github.com/dolthub/vitess v0.0.0-20250324203551-408d22cead28 github.com/go-kit/kit v0.10.0 github.com/go-sql-driver/mysql v1.7.2-0.20231213112541-0004702b931d github.com/gocraft/dbr/v2 v2.7.2 diff --git a/go.sum b/go.sum index 0a66e2c968..1374225e5f 100644 --- a/go.sum +++ b/go.sum @@ -52,22 +52,14 @@ github.com/denisenkom/go-mssqldb v0.10.0/go.mod h1:xbL0rPBG9cCiLr28tMa8zpbdarY27 github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= github.com/dolthub/flatbuffers/v23 v23.3.3-dh.2 h1:u3PMzfF8RkKd3lB9pZ2bfn0qEG+1Gms9599cr0REMww= github.com/dolthub/flatbuffers/v23 v23.3.3-dh.2/go.mod h1:mIEZOHnFx4ZMQeawhw9rhsj+0zwQj7adVsnBX7t+eKY= -github.com/dolthub/go-icu-regex v0.0.0-20250303123116-549b8d7cad00 h1:rh2ij2yTYKJWlX+c8XRg4H5OzqPewbU1lPK8pcfVmx8= -github.com/dolthub/go-icu-regex v0.0.0-20250303123116-549b8d7cad00/go.mod h1:ylU4XjUpsMcvl/BKeRRMXSH7e7WBrPXdSLvnRJYrxEA= github.com/dolthub/go-icu-regex v0.0.0-20250319212010-451ea8d003fa h1:NFbzJ4wjWRz32nz2EimbrHpRx1Xt6k+IaR8N+j4x62k= github.com/dolthub/go-icu-regex v0.0.0-20250319212010-451ea8d003fa/go.mod h1:ylU4XjUpsMcvl/BKeRRMXSH7e7WBrPXdSLvnRJYrxEA= github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71 h1:bMGS25NWAGTEtT5tOBsCuCrlYnLRKpbJVJkDbrTRhwQ= github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71/go.mod h1:2/2zjLQ/JOOSbbSboojeg+cAwcRV0fDLzIiWch/lhqI= github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81 h1:7/v8q9XGFa6q5Ap4Z/OhNkAMBaK5YeuEzwJt+NZdhiE= github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81/go.mod h1:siLfyv2c92W1eN/R4QqG/+RjjX5W2+gCTRjZxBjI3TY= -github.com/dolthub/vitess v0.0.0-20250228011932-c4f6bba87730 h1:GtlMVB7+Z7fZZj7BHRFd2rzxZ574dJ8cB/EHWdq1kbY= -github.com/dolthub/vitess v0.0.0-20250228011932-c4f6bba87730/go.mod h1:1gQZs/byeHLMSul3Lvl3MzioMtOW1je79QYGyi2fd70= -github.com/dolthub/vitess v0.0.0-20250303224041-5cc89c183bc4 h1:wtS9ZWEyEeYzLCcqdGUo+7i3hAV5MWuY9Z7tYbQa65A= -github.com/dolthub/vitess v0.0.0-20250303224041-5cc89c183bc4/go.mod h1:1gQZs/byeHLMSul3Lvl3MzioMtOW1je79QYGyi2fd70= -github.com/dolthub/vitess v0.0.0-20250304211657-920ca9ec2b9a h1:HIH9g4z+yXr4DIFyT6L5qOIEGJ1zVtlj6baPyHAG4Yw= -github.com/dolthub/vitess v0.0.0-20250304211657-920ca9ec2b9a/go.mod h1:1gQZs/byeHLMSul3Lvl3MzioMtOW1je79QYGyi2fd70= -github.com/dolthub/vitess v0.0.0-20250320231804-0e77d549294c h1:Dv2DfEGb8WRBi8I5KF5Sy39TuZi/FI692mpobKWcv4g= -github.com/dolthub/vitess v0.0.0-20250320231804-0e77d549294c/go.mod h1:1gQZs/byeHLMSul3Lvl3MzioMtOW1je79QYGyi2fd70= +github.com/dolthub/vitess v0.0.0-20250324203551-408d22cead28 h1:RUXtVMLAx6Dk9vbvw6AOdOygOOqwWyJ6Y5KQSNkSNEw= +github.com/dolthub/vitess v0.0.0-20250324203551-408d22cead28/go.mod h1:1gQZs/byeHLMSul3Lvl3MzioMtOW1je79QYGyi2fd70= github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/eapache/go-resiliency v1.1.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs= github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU= From c0a76f25af34171c81c31c61a77e473580bf8d56 Mon Sep 17 00:00:00 2001 From: jycor Date: Mon, 24 Mar 2025 20:44:06 +0000 Subject: [PATCH 056/111] [ga-format-pr] Run ./format_repo.sh to fix formatting --- enginetest/memory_engine_test.go | 38 ++++++++++++------------- enginetest/queries/procedure_queries.go | 8 ++---- sql/procedures/interpreter_logic.go | 8 +++--- sql/procedures/interpreter_operation.go | 16 +++++------ 4 files changed, 34 insertions(+), 36 deletions(-) diff --git a/enginetest/memory_engine_test.go b/enginetest/memory_engine_test.go index 11abec39ce..a182f32c9a 100644 --- a/enginetest/memory_engine_test.go +++ b/enginetest/memory_engine_test.go @@ -204,27 +204,27 @@ func TestSingleQueryPrepared(t *testing.T) { func TestSingleScript(t *testing.T) { //t.Skip() var scripts = []queries.ScriptTest{ - { - Name: "creating invalid procedure doesn't error until it is called", - Assertions: []queries.ScriptTestAssertion{ - { - Query: `CREATE PROCEDURE proc1 (OUT out_count INT) READS SQL DATA SELECT COUNT(*) FROM mytable WHERE i = 1 AND s = 'first row' AND func1(i);`, - Expected: []sql.Row{{types.NewOkResult(0)}}, - }, - { - Query: "CALL proc1(@out_count);", - ExpectedErr: sql.ErrTableNotFound, - }, - { - Query: "CREATE TABLE mytable (i int, s varchar(128));", - Expected: []sql.Row{{types.NewOkResult(0)}}, - }, - { - Query: "CALL proc1(@out_count);", - ExpectedErr: sql.ErrFunctionNotFound, + { + Name: "creating invalid procedure doesn't error until it is called", + Assertions: []queries.ScriptTestAssertion{ + { + Query: `CREATE PROCEDURE proc1 (OUT out_count INT) READS SQL DATA SELECT COUNT(*) FROM mytable WHERE i = 1 AND s = 'first row' AND func1(i);`, + Expected: []sql.Row{{types.NewOkResult(0)}}, + }, + { + Query: "CALL proc1(@out_count);", + ExpectedErr: sql.ErrTableNotFound, + }, + { + Query: "CREATE TABLE mytable (i int, s varchar(128));", + Expected: []sql.Row{{types.NewOkResult(0)}}, + }, + { + Query: "CALL proc1(@out_count);", + ExpectedErr: sql.ErrFunctionNotFound, + }, }, }, - }, } for _, test := range scripts { diff --git a/enginetest/queries/procedure_queries.go b/enginetest/queries/procedure_queries.go index 7c0be4903a..6626f7e63e 100644 --- a/enginetest/queries/procedure_queries.go +++ b/enginetest/queries/procedure_queries.go @@ -2970,18 +2970,16 @@ var ProcedureCreateInSubroutineTests = []ScriptTest{ }, { Name: "procedure can CREATE VIEW", - SetUpScript: []string{ - - }, + SetUpScript: []string{}, Assertions: []ScriptTestAssertion{ { - Query: "create procedure p1() create view v as select 1;", + Query: "create procedure p1() create view v as select 1;", Expected: []sql.Row{ {types.NewOkResult(0)}, }, }, { - Query: "create procedure p() begin create view v as select 1; end;", + Query: "create procedure p() begin create view v as select 1; end;", Expected: []sql.Row{ {types.NewOkResult(0)}, }, diff --git a/sql/procedures/interpreter_logic.go b/sql/procedures/interpreter_logic.go index 78e4180f03..ac65dac419 100644 --- a/sql/procedures/interpreter_logic.go +++ b/sql/procedures/interpreter_logic.go @@ -278,10 +278,10 @@ func handleError(ctx *sql.Context, stack *InterpreterStack, runner sql.Statement } switch handler.Condition { case ast.DeclareHandlerCondition_MysqlErrorCode: - case ast.DeclareHandlerCondition_SqlState: - case ast.DeclareHandlerCondition_ConditionName: - case ast.DeclareHandlerCondition_SqlWarning: - case ast.DeclareHandlerCondition_NotFound: + case ast.DeclareHandlerCondition_SqlState: + case ast.DeclareHandlerCondition_ConditionName: + case ast.DeclareHandlerCondition_SqlWarning: + case ast.DeclareHandlerCondition_NotFound: case ast.DeclareHandlerCondition_SqlException: matchingHandler = handler break diff --git a/sql/procedures/interpreter_operation.go b/sql/procedures/interpreter_operation.go index 6d2b9925ea..bc833bcff4 100644 --- a/sql/procedures/interpreter_operation.go +++ b/sql/procedures/interpreter_operation.go @@ -17,20 +17,20 @@ import ast "github.com/dolthub/vitess/go/vt/sqlparser" type OpCode uint16 const ( - OpCode_Select OpCode = iota - OpCode_Declare // https://www.postgresql.org/docs/15/plpgsql-declarations.html + OpCode_Select OpCode = iota + OpCode_Declare // https://www.postgresql.org/docs/15/plpgsql-declarations.html OpCode_Signal OpCode_Open OpCode_Fetch OpCode_Close OpCode_Set - OpCode_If // https://www.postgresql.org/docs/15/plpgsql-control-structures.html#PLPGSQL-CONDITIONALS - OpCode_Goto // All control-flow structures can be represented using Goto - OpCode_Execute // Everything that's not a SELECT - OpCode_Exception // https://www.postgresql.org/docs/15/plpgsql-control-structures.html#PLPGSQL-ERROR-TRAPPING - OpCode_Return // https://www.postgresql.org/docs/15/plpgsql-control-structures.html#PLPGSQL-STATEMENTS-RETURNING + OpCode_If // https://www.postgresql.org/docs/15/plpgsql-control-structures.html#PLPGSQL-CONDITIONALS + OpCode_Goto // All control-flow structures can be represented using Goto + OpCode_Execute // Everything that's not a SELECT + OpCode_Exception // https://www.postgresql.org/docs/15/plpgsql-control-structures.html#PLPGSQL-ERROR-TRAPPING + OpCode_Return // https://www.postgresql.org/docs/15/plpgsql-control-structures.html#PLPGSQL-STATEMENTS-RETURNING OpCode_ScopeBegin // This is used for scope control, specific to Doltgres - OpCode_ScopeEnd // This is used for scope control, specific to Doltgres + OpCode_ScopeEnd // This is used for scope control, specific to Doltgres ) // InterpreterOperation is an operation that will be performed by the interpreter. From 70ea89f3a7587eac888bd3f63713db6be1d962fe Mon Sep 17 00:00:00 2001 From: James Cor Date: Mon, 24 Mar 2025 16:40:41 -0700 Subject: [PATCH 057/111] fix handler scopes --- enginetest/queries/procedure_queries.go | 4 +- sql/procedures/interpreter_logic.go | 54 +++++++++++++++---------- sql/procedures/interpreter_stack.go | 4 +- 3 files changed, 39 insertions(+), 23 deletions(-) diff --git a/enginetest/queries/procedure_queries.go b/enginetest/queries/procedure_queries.go index 7c0be4903a..8a4f0d666b 100644 --- a/enginetest/queries/procedure_queries.go +++ b/enginetest/queries/procedure_queries.go @@ -1396,7 +1396,9 @@ END;`, Assertions: []ScriptTestAssertion{ { Query: "CALL outer_declare();", - Expected: []sql.Row{}, + Expected: []sql.Row{ + {types.NewOkResult(0)}, + }, }, { Query: "CALL inner_declare();", diff --git a/sql/procedures/interpreter_logic.go b/sql/procedures/interpreter_logic.go index 78e4180f03..b74037b166 100644 --- a/sql/procedures/interpreter_logic.go +++ b/sql/procedures/interpreter_logic.go @@ -263,9 +263,9 @@ func query(ctx *sql.Context, runner sql.StatementRunner, stmt ast.Statement) (sq } // handleError handles errors that occur during the execution of a procedure according to the defined handlers. -func handleError(ctx *sql.Context, stack *InterpreterStack, runner sql.StatementRunner, err error) error { +func handleError(ctx *sql.Context, runner sql.StatementRunner, stack *InterpreterStack, statements []*InterpreterOperation, counter int, err error) (int, error) { if err == nil { - return nil + return counter, nil } var matchingHandler *InterpreterHandler @@ -289,38 +289,51 @@ func handleError(ctx *sql.Context, stack *InterpreterStack, runner sql.Statement } if matchingHandler == nil { - return err + return -1, err } handlerOps := make([]*InterpreterOperation, 0, 1) err = ConvertStmt(&handlerOps, stack, matchingHandler.Statement) if err != nil { - return err + return -1, err } _, _, rowIter, err := execOp(ctx, runner, stack, handlerOps[0], handlerOps, -1) if err != nil { - return err + return -1, err } if rowIter != nil { for { _, err = rowIter.Next(ctx) if err != nil { - return err + return -1, err } } } switch matchingHandler.Action { case ast.DeclareHandlerAction_Continue: - return nil + return counter, nil case ast.DeclareHandlerAction_Exit: - return io.EOF + remainingEndScopes := 1 + var newCounter int + for newCounter = matchingHandler.Counter; newCounter < len(statements); newCounter++ { + if remainingEndScopes == 0 { + break + } + switch statements[newCounter].OpCode { + case OpCode_ScopeBegin: + remainingEndScopes++ + case OpCode_ScopeEnd: + remainingEndScopes-- + default: + } + } + return newCounter, io.EOF case ast.DeclareHandlerAction_Undo: - return fmt.Errorf("DECLARE UNDO HANDLER is not supported") + return -1, fmt.Errorf("DECLARE UNDO HANDLER is not supported") } - - return nil + return counter, nil } func execOp(ctx *sql.Context, runner sql.StatementRunner, stack *InterpreterStack, operation *InterpreterOperation, statements []*InterpreterOperation, counter int) (int, sql.RowIter, sql.RowIter, error) { @@ -428,7 +441,7 @@ func execOp(ctx *sql.Context, runner sql.StatementRunner, stack *InterpreterStac return 0, nil, nil, fmt.Errorf("unsupported handler action: %s", handler.Action) } - stack.NewHandler(hCond.ValueType, handler.Action, handler.Statement) + stack.NewHandler(hCond.ValueType, handler.Action, handler.Statement, counter) } // TODO: duplicate variables? @@ -759,16 +772,15 @@ func Call(ctx *sql.Context, iNode InterpreterNode, params []*Parameter) (sql.Row operation := statements[counter] newCounter, newSelIter, rowIter, err := execOp(ctx, runner, stack, operation, statements, counter) if err != nil { - hErr := handleError(ctx, stack, runner, err) - if hErr != nil { - if hErr != io.EOF { - return nil, nil, hErr - } - for counter < len(statements) && statements[counter].OpCode != OpCode_ScopeEnd { - counter++ - } + hCounter, hErr := handleError(ctx, runner, stack, statements, counter, err) + if hErr != nil && hErr != io.EOF { + return nil, nil, hErr + } + if hErr == io.EOF { + newCounter = hCounter + } else { + newCounter = counter } - newCounter = counter } if rowIter != nil { rowIters = append(rowIters, rowIter) diff --git a/sql/procedures/interpreter_stack.go b/sql/procedures/interpreter_stack.go index 9361d286c1..1bc75cf8b7 100644 --- a/sql/procedures/interpreter_stack.go +++ b/sql/procedures/interpreter_stack.go @@ -101,6 +101,7 @@ type InterpreterHandler struct { Condition ast.DeclareHandlerConditionValue Action ast.DeclareHandlerAction Statement ast.Statement + Counter int // This is used to track the current position in the stack for the handler } // InterpreterVariable is a variable that lives on the stack. @@ -274,11 +275,12 @@ func (is *InterpreterStack) GetCursor(name string) *InterpreterCursor { } // NewHandler creates a new handler in the current scope. -func (is *InterpreterStack) NewHandler(cond ast.DeclareHandlerConditionValue, action ast.DeclareHandlerAction, stmt ast.Statement) { +func (is *InterpreterStack) NewHandler(cond ast.DeclareHandlerConditionValue, action ast.DeclareHandlerAction, stmt ast.Statement, counter int) { is.stack.Peek().handlers = append(is.stack.Peek().handlers, &InterpreterHandler{ Condition: cond, Action: action, Statement: stmt, + Counter: counter, }) } From e5bbdfccaa7656afda390952d4b3cb262a497593 Mon Sep 17 00:00:00 2001 From: James Cor Date: Mon, 24 Mar 2025 17:01:27 -0700 Subject: [PATCH 058/111] fix declares --- sql/procedures/interpreter_logic.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/procedures/interpreter_logic.go b/sql/procedures/interpreter_logic.go index b74037b166..dfcab40bdc 100644 --- a/sql/procedures/interpreter_logic.go +++ b/sql/procedures/interpreter_logic.go @@ -329,7 +329,7 @@ func handleError(ctx *sql.Context, runner sql.StatementRunner, stack *Interprete default: } } - return newCounter, io.EOF + return newCounter - 1, io.EOF case ast.DeclareHandlerAction_Undo: return -1, fmt.Errorf("DECLARE UNDO HANDLER is not supported") } From 6a12f7a51697678d10b8a6ca45c8e0396f24a775 Mon Sep 17 00:00:00 2001 From: James Cor Date: Tue, 25 Mar 2025 12:01:02 -0700 Subject: [PATCH 059/111] debugging --- enginetest/memory_engine_test.go | 39 ++++++++++++++++++++------------ 1 file changed, 24 insertions(+), 15 deletions(-) diff --git a/enginetest/memory_engine_test.go b/enginetest/memory_engine_test.go index 11abec39ce..650b4f7b49 100644 --- a/enginetest/memory_engine_test.go +++ b/enginetest/memory_engine_test.go @@ -205,23 +205,32 @@ func TestSingleScript(t *testing.T) { //t.Skip() var scripts = []queries.ScriptTest{ { - Name: "creating invalid procedure doesn't error until it is called", + Name: "SQLEXCEPTION declare handler", + SetUpScript: []string{ + `DROP TABLE IF EXISTS t1;`, + `CREATE TABLE t1 (pk BIGINT PRIMARY KEY);`, + `CREATE PROCEDURE eof() +BEGIN + DECLARE a, b INT DEFAULT 1; + DECLARE cur1 CURSOR FOR SELECT * FROM t1; + OPEN cur1; + BEGIN + DECLARE EXIT HANDLER FOR SQLEXCEPTION SET a = 7; + tloop: LOOP + FETCH cur1 INTO b; + IF a > 1000 THEN + LEAVE tloop; + END IF; + END LOOP; + END; + CLOSE cur1; + SELECT a; +END;`, + }, Assertions: []queries.ScriptTestAssertion{ { - Query: `CREATE PROCEDURE proc1 (OUT out_count INT) READS SQL DATA SELECT COUNT(*) FROM mytable WHERE i = 1 AND s = 'first row' AND func1(i);`, - Expected: []sql.Row{{types.NewOkResult(0)}}, - }, - { - Query: "CALL proc1(@out_count);", - ExpectedErr: sql.ErrTableNotFound, - }, - { - Query: "CREATE TABLE mytable (i int, s varchar(128));", - Expected: []sql.Row{{types.NewOkResult(0)}}, - }, - { - Query: "CALL proc1(@out_count);", - ExpectedErr: sql.ErrFunctionNotFound, + Query: "CALL eof();", + Expected: []sql.Row{}, }, }, }, From 6b7413bc45e1e058ed517fc66d50394b488dc7c8 Mon Sep 17 00:00:00 2001 From: James Cor Date: Tue, 25 Mar 2025 12:35:20 -0700 Subject: [PATCH 060/111] external procs --- enginetest/memory_engine_test.go | 32 +- .../queries/external_procedure_queries.go | 410 +++++++++--------- enginetest/queries/procedure_queries.go | 9 +- sql/rowexec/proc.go | 27 ++ 4 files changed, 246 insertions(+), 232 deletions(-) diff --git a/enginetest/memory_engine_test.go b/enginetest/memory_engine_test.go index 650b4f7b49..7d07f93f48 100644 --- a/enginetest/memory_engine_test.go +++ b/enginetest/memory_engine_test.go @@ -205,32 +205,20 @@ func TestSingleScript(t *testing.T) { //t.Skip() var scripts = []queries.ScriptTest{ { - Name: "SQLEXCEPTION declare handler", + Name: "Nested CALL with INOUT param", SetUpScript: []string{ - `DROP TABLE IF EXISTS t1;`, - `CREATE TABLE t1 (pk BIGINT PRIMARY KEY);`, - `CREATE PROCEDURE eof() -BEGIN - DECLARE a, b INT DEFAULT 1; - DECLARE cur1 CURSOR FOR SELECT * FROM t1; - OPEN cur1; - BEGIN - DECLARE EXIT HANDLER FOR SQLEXCEPTION SET a = 7; - tloop: LOOP - FETCH cur1 INTO b; - IF a > 1000 THEN - LEAVE tloop; - END IF; - END LOOP; - END; - CLOSE cur1; - SELECT a; -END;`, + "SET @outparam = 5", + "CREATE PROCEDURE p3(INOUT z INT) BEGIN SET z = z * 111; END;", + "CREATE PROCEDURE p2(INOUT y DOUBLE) BEGIN SET y = y + 4; CALL p3(y); END;", + "CREATE PROCEDURE p1(INOUT x BIGINT) BEGIN SET x = 3; CALL p2(x); END;", + "CALL p2(@outparam)", }, Assertions: []queries.ScriptTestAssertion{ { - Query: "CALL eof();", - Expected: []sql.Row{}, + Query: "SELECT @outparam", + Expected: []sql.Row{ + {int64(999)}, + }, }, }, }, diff --git a/enginetest/queries/external_procedure_queries.go b/enginetest/queries/external_procedure_queries.go index c4db672e01..547a3e8034 100644 --- a/enginetest/queries/external_procedure_queries.go +++ b/enginetest/queries/external_procedure_queries.go @@ -17,15 +17,15 @@ package queries import "github.com/dolthub/go-mysql-server/sql" var ExternalProcedureTests = []ScriptTest{ - { - Name: "Call external stored procedure that does not exist", - Assertions: []ScriptTestAssertion{ - { - Query: "CALL procedure_does_not_exist('foo');", - ExpectedErr: sql.ErrStoredProcedureDoesNotExist, - }, - }, - }, + //{ + // Name: "Call external stored procedure that does not exist", + // Assertions: []ScriptTestAssertion{ + // { + // Query: "CALL procedure_does_not_exist('foo');", + // ExpectedErr: sql.ErrStoredProcedureDoesNotExist, + // }, + // }, + //}, { Name: "INOUT on first param, IN on second param", SetUpScript: []string{ @@ -39,200 +39,200 @@ var ExternalProcedureTests = []ScriptTest{ }, }, }, - { - Name: "Handle setting uninitialized user variables", - SetUpScript: []string{ - "CALL memory_inout_set_unitialized(@uservar12, @uservar13, @uservar14, @uservar15);", - }, - Assertions: []ScriptTestAssertion{ - { - Query: "SELECT @uservar12;", - Expected: []sql.Row{{5}}, - }, - { - Query: "SELECT @uservar13;", - Expected: []sql.Row{{uint(5)}}, - }, - { - Query: "SELECT @uservar14;", - Expected: []sql.Row{{"5"}}, - }, - { - Query: "SELECT @uservar15;", - Expected: []sql.Row{{0}}, - }, - }, - }, - { - Name: "Called from standard stored procedure", - SetUpScript: []string{ - "CREATE PROCEDURE p1(x BIGINT) BEGIN CALL memory_inout_add(x, x); SELECT x; END;", - }, - Assertions: []ScriptTestAssertion{ - { - Query: "CALL p1(11);", - Expected: []sql.Row{{22}}, - }, - }, - }, - { - Name: "Overloaded Name", - Assertions: []ScriptTestAssertion{ - { - Query: "CALL memory_overloaded_mult(1);", - Expected: []sql.Row{{1}}, - }, - { - Query: "CALL memory_overloaded_mult(2, 3);", - Expected: []sql.Row{{6}}, - }, - { - Query: "CALL memory_overloaded_mult(4, 5, 6);", - Expected: []sql.Row{{120}}, - }, - }, - }, - { - Name: "Passing in all supported types", - Assertions: []ScriptTestAssertion{ - { - Query: "CALL memory_overloaded_type_test(1, 100, 10000, 1000000, 100000000, 3, 300," + - "10, 1000, 100000, 10000000, 1000000000, 30, 3000);", - Expected: []sql.Row{{1111114444}}, - }, - { - Query: "CALL memory_overloaded_type_test(false, 'hi', 'A', '2020-02-20 12:00:00', 123.456," + - "true, 'bye', 'B', '2022-02-02 12:00:00', 654.32);", - Expected: []sql.Row{{`aa:false,ba:true,ab:"hi",bb:"bye",ac:[65],bc:[66],ad:2020-02-20,bd:2022-02-02,ae:123.456,be:654.32`}}, - }, - { - Query: "CALL memory_type_test3(1, 100, 10000, 1000000, 100000000, 3, 300," + - "10, 1000, 100000, 10000000, 1000000000, 30, 3000);", - Expected: []sql.Row{{uint64(1111114444)}}, - }, - }, - }, - { - Name: "BOOL and []BYTE INOUT conversions", - SetUpScript: []string{ - "SET @outparam1 = 1;", - "SET @outparam2 = 0;", - "SET @outparam3 = 'A';", - "SET @outparam4 = 'B';", - }, - Assertions: []ScriptTestAssertion{ - { - Query: "SELECT @outparam1, @outparam2, @outparam3, @outparam4;", - Expected: []sql.Row{{1, 0, "A", "B"}}, - }, - { - Query: "CALL memory_inout_bool_byte(@outparam1, @outparam2, @outparam3, @outparam4);", - Expected: []sql.Row{}, - }, - { - Query: "SELECT @outparam1, @outparam2, @outparam3, @outparam4;", - Expected: []sql.Row{{1, 1, "A", []byte("C")}}, - }, - { - Query: "CALL memory_inout_bool_byte(@outparam1, @outparam2, @outparam3, @outparam4);", - Expected: []sql.Row{}, - }, - { - Query: "SELECT @outparam1, @outparam2, @outparam3, @outparam4;", - Expected: []sql.Row{{1, 0, "A", []byte("D")}}, - }, - }, - }, - { - Name: "Errors returned", - Assertions: []ScriptTestAssertion{ - { - Query: "CALL memory_error_table_not_found();", - ExpectedErr: sql.ErrTableNotFound, - }, - }, - }, - { - Name: "Variadic parameter", - Assertions: []ScriptTestAssertion{ - { - Query: "CALL memory_variadic_add();", - Expected: []sql.Row{{0}}, - }, - { - Query: "CALL memory_variadic_add(1);", - Expected: []sql.Row{{1}}, - }, - { - Query: "CALL memory_variadic_add(1, 2);", - Expected: []sql.Row{{3}}, - }, - { - Query: "CALL memory_variadic_add(1, 2, 3);", - Expected: []sql.Row{{6}}, - }, - { - Query: "CALL memory_variadic_add(1, 2, 3, 4);", - Expected: []sql.Row{{10}}, - }, - }, - }, - { - Name: "Variadic byte slices", - Assertions: []ScriptTestAssertion{ - { - Query: "CALL memory_variadic_byte_slice();", - Expected: []sql.Row{{""}}, - }, - { - Query: "CALL memory_variadic_byte_slice('A');", - Expected: []sql.Row{{"A"}}, - }, - { - Query: "CALL memory_variadic_byte_slice('A', 'B');", - Expected: []sql.Row{{"AB"}}, - }, - }, - }, - { - Name: "Variadic overloading", - Assertions: []ScriptTestAssertion{ - { - Query: "CALL memory_variadic_overload();", - ExpectedErr: sql.ErrCallIncorrectParameterCount, - }, - { - Query: "CALL memory_variadic_overload('A');", - ExpectedErr: sql.ErrCallIncorrectParameterCount, - }, - { - Query: "CALL memory_variadic_overload('A', 'B');", - Expected: []sql.Row{{"A-B"}}, - }, - { - Query: "CALL memory_variadic_overload('A', 'B', 'C');", - ExpectedErr: sql.ErrInvalidValue, - }, - { - Query: "CALL memory_variadic_overload('A', 'B', 5);", - Expected: []sql.Row{{"A,B,[5]"}}, - }, - }, - }, - { - Name: "show create procedure for external stored procedures", - Assertions: []ScriptTestAssertion{ - { - Query: "show create procedure memory_variadic_overload;", - Expected: []sql.Row{{ - "memory_variadic_overload", - "", - "CREATE PROCEDURE memory_variadic_overload() SELECT 'External stored procedure';", - "utf8mb4", - "utf8mb4_0900_bin", - "utf8mb4_0900_bin", - }}, - }, - }, - }, + //{ + // Name: "Handle setting uninitialized user variables", + // SetUpScript: []string{ + // "CALL memory_inout_set_unitialized(@uservar12, @uservar13, @uservar14, @uservar15);", + // }, + // Assertions: []ScriptTestAssertion{ + // { + // Query: "SELECT @uservar12;", + // Expected: []sql.Row{{5}}, + // }, + // { + // Query: "SELECT @uservar13;", + // Expected: []sql.Row{{uint(5)}}, + // }, + // { + // Query: "SELECT @uservar14;", + // Expected: []sql.Row{{"5"}}, + // }, + // { + // Query: "SELECT @uservar15;", + // Expected: []sql.Row{{0}}, + // }, + // }, + //}, + //{ + // Name: "Called from standard stored procedure", + // SetUpScript: []string{ + // "CREATE PROCEDURE p1(x BIGINT) BEGIN CALL memory_inout_add(x, x); SELECT x; END;", + // }, + // Assertions: []ScriptTestAssertion{ + // { + // Query: "CALL p1(11);", + // Expected: []sql.Row{{22}}, + // }, + // }, + //}, + //{ + // Name: "Overloaded Name", + // Assertions: []ScriptTestAssertion{ + // { + // Query: "CALL memory_overloaded_mult(1);", + // Expected: []sql.Row{{1}}, + // }, + // { + // Query: "CALL memory_overloaded_mult(2, 3);", + // Expected: []sql.Row{{6}}, + // }, + // { + // Query: "CALL memory_overloaded_mult(4, 5, 6);", + // Expected: []sql.Row{{120}}, + // }, + // }, + //}, + //{ + // Name: "Passing in all supported types", + // Assertions: []ScriptTestAssertion{ + // { + // Query: "CALL memory_overloaded_type_test(1, 100, 10000, 1000000, 100000000, 3, 300," + + // "10, 1000, 100000, 10000000, 1000000000, 30, 3000);", + // Expected: []sql.Row{{1111114444}}, + // }, + // { + // Query: "CALL memory_overloaded_type_test(false, 'hi', 'A', '2020-02-20 12:00:00', 123.456," + + // "true, 'bye', 'B', '2022-02-02 12:00:00', 654.32);", + // Expected: []sql.Row{{`aa:false,ba:true,ab:"hi",bb:"bye",ac:[65],bc:[66],ad:2020-02-20,bd:2022-02-02,ae:123.456,be:654.32`}}, + // }, + // { + // Query: "CALL memory_type_test3(1, 100, 10000, 1000000, 100000000, 3, 300," + + // "10, 1000, 100000, 10000000, 1000000000, 30, 3000);", + // Expected: []sql.Row{{uint64(1111114444)}}, + // }, + // }, + //}, + //{ + // Name: "BOOL and []BYTE INOUT conversions", + // SetUpScript: []string{ + // "SET @outparam1 = 1;", + // "SET @outparam2 = 0;", + // "SET @outparam3 = 'A';", + // "SET @outparam4 = 'B';", + // }, + // Assertions: []ScriptTestAssertion{ + // { + // Query: "SELECT @outparam1, @outparam2, @outparam3, @outparam4;", + // Expected: []sql.Row{{1, 0, "A", "B"}}, + // }, + // { + // Query: "CALL memory_inout_bool_byte(@outparam1, @outparam2, @outparam3, @outparam4);", + // Expected: []sql.Row{}, + // }, + // { + // Query: "SELECT @outparam1, @outparam2, @outparam3, @outparam4;", + // Expected: []sql.Row{{1, 1, "A", []byte("C")}}, + // }, + // { + // Query: "CALL memory_inout_bool_byte(@outparam1, @outparam2, @outparam3, @outparam4);", + // Expected: []sql.Row{}, + // }, + // { + // Query: "SELECT @outparam1, @outparam2, @outparam3, @outparam4;", + // Expected: []sql.Row{{1, 0, "A", []byte("D")}}, + // }, + // }, + //}, + //{ + // Name: "Errors returned", + // Assertions: []ScriptTestAssertion{ + // { + // Query: "CALL memory_error_table_not_found();", + // ExpectedErr: sql.ErrTableNotFound, + // }, + // }, + //}, + //{ + // Name: "Variadic parameter", + // Assertions: []ScriptTestAssertion{ + // { + // Query: "CALL memory_variadic_add();", + // Expected: []sql.Row{{0}}, + // }, + // { + // Query: "CALL memory_variadic_add(1);", + // Expected: []sql.Row{{1}}, + // }, + // { + // Query: "CALL memory_variadic_add(1, 2);", + // Expected: []sql.Row{{3}}, + // }, + // { + // Query: "CALL memory_variadic_add(1, 2, 3);", + // Expected: []sql.Row{{6}}, + // }, + // { + // Query: "CALL memory_variadic_add(1, 2, 3, 4);", + // Expected: []sql.Row{{10}}, + // }, + // }, + //}, + //{ + // Name: "Variadic byte slices", + // Assertions: []ScriptTestAssertion{ + // { + // Query: "CALL memory_variadic_byte_slice();", + // Expected: []sql.Row{{""}}, + // }, + // { + // Query: "CALL memory_variadic_byte_slice('A');", + // Expected: []sql.Row{{"A"}}, + // }, + // { + // Query: "CALL memory_variadic_byte_slice('A', 'B');", + // Expected: []sql.Row{{"AB"}}, + // }, + // }, + //}, + //{ + // Name: "Variadic overloading", + // Assertions: []ScriptTestAssertion{ + // { + // Query: "CALL memory_variadic_overload();", + // ExpectedErr: sql.ErrCallIncorrectParameterCount, + // }, + // { + // Query: "CALL memory_variadic_overload('A');", + // ExpectedErr: sql.ErrCallIncorrectParameterCount, + // }, + // { + // Query: "CALL memory_variadic_overload('A', 'B');", + // Expected: []sql.Row{{"A-B"}}, + // }, + // { + // Query: "CALL memory_variadic_overload('A', 'B', 'C');", + // ExpectedErr: sql.ErrInvalidValue, + // }, + // { + // Query: "CALL memory_variadic_overload('A', 'B', 5);", + // Expected: []sql.Row{{"A,B,[5]"}}, + // }, + // }, + //}, + //{ + // Name: "show create procedure for external stored procedures", + // Assertions: []ScriptTestAssertion{ + // { + // Query: "show create procedure memory_variadic_overload;", + // Expected: []sql.Row{{ + // "memory_variadic_overload", + // "", + // "CREATE PROCEDURE memory_variadic_overload() SELECT 'External stored procedure';", + // "utf8mb4", + // "utf8mb4_0900_bin", + // "utf8mb4_0900_bin", + // }}, + // }, + // }, + //}, } diff --git a/enginetest/queries/procedure_queries.go b/enginetest/queries/procedure_queries.go index 8a4f0d666b..e1e40815bb 100644 --- a/enginetest/queries/procedure_queries.go +++ b/enginetest/queries/procedure_queries.go @@ -1278,8 +1278,9 @@ END;`, }, Assertions: []ScriptTestAssertion{ { - Query: "CALL eof();", - Expected: []sql.Row{}, + // TODO: MySQL returns: ERROR: 1329: No data - zero rows fetched, selected, or processed + Query: "CALL eof();", + ExpectedErrStr: "exhausted fetch iterator", }, { Query: "CALL duplicate_key();", @@ -2273,9 +2274,7 @@ var ProcedureCallTests = []ScriptTest{ { Query: "SELECT @outparam", Expected: []sql.Row{ - { - int64(777), - }, + {int64(777)}, }, }, }, diff --git a/sql/rowexec/proc.go b/sql/rowexec/proc.go index 9c4a08bfeb..42f6bdc3b5 100644 --- a/sql/rowexec/proc.go +++ b/sql/rowexec/proc.go @@ -182,6 +182,33 @@ func (b *BaseBuilder) buildProcedureResolvedTable(ctx *sql.Context, n *plan.Proc } func (b *BaseBuilder) buildCall(ctx *sql.Context, n *plan.Call, row sql.Row) (sql.RowIter, error) { + if n.Procedure.ExternalProc != nil { + for i, paramExpr := range n.Params { + val, err := paramExpr.Eval(ctx, row) + if err != nil { + return nil, err + } + paramName := n.Procedure.Params[i].Name + paramType := n.Procedure.Params[i].Type + err = n.Pref.InitializeVariable(paramName, paramType, val) + if err != nil { + return nil, err + } + } + + n.Pref.PushScope() + defer n.Pref.PopScope(ctx) + + innerIter, err := b.buildNodeExec(ctx, n.Procedure, row) + if err != nil { + return nil, err + } + return &callIter{ + call: n, + innerIter: innerIter, + }, nil + } + procParams := make([]*procedures.Parameter, len(n.Params)) for i, paramExpr := range n.Params { param := n.Procedure.Params[i] From 247c5d4571b942764f7833df03733c1561f3303e Mon Sep 17 00:00:00 2001 From: James Cor Date: Tue, 25 Mar 2025 23:35:49 -0700 Subject: [PATCH 061/111] fix external procs kinda --- .../queries/external_procedure_queries.go | 410 +++++++++--------- sql/plan/procedure.go | 22 +- sql/rowexec/proc_iters.go | 49 +++ 3 files changed, 275 insertions(+), 206 deletions(-) diff --git a/enginetest/queries/external_procedure_queries.go b/enginetest/queries/external_procedure_queries.go index 547a3e8034..c4db672e01 100644 --- a/enginetest/queries/external_procedure_queries.go +++ b/enginetest/queries/external_procedure_queries.go @@ -17,15 +17,15 @@ package queries import "github.com/dolthub/go-mysql-server/sql" var ExternalProcedureTests = []ScriptTest{ - //{ - // Name: "Call external stored procedure that does not exist", - // Assertions: []ScriptTestAssertion{ - // { - // Query: "CALL procedure_does_not_exist('foo');", - // ExpectedErr: sql.ErrStoredProcedureDoesNotExist, - // }, - // }, - //}, + { + Name: "Call external stored procedure that does not exist", + Assertions: []ScriptTestAssertion{ + { + Query: "CALL procedure_does_not_exist('foo');", + ExpectedErr: sql.ErrStoredProcedureDoesNotExist, + }, + }, + }, { Name: "INOUT on first param, IN on second param", SetUpScript: []string{ @@ -39,200 +39,200 @@ var ExternalProcedureTests = []ScriptTest{ }, }, }, - //{ - // Name: "Handle setting uninitialized user variables", - // SetUpScript: []string{ - // "CALL memory_inout_set_unitialized(@uservar12, @uservar13, @uservar14, @uservar15);", - // }, - // Assertions: []ScriptTestAssertion{ - // { - // Query: "SELECT @uservar12;", - // Expected: []sql.Row{{5}}, - // }, - // { - // Query: "SELECT @uservar13;", - // Expected: []sql.Row{{uint(5)}}, - // }, - // { - // Query: "SELECT @uservar14;", - // Expected: []sql.Row{{"5"}}, - // }, - // { - // Query: "SELECT @uservar15;", - // Expected: []sql.Row{{0}}, - // }, - // }, - //}, - //{ - // Name: "Called from standard stored procedure", - // SetUpScript: []string{ - // "CREATE PROCEDURE p1(x BIGINT) BEGIN CALL memory_inout_add(x, x); SELECT x; END;", - // }, - // Assertions: []ScriptTestAssertion{ - // { - // Query: "CALL p1(11);", - // Expected: []sql.Row{{22}}, - // }, - // }, - //}, - //{ - // Name: "Overloaded Name", - // Assertions: []ScriptTestAssertion{ - // { - // Query: "CALL memory_overloaded_mult(1);", - // Expected: []sql.Row{{1}}, - // }, - // { - // Query: "CALL memory_overloaded_mult(2, 3);", - // Expected: []sql.Row{{6}}, - // }, - // { - // Query: "CALL memory_overloaded_mult(4, 5, 6);", - // Expected: []sql.Row{{120}}, - // }, - // }, - //}, - //{ - // Name: "Passing in all supported types", - // Assertions: []ScriptTestAssertion{ - // { - // Query: "CALL memory_overloaded_type_test(1, 100, 10000, 1000000, 100000000, 3, 300," + - // "10, 1000, 100000, 10000000, 1000000000, 30, 3000);", - // Expected: []sql.Row{{1111114444}}, - // }, - // { - // Query: "CALL memory_overloaded_type_test(false, 'hi', 'A', '2020-02-20 12:00:00', 123.456," + - // "true, 'bye', 'B', '2022-02-02 12:00:00', 654.32);", - // Expected: []sql.Row{{`aa:false,ba:true,ab:"hi",bb:"bye",ac:[65],bc:[66],ad:2020-02-20,bd:2022-02-02,ae:123.456,be:654.32`}}, - // }, - // { - // Query: "CALL memory_type_test3(1, 100, 10000, 1000000, 100000000, 3, 300," + - // "10, 1000, 100000, 10000000, 1000000000, 30, 3000);", - // Expected: []sql.Row{{uint64(1111114444)}}, - // }, - // }, - //}, - //{ - // Name: "BOOL and []BYTE INOUT conversions", - // SetUpScript: []string{ - // "SET @outparam1 = 1;", - // "SET @outparam2 = 0;", - // "SET @outparam3 = 'A';", - // "SET @outparam4 = 'B';", - // }, - // Assertions: []ScriptTestAssertion{ - // { - // Query: "SELECT @outparam1, @outparam2, @outparam3, @outparam4;", - // Expected: []sql.Row{{1, 0, "A", "B"}}, - // }, - // { - // Query: "CALL memory_inout_bool_byte(@outparam1, @outparam2, @outparam3, @outparam4);", - // Expected: []sql.Row{}, - // }, - // { - // Query: "SELECT @outparam1, @outparam2, @outparam3, @outparam4;", - // Expected: []sql.Row{{1, 1, "A", []byte("C")}}, - // }, - // { - // Query: "CALL memory_inout_bool_byte(@outparam1, @outparam2, @outparam3, @outparam4);", - // Expected: []sql.Row{}, - // }, - // { - // Query: "SELECT @outparam1, @outparam2, @outparam3, @outparam4;", - // Expected: []sql.Row{{1, 0, "A", []byte("D")}}, - // }, - // }, - //}, - //{ - // Name: "Errors returned", - // Assertions: []ScriptTestAssertion{ - // { - // Query: "CALL memory_error_table_not_found();", - // ExpectedErr: sql.ErrTableNotFound, - // }, - // }, - //}, - //{ - // Name: "Variadic parameter", - // Assertions: []ScriptTestAssertion{ - // { - // Query: "CALL memory_variadic_add();", - // Expected: []sql.Row{{0}}, - // }, - // { - // Query: "CALL memory_variadic_add(1);", - // Expected: []sql.Row{{1}}, - // }, - // { - // Query: "CALL memory_variadic_add(1, 2);", - // Expected: []sql.Row{{3}}, - // }, - // { - // Query: "CALL memory_variadic_add(1, 2, 3);", - // Expected: []sql.Row{{6}}, - // }, - // { - // Query: "CALL memory_variadic_add(1, 2, 3, 4);", - // Expected: []sql.Row{{10}}, - // }, - // }, - //}, - //{ - // Name: "Variadic byte slices", - // Assertions: []ScriptTestAssertion{ - // { - // Query: "CALL memory_variadic_byte_slice();", - // Expected: []sql.Row{{""}}, - // }, - // { - // Query: "CALL memory_variadic_byte_slice('A');", - // Expected: []sql.Row{{"A"}}, - // }, - // { - // Query: "CALL memory_variadic_byte_slice('A', 'B');", - // Expected: []sql.Row{{"AB"}}, - // }, - // }, - //}, - //{ - // Name: "Variadic overloading", - // Assertions: []ScriptTestAssertion{ - // { - // Query: "CALL memory_variadic_overload();", - // ExpectedErr: sql.ErrCallIncorrectParameterCount, - // }, - // { - // Query: "CALL memory_variadic_overload('A');", - // ExpectedErr: sql.ErrCallIncorrectParameterCount, - // }, - // { - // Query: "CALL memory_variadic_overload('A', 'B');", - // Expected: []sql.Row{{"A-B"}}, - // }, - // { - // Query: "CALL memory_variadic_overload('A', 'B', 'C');", - // ExpectedErr: sql.ErrInvalidValue, - // }, - // { - // Query: "CALL memory_variadic_overload('A', 'B', 5);", - // Expected: []sql.Row{{"A,B,[5]"}}, - // }, - // }, - //}, - //{ - // Name: "show create procedure for external stored procedures", - // Assertions: []ScriptTestAssertion{ - // { - // Query: "show create procedure memory_variadic_overload;", - // Expected: []sql.Row{{ - // "memory_variadic_overload", - // "", - // "CREATE PROCEDURE memory_variadic_overload() SELECT 'External stored procedure';", - // "utf8mb4", - // "utf8mb4_0900_bin", - // "utf8mb4_0900_bin", - // }}, - // }, - // }, - //}, + { + Name: "Handle setting uninitialized user variables", + SetUpScript: []string{ + "CALL memory_inout_set_unitialized(@uservar12, @uservar13, @uservar14, @uservar15);", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "SELECT @uservar12;", + Expected: []sql.Row{{5}}, + }, + { + Query: "SELECT @uservar13;", + Expected: []sql.Row{{uint(5)}}, + }, + { + Query: "SELECT @uservar14;", + Expected: []sql.Row{{"5"}}, + }, + { + Query: "SELECT @uservar15;", + Expected: []sql.Row{{0}}, + }, + }, + }, + { + Name: "Called from standard stored procedure", + SetUpScript: []string{ + "CREATE PROCEDURE p1(x BIGINT) BEGIN CALL memory_inout_add(x, x); SELECT x; END;", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "CALL p1(11);", + Expected: []sql.Row{{22}}, + }, + }, + }, + { + Name: "Overloaded Name", + Assertions: []ScriptTestAssertion{ + { + Query: "CALL memory_overloaded_mult(1);", + Expected: []sql.Row{{1}}, + }, + { + Query: "CALL memory_overloaded_mult(2, 3);", + Expected: []sql.Row{{6}}, + }, + { + Query: "CALL memory_overloaded_mult(4, 5, 6);", + Expected: []sql.Row{{120}}, + }, + }, + }, + { + Name: "Passing in all supported types", + Assertions: []ScriptTestAssertion{ + { + Query: "CALL memory_overloaded_type_test(1, 100, 10000, 1000000, 100000000, 3, 300," + + "10, 1000, 100000, 10000000, 1000000000, 30, 3000);", + Expected: []sql.Row{{1111114444}}, + }, + { + Query: "CALL memory_overloaded_type_test(false, 'hi', 'A', '2020-02-20 12:00:00', 123.456," + + "true, 'bye', 'B', '2022-02-02 12:00:00', 654.32);", + Expected: []sql.Row{{`aa:false,ba:true,ab:"hi",bb:"bye",ac:[65],bc:[66],ad:2020-02-20,bd:2022-02-02,ae:123.456,be:654.32`}}, + }, + { + Query: "CALL memory_type_test3(1, 100, 10000, 1000000, 100000000, 3, 300," + + "10, 1000, 100000, 10000000, 1000000000, 30, 3000);", + Expected: []sql.Row{{uint64(1111114444)}}, + }, + }, + }, + { + Name: "BOOL and []BYTE INOUT conversions", + SetUpScript: []string{ + "SET @outparam1 = 1;", + "SET @outparam2 = 0;", + "SET @outparam3 = 'A';", + "SET @outparam4 = 'B';", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "SELECT @outparam1, @outparam2, @outparam3, @outparam4;", + Expected: []sql.Row{{1, 0, "A", "B"}}, + }, + { + Query: "CALL memory_inout_bool_byte(@outparam1, @outparam2, @outparam3, @outparam4);", + Expected: []sql.Row{}, + }, + { + Query: "SELECT @outparam1, @outparam2, @outparam3, @outparam4;", + Expected: []sql.Row{{1, 1, "A", []byte("C")}}, + }, + { + Query: "CALL memory_inout_bool_byte(@outparam1, @outparam2, @outparam3, @outparam4);", + Expected: []sql.Row{}, + }, + { + Query: "SELECT @outparam1, @outparam2, @outparam3, @outparam4;", + Expected: []sql.Row{{1, 0, "A", []byte("D")}}, + }, + }, + }, + { + Name: "Errors returned", + Assertions: []ScriptTestAssertion{ + { + Query: "CALL memory_error_table_not_found();", + ExpectedErr: sql.ErrTableNotFound, + }, + }, + }, + { + Name: "Variadic parameter", + Assertions: []ScriptTestAssertion{ + { + Query: "CALL memory_variadic_add();", + Expected: []sql.Row{{0}}, + }, + { + Query: "CALL memory_variadic_add(1);", + Expected: []sql.Row{{1}}, + }, + { + Query: "CALL memory_variadic_add(1, 2);", + Expected: []sql.Row{{3}}, + }, + { + Query: "CALL memory_variadic_add(1, 2, 3);", + Expected: []sql.Row{{6}}, + }, + { + Query: "CALL memory_variadic_add(1, 2, 3, 4);", + Expected: []sql.Row{{10}}, + }, + }, + }, + { + Name: "Variadic byte slices", + Assertions: []ScriptTestAssertion{ + { + Query: "CALL memory_variadic_byte_slice();", + Expected: []sql.Row{{""}}, + }, + { + Query: "CALL memory_variadic_byte_slice('A');", + Expected: []sql.Row{{"A"}}, + }, + { + Query: "CALL memory_variadic_byte_slice('A', 'B');", + Expected: []sql.Row{{"AB"}}, + }, + }, + }, + { + Name: "Variadic overloading", + Assertions: []ScriptTestAssertion{ + { + Query: "CALL memory_variadic_overload();", + ExpectedErr: sql.ErrCallIncorrectParameterCount, + }, + { + Query: "CALL memory_variadic_overload('A');", + ExpectedErr: sql.ErrCallIncorrectParameterCount, + }, + { + Query: "CALL memory_variadic_overload('A', 'B');", + Expected: []sql.Row{{"A-B"}}, + }, + { + Query: "CALL memory_variadic_overload('A', 'B', 'C');", + ExpectedErr: sql.ErrInvalidValue, + }, + { + Query: "CALL memory_variadic_overload('A', 'B', 5);", + Expected: []sql.Row{{"A,B,[5]"}}, + }, + }, + }, + { + Name: "show create procedure for external stored procedures", + Assertions: []ScriptTestAssertion{ + { + Query: "show create procedure memory_variadic_overload;", + Expected: []sql.Row{{ + "memory_variadic_overload", + "", + "CREATE PROCEDURE memory_variadic_overload() SELECT 'External stored procedure';", + "utf8mb4", + "utf8mb4_0900_bin", + "utf8mb4_0900_bin", + }}, + }, + }, + }, } diff --git a/sql/plan/procedure.go b/sql/plan/procedure.go index 7b1f1af772..c7af76ef67 100644 --- a/sql/plan/procedure.go +++ b/sql/plan/procedure.go @@ -132,15 +132,24 @@ func NewProcedure( // Resolved implements the sql.Node interface. func (p *Procedure) Resolved() bool { + if p.ExternalProc != nil { + return p.ExternalProc.Resolved() + } return true } func (p *Procedure) IsReadOnly() bool { + if p.ExternalProc != nil { + return p.ExternalProc.IsReadOnly() + } return false } // String implements the sql.Node interface. func (p *Procedure) String() string { + if p.ExternalProc != nil { + return p.ExternalProc.String() + } return "" } @@ -156,12 +165,23 @@ func (p *Procedure) Schema() sql.Schema { // Children implements the sql.Node interface. func (p *Procedure) Children() []sql.Node { + if p.ExternalProc != nil { + return []sql.Node{p.ExternalProc} + } return nil } // WithChildren implements the sql.Node interface. func (p *Procedure) WithChildren(children ...sql.Node) (sql.Node, error) { - return p, nil + if len(children) == 0 { + return p, nil + } + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(p, len(children), 1) + } + np := *p + np.ExternalProc = children[0] + return &np, nil } // CollationCoercibility implements the interface sql.CollationCoercible. diff --git a/sql/rowexec/proc_iters.go b/sql/rowexec/proc_iters.go index 5ec79d1463..87e1a9e837 100644 --- a/sql/rowexec/proc_iters.go +++ b/sql/rowexec/proc_iters.go @@ -134,6 +134,55 @@ func (ci *callIter) Close(ctx *sql.Context) error { if err != nil { return err } + if ci.call.Procedure.ExternalProc == nil { + return nil + } + // Set all user and system variables from INOUT and OUT params + for i, param := range ci.call.Procedure.Params { + if param.Direction == plan.ProcedureParamDirection_Inout || + (param.Direction == plan.ProcedureParamDirection_Out && ci.call.Pref.VariableHasBeenSet(param.Name)) { + val, err := ci.call.Pref.GetVariableValue(param.Name) + if err != nil { + return err + } + + typ := ci.call.Pref.GetVariableType(param.Name) + + switch callParam := ci.call.Params[i].(type) { + case *expression.UserVar: + err = ctx.SetUserVariable(ctx, callParam.Name, val, typ) + if err != nil { + return err + } + case *expression.SystemVar: + // This should have been caught by the analyzer, so a major bug exists somewhere + return fmt.Errorf("unable to set `%s` as it is a system variable", callParam.Name) + case *expression.ProcedureParam: + err = callParam.Set(val, param.Type) + if err != nil { + return err + } + } + } else if param.Direction == plan.ProcedureParamDirection_Out { // VariableHasBeenSet was false + // For OUT only, if a var was not set within the procedure body, then we set the vars to nil. + // If the var had a value before the call then it is basically removed. + switch callParam := ci.call.Params[i].(type) { + case *expression.UserVar: + err = ctx.SetUserVariable(ctx, callParam.Name, nil, ci.call.Pref.GetVariableType(param.Name)) + if err != nil { + return err + } + case *expression.SystemVar: + // This should have been caught by the analyzer, so a major bug exists somewhere + return fmt.Errorf("unable to set `%s` as it is a system variable", callParam.Name) + case *expression.ProcedureParam: + err := callParam.Set(nil, param.Type) + if err != nil { + return err + } + } + } + } return nil } From 1462fc75c5bad87513131665fb6a55b599308349 Mon Sep 17 00:00:00 2001 From: James Cor Date: Wed, 26 Mar 2025 00:07:12 -0700 Subject: [PATCH 062/111] bump --- go.mod | 2 +- go.sum | 16 ++-------------- 2 files changed, 3 insertions(+), 15 deletions(-) diff --git a/go.mod b/go.mod index 4ab2805a3f..346a17ba98 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,7 @@ require ( github.com/dolthub/go-icu-regex v0.0.0-20250319212010-451ea8d003fa github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71 github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81 - github.com/dolthub/vitess v0.0.0-20250325024605-8131be3ca6d3 + github.com/dolthub/vitess v0.0.0-20250326064017-04ab843d56dc github.com/go-kit/kit v0.10.0 github.com/go-sql-driver/mysql v1.7.2-0.20231213112541-0004702b931d github.com/gocraft/dbr/v2 v2.7.2 diff --git a/go.sum b/go.sum index c0084bf8f5..99a42c325e 100644 --- a/go.sum +++ b/go.sum @@ -52,26 +52,14 @@ github.com/denisenkom/go-mssqldb v0.10.0/go.mod h1:xbL0rPBG9cCiLr28tMa8zpbdarY27 github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= github.com/dolthub/flatbuffers/v23 v23.3.3-dh.2 h1:u3PMzfF8RkKd3lB9pZ2bfn0qEG+1Gms9599cr0REMww= github.com/dolthub/flatbuffers/v23 v23.3.3-dh.2/go.mod h1:mIEZOHnFx4ZMQeawhw9rhsj+0zwQj7adVsnBX7t+eKY= -github.com/dolthub/go-icu-regex v0.0.0-20250303123116-549b8d7cad00 h1:rh2ij2yTYKJWlX+c8XRg4H5OzqPewbU1lPK8pcfVmx8= -github.com/dolthub/go-icu-regex v0.0.0-20250303123116-549b8d7cad00/go.mod h1:ylU4XjUpsMcvl/BKeRRMXSH7e7WBrPXdSLvnRJYrxEA= github.com/dolthub/go-icu-regex v0.0.0-20250319212010-451ea8d003fa h1:NFbzJ4wjWRz32nz2EimbrHpRx1Xt6k+IaR8N+j4x62k= github.com/dolthub/go-icu-regex v0.0.0-20250319212010-451ea8d003fa/go.mod h1:ylU4XjUpsMcvl/BKeRRMXSH7e7WBrPXdSLvnRJYrxEA= github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71 h1:bMGS25NWAGTEtT5tOBsCuCrlYnLRKpbJVJkDbrTRhwQ= github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71/go.mod h1:2/2zjLQ/JOOSbbSboojeg+cAwcRV0fDLzIiWch/lhqI= github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81 h1:7/v8q9XGFa6q5Ap4Z/OhNkAMBaK5YeuEzwJt+NZdhiE= github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81/go.mod h1:siLfyv2c92W1eN/R4QqG/+RjjX5W2+gCTRjZxBjI3TY= -github.com/dolthub/vitess v0.0.0-20250228011932-c4f6bba87730 h1:GtlMVB7+Z7fZZj7BHRFd2rzxZ574dJ8cB/EHWdq1kbY= -github.com/dolthub/vitess v0.0.0-20250228011932-c4f6bba87730/go.mod h1:1gQZs/byeHLMSul3Lvl3MzioMtOW1je79QYGyi2fd70= -github.com/dolthub/vitess v0.0.0-20250303224041-5cc89c183bc4 h1:wtS9ZWEyEeYzLCcqdGUo+7i3hAV5MWuY9Z7tYbQa65A= -github.com/dolthub/vitess v0.0.0-20250303224041-5cc89c183bc4/go.mod h1:1gQZs/byeHLMSul3Lvl3MzioMtOW1je79QYGyi2fd70= -github.com/dolthub/vitess v0.0.0-20250304211657-920ca9ec2b9a h1:HIH9g4z+yXr4DIFyT6L5qOIEGJ1zVtlj6baPyHAG4Yw= -github.com/dolthub/vitess v0.0.0-20250304211657-920ca9ec2b9a/go.mod h1:1gQZs/byeHLMSul3Lvl3MzioMtOW1je79QYGyi2fd70= -github.com/dolthub/vitess v0.0.0-20250320231804-0e77d549294c h1:Dv2DfEGb8WRBi8I5KF5Sy39TuZi/FI692mpobKWcv4g= -github.com/dolthub/vitess v0.0.0-20250320231804-0e77d549294c/go.mod h1:1gQZs/byeHLMSul3Lvl3MzioMtOW1je79QYGyi2fd70= -github.com/dolthub/vitess v0.0.0-20250324212634-ee57ba96134a h1:cWE14qNrcxzsVfGOJ8HKPg9Q1MLWKA8ON7SYRqJmWCs= -github.com/dolthub/vitess v0.0.0-20250324212634-ee57ba96134a/go.mod h1:1gQZs/byeHLMSul3Lvl3MzioMtOW1je79QYGyi2fd70= -github.com/dolthub/vitess v0.0.0-20250325024605-8131be3ca6d3 h1:euU+adNAYw46Zcp1HnoaSDWhqjfaL8s/1SPU+i16gYM= -github.com/dolthub/vitess v0.0.0-20250325024605-8131be3ca6d3/go.mod h1:1gQZs/byeHLMSul3Lvl3MzioMtOW1je79QYGyi2fd70= +github.com/dolthub/vitess v0.0.0-20250326064017-04ab843d56dc h1:nmU0xvumZ8fKxjWI6P1rWvvHzvrbh+R+uWzXJCw/dP4= +github.com/dolthub/vitess v0.0.0-20250326064017-04ab843d56dc/go.mod h1:1gQZs/byeHLMSul3Lvl3MzioMtOW1je79QYGyi2fd70= github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/eapache/go-resiliency v1.1.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs= github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU= From a5de72607323a2ce4e2074b8b484271d1044963e Mon Sep 17 00:00:00 2001 From: jycor Date: Wed, 26 Mar 2025 07:08:45 +0000 Subject: [PATCH 063/111] [ga-format-pr] Run ./format_repo.sh to fix formatting --- enginetest/memory_engine_test.go | 14 +++++++------- enginetest/queries/procedure_queries.go | 2 +- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/enginetest/memory_engine_test.go b/enginetest/memory_engine_test.go index ef6f6aec8a..ab96415374 100644 --- a/enginetest/memory_engine_test.go +++ b/enginetest/memory_engine_test.go @@ -208,16 +208,16 @@ func TestSingleScript(t *testing.T) { Name: "Nested CALL with INOUT param", SetUpScript: []string{ "SET @outparam = 5", - "CREATE PROCEDURE p3(INOUT z INT) BEGIN SET z = z * 111; END;", + "CREATE PROCEDURE p3(INOUT z INT) BEGIN SET z = z * 111; END;", "CREATE PROCEDURE p2(INOUT y DOUBLE) BEGIN SET y = y + 4; CALL p3(y); END;", "CREATE PROCEDURE p1(INOUT x BIGINT) BEGIN SET x = 3; CALL p2(x); END;", - "CALL p2(@outparam)", - }, - Assertions: []queries.ScriptTestAssertion{ - { - Query: "SELECT @outparam", + "CALL p2(@outparam)", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "SELECT @outparam", Expected: []sql.Row{ - {int64(999)}, + {int64(999)}, }, }, }, diff --git a/enginetest/queries/procedure_queries.go b/enginetest/queries/procedure_queries.go index 650b827c58..c2a02d0266 100644 --- a/enginetest/queries/procedure_queries.go +++ b/enginetest/queries/procedure_queries.go @@ -1396,7 +1396,7 @@ END;`, }, Assertions: []ScriptTestAssertion{ { - Query: "CALL outer_declare();", + Query: "CALL outer_declare();", Expected: []sql.Row{ {types.NewOkResult(0)}, }, From 77d7501766521dd3015d39807f20026b6f863a28 Mon Sep 17 00:00:00 2001 From: James Cor Date: Thu, 27 Mar 2025 11:22:31 -0700 Subject: [PATCH 064/111] todos --- enginetest/memory_engine_test.go | 30 ++++++++++++++--------------- sql/procedures/interpreter_logic.go | 1 + sql/rowexec/proc.go | 8 ++++++++ 3 files changed, 24 insertions(+), 15 deletions(-) diff --git a/enginetest/memory_engine_test.go b/enginetest/memory_engine_test.go index ab96415374..ade9b78246 100644 --- a/enginetest/memory_engine_test.go +++ b/enginetest/memory_engine_test.go @@ -204,24 +204,24 @@ func TestSingleQueryPrepared(t *testing.T) { func TestSingleScript(t *testing.T) { //t.Skip() var scripts = []queries.ScriptTest{ - { - Name: "Nested CALL with INOUT param", - SetUpScript: []string{ - "SET @outparam = 5", - "CREATE PROCEDURE p3(INOUT z INT) BEGIN SET z = z * 111; END;", - "CREATE PROCEDURE p2(INOUT y DOUBLE) BEGIN SET y = y + 4; CALL p3(y); END;", - "CREATE PROCEDURE p1(INOUT x BIGINT) BEGIN SET x = 3; CALL p2(x); END;", - "CALL p2(@outparam)", - }, - Assertions: []queries.ScriptTestAssertion{ - { - Query: "SELECT @outparam", - Expected: []sql.Row{ - {int64(999)}, - }, + { + Name: "Nested CALL with INOUT param", + SetUpScript: []string{ + "SET @outparam = 5", + "CREATE PROCEDURE p3(INOUT z INT) BEGIN SET z = z * 111; END;", + "CREATE PROCEDURE p2(INOUT y DOUBLE) BEGIN SET y = y + 4; CALL p3(y); END;", + "CREATE PROCEDURE p1(INOUT x BIGINT) BEGIN SET x = 3; CALL p2(x); END;", + "CALL p1(@outparam)", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "SELECT @outparam", + Expected: []sql.Row{ + {int64(777)}, }, }, }, + }, } for _, test := range scripts { diff --git a/sql/procedures/interpreter_logic.go b/sql/procedures/interpreter_logic.go index ddf6a118d0..249f53f563 100644 --- a/sql/procedures/interpreter_logic.go +++ b/sql/procedures/interpreter_logic.go @@ -129,6 +129,7 @@ func replaceVariablesInExpr(stack *InterpreterStack, expr ast.SQLNode) (ast.SQLN } case *ast.Call: for i := range e.Params { + // TODO: don't replace variables for session stack newExpr, err := replaceVariablesInExpr(stack, e.Params[i]) if err != nil { return nil, err diff --git a/sql/rowexec/proc.go b/sql/rowexec/proc.go index 42f6bdc3b5..98565a3a82 100644 --- a/sql/rowexec/proc.go +++ b/sql/rowexec/proc.go @@ -229,6 +229,14 @@ func (b *BaseBuilder) buildCall(ctx *sql.Context, n *plan.Call, row sql.Row) (sq } } + // check for existing procedure stack in session + // if not found, create a new one + + // + + // TODO: add all procedure parameters + // ctx.Session.GetProcedureVariables + rowIter, stack, err := procedures.Call(ctx, n, procParams) if err != nil { return nil, err From 423e8bd8a29665444991d01442875f64f6590d57 Mon Sep 17 00:00:00 2001 From: James Cor Date: Sun, 30 Mar 2025 19:02:40 -0700 Subject: [PATCH 065/111] fix as of --- enginetest/memory_engine_test.go | 2 +- sql/plan/call.go | 10 +-- sql/procedures/interpreter_logic.go | 100 ++++++++++++++++++---------- 3 files changed, 69 insertions(+), 43 deletions(-) diff --git a/enginetest/memory_engine_test.go b/enginetest/memory_engine_test.go index ade9b78246..04b2870f9f 100644 --- a/enginetest/memory_engine_test.go +++ b/enginetest/memory_engine_test.go @@ -202,7 +202,7 @@ func TestSingleQueryPrepared(t *testing.T) { // Convenience test for debugging a single query. Unskip and set to the desired query. func TestSingleScript(t *testing.T) { - //t.Skip() + t.Skip() var scripts = []queries.ScriptTest{ { Name: "Nested CALL with INOUT param", diff --git a/sql/plan/call.go b/sql/plan/call.go index ee37d085d1..30dc5271bd 100644 --- a/sql/plan/call.go +++ b/sql/plan/call.go @@ -215,14 +215,8 @@ func (c *Call) GetRunner() sql.StatementRunner { return c.Runner } -// GetParameters implements the sql.InterpreterNode interface. -func (c *Call) GetParameters() []sql.Type { - return nil -} - -// GetParameterNames implements the sql.InterpreterNode interface. -func (c *Call) GetParameterNames() []string { - return nil +func (c *Call) GetAsOf() sql.Expression { + return c.asOf } // GetStatements implements the sql.InterpreterNode interface. diff --git a/sql/procedures/interpreter_logic.go b/sql/procedures/interpreter_logic.go index 249f53f563..94d58c707a 100644 --- a/sql/procedures/interpreter_logic.go +++ b/sql/procedures/interpreter_logic.go @@ -32,6 +32,7 @@ import ( // InterpreterNode is an interface that implements an interpreter. These are typically used for functions (which may be // implemented as a set of operations that are interpreted during runtime). type InterpreterNode interface { + GetAsOf() sql.Expression GetRunner() sql.StatementRunner GetReturn() sql.Type GetStatements() []*InterpreterOperation @@ -44,7 +45,7 @@ type Parameter struct { Value any } -func replaceVariablesInExpr(stack *InterpreterStack, expr ast.SQLNode) (ast.SQLNode, error) { +func replaceVariablesInExpr(stack *InterpreterStack, expr ast.SQLNode, asOf *ast.AsOf) (ast.SQLNode, error) { switch e := expr.(type) { case *ast.ColName: iv := stack.GetVariable(strings.ToLower(e.Name.String())) @@ -58,40 +59,43 @@ func replaceVariablesInExpr(stack *InterpreterStack, expr ast.SQLNode) (ast.SQLN StoredProcVal: newExpr, }, nil case *ast.ParenExpr: - newExpr, err := replaceVariablesInExpr(stack, e.Expr) + newExpr, err := replaceVariablesInExpr(stack, e.Expr, asOf) if err != nil { return nil, err } e.Expr = newExpr.(ast.Expr) case *ast.AliasedTableExpr: - newExpr, err := replaceVariablesInExpr(stack, e.Expr) + newExpr, err := replaceVariablesInExpr(stack, e.Expr, asOf) if err != nil { return nil, err } e.Expr = newExpr.(ast.SimpleTableExpr) + if e.AsOf == nil && asOf != nil { + e.AsOf = asOf + } case *ast.AliasedExpr: - newExpr, err := replaceVariablesInExpr(stack, e.Expr) + newExpr, err := replaceVariablesInExpr(stack, e.Expr, asOf) if err != nil { return nil, err } e.Expr = newExpr.(ast.Expr) case *ast.BinaryExpr: - newLeftExpr, err := replaceVariablesInExpr(stack, e.Left) + newLeftExpr, err := replaceVariablesInExpr(stack, e.Left, asOf) if err != nil { return nil, err } - newRightExpr, err := replaceVariablesInExpr(stack, e.Right) + newRightExpr, err := replaceVariablesInExpr(stack, e.Right, asOf) if err != nil { return nil, err } e.Left = newLeftExpr.(ast.Expr) e.Right = newRightExpr.(ast.Expr) case *ast.ComparisonExpr: - newLeftExpr, err := replaceVariablesInExpr(stack, e.Left) + newLeftExpr, err := replaceVariablesInExpr(stack, e.Left, asOf) if err != nil { return nil, err } - newRightExpr, err := replaceVariablesInExpr(stack, e.Right) + newRightExpr, err := replaceVariablesInExpr(stack, e.Right, asOf) if err != nil { return nil, err } @@ -99,14 +103,14 @@ func replaceVariablesInExpr(stack *InterpreterStack, expr ast.SQLNode) (ast.SQLN e.Right = newRightExpr.(ast.Expr) case *ast.FuncExpr: for i := range e.Exprs { - newExpr, err := replaceVariablesInExpr(stack, e.Exprs[i]) + newExpr, err := replaceVariablesInExpr(stack, e.Exprs[i], asOf) if err != nil { return nil, err } e.Exprs[i] = newExpr.(ast.SelectExpr) } case *ast.NotExpr: - newExpr, err := replaceVariablesInExpr(stack, e.Expr) + newExpr, err := replaceVariablesInExpr(stack, e.Expr, asOf) if err != nil { return nil, err } @@ -114,7 +118,7 @@ func replaceVariablesInExpr(stack *InterpreterStack, expr ast.SQLNode) (ast.SQLN case *ast.Set: for _, setExpr := range e.Exprs { // TODO: properly handle user scope variables - newExpr, err := replaceVariablesInExpr(stack, setExpr.Expr) + newExpr, err := replaceVariablesInExpr(stack, setExpr.Expr, asOf) if err != nil { return nil, err } @@ -130,18 +134,21 @@ func replaceVariablesInExpr(stack *InterpreterStack, expr ast.SQLNode) (ast.SQLN case *ast.Call: for i := range e.Params { // TODO: don't replace variables for session stack - newExpr, err := replaceVariablesInExpr(stack, e.Params[i]) + newExpr, err := replaceVariablesInExpr(stack, e.Params[i], asOf) if err != nil { return nil, err } e.Params[i] = newExpr.(ast.Expr) } + if e.AsOf == nil && asOf != nil { + e.AsOf = asOf.Time + } case *ast.Limit: - newOffset, err := replaceVariablesInExpr(stack, e.Offset) + newOffset, err := replaceVariablesInExpr(stack, e.Offset, asOf) if err != nil { return nil, err } - newRowCount, err := replaceVariablesInExpr(stack, e.Rowcount) + newRowCount, err := replaceVariablesInExpr(stack, e.Rowcount, asOf) if err != nil { return nil, err } @@ -154,7 +161,7 @@ func replaceVariablesInExpr(stack *InterpreterStack, expr ast.SQLNode) (ast.SQLN case *ast.Into: // TODO: somehow support select into variables for i := range e.Variables { - newExpr, err := replaceVariablesInExpr(stack, e.Variables[i]) + newExpr, err := replaceVariablesInExpr(stack, e.Variables[i], asOf) if err != nil { return nil, err } @@ -162,7 +169,7 @@ func replaceVariablesInExpr(stack *InterpreterStack, expr ast.SQLNode) (ast.SQLN } case *ast.Select: for i := range e.SelectExprs { - newExpr, err := replaceVariablesInExpr(stack, e.SelectExprs[i]) + newExpr, err := replaceVariablesInExpr(stack, e.SelectExprs[i], asOf) if err != nil { return nil, err } @@ -170,7 +177,7 @@ func replaceVariablesInExpr(stack *InterpreterStack, expr ast.SQLNode) (ast.SQLN } if e.With != nil { for i := range e.With.Ctes { - newExpr, err := replaceVariablesInExpr(stack, e.With.Ctes[i].AliasedTableExpr) + newExpr, err := replaceVariablesInExpr(stack, e.With.Ctes[i].AliasedTableExpr, asOf) if err != nil { return nil, err } @@ -178,38 +185,47 @@ func replaceVariablesInExpr(stack *InterpreterStack, expr ast.SQLNode) (ast.SQLN } } if e.Into != nil { - newExpr, err := replaceVariablesInExpr(stack, e.Into) + newExpr, err := replaceVariablesInExpr(stack, e.Into, asOf) if err != nil { return nil, err } e.Into = newExpr.(*ast.Into) } if e.Where != nil { - newExpr, err := replaceVariablesInExpr(stack, e.Where.Expr) + newExpr, err := replaceVariablesInExpr(stack, e.Where.Expr, asOf) if err != nil { return nil, err } e.Where.Expr = newExpr.(ast.Expr) } if e.Limit != nil { - newExpr, err := replaceVariablesInExpr(stack, e.Limit) + newExpr, err := replaceVariablesInExpr(stack, e.Limit, asOf) if err != nil { return nil, err } e.Limit = newExpr.(*ast.Limit) } + if e.From != nil { + for i := range e.From { + newExpr, err := replaceVariablesInExpr(stack, e.From[i], asOf) + if err != nil { + return nil, err + } + e.From[i] = newExpr.(*ast.AliasedTableExpr) + } + } case *ast.Subquery: - newExpr, err := replaceVariablesInExpr(stack, e.Select) + newExpr, err := replaceVariablesInExpr(stack, e.Select, asOf) if err != nil { return nil, err } e.Select = newExpr.(*ast.Select) case *ast.SetOp: - newLeftExpr, err := replaceVariablesInExpr(stack, e.Left) + newLeftExpr, err := replaceVariablesInExpr(stack, e.Left, asOf) if err != nil { return nil, err } - newRightExpr, err := replaceVariablesInExpr(stack, e.Right) + newRightExpr, err := replaceVariablesInExpr(stack, e.Right, asOf) if err != nil { return nil, err } @@ -217,7 +233,7 @@ func replaceVariablesInExpr(stack *InterpreterStack, expr ast.SQLNode) (ast.SQLN e.Right = newRightExpr.(ast.SelectStatement) case ast.ValTuple: for i := range e { - newExpr, err := replaceVariablesInExpr(stack, e[i]) + newExpr, err := replaceVariablesInExpr(stack, e[i], asOf) if err != nil { return nil, err } @@ -225,14 +241,14 @@ func replaceVariablesInExpr(stack *InterpreterStack, expr ast.SQLNode) (ast.SQLN } case *ast.AliasedValues: for i := range e.Values { - newExpr, err := replaceVariablesInExpr(stack, e.Values[i]) + newExpr, err := replaceVariablesInExpr(stack, e.Values[i], asOf) if err != nil { return nil, err } e.Values[i] = newExpr.(ast.ValTuple) } case *ast.Insert: - newExpr, err := replaceVariablesInExpr(stack, e.Rows) + newExpr, err := replaceVariablesInExpr(stack, e.Rows, asOf) if err != nil { return nil, err } @@ -299,7 +315,7 @@ func handleError(ctx *sql.Context, runner sql.StatementRunner, stack *Interprete return -1, err } - _, _, rowIter, err := execOp(ctx, runner, stack, handlerOps[0], handlerOps, -1) + _, _, rowIter, err := execOp(ctx, runner, stack, handlerOps[0], handlerOps, nil, -1) if err != nil { return -1, err } @@ -337,11 +353,11 @@ func handleError(ctx *sql.Context, runner sql.StatementRunner, stack *Interprete return counter, nil } -func execOp(ctx *sql.Context, runner sql.StatementRunner, stack *InterpreterStack, operation *InterpreterOperation, statements []*InterpreterOperation, counter int) (int, sql.RowIter, sql.RowIter, error) { +func execOp(ctx *sql.Context, runner sql.StatementRunner, stack *InterpreterStack, operation *InterpreterOperation, statements []*InterpreterOperation, asOf *ast.AsOf, counter int) (int, sql.RowIter, sql.RowIter, error) { switch operation.OpCode { case OpCode_Select: selectStmt := operation.PrimaryData.(*ast.Select) - if newSelectStmt, err := replaceVariablesInExpr(stack, selectStmt); err == nil { + if newSelectStmt, err := replaceVariablesInExpr(stack, selectStmt, asOf); err == nil { selectStmt = newSelectStmt.(*ast.Select) } else { return 0, nil, nil, err @@ -560,7 +576,7 @@ func execOp(ctx *sql.Context, runner sql.StatementRunner, stack *InterpreterStac if cursor.RowIter != nil { return 0, nil, nil, sql.ErrCursorAlreadyOpen.New(openCur.Name) } - stmt, err := replaceVariablesInExpr(stack, cursor.SelectStmt) + stmt, err := replaceVariablesInExpr(stack, cursor.SelectStmt, asOf) if err != nil { return 0, nil, nil, err } @@ -625,7 +641,7 @@ func execOp(ctx *sql.Context, runner sql.StatementRunner, stack *InterpreterStac panic("select stmt with no select exprs") } for i := range selectStmt.SelectExprs { - newNode, err := replaceVariablesInExpr(stack, selectStmt.SelectExprs[i]) + newNode, err := replaceVariablesInExpr(stack, selectStmt.SelectExprs[i], asOf) if err != nil { return 0, nil, nil, err } @@ -657,7 +673,7 @@ func execOp(ctx *sql.Context, runner sql.StatementRunner, stack *InterpreterStac panic("select stmt with no select exprs") } for i := range selectStmt.SelectExprs { - newNode, err := replaceVariablesInExpr(stack, selectStmt.SelectExprs[i]) + newNode, err := replaceVariablesInExpr(stack, selectStmt.SelectExprs[i], asOf) if err != nil { return 0, nil, nil, err } @@ -718,7 +734,7 @@ func execOp(ctx *sql.Context, runner sql.StatementRunner, stack *InterpreterStac } case OpCode_Execute: - stmt, err := replaceVariablesInExpr(stack, operation.PrimaryData) + stmt, err := replaceVariablesInExpr(stack, operation.PrimaryData, asOf) if err != nil { return 0, nil, nil, err } @@ -753,6 +769,22 @@ func Call(ctx *sql.Context, iNode InterpreterNode, params []*Parameter) (sql.Row stack.NewVariableWithValue(param.Name, param.Type, param.Value) } + var asOf *ast.AsOf + if asOfExpr := iNode.GetAsOf(); asOfExpr != nil { + switch a := asOfExpr.(type) { + case *expression.Literal: + v, err := a.Eval(ctx, nil) + if err != nil { + return nil, nil, err + } + asOfStr := v.(string) + asOf = &ast.AsOf{ + Time: ast.NewStrVal([]byte(asOfStr)), + } + default: + } + } + // TODO: remove this; track last selectRowIter var selIter sql.RowIter @@ -771,7 +803,7 @@ func Call(ctx *sql.Context, iNode InterpreterNode, params []*Parameter) (sql.Row } operation := statements[counter] - newCounter, newSelIter, rowIter, err := execOp(ctx, runner, stack, operation, statements, counter) + newCounter, newSelIter, rowIter, err := execOp(ctx, runner, stack, operation, statements, asOf, counter) if err != nil { hCounter, hErr := handleError(ctx, runner, stack, statements, counter, err) if hErr != nil && hErr != io.EOF { From a7feeda212d40757162dfe4d391a596d2ccb3ecc Mon Sep 17 00:00:00 2001 From: James Cor Date: Sun, 30 Mar 2025 19:08:35 -0700 Subject: [PATCH 066/111] fix join --- enginetest/memory_engine_test.go | 71 +++++++++++++++++++++++++---- sql/procedures/interpreter_logic.go | 2 +- 2 files changed, 63 insertions(+), 10 deletions(-) diff --git a/enginetest/memory_engine_test.go b/enginetest/memory_engine_test.go index 04b2870f9f..47a07cb048 100644 --- a/enginetest/memory_engine_test.go +++ b/enginetest/memory_engine_test.go @@ -202,22 +202,75 @@ func TestSingleQueryPrepared(t *testing.T) { // Convenience test for debugging a single query. Unskip and set to the desired query. func TestSingleScript(t *testing.T) { - t.Skip() + //t.Skip() var scripts = []queries.ScriptTest{ { - Name: "Nested CALL with INOUT param", + Name: "SELECT with JOIN and table aliases", SetUpScript: []string{ - "SET @outparam = 5", - "CREATE PROCEDURE p3(INOUT z INT) BEGIN SET z = z * 111; END;", - "CREATE PROCEDURE p2(INOUT y DOUBLE) BEGIN SET y = y + 4; CALL p3(y); END;", - "CREATE PROCEDURE p1(INOUT x BIGINT) BEGIN SET x = 3; CALL p2(x); END;", - "CALL p1(@outparam)", + "CREATE TABLE foo(a BIGINT PRIMARY KEY, b VARCHAR(20))", + "INSERT INTO foo VALUES (1, 'd'), (2, 'e'), (3, 'f')", + "CREATE TABLE bar(b VARCHAR(30) PRIMARY KEY, c BIGINT)", + "INSERT INTO bar VALUES ('x', 3), ('y', 2), ('z', 1)", + // Direct child is SELECT + "CREATE PROCEDURE p1() SELECT f.a, bar.b, f.b FROM foo f INNER JOIN bar ON f.a = bar.c ORDER BY 1", + // Direct child is BEGIN/END + "CREATE PROCEDURE p2() BEGIN SELECT f.a, bar.b, f.b FROM foo f INNER JOIN bar ON f.a = bar.c ORDER BY 1; END;", + // Direct child is IF + "CREATE PROCEDURE p3() IF 0 = 0 THEN SELECT f.a, bar.b, f.b FROM foo f INNER JOIN bar ON f.a = bar.c ORDER BY 1; END IF;", + // Direct child is BEGIN/END with preceding SELECT + "CREATE PROCEDURE p4() BEGIN SELECT 7; SELECT f.a, bar.b, f.b FROM foo f INNER JOIN bar ON f.a = bar.c ORDER BY 1; END;", + // Direct child is IF with preceding SELECT + "CREATE PROCEDURE p5() IF 0 = 0 THEN SELECT 7; SELECT f.a, bar.b, f.b FROM foo f INNER JOIN bar ON f.a = bar.c ORDER BY 1; END IF;", }, Assertions: []queries.ScriptTestAssertion{ + { // Enforces that this is the expected output from the query normally + Query: "SELECT f.a, bar.b, f.b FROM foo f INNER JOIN bar ON f.a = bar.c ORDER BY 1", + Expected: []sql.Row{ + {int64(1), "z", "d"}, + {int64(2), "y", "e"}, + {int64(3), "x", "f"}, + }, + }, + { + Query: "CALL p1()", + Expected: []sql.Row{ + {int64(1), "z", "d"}, + {int64(2), "y", "e"}, + {int64(3), "x", "f"}, + }, + }, + { + Query: "CALL p2()", + Expected: []sql.Row{ + {int64(1), "z", "d"}, + {int64(2), "y", "e"}, + {int64(3), "x", "f"}, + }, + }, + { + SkipResultCheckOnServerEngine: true, // tracking issue: https://github.com/dolthub/dolt/issues/6918 + Query: "CALL p3()", + Expected: []sql.Row{ + {int64(1), "z", "d"}, + {int64(2), "y", "e"}, + {int64(3), "x", "f"}, + }, + }, + { + Query: "CALL p4()", + Expected: []sql.Row{ + {int64(1), "z", "d"}, + {int64(2), "y", "e"}, + {int64(3), "x", "f"}, + }, + }, { - Query: "SELECT @outparam", + SkipResultCheckOnServerEngine: true, // tracking issue: https://github.com/dolthub/dolt/issues/6918 + Query: "CALL p5()", Expected: []sql.Row{ - {int64(777)}, + {int64(1), "z", "d"}, + {int64(2), "y", "e"}, + {int64(3), "x", "f"}, }, }, }, diff --git a/sql/procedures/interpreter_logic.go b/sql/procedures/interpreter_logic.go index 94d58c707a..8e7c8e4cc7 100644 --- a/sql/procedures/interpreter_logic.go +++ b/sql/procedures/interpreter_logic.go @@ -211,7 +211,7 @@ func replaceVariablesInExpr(stack *InterpreterStack, expr ast.SQLNode, asOf *ast if err != nil { return nil, err } - e.From[i] = newExpr.(*ast.AliasedTableExpr) + e.From[i] = newExpr.(ast.TableExpr) } } case *ast.Subquery: From 2027f89fd8b4b61d164d1f214a64f946f04dd987 Mon Sep 17 00:00:00 2001 From: James Cor Date: Mon, 31 Mar 2025 11:08:18 -0700 Subject: [PATCH 067/111] not sure what's wrong with double insert yet --- enginetest/memory_engine_test.go | 77 +++++++++-------------------- sql/procedures/interpreter_logic.go | 16 ++++++ 2 files changed, 40 insertions(+), 53 deletions(-) diff --git a/enginetest/memory_engine_test.go b/enginetest/memory_engine_test.go index 47a07cb048..f62d6fedaf 100644 --- a/enginetest/memory_engine_test.go +++ b/enginetest/memory_engine_test.go @@ -205,72 +205,43 @@ func TestSingleScript(t *testing.T) { //t.Skip() var scripts = []queries.ScriptTest{ { - Name: "SELECT with JOIN and table aliases", + Name: "insert trigger with stored procedure with deletes", SetUpScript: []string{ - "CREATE TABLE foo(a BIGINT PRIMARY KEY, b VARCHAR(20))", - "INSERT INTO foo VALUES (1, 'd'), (2, 'e'), (3, 'f')", - "CREATE TABLE bar(b VARCHAR(30) PRIMARY KEY, c BIGINT)", - "INSERT INTO bar VALUES ('x', 3), ('y', 2), ('z', 1)", - // Direct child is SELECT - "CREATE PROCEDURE p1() SELECT f.a, bar.b, f.b FROM foo f INNER JOIN bar ON f.a = bar.c ORDER BY 1", - // Direct child is BEGIN/END - "CREATE PROCEDURE p2() BEGIN SELECT f.a, bar.b, f.b FROM foo f INNER JOIN bar ON f.a = bar.c ORDER BY 1; END;", - // Direct child is IF - "CREATE PROCEDURE p3() IF 0 = 0 THEN SELECT f.a, bar.b, f.b FROM foo f INNER JOIN bar ON f.a = bar.c ORDER BY 1; END IF;", - // Direct child is BEGIN/END with preceding SELECT - "CREATE PROCEDURE p4() BEGIN SELECT 7; SELECT f.a, bar.b, f.b FROM foo f INNER JOIN bar ON f.a = bar.c ORDER BY 1; END;", - // Direct child is IF with preceding SELECT - "CREATE PROCEDURE p5() IF 0 = 0 THEN SELECT 7; SELECT f.a, bar.b, f.b FROM foo f INNER JOIN bar ON f.a = bar.c ORDER BY 1; END IF;", + "create table t (i int);", + "create table t1 (j int);", + ` +create procedure proc(x int) +begin + insert into t1 values (x); + insert into t1 values (x); +end; +`, + ` +create trigger trig before insert on t +for each row +begin + call proc(new.i); +end; +`, }, Assertions: []queries.ScriptTestAssertion{ - { // Enforces that this is the expected output from the query normally - Query: "SELECT f.a, bar.b, f.b FROM foo f INNER JOIN bar ON f.a = bar.c ORDER BY 1", - Expected: []sql.Row{ - {int64(1), "z", "d"}, - {int64(2), "y", "e"}, - {int64(3), "x", "f"}, - }, - }, - { - Query: "CALL p1()", - Expected: []sql.Row{ - {int64(1), "z", "d"}, - {int64(2), "y", "e"}, - {int64(3), "x", "f"}, - }, - }, - { - Query: "CALL p2()", - Expected: []sql.Row{ - {int64(1), "z", "d"}, - {int64(2), "y", "e"}, - {int64(3), "x", "f"}, - }, - }, { - SkipResultCheckOnServerEngine: true, // tracking issue: https://github.com/dolthub/dolt/issues/6918 - Query: "CALL p3()", + Query: "insert into t values (1);", Expected: []sql.Row{ - {int64(1), "z", "d"}, - {int64(2), "y", "e"}, - {int64(3), "x", "f"}, + {types.NewOkResult(1)}, }, }, { - Query: "CALL p4()", + Query: "select * from t;", Expected: []sql.Row{ - {int64(1), "z", "d"}, - {int64(2), "y", "e"}, - {int64(3), "x", "f"}, + {1}, }, }, { - SkipResultCheckOnServerEngine: true, // tracking issue: https://github.com/dolthub/dolt/issues/6918 - Query: "CALL p5()", + Query: "select * from t1;", Expected: []sql.Row{ - {int64(1), "z", "d"}, - {int64(2), "y", "e"}, - {int64(3), "x", "f"}, + {1}, + {1}, }, }, }, diff --git a/sql/procedures/interpreter_logic.go b/sql/procedures/interpreter_logic.go index 8e7c8e4cc7..555b93b7c3 100644 --- a/sql/procedures/interpreter_logic.go +++ b/sql/procedures/interpreter_logic.go @@ -253,6 +253,22 @@ func replaceVariablesInExpr(stack *InterpreterStack, expr ast.SQLNode, asOf *ast return nil, err } e.Rows = newExpr.(ast.InsertRows) + case *ast.Delete: + if e.Where != nil { + newExpr, err := replaceVariablesInExpr(stack, e.Where.Expr, asOf) + if err != nil { + return nil, err + } + e.Where.Expr = newExpr.(ast.Expr) + } + case *ast.Update: + if e.Where != nil { + newExpr, err := replaceVariablesInExpr(stack, e.Where.Expr, asOf) + if err != nil { + return nil, err + } + e.Where.Expr = newExpr.(ast.Expr) + } } return expr, nil } From 9df3d57fbecc4870a514a026db5f5f89987a133e Mon Sep 17 00:00:00 2001 From: James Cor Date: Tue, 1 Apr 2025 16:55:06 -0700 Subject: [PATCH 068/111] fixing triggers --- enginetest/memory_engine_test.go | 17 ++++++++--------- sql/procedures/interpreter_logic.go | 8 ++++++++ sql/session.go | 4 ++++ 3 files changed, 20 insertions(+), 9 deletions(-) diff --git a/enginetest/memory_engine_test.go b/enginetest/memory_engine_test.go index da527ee3c6..b895aec964 100644 --- a/enginetest/memory_engine_test.go +++ b/enginetest/memory_engine_test.go @@ -212,8 +212,7 @@ func TestSingleScript(t *testing.T) { ` create procedure proc(x int) begin - insert into t1 values (x); - insert into t1 values (x); + insert into t1 values (x + 100); end; `, ` @@ -237,13 +236,13 @@ end; {1}, }, }, - { - Query: "select * from t1;", - Expected: []sql.Row{ - {1}, - {1}, - }, - }, + //{ + // Query: "select * from t1;", + // Expected: []sql.Row{ + // {101}, + // {201}, + // }, + //}, }, }, } diff --git a/sql/procedures/interpreter_logic.go b/sql/procedures/interpreter_logic.go index 555b93b7c3..a2413b568b 100644 --- a/sql/procedures/interpreter_logic.go +++ b/sql/procedures/interpreter_logic.go @@ -778,6 +778,14 @@ func execOp(ctx *sql.Context, runner sql.StatementRunner, stack *InterpreterStac // Call runs the contained operations on the given runner. func Call(ctx *sql.Context, iNode InterpreterNode, params []*Parameter) (sql.RowIter, *InterpreterStack, error) { + // TODO: what about nested stored procedures? + if transSess, isTransSess := ctx.Session.(sql.TransactionSession); isTransSess { + transSess.SetInStoredProcedure(true) + defer func() { + transSess.SetInStoredProcedure(false) + }() + } + // Set up the initial state of the function counter := -1 // We increment before accessing, so start at -1 stack := NewInterpreterStack() diff --git a/sql/session.go b/sql/session.go index 8e45dc3305..25413cd900 100644 --- a/sql/session.go +++ b/sql/session.go @@ -171,6 +171,8 @@ type Session interface { // ValidateSession provides integrators a chance to do any custom validation of this session before any query is // executed in it. For example, Dolt uses this hook to validate that the session's working set is valid. ValidateSession(ctx *Context) error + + //SetInStoredProcedure(val bool) } // PersistableSession supports serializing/deserializing global system variables/ @@ -205,6 +207,8 @@ type TransactionSession interface { RollbackToSavepoint(ctx *Context, transaction Transaction, name string) error // ReleaseSavepoint removes the savepoint named from the transaction given ReleaseSavepoint(ctx *Context, transaction Transaction, name string) error + + SetInStoredProcedure(val bool) } // A LifecycleAwareSession is a a sql.Session that gets lifecycle callbacks From 742e9a8e237232da73b5d94cc91daa47092212eb Mon Sep 17 00:00:00 2001 From: James Cor Date: Wed, 2 Apr 2025 10:36:11 -0700 Subject: [PATCH 069/111] stored procedures from triggers --- sql/procedures/interpreter_logic.go | 8 -------- sql/rowexec/proc.go | 17 +++++++++-------- sql/session.go | 2 -- 3 files changed, 9 insertions(+), 18 deletions(-) diff --git a/sql/procedures/interpreter_logic.go b/sql/procedures/interpreter_logic.go index a2413b568b..555b93b7c3 100644 --- a/sql/procedures/interpreter_logic.go +++ b/sql/procedures/interpreter_logic.go @@ -778,14 +778,6 @@ func execOp(ctx *sql.Context, runner sql.StatementRunner, stack *InterpreterStac // Call runs the contained operations on the given runner. func Call(ctx *sql.Context, iNode InterpreterNode, params []*Parameter) (sql.RowIter, *InterpreterStack, error) { - // TODO: what about nested stored procedures? - if transSess, isTransSess := ctx.Session.(sql.TransactionSession); isTransSess { - transSess.SetInStoredProcedure(true) - defer func() { - transSess.SetInStoredProcedure(false) - }() - } - // Set up the initial state of the function counter := -1 // We increment before accessing, so start at -1 stack := NewInterpreterStack() diff --git a/sql/rowexec/proc.go b/sql/rowexec/proc.go index 98565a3a82..a063ca04be 100644 --- a/sql/rowexec/proc.go +++ b/sql/rowexec/proc.go @@ -229,14 +229,6 @@ func (b *BaseBuilder) buildCall(ctx *sql.Context, n *plan.Call, row sql.Row) (sq } } - // check for existing procedure stack in session - // if not found, create a new one - - // - - // TODO: add all procedure parameters - // ctx.Session.GetProcedureVariables - rowIter, stack, err := procedures.Call(ctx, n, procParams) if err != nil { return nil, err @@ -270,6 +262,15 @@ func (b *BaseBuilder) buildCall(ctx *sql.Context, n *plan.Call, row sql.Row) (sq } } + // We might close transactions in the procedure, so we need to start a new one if we're not in one already + if sess, ok := ctx.Session.(sql.TransactionSession); ok { + tx, tErr := sess.StartTransaction(ctx, sql.ReadWrite) + if tErr != nil { + return nil, tErr + } + ctx.SetTransaction(tx) + } + return &callIter{ call: n, innerIter: rowIter, diff --git a/sql/session.go b/sql/session.go index 25413cd900..500b4ef462 100644 --- a/sql/session.go +++ b/sql/session.go @@ -207,8 +207,6 @@ type TransactionSession interface { RollbackToSavepoint(ctx *Context, transaction Transaction, name string) error // ReleaseSavepoint removes the savepoint named from the transaction given ReleaseSavepoint(ctx *Context, transaction Transaction, name string) error - - SetInStoredProcedure(val bool) } // A LifecycleAwareSession is a a sql.Session that gets lifecycle callbacks From 751455904e8204b467f7b29c5f914d66632e57a4 Mon Sep 17 00:00:00 2001 From: jycor Date: Wed, 2 Apr 2025 17:38:37 +0000 Subject: [PATCH 070/111] [ga-format-pr] Run ./format_repo.sh to fix formatting --- enginetest/memory_engine_test.go | 52 ++++++++++++++++---------------- 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/enginetest/memory_engine_test.go b/enginetest/memory_engine_test.go index b895aec964..295926b921 100644 --- a/enginetest/memory_engine_test.go +++ b/enginetest/memory_engine_test.go @@ -204,47 +204,47 @@ func TestSingleQueryPrepared(t *testing.T) { func TestSingleScript(t *testing.T) { //t.Skip() var scripts = []queries.ScriptTest{ - { - Name: "insert trigger with stored procedure with deletes", - SetUpScript: []string{ - "create table t (i int);", - "create table t1 (j int);", - ` + { + Name: "insert trigger with stored procedure with deletes", + SetUpScript: []string{ + "create table t (i int);", + "create table t1 (j int);", + ` create procedure proc(x int) begin insert into t1 values (x + 100); end; `, - ` + ` create trigger trig before insert on t for each row begin call proc(new.i); end; `, - }, - Assertions: []queries.ScriptTestAssertion{ - { - Query: "insert into t values (1);", - Expected: []sql.Row{ - {types.NewOkResult(1)}, - }, }, - { - Query: "select * from t;", - Expected: []sql.Row{ - {1}, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "insert into t values (1);", + Expected: []sql.Row{ + {types.NewOkResult(1)}, + }, + }, + { + Query: "select * from t;", + Expected: []sql.Row{ + {1}, + }, }, + //{ + // Query: "select * from t1;", + // Expected: []sql.Row{ + // {101}, + // {201}, + // }, + //}, }, - //{ - // Query: "select * from t1;", - // Expected: []sql.Row{ - // {101}, - // {201}, - // }, - //}, }, - }, } for _, test := range scripts { From 7969c062a460a540c38ec77a2bc42551c71e1791 Mon Sep 17 00:00:00 2001 From: James Cor Date: Wed, 2 Apr 2025 11:12:36 -0700 Subject: [PATCH 071/111] go mod tidy --- go.sum | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/go.sum b/go.sum index ba4efa74af..79090cef33 100644 --- a/go.sum +++ b/go.sum @@ -58,8 +58,8 @@ github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71 h1:bMGS25NWAGTE github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71/go.mod h1:2/2zjLQ/JOOSbbSboojeg+cAwcRV0fDLzIiWch/lhqI= github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81 h1:7/v8q9XGFa6q5Ap4Z/OhNkAMBaK5YeuEzwJt+NZdhiE= github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81/go.mod h1:siLfyv2c92W1eN/R4QqG/+RjjX5W2+gCTRjZxBjI3TY= -github.com/dolthub/vitess v0.0.0-20250325024605-8131be3ca6d3 h1:euU+adNAYw46Zcp1HnoaSDWhqjfaL8s/1SPU+i16gYM= -github.com/dolthub/vitess v0.0.0-20250325024605-8131be3ca6d3/go.mod h1:1gQZs/byeHLMSul3Lvl3MzioMtOW1je79QYGyi2fd70= +github.com/dolthub/vitess v0.0.0-20250326064017-04ab843d56dc h1:nmU0xvumZ8fKxjWI6P1rWvvHzvrbh+R+uWzXJCw/dP4= +github.com/dolthub/vitess v0.0.0-20250326064017-04ab843d56dc/go.mod h1:1gQZs/byeHLMSul3Lvl3MzioMtOW1je79QYGyi2fd70= github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/eapache/go-resiliency v1.1.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs= github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU= From 6f570b00d61d8a0d472a6c3f5b24bd936c66ad1c Mon Sep 17 00:00:00 2001 From: James Cor Date: Fri, 4 Apr 2025 02:09:31 -0700 Subject: [PATCH 072/111] fix parameters --- sql/base_session.go | 67 ++++++++++++----- sql/core.go | 15 ++++ sql/planbuilder/proc.go | 18 +++++ sql/procedures/interpreter_logic.go | 109 +++++++++++++++++----------- sql/rowexec/proc.go | 5 +- sql/session.go | 7 ++ 6 files changed, 157 insertions(+), 64 deletions(-) diff --git a/sql/base_session.go b/sql/base_session.go index d5998dc6fe..69b4faaf16 100644 --- a/sql/base_session.go +++ b/sql/base_session.go @@ -53,6 +53,8 @@ type BaseSession struct { // privilege set if our counter doesn't equal the database's counter. privSetCounter uint64 privilegeSet PrivilegeSet + + storedProcParams map[string]*StoredProcParam } func (s *BaseSession) GetLogger() *logrus.Entry { @@ -252,6 +254,29 @@ func (s *BaseSession) IncrementStatusVariable(ctx *Context, statVarName string, return } +func (s *BaseSession) NewStoredProcParam(name string, param *StoredProcParam) { + if _, ok := s.storedProcParams[name]; ok { + return + } + s.storedProcParams[name] = param +} + +func (s *BaseSession) GetStoredProcParam(name string) *StoredProcParam { + if param, ok := s.storedProcParams[name]; ok { + return param + } + return nil +} + +func (s *BaseSession) SetStoredProcParam(name string, val any) error { + param := s.GetStoredProcParam(name) + if param == nil { + return fmt.Errorf("variable `%s` could not be found", name) + } + param.SetValue(val) + return nil +} + // GetCharacterSet returns the character set for this session (defined by the system variable `character_set_connection`). func (s *BaseSession) GetCharacterSet() CharacterSetID { sysVar, _ := s.systemVars[characterSetConnectionSysVarName] @@ -504,17 +529,18 @@ func NewBaseSessionWithClientServer(server string, client Client, id uint32) *Ba statusVars = make(map[string]StatusVarValue) } return &BaseSession{ - addr: server, - client: client, - id: id, - systemVars: systemVars, - statusVars: statusVars, - userVars: NewUserVars(), - idxReg: NewIndexRegistry(), - viewReg: NewViewRegistry(), - locks: make(map[string]bool), - lastQueryInfo: defaultLastQueryInfo(), - privSetCounter: 0, + addr: server, + client: client, + id: id, + systemVars: systemVars, + statusVars: statusVars, + userVars: NewUserVars(), + storedProcParams: make(map[string]*StoredProcParam), + idxReg: NewIndexRegistry(), + viewReg: NewViewRegistry(), + locks: make(map[string]bool), + lastQueryInfo: defaultLastQueryInfo(), + privSetCounter: 0, } } @@ -534,14 +560,15 @@ func NewBaseSession() *BaseSession { statusVars = make(map[string]StatusVarValue) } return &BaseSession{ - id: atomic.AddUint32(&autoSessionIDs, 1), - systemVars: systemVars, - statusVars: statusVars, - userVars: NewUserVars(), - idxReg: NewIndexRegistry(), - viewReg: NewViewRegistry(), - locks: make(map[string]bool), - lastQueryInfo: defaultLastQueryInfo(), - privSetCounter: 0, + id: atomic.AddUint32(&autoSessionIDs, 1), + systemVars: systemVars, + statusVars: statusVars, + userVars: NewUserVars(), + storedProcParams: make(map[string]*StoredProcParam), + idxReg: NewIndexRegistry(), + viewReg: NewViewRegistry(), + locks: make(map[string]bool), + lastQueryInfo: defaultLastQueryInfo(), + privSetCounter: 0, } } diff --git a/sql/core.go b/sql/core.go index f56d55bfa6..685d3d92a6 100644 --- a/sql/core.go +++ b/sql/core.go @@ -880,6 +880,21 @@ func IncrementStatusVariable(ctx *Context, name string, val int) { ctx.Session.IncrementStatusVariable(ctx, name, val) } +type StoredProcParam struct { + Type Type + Value any + HasBeenSet bool + Reference *StoredProcParam +} + +func (s *StoredProcParam) SetValue(val any) { + s.Value = val + s.HasBeenSet = true + if s.Reference != nil { + s.Reference.SetValue(val) + } +} + // OrderAndLimit stores the context of an ORDER BY ... LIMIT statement, and is used by index lookups and iterators. type OrderAndLimit struct { OrderBy Expression diff --git a/sql/planbuilder/proc.go b/sql/planbuilder/proc.go index c8fb549cb7..9d4fdd10f9 100644 --- a/sql/planbuilder/proc.go +++ b/sql/planbuilder/proc.go @@ -345,8 +345,26 @@ func (b *Builder) buildCall(inScope *scope, c *ast.Call) (outScope *scope) { b.handleErr(err) } + // TODO: build references here? + // TODO: here fill in x from session params := make([]sql.Expression, len(c.Params)) for i, param := range c.Params { + if len(proc.Params) == len(c.Params) { + procParam := proc.Params[i] + rspp := &sql.StoredProcParam{Type: procParam.Type} + b.ctx.Session.NewStoredProcParam(procParam.Name, rspp) + if col, isCol := param.(*ast.ColName); isCol { + colName := col.Name.String() // TODO: to lower? + if spp := b.ctx.Session.GetStoredProcParam(colName); spp != nil { + iv := &procedures.InterpreterVariable{ + Type: spp.Type, + Value: spp.Value, + } + param = iv.ToAST() + rspp.Reference = spp + } + } + } expr := b.buildScalar(inScope, param) params[i] = expr } diff --git a/sql/procedures/interpreter_logic.go b/sql/procedures/interpreter_logic.go index 555b93b7c3..580b65866c 100644 --- a/sql/procedures/interpreter_logic.go +++ b/sql/procedures/interpreter_logic.go @@ -45,12 +45,21 @@ type Parameter struct { Value any } -func replaceVariablesInExpr(stack *InterpreterStack, expr ast.SQLNode, asOf *ast.AsOf) (ast.SQLNode, error) { +func replaceVariablesInExpr(ctx *sql.Context, stack *InterpreterStack, expr ast.SQLNode, asOf *ast.AsOf) (ast.SQLNode, error) { switch e := expr.(type) { case *ast.ColName: - iv := stack.GetVariable(strings.ToLower(e.Name.String())) + varName := strings.ToLower(e.Name.String()) + iv := stack.GetVariable(varName) if iv == nil { - return expr, nil + spp := ctx.Session.GetStoredProcParam(varName) + if spp == nil { + return expr, nil + } + iv = &InterpreterVariable{ + Value: spp.Value, + Type: spp.Type, + HasBeenSet: spp.HasBeenSet, + } } newExpr := iv.ToAST() return &ast.ColName{ @@ -59,13 +68,13 @@ func replaceVariablesInExpr(stack *InterpreterStack, expr ast.SQLNode, asOf *ast StoredProcVal: newExpr, }, nil case *ast.ParenExpr: - newExpr, err := replaceVariablesInExpr(stack, e.Expr, asOf) + newExpr, err := replaceVariablesInExpr(ctx, stack, e.Expr, asOf) if err != nil { return nil, err } e.Expr = newExpr.(ast.Expr) case *ast.AliasedTableExpr: - newExpr, err := replaceVariablesInExpr(stack, e.Expr, asOf) + newExpr, err := replaceVariablesInExpr(ctx, stack, e.Expr, asOf) if err != nil { return nil, err } @@ -74,28 +83,28 @@ func replaceVariablesInExpr(stack *InterpreterStack, expr ast.SQLNode, asOf *ast e.AsOf = asOf } case *ast.AliasedExpr: - newExpr, err := replaceVariablesInExpr(stack, e.Expr, asOf) + newExpr, err := replaceVariablesInExpr(ctx, stack, e.Expr, asOf) if err != nil { return nil, err } e.Expr = newExpr.(ast.Expr) case *ast.BinaryExpr: - newLeftExpr, err := replaceVariablesInExpr(stack, e.Left, asOf) + newLeftExpr, err := replaceVariablesInExpr(ctx, stack, e.Left, asOf) if err != nil { return nil, err } - newRightExpr, err := replaceVariablesInExpr(stack, e.Right, asOf) + newRightExpr, err := replaceVariablesInExpr(ctx, stack, e.Right, asOf) if err != nil { return nil, err } e.Left = newLeftExpr.(ast.Expr) e.Right = newRightExpr.(ast.Expr) case *ast.ComparisonExpr: - newLeftExpr, err := replaceVariablesInExpr(stack, e.Left, asOf) + newLeftExpr, err := replaceVariablesInExpr(ctx, stack, e.Left, asOf) if err != nil { return nil, err } - newRightExpr, err := replaceVariablesInExpr(stack, e.Right, asOf) + newRightExpr, err := replaceVariablesInExpr(ctx, stack, e.Right, asOf) if err != nil { return nil, err } @@ -103,14 +112,14 @@ func replaceVariablesInExpr(stack *InterpreterStack, expr ast.SQLNode, asOf *ast e.Right = newRightExpr.(ast.Expr) case *ast.FuncExpr: for i := range e.Exprs { - newExpr, err := replaceVariablesInExpr(stack, e.Exprs[i], asOf) + newExpr, err := replaceVariablesInExpr(ctx, stack, e.Exprs[i], asOf) if err != nil { return nil, err } e.Exprs[i] = newExpr.(ast.SelectExpr) } case *ast.NotExpr: - newExpr, err := replaceVariablesInExpr(stack, e.Expr, asOf) + newExpr, err := replaceVariablesInExpr(ctx, stack, e.Expr, asOf) if err != nil { return nil, err } @@ -118,7 +127,7 @@ func replaceVariablesInExpr(stack *InterpreterStack, expr ast.SQLNode, asOf *ast case *ast.Set: for _, setExpr := range e.Exprs { // TODO: properly handle user scope variables - newExpr, err := replaceVariablesInExpr(stack, setExpr.Expr, asOf) + newExpr, err := replaceVariablesInExpr(ctx, stack, setExpr.Expr, asOf) if err != nil { return nil, err } @@ -133,8 +142,8 @@ func replaceVariablesInExpr(stack *InterpreterStack, expr ast.SQLNode, asOf *ast } case *ast.Call: for i := range e.Params { - // TODO: don't replace variables for session stack - newExpr, err := replaceVariablesInExpr(stack, e.Params[i], asOf) + // TODO: do not replace certain params + newExpr, err := replaceVariablesInExpr(ctx, stack, e.Params[i], asOf) if err != nil { return nil, err } @@ -144,11 +153,11 @@ func replaceVariablesInExpr(stack *InterpreterStack, expr ast.SQLNode, asOf *ast e.AsOf = asOf.Time } case *ast.Limit: - newOffset, err := replaceVariablesInExpr(stack, e.Offset, asOf) + newOffset, err := replaceVariablesInExpr(ctx, stack, e.Offset, asOf) if err != nil { return nil, err } - newRowCount, err := replaceVariablesInExpr(stack, e.Rowcount, asOf) + newRowCount, err := replaceVariablesInExpr(ctx, stack, e.Rowcount, asOf) if err != nil { return nil, err } @@ -161,7 +170,7 @@ func replaceVariablesInExpr(stack *InterpreterStack, expr ast.SQLNode, asOf *ast case *ast.Into: // TODO: somehow support select into variables for i := range e.Variables { - newExpr, err := replaceVariablesInExpr(stack, e.Variables[i], asOf) + newExpr, err := replaceVariablesInExpr(ctx, stack, e.Variables[i], asOf) if err != nil { return nil, err } @@ -169,7 +178,7 @@ func replaceVariablesInExpr(stack *InterpreterStack, expr ast.SQLNode, asOf *ast } case *ast.Select: for i := range e.SelectExprs { - newExpr, err := replaceVariablesInExpr(stack, e.SelectExprs[i], asOf) + newExpr, err := replaceVariablesInExpr(ctx, stack, e.SelectExprs[i], asOf) if err != nil { return nil, err } @@ -177,7 +186,7 @@ func replaceVariablesInExpr(stack *InterpreterStack, expr ast.SQLNode, asOf *ast } if e.With != nil { for i := range e.With.Ctes { - newExpr, err := replaceVariablesInExpr(stack, e.With.Ctes[i].AliasedTableExpr, asOf) + newExpr, err := replaceVariablesInExpr(ctx, stack, e.With.Ctes[i].AliasedTableExpr, asOf) if err != nil { return nil, err } @@ -185,21 +194,21 @@ func replaceVariablesInExpr(stack *InterpreterStack, expr ast.SQLNode, asOf *ast } } if e.Into != nil { - newExpr, err := replaceVariablesInExpr(stack, e.Into, asOf) + newExpr, err := replaceVariablesInExpr(ctx, stack, e.Into, asOf) if err != nil { return nil, err } e.Into = newExpr.(*ast.Into) } if e.Where != nil { - newExpr, err := replaceVariablesInExpr(stack, e.Where.Expr, asOf) + newExpr, err := replaceVariablesInExpr(ctx, stack, e.Where.Expr, asOf) if err != nil { return nil, err } e.Where.Expr = newExpr.(ast.Expr) } if e.Limit != nil { - newExpr, err := replaceVariablesInExpr(stack, e.Limit, asOf) + newExpr, err := replaceVariablesInExpr(ctx, stack, e.Limit, asOf) if err != nil { return nil, err } @@ -207,7 +216,7 @@ func replaceVariablesInExpr(stack *InterpreterStack, expr ast.SQLNode, asOf *ast } if e.From != nil { for i := range e.From { - newExpr, err := replaceVariablesInExpr(stack, e.From[i], asOf) + newExpr, err := replaceVariablesInExpr(ctx, stack, e.From[i], asOf) if err != nil { return nil, err } @@ -215,17 +224,17 @@ func replaceVariablesInExpr(stack *InterpreterStack, expr ast.SQLNode, asOf *ast } } case *ast.Subquery: - newExpr, err := replaceVariablesInExpr(stack, e.Select, asOf) + newExpr, err := replaceVariablesInExpr(ctx, stack, e.Select, asOf) if err != nil { return nil, err } e.Select = newExpr.(*ast.Select) case *ast.SetOp: - newLeftExpr, err := replaceVariablesInExpr(stack, e.Left, asOf) + newLeftExpr, err := replaceVariablesInExpr(ctx, stack, e.Left, asOf) if err != nil { return nil, err } - newRightExpr, err := replaceVariablesInExpr(stack, e.Right, asOf) + newRightExpr, err := replaceVariablesInExpr(ctx, stack, e.Right, asOf) if err != nil { return nil, err } @@ -233,7 +242,7 @@ func replaceVariablesInExpr(stack *InterpreterStack, expr ast.SQLNode, asOf *ast e.Right = newRightExpr.(ast.SelectStatement) case ast.ValTuple: for i := range e { - newExpr, err := replaceVariablesInExpr(stack, e[i], asOf) + newExpr, err := replaceVariablesInExpr(ctx, stack, e[i], asOf) if err != nil { return nil, err } @@ -241,21 +250,21 @@ func replaceVariablesInExpr(stack *InterpreterStack, expr ast.SQLNode, asOf *ast } case *ast.AliasedValues: for i := range e.Values { - newExpr, err := replaceVariablesInExpr(stack, e.Values[i], asOf) + newExpr, err := replaceVariablesInExpr(ctx, stack, e.Values[i], asOf) if err != nil { return nil, err } e.Values[i] = newExpr.(ast.ValTuple) } case *ast.Insert: - newExpr, err := replaceVariablesInExpr(stack, e.Rows, asOf) + newExpr, err := replaceVariablesInExpr(ctx, stack, e.Rows, asOf) if err != nil { return nil, err } e.Rows = newExpr.(ast.InsertRows) case *ast.Delete: if e.Where != nil { - newExpr, err := replaceVariablesInExpr(stack, e.Where.Expr, asOf) + newExpr, err := replaceVariablesInExpr(ctx, stack, e.Where.Expr, asOf) if err != nil { return nil, err } @@ -263,7 +272,7 @@ func replaceVariablesInExpr(stack *InterpreterStack, expr ast.SQLNode, asOf *ast } case *ast.Update: if e.Where != nil { - newExpr, err := replaceVariablesInExpr(stack, e.Where.Expr, asOf) + newExpr, err := replaceVariablesInExpr(ctx, stack, e.Where.Expr, asOf) if err != nil { return nil, err } @@ -373,7 +382,7 @@ func execOp(ctx *sql.Context, runner sql.StatementRunner, stack *InterpreterStac switch operation.OpCode { case OpCode_Select: selectStmt := operation.PrimaryData.(*ast.Select) - if newSelectStmt, err := replaceVariablesInExpr(stack, selectStmt, asOf); err == nil { + if newSelectStmt, err := replaceVariablesInExpr(ctx, stack, selectStmt, asOf); err == nil { selectStmt = newSelectStmt.(*ast.Select) } else { return 0, nil, nil, err @@ -416,7 +425,10 @@ func execOp(ctx *sql.Context, runner sql.StatementRunner, stack *InterpreterStac } err = stack.SetVariable(intoVar, row[i]) if err != nil { - return 0, nil, nil, err + err = ctx.Session.SetStoredProcParam(intoVar, row[i]) + if err != nil { + return 0, nil, nil, err + } } } @@ -592,7 +604,7 @@ func execOp(ctx *sql.Context, runner sql.StatementRunner, stack *InterpreterStac if cursor.RowIter != nil { return 0, nil, nil, sql.ErrCursorAlreadyOpen.New(openCur.Name) } - stmt, err := replaceVariablesInExpr(stack, cursor.SelectStmt, asOf) + stmt, err := replaceVariablesInExpr(ctx, stack, cursor.SelectStmt, asOf) if err != nil { return 0, nil, nil, err } @@ -657,7 +669,7 @@ func execOp(ctx *sql.Context, runner sql.StatementRunner, stack *InterpreterStac panic("select stmt with no select exprs") } for i := range selectStmt.SelectExprs { - newNode, err := replaceVariablesInExpr(stack, selectStmt.SelectExprs[i], asOf) + newNode, err := replaceVariablesInExpr(ctx, stack, selectStmt.SelectExprs[i], asOf) if err != nil { return 0, nil, nil, err } @@ -680,7 +692,10 @@ func execOp(ctx *sql.Context, runner sql.StatementRunner, stack *InterpreterStac err = stack.SetVariable(strings.ToLower(operation.Target), row[0]) if err != nil { - return 0, nil, nil, err + err = ctx.Session.SetStoredProcParam(operation.Target, row[0]) + if err != nil { + return 0, nil, nil, err + } } case OpCode_If: @@ -689,7 +704,7 @@ func execOp(ctx *sql.Context, runner sql.StatementRunner, stack *InterpreterStac panic("select stmt with no select exprs") } for i := range selectStmt.SelectExprs { - newNode, err := replaceVariablesInExpr(stack, selectStmt.SelectExprs[i], asOf) + newNode, err := replaceVariablesInExpr(ctx, stack, selectStmt.SelectExprs[i], asOf) if err != nil { return 0, nil, nil, err } @@ -750,10 +765,11 @@ func execOp(ctx *sql.Context, runner sql.StatementRunner, stack *InterpreterStac } case OpCode_Execute: - stmt, err := replaceVariablesInExpr(stack, operation.PrimaryData, asOf) + stmt, err := replaceVariablesInExpr(ctx, stack, operation.PrimaryData, asOf) if err != nil { return 0, nil, nil, err } + // TODO: create a OpCode_Call to store procedures in the stack rowIter, err := query(ctx, runner, stmt.(ast.Statement)) if err != nil { return 0, nil, nil, err @@ -778,12 +794,21 @@ func execOp(ctx *sql.Context, runner sql.StatementRunner, stack *InterpreterStac // Call runs the contained operations on the given runner. func Call(ctx *sql.Context, iNode InterpreterNode, params []*Parameter) (sql.RowIter, *InterpreterStack, error) { + for _, param := range params { + var spp *sql.StoredProcParam + spp = ctx.Session.GetStoredProcParam(param.Name) + for spp != nil { + spp.Value = param.Value + spp = spp.Reference + } + } + // Set up the initial state of the function counter := -1 // We increment before accessing, so start at -1 stack := NewInterpreterStack() - for _, param := range params { - stack.NewVariableWithValue(param.Name, param.Type, param.Value) - } + //for _, param := range params { + // stack.NewVariableWithValue(param.Name, param.Type, param.Value) + //} var asOf *ast.AsOf if asOfExpr := iNode.GetAsOf(); asOfExpr != nil { diff --git a/sql/rowexec/proc.go b/sql/rowexec/proc.go index a063ca04be..427bbf212a 100644 --- a/sql/rowexec/proc.go +++ b/sql/rowexec/proc.go @@ -209,6 +209,7 @@ func (b *BaseBuilder) buildCall(ctx *sql.Context, n *plan.Call, row sql.Row) (sq }, nil } + // TODO: replace with direct ctx modification procParams := make([]*procedures.Parameter, len(n.Params)) for i, paramExpr := range n.Params { param := n.Procedure.Params[i] @@ -229,7 +230,7 @@ func (b *BaseBuilder) buildCall(ctx *sql.Context, n *plan.Call, row sql.Row) (sq } } - rowIter, stack, err := procedures.Call(ctx, n, procParams) + rowIter, _, err := procedures.Call(ctx, n, procParams) if err != nil { return nil, err } @@ -240,7 +241,7 @@ func (b *BaseBuilder) buildCall(ctx *sql.Context, n *plan.Call, row sql.Row) (sq continue } // Set all user and system variables from INOUT and OUT params - stackVar := stack.GetVariable(procParam.Name) // TODO: ToLower? + stackVar := ctx.Session.GetStoredProcParam(procParam.Name) // TODO: ToLower? switch p := param.(type) { case *expression.ProcedureParam: err = p.Set(stackVar.Value, stackVar.Type) diff --git a/sql/session.go b/sql/session.go index 500b4ef462..852939f76f 100644 --- a/sql/session.go +++ b/sql/session.go @@ -92,6 +92,13 @@ type Session interface { GetAllStatusVariables(ctx *Context) map[string]StatusVarValue // IncrementStatusVariable increments the value of the status variable by the integer value IncrementStatusVariable(ctx *Context, statVarName string, val int) + + NewStoredProcParam(name string, param *StoredProcParam) + + GetStoredProcParam(name string) *StoredProcParam + + SetStoredProcParam(name string, val any) error + // GetCurrentDatabase gets the current database for this session GetCurrentDatabase() string // SetCurrentDatabase sets the current database for this session From 0872ac17646dd72de1bea6cbb8f1fdf8d58b55e8 Mon Sep 17 00:00:00 2001 From: jycor Date: Fri, 4 Apr 2025 09:10:49 +0000 Subject: [PATCH 073/111] [ga-format-pr] Run ./format_repo.sh to fix formatting --- sql/base_session.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/base_session.go b/sql/base_session.go index 69b4faaf16..eb9320f5b4 100644 --- a/sql/base_session.go +++ b/sql/base_session.go @@ -270,7 +270,7 @@ func (s *BaseSession) GetStoredProcParam(name string) *StoredProcParam { func (s *BaseSession) SetStoredProcParam(name string, val any) error { param := s.GetStoredProcParam(name) - if param == nil { + if param == nil { return fmt.Errorf("variable `%s` could not be found", name) } param.SetValue(val) From 35c46febba9e2bcd22bffd914764218ca0b999bb Mon Sep 17 00:00:00 2001 From: James Cor Date: Fri, 4 Apr 2025 10:48:08 -0700 Subject: [PATCH 074/111] test --- enginetest/memory_engine_test.go | 75 +++++++++++++++++++++----------- 1 file changed, 49 insertions(+), 26 deletions(-) diff --git a/enginetest/memory_engine_test.go b/enginetest/memory_engine_test.go index 295926b921..46d3191746 100644 --- a/enginetest/memory_engine_test.go +++ b/enginetest/memory_engine_test.go @@ -204,47 +204,70 @@ func TestSingleQueryPrepared(t *testing.T) { func TestSingleScript(t *testing.T) { //t.Skip() var scripts = []queries.ScriptTest{ - { - Name: "insert trigger with stored procedure with deletes", - SetUpScript: []string{ - "create table t (i int);", - "create table t1 (j int);", - ` + { + Name: "insert trigger with stored procedure with deletes", + SetUpScript: []string{ + "create table t (i int);", + "create table t1 (j int);", + "insert into t1 values (1);", + "create table t2 (k int);", + "insert into t2 values (1);", + "create table t3 (l int);", + "insert into t3 values (1);", + "create table t4 (m int);", + ` create procedure proc(x int) begin - insert into t1 values (x + 100); + delete from t2 where k = (select j from t1 where j = x); + update t3 set l = 10 where l = x; + insert into t4 values (x); end; `, - ` + ` create trigger trig before insert on t for each row begin call proc(new.i); end; `, + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "insert into t values (1);", + Expected: []sql.Row{ + {types.NewOkResult(1)}, + }, }, - Assertions: []queries.ScriptTestAssertion{ - { - Query: "insert into t values (1);", - Expected: []sql.Row{ - {types.NewOkResult(1)}, - }, + { + Query: "select * from t;", + Expected: []sql.Row{ + {1}, }, - { - Query: "select * from t;", - Expected: []sql.Row{ - {1}, - }, + }, + { + Query: "select * from t1;", + Expected: []sql.Row{ + {1}, + }, + }, + { + Query: "select * from t2;", + Expected: []sql.Row{}, + }, + { + Query: "select * from t3;", + Expected: []sql.Row{ + {10}, + }, + }, + { + Query: "select * from t4;", + Expected: []sql.Row{ + {1}, }, - //{ - // Query: "select * from t1;", - // Expected: []sql.Row{ - // {101}, - // {201}, - // }, - //}, }, }, + }, } for _, test := range scripts { From c137d53461c793b8f27c7dd6bbbca990f7371590 Mon Sep 17 00:00:00 2001 From: James Cor Date: Fri, 4 Apr 2025 14:00:57 -0700 Subject: [PATCH 075/111] fix external procs --- sql/rowexec/rel.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/sql/rowexec/rel.go b/sql/rowexec/rel.go index c95adc2825..308acab4de 100644 --- a/sql/rowexec/rel.go +++ b/sql/rowexec/rel.go @@ -730,12 +730,17 @@ func (b *BaseBuilder) buildExternalProcedure(ctx *sql.Context, n *plan.ExternalP } for i, paramDefinition := range n.ParamDefinitions { if paramDefinition.Direction == plan.ProcedureParamDirection_Inout || paramDefinition.Direction == plan.ProcedureParamDirection_Out { + // TODO: not sure if we should still be doing this exprParam := n.Params[i] funcParamVal := funcParams[i+1].Elem().Interface() err := exprParam.Set(funcParamVal, exprParam.Type()) if err != nil { return nil, err } + err = ctx.Session.SetStoredProcParam(exprParam.Name(), funcParamVal) + if err != nil { + return nil, err + } } } // It's not invalid to return a nil RowIter, as having no rows to return is expected of many stored procedures. From cdfc450e6c65ee7471bf89e793b1658a2ce62880 Mon Sep 17 00:00:00 2001 From: jycor Date: Fri, 4 Apr 2025 21:02:17 +0000 Subject: [PATCH 076/111] [ga-format-pr] Run ./format_repo.sh to fix formatting --- enginetest/memory_engine_test.go | 88 ++++++++++++++++---------------- 1 file changed, 44 insertions(+), 44 deletions(-) diff --git a/enginetest/memory_engine_test.go b/enginetest/memory_engine_test.go index 46d3191746..c814ffc4f8 100644 --- a/enginetest/memory_engine_test.go +++ b/enginetest/memory_engine_test.go @@ -204,18 +204,18 @@ func TestSingleQueryPrepared(t *testing.T) { func TestSingleScript(t *testing.T) { //t.Skip() var scripts = []queries.ScriptTest{ - { - Name: "insert trigger with stored procedure with deletes", - SetUpScript: []string{ - "create table t (i int);", - "create table t1 (j int);", - "insert into t1 values (1);", - "create table t2 (k int);", - "insert into t2 values (1);", - "create table t3 (l int);", - "insert into t3 values (1);", - "create table t4 (m int);", - ` + { + Name: "insert trigger with stored procedure with deletes", + SetUpScript: []string{ + "create table t (i int);", + "create table t1 (j int);", + "insert into t1 values (1);", + "create table t2 (k int);", + "insert into t2 values (1);", + "create table t3 (l int);", + "insert into t3 values (1);", + "create table t4 (m int);", + ` create procedure proc(x int) begin delete from t2 where k = (select j from t1 where j = x); @@ -223,51 +223,51 @@ begin insert into t4 values (x); end; `, - ` + ` create trigger trig before insert on t for each row begin call proc(new.i); end; `, - }, - Assertions: []queries.ScriptTestAssertion{ - { - Query: "insert into t values (1);", - Expected: []sql.Row{ - {types.NewOkResult(1)}, - }, }, - { - Query: "select * from t;", - Expected: []sql.Row{ - {1}, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "insert into t values (1);", + Expected: []sql.Row{ + {types.NewOkResult(1)}, + }, }, - }, - { - Query: "select * from t1;", - Expected: []sql.Row{ - {1}, + { + Query: "select * from t;", + Expected: []sql.Row{ + {1}, + }, }, - }, - { - Query: "select * from t2;", - Expected: []sql.Row{}, - }, - { - Query: "select * from t3;", - Expected: []sql.Row{ - {10}, + { + Query: "select * from t1;", + Expected: []sql.Row{ + {1}, + }, }, - }, - { - Query: "select * from t4;", - Expected: []sql.Row{ - {1}, + { + Query: "select * from t2;", + Expected: []sql.Row{}, + }, + { + Query: "select * from t3;", + Expected: []sql.Row{ + {10}, + }, + }, + { + Query: "select * from t4;", + Expected: []sql.Row{ + {1}, + }, }, }, }, - }, } for _, test := range scripts { From ba4c9230fdf0d2c7e2b72da184d14d5ce02d4153 Mon Sep 17 00:00:00 2001 From: James Cor Date: Fri, 4 Apr 2025 15:19:50 -0700 Subject: [PATCH 077/111] debugging --- enginetest/engine_only_test.go | 1 + .../queries/external_procedure_queries.go | 410 +++++++++--------- 2 files changed, 206 insertions(+), 205 deletions(-) diff --git a/enginetest/engine_only_test.go b/enginetest/engine_only_test.go index 4ea075ca3d..d2efe06725 100644 --- a/enginetest/engine_only_test.go +++ b/enginetest/engine_only_test.go @@ -608,6 +608,7 @@ func TestTableFunctions(t *testing.T) { func TestExternalProcedures(t *testing.T) { harness := enginetest.NewDefaultMemoryHarness() harness.Setup(setup.MydbData) + harness.UseServer() for _, script := range queries.ExternalProcedureTests { func() { e, err := harness.NewEngine(t) diff --git a/enginetest/queries/external_procedure_queries.go b/enginetest/queries/external_procedure_queries.go index c4db672e01..547a3e8034 100644 --- a/enginetest/queries/external_procedure_queries.go +++ b/enginetest/queries/external_procedure_queries.go @@ -17,15 +17,15 @@ package queries import "github.com/dolthub/go-mysql-server/sql" var ExternalProcedureTests = []ScriptTest{ - { - Name: "Call external stored procedure that does not exist", - Assertions: []ScriptTestAssertion{ - { - Query: "CALL procedure_does_not_exist('foo');", - ExpectedErr: sql.ErrStoredProcedureDoesNotExist, - }, - }, - }, + //{ + // Name: "Call external stored procedure that does not exist", + // Assertions: []ScriptTestAssertion{ + // { + // Query: "CALL procedure_does_not_exist('foo');", + // ExpectedErr: sql.ErrStoredProcedureDoesNotExist, + // }, + // }, + //}, { Name: "INOUT on first param, IN on second param", SetUpScript: []string{ @@ -39,200 +39,200 @@ var ExternalProcedureTests = []ScriptTest{ }, }, }, - { - Name: "Handle setting uninitialized user variables", - SetUpScript: []string{ - "CALL memory_inout_set_unitialized(@uservar12, @uservar13, @uservar14, @uservar15);", - }, - Assertions: []ScriptTestAssertion{ - { - Query: "SELECT @uservar12;", - Expected: []sql.Row{{5}}, - }, - { - Query: "SELECT @uservar13;", - Expected: []sql.Row{{uint(5)}}, - }, - { - Query: "SELECT @uservar14;", - Expected: []sql.Row{{"5"}}, - }, - { - Query: "SELECT @uservar15;", - Expected: []sql.Row{{0}}, - }, - }, - }, - { - Name: "Called from standard stored procedure", - SetUpScript: []string{ - "CREATE PROCEDURE p1(x BIGINT) BEGIN CALL memory_inout_add(x, x); SELECT x; END;", - }, - Assertions: []ScriptTestAssertion{ - { - Query: "CALL p1(11);", - Expected: []sql.Row{{22}}, - }, - }, - }, - { - Name: "Overloaded Name", - Assertions: []ScriptTestAssertion{ - { - Query: "CALL memory_overloaded_mult(1);", - Expected: []sql.Row{{1}}, - }, - { - Query: "CALL memory_overloaded_mult(2, 3);", - Expected: []sql.Row{{6}}, - }, - { - Query: "CALL memory_overloaded_mult(4, 5, 6);", - Expected: []sql.Row{{120}}, - }, - }, - }, - { - Name: "Passing in all supported types", - Assertions: []ScriptTestAssertion{ - { - Query: "CALL memory_overloaded_type_test(1, 100, 10000, 1000000, 100000000, 3, 300," + - "10, 1000, 100000, 10000000, 1000000000, 30, 3000);", - Expected: []sql.Row{{1111114444}}, - }, - { - Query: "CALL memory_overloaded_type_test(false, 'hi', 'A', '2020-02-20 12:00:00', 123.456," + - "true, 'bye', 'B', '2022-02-02 12:00:00', 654.32);", - Expected: []sql.Row{{`aa:false,ba:true,ab:"hi",bb:"bye",ac:[65],bc:[66],ad:2020-02-20,bd:2022-02-02,ae:123.456,be:654.32`}}, - }, - { - Query: "CALL memory_type_test3(1, 100, 10000, 1000000, 100000000, 3, 300," + - "10, 1000, 100000, 10000000, 1000000000, 30, 3000);", - Expected: []sql.Row{{uint64(1111114444)}}, - }, - }, - }, - { - Name: "BOOL and []BYTE INOUT conversions", - SetUpScript: []string{ - "SET @outparam1 = 1;", - "SET @outparam2 = 0;", - "SET @outparam3 = 'A';", - "SET @outparam4 = 'B';", - }, - Assertions: []ScriptTestAssertion{ - { - Query: "SELECT @outparam1, @outparam2, @outparam3, @outparam4;", - Expected: []sql.Row{{1, 0, "A", "B"}}, - }, - { - Query: "CALL memory_inout_bool_byte(@outparam1, @outparam2, @outparam3, @outparam4);", - Expected: []sql.Row{}, - }, - { - Query: "SELECT @outparam1, @outparam2, @outparam3, @outparam4;", - Expected: []sql.Row{{1, 1, "A", []byte("C")}}, - }, - { - Query: "CALL memory_inout_bool_byte(@outparam1, @outparam2, @outparam3, @outparam4);", - Expected: []sql.Row{}, - }, - { - Query: "SELECT @outparam1, @outparam2, @outparam3, @outparam4;", - Expected: []sql.Row{{1, 0, "A", []byte("D")}}, - }, - }, - }, - { - Name: "Errors returned", - Assertions: []ScriptTestAssertion{ - { - Query: "CALL memory_error_table_not_found();", - ExpectedErr: sql.ErrTableNotFound, - }, - }, - }, - { - Name: "Variadic parameter", - Assertions: []ScriptTestAssertion{ - { - Query: "CALL memory_variadic_add();", - Expected: []sql.Row{{0}}, - }, - { - Query: "CALL memory_variadic_add(1);", - Expected: []sql.Row{{1}}, - }, - { - Query: "CALL memory_variadic_add(1, 2);", - Expected: []sql.Row{{3}}, - }, - { - Query: "CALL memory_variadic_add(1, 2, 3);", - Expected: []sql.Row{{6}}, - }, - { - Query: "CALL memory_variadic_add(1, 2, 3, 4);", - Expected: []sql.Row{{10}}, - }, - }, - }, - { - Name: "Variadic byte slices", - Assertions: []ScriptTestAssertion{ - { - Query: "CALL memory_variadic_byte_slice();", - Expected: []sql.Row{{""}}, - }, - { - Query: "CALL memory_variadic_byte_slice('A');", - Expected: []sql.Row{{"A"}}, - }, - { - Query: "CALL memory_variadic_byte_slice('A', 'B');", - Expected: []sql.Row{{"AB"}}, - }, - }, - }, - { - Name: "Variadic overloading", - Assertions: []ScriptTestAssertion{ - { - Query: "CALL memory_variadic_overload();", - ExpectedErr: sql.ErrCallIncorrectParameterCount, - }, - { - Query: "CALL memory_variadic_overload('A');", - ExpectedErr: sql.ErrCallIncorrectParameterCount, - }, - { - Query: "CALL memory_variadic_overload('A', 'B');", - Expected: []sql.Row{{"A-B"}}, - }, - { - Query: "CALL memory_variadic_overload('A', 'B', 'C');", - ExpectedErr: sql.ErrInvalidValue, - }, - { - Query: "CALL memory_variadic_overload('A', 'B', 5);", - Expected: []sql.Row{{"A,B,[5]"}}, - }, - }, - }, - { - Name: "show create procedure for external stored procedures", - Assertions: []ScriptTestAssertion{ - { - Query: "show create procedure memory_variadic_overload;", - Expected: []sql.Row{{ - "memory_variadic_overload", - "", - "CREATE PROCEDURE memory_variadic_overload() SELECT 'External stored procedure';", - "utf8mb4", - "utf8mb4_0900_bin", - "utf8mb4_0900_bin", - }}, - }, - }, - }, + //{ + // Name: "Handle setting uninitialized user variables", + // SetUpScript: []string{ + // "CALL memory_inout_set_unitialized(@uservar12, @uservar13, @uservar14, @uservar15);", + // }, + // Assertions: []ScriptTestAssertion{ + // { + // Query: "SELECT @uservar12;", + // Expected: []sql.Row{{5}}, + // }, + // { + // Query: "SELECT @uservar13;", + // Expected: []sql.Row{{uint(5)}}, + // }, + // { + // Query: "SELECT @uservar14;", + // Expected: []sql.Row{{"5"}}, + // }, + // { + // Query: "SELECT @uservar15;", + // Expected: []sql.Row{{0}}, + // }, + // }, + //}, + //{ + // Name: "Called from standard stored procedure", + // SetUpScript: []string{ + // "CREATE PROCEDURE p1(x BIGINT) BEGIN CALL memory_inout_add(x, x); SELECT x; END;", + // }, + // Assertions: []ScriptTestAssertion{ + // { + // Query: "CALL p1(11);", + // Expected: []sql.Row{{22}}, + // }, + // }, + //}, + //{ + // Name: "Overloaded Name", + // Assertions: []ScriptTestAssertion{ + // { + // Query: "CALL memory_overloaded_mult(1);", + // Expected: []sql.Row{{1}}, + // }, + // { + // Query: "CALL memory_overloaded_mult(2, 3);", + // Expected: []sql.Row{{6}}, + // }, + // { + // Query: "CALL memory_overloaded_mult(4, 5, 6);", + // Expected: []sql.Row{{120}}, + // }, + // }, + //}, + //{ + // Name: "Passing in all supported types", + // Assertions: []ScriptTestAssertion{ + // { + // Query: "CALL memory_overloaded_type_test(1, 100, 10000, 1000000, 100000000, 3, 300," + + // "10, 1000, 100000, 10000000, 1000000000, 30, 3000);", + // Expected: []sql.Row{{1111114444}}, + // }, + // { + // Query: "CALL memory_overloaded_type_test(false, 'hi', 'A', '2020-02-20 12:00:00', 123.456," + + // "true, 'bye', 'B', '2022-02-02 12:00:00', 654.32);", + // Expected: []sql.Row{{`aa:false,ba:true,ab:"hi",bb:"bye",ac:[65],bc:[66],ad:2020-02-20,bd:2022-02-02,ae:123.456,be:654.32`}}, + // }, + // { + // Query: "CALL memory_type_test3(1, 100, 10000, 1000000, 100000000, 3, 300," + + // "10, 1000, 100000, 10000000, 1000000000, 30, 3000);", + // Expected: []sql.Row{{uint64(1111114444)}}, + // }, + // }, + //}, + //{ + // Name: "BOOL and []BYTE INOUT conversions", + // SetUpScript: []string{ + // "SET @outparam1 = 1;", + // "SET @outparam2 = 0;", + // "SET @outparam3 = 'A';", + // "SET @outparam4 = 'B';", + // }, + // Assertions: []ScriptTestAssertion{ + // { + // Query: "SELECT @outparam1, @outparam2, @outparam3, @outparam4;", + // Expected: []sql.Row{{1, 0, "A", "B"}}, + // }, + // { + // Query: "CALL memory_inout_bool_byte(@outparam1, @outparam2, @outparam3, @outparam4);", + // Expected: []sql.Row{}, + // }, + // { + // Query: "SELECT @outparam1, @outparam2, @outparam3, @outparam4;", + // Expected: []sql.Row{{1, 1, "A", []byte("C")}}, + // }, + // { + // Query: "CALL memory_inout_bool_byte(@outparam1, @outparam2, @outparam3, @outparam4);", + // Expected: []sql.Row{}, + // }, + // { + // Query: "SELECT @outparam1, @outparam2, @outparam3, @outparam4;", + // Expected: []sql.Row{{1, 0, "A", []byte("D")}}, + // }, + // }, + //}, + //{ + // Name: "Errors returned", + // Assertions: []ScriptTestAssertion{ + // { + // Query: "CALL memory_error_table_not_found();", + // ExpectedErr: sql.ErrTableNotFound, + // }, + // }, + //}, + //{ + // Name: "Variadic parameter", + // Assertions: []ScriptTestAssertion{ + // { + // Query: "CALL memory_variadic_add();", + // Expected: []sql.Row{{0}}, + // }, + // { + // Query: "CALL memory_variadic_add(1);", + // Expected: []sql.Row{{1}}, + // }, + // { + // Query: "CALL memory_variadic_add(1, 2);", + // Expected: []sql.Row{{3}}, + // }, + // { + // Query: "CALL memory_variadic_add(1, 2, 3);", + // Expected: []sql.Row{{6}}, + // }, + // { + // Query: "CALL memory_variadic_add(1, 2, 3, 4);", + // Expected: []sql.Row{{10}}, + // }, + // }, + //}, + //{ + // Name: "Variadic byte slices", + // Assertions: []ScriptTestAssertion{ + // { + // Query: "CALL memory_variadic_byte_slice();", + // Expected: []sql.Row{{""}}, + // }, + // { + // Query: "CALL memory_variadic_byte_slice('A');", + // Expected: []sql.Row{{"A"}}, + // }, + // { + // Query: "CALL memory_variadic_byte_slice('A', 'B');", + // Expected: []sql.Row{{"AB"}}, + // }, + // }, + //}, + //{ + // Name: "Variadic overloading", + // Assertions: []ScriptTestAssertion{ + // { + // Query: "CALL memory_variadic_overload();", + // ExpectedErr: sql.ErrCallIncorrectParameterCount, + // }, + // { + // Query: "CALL memory_variadic_overload('A');", + // ExpectedErr: sql.ErrCallIncorrectParameterCount, + // }, + // { + // Query: "CALL memory_variadic_overload('A', 'B');", + // Expected: []sql.Row{{"A-B"}}, + // }, + // { + // Query: "CALL memory_variadic_overload('A', 'B', 'C');", + // ExpectedErr: sql.ErrInvalidValue, + // }, + // { + // Query: "CALL memory_variadic_overload('A', 'B', 5);", + // Expected: []sql.Row{{"A,B,[5]"}}, + // }, + // }, + //}, + //{ + // Name: "show create procedure for external stored procedures", + // Assertions: []ScriptTestAssertion{ + // { + // Query: "show create procedure memory_variadic_overload;", + // Expected: []sql.Row{{ + // "memory_variadic_overload", + // "", + // "CREATE PROCEDURE memory_variadic_overload() SELECT 'External stored procedure';", + // "utf8mb4", + // "utf8mb4_0900_bin", + // "utf8mb4_0900_bin", + // }}, + // }, + // }, + //}, } From 112769f9443c0b629547e23c3d2a84df7f537946 Mon Sep 17 00:00:00 2001 From: James Cor Date: Mon, 7 Apr 2025 17:22:02 -0700 Subject: [PATCH 078/111] fix? --- .../queries/external_procedure_queries.go | 410 +++++++++--------- sql/plan/call.go | 13 +- sql/plan/procedure.go | 3 + sql/procedures/interpreter_logic.go | 166 +++---- sql/rowexec/proc.go | 3 + 5 files changed, 312 insertions(+), 283 deletions(-) diff --git a/enginetest/queries/external_procedure_queries.go b/enginetest/queries/external_procedure_queries.go index 547a3e8034..c4db672e01 100644 --- a/enginetest/queries/external_procedure_queries.go +++ b/enginetest/queries/external_procedure_queries.go @@ -17,15 +17,15 @@ package queries import "github.com/dolthub/go-mysql-server/sql" var ExternalProcedureTests = []ScriptTest{ - //{ - // Name: "Call external stored procedure that does not exist", - // Assertions: []ScriptTestAssertion{ - // { - // Query: "CALL procedure_does_not_exist('foo');", - // ExpectedErr: sql.ErrStoredProcedureDoesNotExist, - // }, - // }, - //}, + { + Name: "Call external stored procedure that does not exist", + Assertions: []ScriptTestAssertion{ + { + Query: "CALL procedure_does_not_exist('foo');", + ExpectedErr: sql.ErrStoredProcedureDoesNotExist, + }, + }, + }, { Name: "INOUT on first param, IN on second param", SetUpScript: []string{ @@ -39,200 +39,200 @@ var ExternalProcedureTests = []ScriptTest{ }, }, }, - //{ - // Name: "Handle setting uninitialized user variables", - // SetUpScript: []string{ - // "CALL memory_inout_set_unitialized(@uservar12, @uservar13, @uservar14, @uservar15);", - // }, - // Assertions: []ScriptTestAssertion{ - // { - // Query: "SELECT @uservar12;", - // Expected: []sql.Row{{5}}, - // }, - // { - // Query: "SELECT @uservar13;", - // Expected: []sql.Row{{uint(5)}}, - // }, - // { - // Query: "SELECT @uservar14;", - // Expected: []sql.Row{{"5"}}, - // }, - // { - // Query: "SELECT @uservar15;", - // Expected: []sql.Row{{0}}, - // }, - // }, - //}, - //{ - // Name: "Called from standard stored procedure", - // SetUpScript: []string{ - // "CREATE PROCEDURE p1(x BIGINT) BEGIN CALL memory_inout_add(x, x); SELECT x; END;", - // }, - // Assertions: []ScriptTestAssertion{ - // { - // Query: "CALL p1(11);", - // Expected: []sql.Row{{22}}, - // }, - // }, - //}, - //{ - // Name: "Overloaded Name", - // Assertions: []ScriptTestAssertion{ - // { - // Query: "CALL memory_overloaded_mult(1);", - // Expected: []sql.Row{{1}}, - // }, - // { - // Query: "CALL memory_overloaded_mult(2, 3);", - // Expected: []sql.Row{{6}}, - // }, - // { - // Query: "CALL memory_overloaded_mult(4, 5, 6);", - // Expected: []sql.Row{{120}}, - // }, - // }, - //}, - //{ - // Name: "Passing in all supported types", - // Assertions: []ScriptTestAssertion{ - // { - // Query: "CALL memory_overloaded_type_test(1, 100, 10000, 1000000, 100000000, 3, 300," + - // "10, 1000, 100000, 10000000, 1000000000, 30, 3000);", - // Expected: []sql.Row{{1111114444}}, - // }, - // { - // Query: "CALL memory_overloaded_type_test(false, 'hi', 'A', '2020-02-20 12:00:00', 123.456," + - // "true, 'bye', 'B', '2022-02-02 12:00:00', 654.32);", - // Expected: []sql.Row{{`aa:false,ba:true,ab:"hi",bb:"bye",ac:[65],bc:[66],ad:2020-02-20,bd:2022-02-02,ae:123.456,be:654.32`}}, - // }, - // { - // Query: "CALL memory_type_test3(1, 100, 10000, 1000000, 100000000, 3, 300," + - // "10, 1000, 100000, 10000000, 1000000000, 30, 3000);", - // Expected: []sql.Row{{uint64(1111114444)}}, - // }, - // }, - //}, - //{ - // Name: "BOOL and []BYTE INOUT conversions", - // SetUpScript: []string{ - // "SET @outparam1 = 1;", - // "SET @outparam2 = 0;", - // "SET @outparam3 = 'A';", - // "SET @outparam4 = 'B';", - // }, - // Assertions: []ScriptTestAssertion{ - // { - // Query: "SELECT @outparam1, @outparam2, @outparam3, @outparam4;", - // Expected: []sql.Row{{1, 0, "A", "B"}}, - // }, - // { - // Query: "CALL memory_inout_bool_byte(@outparam1, @outparam2, @outparam3, @outparam4);", - // Expected: []sql.Row{}, - // }, - // { - // Query: "SELECT @outparam1, @outparam2, @outparam3, @outparam4;", - // Expected: []sql.Row{{1, 1, "A", []byte("C")}}, - // }, - // { - // Query: "CALL memory_inout_bool_byte(@outparam1, @outparam2, @outparam3, @outparam4);", - // Expected: []sql.Row{}, - // }, - // { - // Query: "SELECT @outparam1, @outparam2, @outparam3, @outparam4;", - // Expected: []sql.Row{{1, 0, "A", []byte("D")}}, - // }, - // }, - //}, - //{ - // Name: "Errors returned", - // Assertions: []ScriptTestAssertion{ - // { - // Query: "CALL memory_error_table_not_found();", - // ExpectedErr: sql.ErrTableNotFound, - // }, - // }, - //}, - //{ - // Name: "Variadic parameter", - // Assertions: []ScriptTestAssertion{ - // { - // Query: "CALL memory_variadic_add();", - // Expected: []sql.Row{{0}}, - // }, - // { - // Query: "CALL memory_variadic_add(1);", - // Expected: []sql.Row{{1}}, - // }, - // { - // Query: "CALL memory_variadic_add(1, 2);", - // Expected: []sql.Row{{3}}, - // }, - // { - // Query: "CALL memory_variadic_add(1, 2, 3);", - // Expected: []sql.Row{{6}}, - // }, - // { - // Query: "CALL memory_variadic_add(1, 2, 3, 4);", - // Expected: []sql.Row{{10}}, - // }, - // }, - //}, - //{ - // Name: "Variadic byte slices", - // Assertions: []ScriptTestAssertion{ - // { - // Query: "CALL memory_variadic_byte_slice();", - // Expected: []sql.Row{{""}}, - // }, - // { - // Query: "CALL memory_variadic_byte_slice('A');", - // Expected: []sql.Row{{"A"}}, - // }, - // { - // Query: "CALL memory_variadic_byte_slice('A', 'B');", - // Expected: []sql.Row{{"AB"}}, - // }, - // }, - //}, - //{ - // Name: "Variadic overloading", - // Assertions: []ScriptTestAssertion{ - // { - // Query: "CALL memory_variadic_overload();", - // ExpectedErr: sql.ErrCallIncorrectParameterCount, - // }, - // { - // Query: "CALL memory_variadic_overload('A');", - // ExpectedErr: sql.ErrCallIncorrectParameterCount, - // }, - // { - // Query: "CALL memory_variadic_overload('A', 'B');", - // Expected: []sql.Row{{"A-B"}}, - // }, - // { - // Query: "CALL memory_variadic_overload('A', 'B', 'C');", - // ExpectedErr: sql.ErrInvalidValue, - // }, - // { - // Query: "CALL memory_variadic_overload('A', 'B', 5);", - // Expected: []sql.Row{{"A,B,[5]"}}, - // }, - // }, - //}, - //{ - // Name: "show create procedure for external stored procedures", - // Assertions: []ScriptTestAssertion{ - // { - // Query: "show create procedure memory_variadic_overload;", - // Expected: []sql.Row{{ - // "memory_variadic_overload", - // "", - // "CREATE PROCEDURE memory_variadic_overload() SELECT 'External stored procedure';", - // "utf8mb4", - // "utf8mb4_0900_bin", - // "utf8mb4_0900_bin", - // }}, - // }, - // }, - //}, + { + Name: "Handle setting uninitialized user variables", + SetUpScript: []string{ + "CALL memory_inout_set_unitialized(@uservar12, @uservar13, @uservar14, @uservar15);", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "SELECT @uservar12;", + Expected: []sql.Row{{5}}, + }, + { + Query: "SELECT @uservar13;", + Expected: []sql.Row{{uint(5)}}, + }, + { + Query: "SELECT @uservar14;", + Expected: []sql.Row{{"5"}}, + }, + { + Query: "SELECT @uservar15;", + Expected: []sql.Row{{0}}, + }, + }, + }, + { + Name: "Called from standard stored procedure", + SetUpScript: []string{ + "CREATE PROCEDURE p1(x BIGINT) BEGIN CALL memory_inout_add(x, x); SELECT x; END;", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "CALL p1(11);", + Expected: []sql.Row{{22}}, + }, + }, + }, + { + Name: "Overloaded Name", + Assertions: []ScriptTestAssertion{ + { + Query: "CALL memory_overloaded_mult(1);", + Expected: []sql.Row{{1}}, + }, + { + Query: "CALL memory_overloaded_mult(2, 3);", + Expected: []sql.Row{{6}}, + }, + { + Query: "CALL memory_overloaded_mult(4, 5, 6);", + Expected: []sql.Row{{120}}, + }, + }, + }, + { + Name: "Passing in all supported types", + Assertions: []ScriptTestAssertion{ + { + Query: "CALL memory_overloaded_type_test(1, 100, 10000, 1000000, 100000000, 3, 300," + + "10, 1000, 100000, 10000000, 1000000000, 30, 3000);", + Expected: []sql.Row{{1111114444}}, + }, + { + Query: "CALL memory_overloaded_type_test(false, 'hi', 'A', '2020-02-20 12:00:00', 123.456," + + "true, 'bye', 'B', '2022-02-02 12:00:00', 654.32);", + Expected: []sql.Row{{`aa:false,ba:true,ab:"hi",bb:"bye",ac:[65],bc:[66],ad:2020-02-20,bd:2022-02-02,ae:123.456,be:654.32`}}, + }, + { + Query: "CALL memory_type_test3(1, 100, 10000, 1000000, 100000000, 3, 300," + + "10, 1000, 100000, 10000000, 1000000000, 30, 3000);", + Expected: []sql.Row{{uint64(1111114444)}}, + }, + }, + }, + { + Name: "BOOL and []BYTE INOUT conversions", + SetUpScript: []string{ + "SET @outparam1 = 1;", + "SET @outparam2 = 0;", + "SET @outparam3 = 'A';", + "SET @outparam4 = 'B';", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "SELECT @outparam1, @outparam2, @outparam3, @outparam4;", + Expected: []sql.Row{{1, 0, "A", "B"}}, + }, + { + Query: "CALL memory_inout_bool_byte(@outparam1, @outparam2, @outparam3, @outparam4);", + Expected: []sql.Row{}, + }, + { + Query: "SELECT @outparam1, @outparam2, @outparam3, @outparam4;", + Expected: []sql.Row{{1, 1, "A", []byte("C")}}, + }, + { + Query: "CALL memory_inout_bool_byte(@outparam1, @outparam2, @outparam3, @outparam4);", + Expected: []sql.Row{}, + }, + { + Query: "SELECT @outparam1, @outparam2, @outparam3, @outparam4;", + Expected: []sql.Row{{1, 0, "A", []byte("D")}}, + }, + }, + }, + { + Name: "Errors returned", + Assertions: []ScriptTestAssertion{ + { + Query: "CALL memory_error_table_not_found();", + ExpectedErr: sql.ErrTableNotFound, + }, + }, + }, + { + Name: "Variadic parameter", + Assertions: []ScriptTestAssertion{ + { + Query: "CALL memory_variadic_add();", + Expected: []sql.Row{{0}}, + }, + { + Query: "CALL memory_variadic_add(1);", + Expected: []sql.Row{{1}}, + }, + { + Query: "CALL memory_variadic_add(1, 2);", + Expected: []sql.Row{{3}}, + }, + { + Query: "CALL memory_variadic_add(1, 2, 3);", + Expected: []sql.Row{{6}}, + }, + { + Query: "CALL memory_variadic_add(1, 2, 3, 4);", + Expected: []sql.Row{{10}}, + }, + }, + }, + { + Name: "Variadic byte slices", + Assertions: []ScriptTestAssertion{ + { + Query: "CALL memory_variadic_byte_slice();", + Expected: []sql.Row{{""}}, + }, + { + Query: "CALL memory_variadic_byte_slice('A');", + Expected: []sql.Row{{"A"}}, + }, + { + Query: "CALL memory_variadic_byte_slice('A', 'B');", + Expected: []sql.Row{{"AB"}}, + }, + }, + }, + { + Name: "Variadic overloading", + Assertions: []ScriptTestAssertion{ + { + Query: "CALL memory_variadic_overload();", + ExpectedErr: sql.ErrCallIncorrectParameterCount, + }, + { + Query: "CALL memory_variadic_overload('A');", + ExpectedErr: sql.ErrCallIncorrectParameterCount, + }, + { + Query: "CALL memory_variadic_overload('A', 'B');", + Expected: []sql.Row{{"A-B"}}, + }, + { + Query: "CALL memory_variadic_overload('A', 'B', 'C');", + ExpectedErr: sql.ErrInvalidValue, + }, + { + Query: "CALL memory_variadic_overload('A', 'B', 5);", + Expected: []sql.Row{{"A,B,[5]"}}, + }, + }, + }, + { + Name: "show create procedure for external stored procedures", + Assertions: []ScriptTestAssertion{ + { + Query: "show create procedure memory_variadic_overload;", + Expected: []sql.Row{{ + "memory_variadic_overload", + "", + "CREATE PROCEDURE memory_variadic_overload() SELECT 'External stored procedure';", + "utf8mb4", + "utf8mb4_0900_bin", + "utf8mb4_0900_bin", + }}, + }, + }, + }, } diff --git a/sql/plan/call.go b/sql/plan/call.go index 30dc5271bd..d27a07dc98 100644 --- a/sql/plan/call.go +++ b/sql/plan/call.go @@ -38,6 +38,9 @@ type Call struct { // this will have list of parsed operations to run Runner sql.StatementRunner Ops []procedures.InterpreterOperation + + // TODO: sure whatever + resSch sql.Schema } var _ sql.Node = (*Call)(nil) @@ -84,6 +87,9 @@ func (c *Call) IsReadOnly() bool { // Schema implements the sql.Node interface. func (c *Call) Schema() sql.Schema { + if c.resSch != nil { + return c.resSch + } if c.Procedure != nil { return c.Procedure.Schema() } @@ -224,7 +230,8 @@ func (c *Call) GetStatements() []*procedures.InterpreterOperation { return c.Procedure.Ops } -// GetReturn implements the sql.InterpreterNode interface. -func (c *Call) GetReturn() sql.Type { - return nil +// SetSchema implements the sql.InterpreterNode interface. +func (c *Call) SetSchema(sch sql.Schema) { + c.resSch = sch } + diff --git a/sql/plan/procedure.go b/sql/plan/procedure.go index c7af76ef67..50222ad68c 100644 --- a/sql/plan/procedure.go +++ b/sql/plan/procedure.go @@ -160,6 +160,9 @@ func (p *Procedure) DebugString() string { // Schema implements the sql.Node interface. func (p *Procedure) Schema() sql.Schema { + if p.ExternalProc != nil { + return p.ExternalProc.Schema() + } return types.OkResultSchema } diff --git a/sql/procedures/interpreter_logic.go b/sql/procedures/interpreter_logic.go index 580b65866c..6ccb7dbbae 100644 --- a/sql/procedures/interpreter_logic.go +++ b/sql/procedures/interpreter_logic.go @@ -15,7 +15,8 @@ package procedures import ( - "errors" + "context" +"errors" "fmt" "io" "strconv" @@ -34,9 +35,9 @@ import ( type InterpreterNode interface { GetAsOf() sql.Expression GetRunner() sql.StatementRunner - GetReturn() sql.Type GetStatements() []*InterpreterOperation SetStatementRunner(ctx *sql.Context, runner sql.StatementRunner) sql.Node + SetSchema(sch sql.Schema) } type Parameter struct { @@ -282,10 +283,10 @@ func replaceVariablesInExpr(ctx *sql.Context, stack *InterpreterStack, expr ast. return expr, nil } -func query(ctx *sql.Context, runner sql.StatementRunner, stmt ast.Statement) (sql.RowIter, error) { - _, rowIter, _, err := runner.QueryWithBindings(ctx, "", stmt, nil, nil) +func query(ctx *sql.Context, runner sql.StatementRunner, stmt ast.Statement) (sql.Schema, sql.RowIter, error) { + sch, rowIter, _, err := runner.QueryWithBindings(ctx, "", stmt, nil, nil) if err != nil { - return nil, err + return nil, nil, err } var rows []sql.Row for { @@ -294,14 +295,14 @@ func query(ctx *sql.Context, runner sql.StatementRunner, stmt ast.Statement) (sq if rErr == io.EOF { break } - return nil, rErr + return nil, nil, rErr } rows = append(rows, row) } if err = rowIter.Close(ctx); err != nil { - return nil, err + return nil, nil, err } - return sql.RowsToRowIter(rows...), nil + return sch, sql.RowsToRowIter(rows...), nil } // handleError handles errors that occur during the execution of a procedure according to the defined handlers. @@ -340,7 +341,7 @@ func handleError(ctx *sql.Context, runner sql.StatementRunner, stack *Interprete return -1, err } - _, _, rowIter, err := execOp(ctx, runner, stack, handlerOps[0], handlerOps, nil, -1) + _, _, _, rowIter, err := execOp(ctx, runner, stack, handlerOps[0], handlerOps, nil, -1) if err != nil { return -1, err } @@ -378,56 +379,59 @@ func handleError(ctx *sql.Context, runner sql.StatementRunner, stack *Interprete return counter, nil } -func execOp(ctx *sql.Context, runner sql.StatementRunner, stack *InterpreterStack, operation *InterpreterOperation, statements []*InterpreterOperation, asOf *ast.AsOf, counter int) (int, sql.RowIter, sql.RowIter, error) { +func execOp(ctx *sql.Context, runner sql.StatementRunner, stack *InterpreterStack, operation *InterpreterOperation, statements []*InterpreterOperation, asOf *ast.AsOf, counter int) (int, sql.Schema, sql.RowIter, sql.RowIter, error) { switch operation.OpCode { case OpCode_Select: + if counter == 2 { + print() + } selectStmt := operation.PrimaryData.(*ast.Select) if newSelectStmt, err := replaceVariablesInExpr(ctx, stack, selectStmt, asOf); err == nil { selectStmt = newSelectStmt.(*ast.Select) } else { - return 0, nil, nil, err + return 0, nil, nil, nil, err } if selectStmt.Into == nil { - rowIter, err := query(ctx, runner, selectStmt) + sch, rowIter, err := query(ctx, runner, selectStmt) if err != nil { - return 0, nil, nil, err + return 0, nil, nil, nil, err } - return counter, rowIter, rowIter, nil + return counter, sch, rowIter, rowIter, nil } selectInto := selectStmt.Into selectStmt.Into = nil schema, rowIter, _, err := runner.QueryWithBindings(ctx, "", selectStmt, nil, nil) if err != nil { - return 0, nil, nil, err + return 0, nil, nil, nil, err } row, err := rowIter.Next(ctx) if err != nil { - return 0, nil, nil, err + return 0, nil, nil, nil, err } if _, err = rowIter.Next(ctx); err != io.EOF { - return 0, nil, nil, err + return 0, nil, nil, nil, err } if err = rowIter.Close(ctx); err != nil { - return 0, nil, nil, err + return 0, nil, nil, nil, err } if len(row) != len(selectInto.Variables) { - return 0, nil, nil, sql.ErrColumnNumberDoesNotMatch.New() + return 0, nil, nil, nil, sql.ErrColumnNumberDoesNotMatch.New() } for i := range selectInto.Variables { intoVar := strings.ToLower(selectInto.Variables[i].String()) if strings.HasPrefix(intoVar, "@") { err = ctx.SetUserVariable(ctx, intoVar, row[i], schema[i].Type) if err != nil { - return 0, nil, nil, err + return 0, nil, nil, nil, err } } err = stack.SetVariable(intoVar, row[i]) if err != nil { err = ctx.Session.SetStoredProcParam(intoVar, row[i]) if err != nil { - return 0, nil, nil, err + return 0, nil, nil, nil, err } } } @@ -443,17 +447,17 @@ func execOp(ctx *sql.Context, runner sql.StatementRunner, stack *InterpreterStac var err error if stateVal != "" { if len(stateVal) != 5 { - return 0, nil, nil, fmt.Errorf("SQLSTATE VALUE must be a string with length 5 consisting of only integers") + return 0, nil, nil, nil, fmt.Errorf("SQLSTATE VALUE must be a string with length 5 consisting of only integers") } if stateVal[0:2] == "00" { - return 0, nil, nil, fmt.Errorf("invalid SQLSTATE VALUE: '%s'", stateVal) + return 0, nil, nil, nil, fmt.Errorf("invalid SQLSTATE VALUE: '%s'", stateVal) } } else { // use our own error num, err = strconv.ParseInt(string(cond.MysqlErrorCode.Val), 10, 64) if err != nil || num == 0 { err = fmt.Errorf("invalid value '%s' for MySQL error code", string(cond.MysqlErrorCode.Val)) - return 0, nil, nil, err + return 0, nil, nil, nil, err } } stack.NewCondition(condName, stateVal, num) @@ -468,7 +472,7 @@ func execOp(ctx *sql.Context, runner sql.StatementRunner, stack *InterpreterStac // TODO: duplicate handlers? if handler := declareStmt.Handler; handler != nil { if len(handler.ConditionValues) != 1 { - return 0, nil, nil, sql.ErrUnsupportedSyntax.New(ast.String(declareStmt)) + return 0, nil, nil, nil, sql.ErrUnsupportedSyntax.New(ast.String(declareStmt)) } hCond := handler.ConditionValues[0] @@ -476,14 +480,14 @@ func execOp(ctx *sql.Context, runner sql.StatementRunner, stack *InterpreterStac case ast.DeclareHandlerCondition_NotFound: case ast.DeclareHandlerCondition_SqlException: default: - return 0, nil, nil, sql.ErrUnsupportedSyntax.New(ast.String(declareStmt)) + return 0, nil, nil, nil, sql.ErrUnsupportedSyntax.New(ast.String(declareStmt)) } switch handler.Action { case ast.DeclareHandlerAction_Continue: case ast.DeclareHandlerAction_Exit: case ast.DeclareHandlerAction_Undo: - return 0, nil, nil, fmt.Errorf("unsupported handler action: %s", handler.Action) + return 0, nil, nil, nil, fmt.Errorf("unsupported handler action: %s", handler.Action) } stack.NewHandler(hCond.ValueType, handler.Action, handler.Statement, counter) @@ -494,7 +498,7 @@ func execOp(ctx *sql.Context, runner sql.StatementRunner, stack *InterpreterStac for _, decl := range vars.Names { varType, err := types.ColumnTypeToType(&vars.VarType) if err != nil { - return 0, nil, nil, err + return 0, nil, nil, nil, err } varName := strings.ToLower(decl.String()) if vars.VarType.Default == nil { @@ -514,19 +518,19 @@ func execOp(ctx *sql.Context, runner sql.StatementRunner, stack *InterpreterStac if signalStmt.ConditionName == "" { sqlState = signalStmt.SqlStateValue if sqlState[0:2] == "01" { - return 0, nil, nil, fmt.Errorf("warnings not yet implemented") + return 0, nil, nil, nil, fmt.Errorf("warnings not yet implemented") } } else { cond := stack.GetCondition(strings.ToLower(signalStmt.ConditionName)) if cond == nil { - return 0, nil, nil, sql.ErrDeclareConditionNotFound.New(signalStmt.ConditionName) + return 0, nil, nil, nil, sql.ErrDeclareConditionNotFound.New(signalStmt.ConditionName) } sqlState = cond.SQLState mysqlErrNo = int(cond.MySQLErrCode) } if len(sqlState) != 5 { - return 0, nil, nil, fmt.Errorf("SQLSTATE VALUE must be a string with length 5 consisting of only integers") + return 0, nil, nil, nil, fmt.Errorf("SQLSTATE VALUE must be a string with length 5 consisting of only integers") } for _, item := range signalStmt.Info { @@ -536,35 +540,35 @@ func execOp(ctx *sql.Context, runner sql.StatementRunner, stack *InterpreterStac case *ast.SQLVal: num, err := strconv.ParseInt(string(val.Val), 10, 64) if err != nil || num == 0 { - return 0, nil, nil, fmt.Errorf("invalid value '%s' for MySQL error code", string(val.Val)) + return 0, nil, nil, nil, fmt.Errorf("invalid value '%s' for MySQL error code", string(val.Val)) } mysqlErrNo = int(num) case *ast.ColName: - return 0, nil, nil, fmt.Errorf("unsupported signal message text type: %T", val) + return 0, nil, nil, nil, fmt.Errorf("unsupported signal message text type: %T", val) default: - return 0, nil, nil, fmt.Errorf("invalid value '%v' for signal condition information item MESSAGE_TEXT", val) + return 0, nil, nil, nil, fmt.Errorf("invalid value '%v' for signal condition information item MESSAGE_TEXT", val) } case ast.SignalConditionItemName_MessageText: switch val := item.Value.(type) { case *ast.SQLVal: msgTxt = string(val.Val) if len(msgTxt) > 128 { - return 0, nil, nil, fmt.Errorf("signal condition information item MESSAGE_TEXT has max length of 128") + return 0, nil, nil, nil, fmt.Errorf("signal condition information item MESSAGE_TEXT has max length of 128") } case *ast.ColName: - return 0, nil, nil, fmt.Errorf("unsupported signal message text type: %T", val) + return 0, nil, nil, nil, fmt.Errorf("unsupported signal message text type: %T", val) default: - return 0, nil, nil, fmt.Errorf("invalid value '%v' for signal condition information item MESSAGE_TEXT", val) + return 0, nil, nil, nil, fmt.Errorf("invalid value '%v' for signal condition information item MESSAGE_TEXT", val) } default: switch val := item.Value.(type) { case *ast.SQLVal: msgTxt = string(val.Val) if len(msgTxt) > 64 { - return 0, nil, nil, fmt.Errorf("signal condition information item %s has max length of 64", strings.ToUpper(string(item.ConditionItemName))) + return 0, nil, nil, nil, fmt.Errorf("signal condition information item %s has max length of 64", strings.ToUpper(string(item.ConditionItemName))) } default: - return 0, nil, nil, fmt.Errorf("invalid value '%v' for signal condition information item '%s''", item.Value, strings.ToUpper(string(item.ConditionItemName))) + return 0, nil, nil, nil, fmt.Errorf("invalid value '%v' for signal condition information item '%s''", item.Value, strings.ToUpper(string(item.ConditionItemName))) } } } @@ -583,7 +587,7 @@ func execOp(ctx *sql.Context, runner sql.StatementRunner, stack *InterpreterStac if msgTxt == "" { switch sqlState[0:2] { case "00": - return 0, nil, nil, fmt.Errorf("invalid SQLSTATE VALUE: '%s'", sqlState) + return 0, nil, nil, nil, fmt.Errorf("invalid SQLSTATE VALUE: '%s'", sqlState) case "01": msgTxt = "Unhandled user-defined warning condition" case "02": @@ -593,24 +597,24 @@ func execOp(ctx *sql.Context, runner sql.StatementRunner, stack *InterpreterStac } } - return 0, nil, nil, mysql.NewSQLError(mysqlErrNo, sqlState, msgTxt) + return 0, nil, nil, nil, mysql.NewSQLError(mysqlErrNo, sqlState, msgTxt) case OpCode_Open: openCur := operation.PrimaryData.(*ast.OpenCursor) cursor := stack.GetCursor(strings.ToLower(openCur.Name)) if cursor == nil { - return 0, nil, nil, sql.ErrCursorNotFound.New(openCur.Name) + return 0, nil, nil, nil, sql.ErrCursorNotFound.New(openCur.Name) } if cursor.RowIter != nil { - return 0, nil, nil, sql.ErrCursorAlreadyOpen.New(openCur.Name) + return 0, nil, nil, nil, sql.ErrCursorAlreadyOpen.New(openCur.Name) } stmt, err := replaceVariablesInExpr(ctx, stack, cursor.SelectStmt, asOf) if err != nil { - return 0, nil, nil, err + return 0, nil, nil, nil, err } schema, rowIter, _, err := runner.QueryWithBindings(ctx, "", stmt.(ast.Statement), nil, nil) if err != nil { - return 0, nil, nil, err + return 0, nil, nil, nil, err } cursor.Schema = schema cursor.RowIter = rowIter @@ -619,33 +623,33 @@ func execOp(ctx *sql.Context, runner sql.StatementRunner, stack *InterpreterStac fetchCur := operation.PrimaryData.(*ast.FetchCursor) cursor := stack.GetCursor(strings.ToLower(fetchCur.Name)) if cursor == nil { - return 0, nil, nil, sql.ErrCursorNotFound.New(fetchCur.Name) + return 0, nil, nil, nil, sql.ErrCursorNotFound.New(fetchCur.Name) } if cursor.RowIter == nil { - return 0, nil, nil, sql.ErrCursorNotOpen.New(fetchCur.Name) + return 0, nil, nil, nil, sql.ErrCursorNotOpen.New(fetchCur.Name) } row, err := cursor.RowIter.Next(ctx) if err != nil { if err == io.EOF { - return 0, nil, nil, expression.FetchEOF + return 0, nil, nil, nil, expression.FetchEOF } - return 0, nil, nil, err + return 0, nil, nil, nil, err } if len(row) != len(fetchCur.Variables) { - return 0, nil, nil, sql.ErrFetchIncorrectCount.New() + return 0, nil, nil, nil, sql.ErrFetchIncorrectCount.New() } for i := range fetchCur.Variables { varName := strings.ToLower(fetchCur.Variables[i]) if strings.HasPrefix(varName, "@") { err = ctx.SetUserVariable(ctx, varName, row[i], cursor.Schema[i].Type) if err != nil { - return 0, nil, nil, err + return 0, nil, nil, nil, err } continue } err = stack.SetVariable(varName, row[i]) if err != nil { - return 0, nil, nil, err + return 0, nil, nil, nil, err } } @@ -653,13 +657,13 @@ func execOp(ctx *sql.Context, runner sql.StatementRunner, stack *InterpreterStac closeCur := operation.PrimaryData.(*ast.CloseCursor) cursor := stack.GetCursor(strings.ToLower(closeCur.Name)) if cursor == nil { - return 0, nil, nil, sql.ErrCursorNotFound.New(closeCur.Name) + return 0, nil, nil, nil, sql.ErrCursorNotFound.New(closeCur.Name) } if cursor.RowIter == nil { - return 0, nil, nil, sql.ErrCursorNotOpen.New(closeCur.Name) + return 0, nil, nil, nil, sql.ErrCursorNotOpen.New(closeCur.Name) } if err := cursor.RowIter.Close(ctx); err != nil { - return 0, nil, nil, err + return 0, nil, nil, nil, err } cursor.RowIter = nil @@ -671,30 +675,30 @@ func execOp(ctx *sql.Context, runner sql.StatementRunner, stack *InterpreterStac for i := range selectStmt.SelectExprs { newNode, err := replaceVariablesInExpr(ctx, stack, selectStmt.SelectExprs[i], asOf) if err != nil { - return 0, nil, nil, err + return 0, nil, nil, nil, err } selectStmt.SelectExprs[i] = newNode.(ast.SelectExpr) } _, rowIter, _, err := runner.QueryWithBindings(ctx, "", selectStmt, nil, nil) if err != nil { - return 0, nil, nil, err + return 0, nil, nil, nil, err } row, err := rowIter.Next(ctx) if err != nil { - return 0, nil, nil, err + return 0, nil, nil, nil, err } if _, err = rowIter.Next(ctx); err != io.EOF { - return 0, nil, nil, err + return 0, nil, nil, nil, err } if err = rowIter.Close(ctx); err != nil { - return 0, nil, nil, err + return 0, nil, nil, nil, err } err = stack.SetVariable(strings.ToLower(operation.Target), row[0]) if err != nil { err = ctx.Session.SetStoredProcParam(operation.Target, row[0]) if err != nil { - return 0, nil, nil, err + return 0, nil, nil, nil, err } } @@ -706,30 +710,30 @@ func execOp(ctx *sql.Context, runner sql.StatementRunner, stack *InterpreterStac for i := range selectStmt.SelectExprs { newNode, err := replaceVariablesInExpr(ctx, stack, selectStmt.SelectExprs[i], asOf) if err != nil { - return 0, nil, nil, err + return 0, nil, nil, nil, err } selectStmt.SelectExprs[i] = newNode.(ast.SelectExpr) } _, rowIter, _, err := runner.QueryWithBindings(ctx, "", selectStmt, nil, nil) if err != nil { - return 0, nil, nil, err + return 0, nil, nil, nil, err } // TODO: exactly one result that is a bool for now row, err := rowIter.Next(ctx) if err != nil { - return 0, nil, nil, err + return 0, nil, nil, nil, err } if _, err = rowIter.Next(ctx); err != io.EOF { - return 0, nil, nil, err + return 0, nil, nil, nil, err } if err = rowIter.Close(ctx); err != nil { - return 0, nil, nil, err + return 0, nil, nil, nil, err } // go to the appropriate block cond, _, err := types.Boolean.Convert(row[0]) if err != nil { - return 0, nil, nil, err + return 0, nil, nil, nil, err } if cond == nil || cond.(int8) == 0 { counter = operation.Index - 1 // index of the else block, offset by 1 @@ -767,17 +771,17 @@ func execOp(ctx *sql.Context, runner sql.StatementRunner, stack *InterpreterStac case OpCode_Execute: stmt, err := replaceVariablesInExpr(ctx, stack, operation.PrimaryData, asOf) if err != nil { - return 0, nil, nil, err + return 0, nil, nil, nil, err } // TODO: create a OpCode_Call to store procedures in the stack - rowIter, err := query(ctx, runner, stmt.(ast.Statement)) + _, rowIter, err := query(ctx, runner, stmt.(ast.Statement)) if err != nil { - return 0, nil, nil, err + return 0, nil, nil, nil, err } - return counter, nil, rowIter, err + return counter, nil, nil, rowIter, err case OpCode_Exception: - return 0, nil, nil, operation.Error + return 0, nil, nil, nil, operation.Error case OpCode_ScopeBegin: stack.PushScope() @@ -789,7 +793,7 @@ func execOp(ctx *sql.Context, runner sql.StatementRunner, stack *InterpreterStac panic("unimplemented opcode") } - return counter, nil, nil, nil + return counter, nil, nil, nil, nil } // Call runs the contained operations on the given runner. @@ -828,6 +832,7 @@ func Call(ctx *sql.Context, iNode InterpreterNode, params []*Parameter) (sql.Row // TODO: remove this; track last selectRowIter var selIter sql.RowIter + var selSch sql.Schema // Run the statements // TODO: eventually return multiple sql.RowIters @@ -843,10 +848,17 @@ func Call(ctx *sql.Context, iNode InterpreterNode, params []*Parameter) (sql.Row break } + // TODO: server engine can't run multiple statements in procedures because the ctx keeps getting cancelled + // from the TrackedRowIter. + // Uncancel query? + // Use subcontexts? + subCtx := sql.NewContext(context.Background()) + subCtx.Session = ctx.Session + operation := statements[counter] - newCounter, newSelIter, rowIter, err := execOp(ctx, runner, stack, operation, statements, asOf, counter) + newCounter, newSelSch, newSelIter, rowIter, err := execOp(subCtx, runner, stack, operation, statements, asOf, counter) if err != nil { - hCounter, hErr := handleError(ctx, runner, stack, statements, counter, err) + hCounter, hErr := handleError(subCtx, runner, stack, statements, counter, err) if hErr != nil && hErr != io.EOF { return nil, nil, hErr } @@ -861,15 +873,19 @@ func Call(ctx *sql.Context, iNode InterpreterNode, params []*Parameter) (sql.Row } if newSelIter != nil { selIter = newSelIter + selSch = newSelSch } counter = newCounter } if selIter != nil { + iNode.SetSchema(selSch) return selIter, stack, nil } if len(rowIters) == 0 { rowIters = append(rowIters, sql.RowsToRowIter(sql.Row{types.NewOkResult(0)})) } + + // TODO: probably need to set result schema for these too return rowIters[len(rowIters)-1], stack, nil } diff --git a/sql/rowexec/proc.go b/sql/rowexec/proc.go index 427bbf212a..c81189f3e0 100644 --- a/sql/rowexec/proc.go +++ b/sql/rowexec/proc.go @@ -231,6 +231,9 @@ func (b *BaseBuilder) buildCall(ctx *sql.Context, n *plan.Call, row sql.Row) (sq } rowIter, _, err := procedures.Call(ctx, n, procParams) + if err != nil && err.Error() == "context canceled" { + print() + } if err != nil { return nil, err } From d364eef449805bf8214813450ba4e281c76ace70 Mon Sep 17 00:00:00 2001 From: jycor Date: Tue, 8 Apr 2025 07:44:41 +0000 Subject: [PATCH 079/111] [ga-format-pr] Run ./format_repo.sh to fix formatting --- sql/plan/call.go | 3 +-- sql/procedures/interpreter_logic.go | 8 ++++---- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/sql/plan/call.go b/sql/plan/call.go index d27a07dc98..b3ecba4ec2 100644 --- a/sql/plan/call.go +++ b/sql/plan/call.go @@ -40,7 +40,7 @@ type Call struct { Ops []procedures.InterpreterOperation // TODO: sure whatever - resSch sql.Schema + resSch sql.Schema } var _ sql.Node = (*Call)(nil) @@ -234,4 +234,3 @@ func (c *Call) GetStatements() []*procedures.InterpreterOperation { func (c *Call) SetSchema(sch sql.Schema) { c.resSch = sch } - diff --git a/sql/procedures/interpreter_logic.go b/sql/procedures/interpreter_logic.go index 6ccb7dbbae..fb55302012 100644 --- a/sql/procedures/interpreter_logic.go +++ b/sql/procedures/interpreter_logic.go @@ -15,8 +15,8 @@ package procedures import ( - "context" -"errors" + "context" + "errors" "fmt" "io" "strconv" @@ -832,7 +832,7 @@ func Call(ctx *sql.Context, iNode InterpreterNode, params []*Parameter) (sql.Row // TODO: remove this; track last selectRowIter var selIter sql.RowIter - var selSch sql.Schema + var selSch sql.Schema // Run the statements // TODO: eventually return multiple sql.RowIters @@ -873,7 +873,7 @@ func Call(ctx *sql.Context, iNode InterpreterNode, params []*Parameter) (sql.Row } if newSelIter != nil { selIter = newSelIter - selSch = newSelSch + selSch = newSelSch } counter = newCounter } From d77b59ab35e0f70336e35d91634f31b9f7cf6ed2 Mon Sep 17 00:00:00 2001 From: James Cor Date: Tue, 8 Apr 2025 01:01:55 -0700 Subject: [PATCH 080/111] fix --- sql/procedures/interpreter_logic.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/procedures/interpreter_logic.go b/sql/procedures/interpreter_logic.go index fb55302012..c33c70498e 100644 --- a/sql/procedures/interpreter_logic.go +++ b/sql/procedures/interpreter_logic.go @@ -731,7 +731,7 @@ func execOp(ctx *sql.Context, runner sql.StatementRunner, stack *InterpreterStac } // go to the appropriate block - cond, _, err := types.Boolean.Convert(row[0]) + cond, _, err := types.Boolean.Convert(ctx, row[0]) if err != nil { return 0, nil, nil, nil, err } From 8c87be399f01ac9d93293d57bf65e30576051147 Mon Sep 17 00:00:00 2001 From: James Cor Date: Tue, 8 Apr 2025 01:21:06 -0700 Subject: [PATCH 081/111] asdf --- sql/rowexec/proc.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/rowexec/proc.go b/sql/rowexec/proc.go index 1a3bb2c826..bf0f098f1d 100644 --- a/sql/rowexec/proc.go +++ b/sql/rowexec/proc.go @@ -219,7 +219,7 @@ func (b *BaseBuilder) buildCall(ctx *sql.Context, n *plan.Call, row sql.Row) (sq if err != nil { return nil, err } - paramVal, _, err = paramType.Convert(paramVal) + paramVal, _, err = paramType.Convert(ctx, paramVal) if err != nil { return nil, err } @@ -247,7 +247,7 @@ func (b *BaseBuilder) buildCall(ctx *sql.Context, n *plan.Call, row sql.Row) (sq stackVar := ctx.Session.GetStoredProcParam(procParam.Name) // TODO: ToLower? switch p := param.(type) { case *expression.ProcedureParam: - err = p.Set(stackVar.Value, stackVar.Type) + err = p.Set(ctx, stackVar.Value, stackVar.Type) if err != nil { return nil, err } From 4e5e399ea1a30da3a3ecbdfcfe4f43b0e6807749 Mon Sep 17 00:00:00 2001 From: James Cor Date: Tue, 8 Apr 2025 01:46:55 -0700 Subject: [PATCH 082/111] cleaning up --- sql/procedures/interpreter_logic.go | 20 +------------ sql/rowexec/proc.go | 46 +++++++++++++---------------- 2 files changed, 21 insertions(+), 45 deletions(-) diff --git a/sql/procedures/interpreter_logic.go b/sql/procedures/interpreter_logic.go index c33c70498e..7c72408285 100644 --- a/sql/procedures/interpreter_logic.go +++ b/sql/procedures/interpreter_logic.go @@ -40,12 +40,6 @@ type InterpreterNode interface { SetSchema(sch sql.Schema) } -type Parameter struct { - Name string - Type sql.Type - Value any -} - func replaceVariablesInExpr(ctx *sql.Context, stack *InterpreterStack, expr ast.SQLNode, asOf *ast.AsOf) (ast.SQLNode, error) { switch e := expr.(type) { case *ast.ColName: @@ -797,22 +791,10 @@ func execOp(ctx *sql.Context, runner sql.StatementRunner, stack *InterpreterStac } // Call runs the contained operations on the given runner. -func Call(ctx *sql.Context, iNode InterpreterNode, params []*Parameter) (sql.RowIter, *InterpreterStack, error) { - for _, param := range params { - var spp *sql.StoredProcParam - spp = ctx.Session.GetStoredProcParam(param.Name) - for spp != nil { - spp.Value = param.Value - spp = spp.Reference - } - } - +func Call(ctx *sql.Context, iNode InterpreterNode) (sql.RowIter, *InterpreterStack, error) { // Set up the initial state of the function counter := -1 // We increment before accessing, so start at -1 stack := NewInterpreterStack() - //for _, param := range params { - // stack.NewVariableWithValue(param.Name, param.Type, param.Value) - //} var asOf *ast.AsOf if asOfExpr := iNode.GetAsOf(); asOfExpr != nil { diff --git a/sql/rowexec/proc.go b/sql/rowexec/proc.go index bf0f098f1d..150cce4c07 100644 --- a/sql/rowexec/proc.go +++ b/sql/rowexec/proc.go @@ -209,31 +209,25 @@ func (b *BaseBuilder) buildCall(ctx *sql.Context, n *plan.Call, row sql.Row) (sq }, nil } - // TODO: replace with direct ctx modification - procParams := make([]*procedures.Parameter, len(n.Params)) + // Initialize parameters for i, paramExpr := range n.Params { param := n.Procedure.Params[i] - paramName := strings.ToLower(param.Name) - paramType := param.Type paramVal, err := paramExpr.Eval(ctx, row) if err != nil { return nil, err } - paramVal, _, err = paramType.Convert(ctx, paramVal) + paramVal, _, err = param.Type.Convert(ctx, paramVal) if err != nil { return nil, err } - procParams[i] = &procedures.Parameter{ - Name: paramName, - Value: paramVal, - Type: paramType, + paramName := strings.ToLower(param.Name) + for spp := ctx.Session.GetStoredProcParam(paramName); spp != nil; { + spp.Value = paramVal + spp = spp.Reference } } - rowIter, _, err := procedures.Call(ctx, n, procParams) - if err != nil && err.Error() == "context canceled" { - print() - } + rowIter, _, err := procedures.Call(ctx, n) if err != nil { return nil, err } @@ -244,25 +238,25 @@ func (b *BaseBuilder) buildCall(ctx *sql.Context, n *plan.Call, row sql.Row) (sq continue } // Set all user and system variables from INOUT and OUT params - stackVar := ctx.Session.GetStoredProcParam(procParam.Name) // TODO: ToLower? + paramName := strings.ToLower(procParam.Name) + spp := ctx.Session.GetStoredProcParam(paramName) + if spp == nil { + return nil, fmt.Errorf("parameter `%s` not found", paramName) + } switch p := param.(type) { case *expression.ProcedureParam: - err = p.Set(ctx, stackVar.Value, stackVar.Type) - if err != nil { - return nil, err - } + err = p.Set(ctx, spp.Value, spp.Type) case *expression.UserVar: - val := stackVar.Value - if procParam.Direction == plan.ProcedureParamDirection_Out && !stackVar.HasBeenSet { + val := spp.Value + if procParam.Direction == plan.ProcedureParamDirection_Out && !spp.HasBeenSet { val = nil } - err = ctx.SetUserVariable(ctx, p.Name, val, stackVar.Type) - if err != nil { - return nil, err - } + err = ctx.SetUserVariable(ctx, p.Name, val, spp.Type) case *expression.SystemVar: - // This should have been caught by the analyzer, so a major bug exists somewhere - return nil, fmt.Errorf("unable to set `%s` as it is a system variable", p.Name) + err = fmt.Errorf("unable to set `%s` as it is a system variable", p.Name) + } + if err != nil { + return nil, err } } From 50fbea8d862e6de6fdc9516fffc6e37a74cef606 Mon Sep 17 00:00:00 2001 From: James Cor Date: Fri, 11 Apr 2025 10:26:46 -0700 Subject: [PATCH 083/111] bump --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index aed4846e2d..b68c09402b 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,7 @@ require ( github.com/dolthub/go-icu-regex v0.0.0-20250327004329-6799764f2dad github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71 github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81 - github.com/dolthub/vitess v0.0.0-20250410090211-143e6b272ad4 + github.com/dolthub/vitess v0.0.0-20250411172524-030d7ebedb96 github.com/go-kit/kit v0.10.0 github.com/go-sql-driver/mysql v1.7.2-0.20231213112541-0004702b931d github.com/gocraft/dbr/v2 v2.7.2 diff --git a/go.sum b/go.sum index 391e1a6454..5a7c014fb4 100644 --- a/go.sum +++ b/go.sum @@ -58,8 +58,8 @@ github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71 h1:bMGS25NWAGTE github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71/go.mod h1:2/2zjLQ/JOOSbbSboojeg+cAwcRV0fDLzIiWch/lhqI= github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81 h1:7/v8q9XGFa6q5Ap4Z/OhNkAMBaK5YeuEzwJt+NZdhiE= github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81/go.mod h1:siLfyv2c92W1eN/R4QqG/+RjjX5W2+gCTRjZxBjI3TY= -github.com/dolthub/vitess v0.0.0-20250410090211-143e6b272ad4 h1:LGTt2LtYX8vaai32d+c9L0sMcP+Dg9w1kO6+lbsxxYg= -github.com/dolthub/vitess v0.0.0-20250410090211-143e6b272ad4/go.mod h1:1gQZs/byeHLMSul3Lvl3MzioMtOW1je79QYGyi2fd70= +github.com/dolthub/vitess v0.0.0-20250411172524-030d7ebedb96 h1:CRHpJLqR9rYaBfEwFn5De/6j7IPQxUObMxuO3svg2FI= +github.com/dolthub/vitess v0.0.0-20250411172524-030d7ebedb96/go.mod h1:1gQZs/byeHLMSul3Lvl3MzioMtOW1je79QYGyi2fd70= github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/eapache/go-resiliency v1.1.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs= github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU= From a205f74e199ddead2aa2701ae213770f0feec57f Mon Sep 17 00:00:00 2001 From: James Cor Date: Fri, 11 Apr 2025 12:44:00 -0700 Subject: [PATCH 084/111] fix schemas --- sql/procedures/interpreter_logic.go | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/sql/procedures/interpreter_logic.go b/sql/procedures/interpreter_logic.go index 7c72408285..b6a2838cc0 100644 --- a/sql/procedures/interpreter_logic.go +++ b/sql/procedures/interpreter_logic.go @@ -768,11 +768,11 @@ func execOp(ctx *sql.Context, runner sql.StatementRunner, stack *InterpreterStac return 0, nil, nil, nil, err } // TODO: create a OpCode_Call to store procedures in the stack - _, rowIter, err := query(ctx, runner, stmt.(ast.Statement)) + sch, rowIter, err := query(ctx, runner, stmt.(ast.Statement)) if err != nil { return 0, nil, nil, nil, err } - return counter, nil, nil, rowIter, err + return counter, sch, nil, rowIter, err case OpCode_Exception: return 0, nil, nil, nil, operation.Error @@ -819,6 +819,7 @@ func Call(ctx *sql.Context, iNode InterpreterNode) (sql.RowIter, *InterpreterSta // Run the statements // TODO: eventually return multiple sql.RowIters var rowIters []sql.RowIter + var retSch sql.Schema runner := iNode.GetRunner() statements := iNode.GetStatements() for { @@ -838,7 +839,7 @@ func Call(ctx *sql.Context, iNode InterpreterNode) (sql.RowIter, *InterpreterSta subCtx.Session = ctx.Session operation := statements[counter] - newCounter, newSelSch, newSelIter, rowIter, err := execOp(subCtx, runner, stack, operation, statements, asOf, counter) + newCounter, sch, newSelIter, rowIter, err := execOp(subCtx, runner, stack, operation, statements, asOf, counter) if err != nil { hCounter, hErr := handleError(subCtx, runner, stack, statements, counter, err) if hErr != nil && hErr != io.EOF { @@ -852,10 +853,11 @@ func Call(ctx *sql.Context, iNode InterpreterNode) (sql.RowIter, *InterpreterSta } if rowIter != nil { rowIters = append(rowIters, rowIter) + retSch = sch } if newSelIter != nil { selIter = newSelIter - selSch = newSelSch + selSch = sch } counter = newCounter } @@ -865,7 +867,10 @@ func Call(ctx *sql.Context, iNode InterpreterNode) (sql.RowIter, *InterpreterSta return selIter, stack, nil } if len(rowIters) == 0 { + iNode.SetSchema(types.OkResultSchema) rowIters = append(rowIters, sql.RowsToRowIter(sql.Row{types.NewOkResult(0)})) + } else { + iNode.SetSchema(retSch) } // TODO: probably need to set result schema for these too From 7509fb7987d508b69535629a8a2d0e122ca84336 Mon Sep 17 00:00:00 2001 From: James Cor Date: Fri, 11 Apr 2025 12:56:39 -0700 Subject: [PATCH 085/111] this? --- sql/procedures/interpreter_logic.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/procedures/interpreter_logic.go b/sql/procedures/interpreter_logic.go index b6a2838cc0..5056a26882 100644 --- a/sql/procedures/interpreter_logic.go +++ b/sql/procedures/interpreter_logic.go @@ -869,7 +869,7 @@ func Call(ctx *sql.Context, iNode InterpreterNode) (sql.RowIter, *InterpreterSta if len(rowIters) == 0 { iNode.SetSchema(types.OkResultSchema) rowIters = append(rowIters, sql.RowsToRowIter(sql.Row{types.NewOkResult(0)})) - } else { + } else if retSch != nil { iNode.SetSchema(retSch) } From 25087dddeaa30f7de97834eb1ac4953cb54bfe6e Mon Sep 17 00:00:00 2001 From: James Cor Date: Fri, 11 Apr 2025 15:38:24 -0700 Subject: [PATCH 086/111] asdfasdfafdg --- enginetest/memory_engine_test.go | 61 +++---------------------- enginetest/queries/procedure_queries.go | 4 ++ 2 files changed, 11 insertions(+), 54 deletions(-) diff --git a/enginetest/memory_engine_test.go b/enginetest/memory_engine_test.go index c814ffc4f8..d37f0f13f0 100644 --- a/enginetest/memory_engine_test.go +++ b/enginetest/memory_engine_test.go @@ -205,65 +205,19 @@ func TestSingleScript(t *testing.T) { //t.Skip() var scripts = []queries.ScriptTest{ { - Name: "insert trigger with stored procedure with deletes", - SetUpScript: []string{ - "create table t (i int);", - "create table t1 (j int);", - "insert into t1 values (1);", - "create table t2 (k int);", - "insert into t2 values (1);", - "create table t3 (l int);", - "insert into t3 values (1);", - "create table t4 (m int);", - ` -create procedure proc(x int) -begin - delete from t2 where k = (select j from t1 where j = x); - update t3 set l = 10 where l = x; - insert into t4 values (x); -end; -`, - ` -create trigger trig before insert on t -for each row -begin - call proc(new.i); -end; -`, - }, + Name: "AS OF propagates to nested CALLs", + SetUpScript: []string{}, Assertions: []queries.ScriptTestAssertion{ { - Query: "insert into t values (1);", - Expected: []sql.Row{ - {types.NewOkResult(1)}, - }, - }, - { - Query: "select * from t;", + Query: "create procedure create_proc() create table t (i int primary key, j int);", Expected: []sql.Row{ - {1}, + {types.NewOkResult(0)}, }, }, { - Query: "select * from t1;", + Query: "call create_proc()", Expected: []sql.Row{ - {1}, - }, - }, - { - Query: "select * from t2;", - Expected: []sql.Row{}, - }, - { - Query: "select * from t3;", - Expected: []sql.Row{ - {10}, - }, - }, - { - Query: "select * from t4;", - Expected: []sql.Row{ - {1}, + {types.NewOkResult(0)}, }, }, }, @@ -272,8 +226,7 @@ end; for _, test := range scripts { harness := enginetest.NewMemoryHarness("", 1, testNumPartitions, true, nil) - // TODO: fix this - //harness.UseServer() + harness.UseServer() engine, err := harness.NewEngine(t) if err != nil { panic(err) diff --git a/enginetest/queries/procedure_queries.go b/enginetest/queries/procedure_queries.go index c2a02d0266..7562e455ef 100644 --- a/enginetest/queries/procedure_queries.go +++ b/enginetest/queries/procedure_queries.go @@ -988,6 +988,8 @@ END;`, Expected: []sql.Row{{}}, }, { + // TODO: Set statements don't return anything for whatever reason + SkipResultCheckOnServerEngine: true, Query: "CALL p1(@x);", Expected: []sql.Row{ {types.NewOkResult(0)}, @@ -1396,6 +1398,8 @@ END;`, }, Assertions: []ScriptTestAssertion{ { + // TODO: Set statements don't return anything for whatever reason + SkipResultCheckOnServerEngine: true, Query: "CALL outer_declare();", Expected: []sql.Row{ {types.NewOkResult(0)}, From de57c567e3767e46b920667770dd4e1a7a68f2f6 Mon Sep 17 00:00:00 2001 From: James Cor Date: Fri, 11 Apr 2025 15:41:39 -0700 Subject: [PATCH 087/111] aaaaaaaa --- enginetest/server_engine.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/enginetest/server_engine.go b/enginetest/server_engine.go index 868042d45e..251b558dc7 100644 --- a/enginetest/server_engine.go +++ b/enginetest/server_engine.go @@ -216,7 +216,7 @@ func (s *ServerQueryEngine) queryOrExec(ctx *sql.Context, stmt *gosql.Stmt, pars var err error switch parsed.(type) { // TODO: added `FLUSH` stmt here (should be `exec`) because we don't support `FLUSH BINARY LOGS` or `FLUSH ENGINE LOGS`, so nil schema is returned. - case *sqlparser.Select, *sqlparser.SetOp, *sqlparser.Show, *sqlparser.Set, *sqlparser.Call, *sqlparser.Begin, *sqlparser.Use, *sqlparser.Load, *sqlparser.Execute, *sqlparser.Analyze, *sqlparser.Flush: + case *sqlparser.Select, *sqlparser.SetOp, *sqlparser.Show, *sqlparser.Set, *sqlparser.Begin, *sqlparser.Use, *sqlparser.Load, *sqlparser.Execute, *sqlparser.Analyze, *sqlparser.Flush: var rows *gosql.Rows if stmt != nil { rows, err = stmt.Query(args...) From 64ad194242231817ab6b4cf2b20602bf05109d82 Mon Sep 17 00:00:00 2001 From: jycor Date: Fri, 11 Apr 2025 22:43:06 +0000 Subject: [PATCH 088/111] [ga-format-pr] Run ./format_repo.sh to fix formatting --- enginetest/memory_engine_test.go | 2 +- enginetest/queries/procedure_queries.go | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/enginetest/memory_engine_test.go b/enginetest/memory_engine_test.go index d37f0f13f0..52d5523895 100644 --- a/enginetest/memory_engine_test.go +++ b/enginetest/memory_engine_test.go @@ -205,7 +205,7 @@ func TestSingleScript(t *testing.T) { //t.Skip() var scripts = []queries.ScriptTest{ { - Name: "AS OF propagates to nested CALLs", + Name: "AS OF propagates to nested CALLs", SetUpScript: []string{}, Assertions: []queries.ScriptTestAssertion{ { diff --git a/enginetest/queries/procedure_queries.go b/enginetest/queries/procedure_queries.go index 7562e455ef..a2f416084a 100644 --- a/enginetest/queries/procedure_queries.go +++ b/enginetest/queries/procedure_queries.go @@ -990,7 +990,7 @@ END;`, { // TODO: Set statements don't return anything for whatever reason SkipResultCheckOnServerEngine: true, - Query: "CALL p1(@x);", + Query: "CALL p1(@x);", Expected: []sql.Row{ {types.NewOkResult(0)}, }, @@ -1400,7 +1400,7 @@ END;`, { // TODO: Set statements don't return anything for whatever reason SkipResultCheckOnServerEngine: true, - Query: "CALL outer_declare();", + Query: "CALL outer_declare();", Expected: []sql.Row{ {types.NewOkResult(0)}, }, From 5de9a32d3df43771ec740e27174e1f50607a87c7 Mon Sep 17 00:00:00 2001 From: James Cor Date: Fri, 11 Apr 2025 15:51:16 -0700 Subject: [PATCH 089/111] dfasdfasdf --- enginetest/engine_only_test.go | 1 - enginetest/memory_engine_test.go | 2 +- enginetest/server_engine.go | 2 +- 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/enginetest/engine_only_test.go b/enginetest/engine_only_test.go index 83c32f49ef..9996e63088 100644 --- a/enginetest/engine_only_test.go +++ b/enginetest/engine_only_test.go @@ -608,7 +608,6 @@ func TestTableFunctions(t *testing.T) { func TestExternalProcedures(t *testing.T) { harness := enginetest.NewDefaultMemoryHarness() harness.Setup(setup.MydbData) - harness.UseServer() for _, script := range queries.ExternalProcedureTests { func() { e, err := harness.NewEngine(t) diff --git a/enginetest/memory_engine_test.go b/enginetest/memory_engine_test.go index 52d5523895..cb5455eed8 100644 --- a/enginetest/memory_engine_test.go +++ b/enginetest/memory_engine_test.go @@ -202,7 +202,7 @@ func TestSingleQueryPrepared(t *testing.T) { // Convenience test for debugging a single query. Unskip and set to the desired query. func TestSingleScript(t *testing.T) { - //t.Skip() + t.Skip() var scripts = []queries.ScriptTest{ { Name: "AS OF propagates to nested CALLs", diff --git a/enginetest/server_engine.go b/enginetest/server_engine.go index 251b558dc7..868042d45e 100644 --- a/enginetest/server_engine.go +++ b/enginetest/server_engine.go @@ -216,7 +216,7 @@ func (s *ServerQueryEngine) queryOrExec(ctx *sql.Context, stmt *gosql.Stmt, pars var err error switch parsed.(type) { // TODO: added `FLUSH` stmt here (should be `exec`) because we don't support `FLUSH BINARY LOGS` or `FLUSH ENGINE LOGS`, so nil schema is returned. - case *sqlparser.Select, *sqlparser.SetOp, *sqlparser.Show, *sqlparser.Set, *sqlparser.Begin, *sqlparser.Use, *sqlparser.Load, *sqlparser.Execute, *sqlparser.Analyze, *sqlparser.Flush: + case *sqlparser.Select, *sqlparser.SetOp, *sqlparser.Show, *sqlparser.Set, *sqlparser.Call, *sqlparser.Begin, *sqlparser.Use, *sqlparser.Load, *sqlparser.Execute, *sqlparser.Analyze, *sqlparser.Flush: var rows *gosql.Rows if stmt != nil { rows, err = stmt.Query(args...) From 2d4d652f89ee4ed0a7949cdcda77ee5eebb55fd3 Mon Sep 17 00:00:00 2001 From: James Cor Date: Fri, 11 Apr 2025 16:04:09 -0700 Subject: [PATCH 090/111] remove todos --- sql/expression/procedurereference.go | 2 -- sql/planbuilder/proc.go | 2 -- sql/procedures.go | 3 --- sql/procedures/interpreter_logic.go | 17 ----------------- sql/rowexec/rel.go | 1 - sql/session.go | 2 -- 6 files changed, 27 deletions(-) diff --git a/sql/expression/procedurereference.go b/sql/expression/procedurereference.go index 418954a5b2..ba92f546ce 100644 --- a/sql/expression/procedurereference.go +++ b/sql/expression/procedurereference.go @@ -23,8 +23,6 @@ import ( "github.com/dolthub/go-mysql-server/sql/types" ) -// TODO: instead of procedure reference, copy stack from doltgres - // ProcedureReference contains the state for a single CALL statement of a stored procedure. type ProcedureReference struct { InnermostScope *procedureScope diff --git a/sql/planbuilder/proc.go b/sql/planbuilder/proc.go index 9d4fdd10f9..3a3693b5b0 100644 --- a/sql/planbuilder/proc.go +++ b/sql/planbuilder/proc.go @@ -345,8 +345,6 @@ func (b *Builder) buildCall(inScope *scope, c *ast.Call) (outScope *scope) { b.handleErr(err) } - // TODO: build references here? - // TODO: here fill in x from session params := make([]sql.Expression, len(c.Params)) for i, param := range c.Params { if len(proc.Params) == len(c.Params) { diff --git a/sql/procedures.go b/sql/procedures.go index 893f832141..ed84b1a34c 100644 --- a/sql/procedures.go +++ b/sql/procedures.go @@ -27,9 +27,6 @@ type Interpreter interface { SetStatementRunner(ctx *Context, runner StatementRunner) Expression } -// TODO: InterpreterNode interface -// TODO: alternatively have plan.Call just have an interpreter expression - // StatementRunner is essentially an interface that the engine will implement. We cannot directly reference the engine // here as it will cause an import cycle, so this may be updated to suit any function changes that the engine // experiences. diff --git a/sql/procedures/interpreter_logic.go b/sql/procedures/interpreter_logic.go index 5056a26882..5e6dd291bb 100644 --- a/sql/procedures/interpreter_logic.go +++ b/sql/procedures/interpreter_logic.go @@ -121,7 +121,6 @@ func replaceVariablesInExpr(ctx *sql.Context, stack *InterpreterStack, expr ast. e.Expr = newExpr.(ast.Expr) case *ast.Set: for _, setExpr := range e.Exprs { - // TODO: properly handle user scope variables newExpr, err := replaceVariablesInExpr(ctx, stack, setExpr.Expr, asOf) if err != nil { return nil, err @@ -137,7 +136,6 @@ func replaceVariablesInExpr(ctx *sql.Context, stack *InterpreterStack, expr ast. } case *ast.Call: for i := range e.Params { - // TODO: do not replace certain params newExpr, err := replaceVariablesInExpr(ctx, stack, e.Params[i], asOf) if err != nil { return nil, err @@ -163,7 +161,6 @@ func replaceVariablesInExpr(ctx *sql.Context, stack *InterpreterStack, expr ast. e.Rowcount = newRowCount.(ast.Expr) } case *ast.Into: - // TODO: somehow support select into variables for i := range e.Variables { newExpr, err := replaceVariablesInExpr(ctx, stack, e.Variables[i], asOf) if err != nil { @@ -433,7 +430,6 @@ func execOp(ctx *sql.Context, runner sql.StatementRunner, stack *InterpreterStac case OpCode_Declare: declareStmt := operation.PrimaryData.(*ast.Declare) - // TODO: duplicate conditions? if cond := declareStmt.Condition; cond != nil { condName := strings.ToLower(cond.Name) stateVal := cond.SqlStateValue @@ -457,13 +453,11 @@ func execOp(ctx *sql.Context, runner sql.StatementRunner, stack *InterpreterStac stack.NewCondition(condName, stateVal, num) } - // TODO: duplicate cursors? if cursor := declareStmt.Cursor; cursor != nil { cursorName := strings.ToLower(cursor.Name) stack.NewCursor(cursorName, cursor.SelectStmt) } - // TODO: duplicate handlers? if handler := declareStmt.Handler; handler != nil { if len(handler.ConditionValues) != 1 { return 0, nil, nil, nil, sql.ErrUnsupportedSyntax.New(ast.String(declareStmt)) @@ -487,7 +481,6 @@ func execOp(ctx *sql.Context, runner sql.StatementRunner, stack *InterpreterStac stack.NewHandler(hCond.ValueType, handler.Action, handler.Statement, counter) } - // TODO: duplicate variables? if vars := declareStmt.Variables; vars != nil { for _, decl := range vars.Names { varType, err := types.ColumnTypeToType(&vars.VarType) @@ -504,7 +497,6 @@ func execOp(ctx *sql.Context, runner sql.StatementRunner, stack *InterpreterStac } case OpCode_Signal: - // TODO: copy logic from planbuilder/proc.go: buildSignal() signalStmt := operation.PrimaryData.(*ast.Signal) var msgTxt string var sqlState string @@ -712,7 +704,6 @@ func execOp(ctx *sql.Context, runner sql.StatementRunner, stack *InterpreterStac if err != nil { return 0, nil, nil, nil, err } - // TODO: exactly one result that is a bool for now row, err := rowIter.Next(ctx) if err != nil { return 0, nil, nil, nil, err @@ -767,7 +758,6 @@ func execOp(ctx *sql.Context, runner sql.StatementRunner, stack *InterpreterStac if err != nil { return 0, nil, nil, nil, err } - // TODO: create a OpCode_Call to store procedures in the stack sch, rowIter, err := query(ctx, runner, stmt.(ast.Statement)) if err != nil { return 0, nil, nil, nil, err @@ -812,12 +802,10 @@ func Call(ctx *sql.Context, iNode InterpreterNode) (sql.RowIter, *InterpreterSta } } - // TODO: remove this; track last selectRowIter var selIter sql.RowIter var selSch sql.Schema // Run the statements - // TODO: eventually return multiple sql.RowIters var rowIters []sql.RowIter var retSch sql.Schema runner := iNode.GetRunner() @@ -831,10 +819,6 @@ func Call(ctx *sql.Context, iNode InterpreterNode) (sql.RowIter, *InterpreterSta break } - // TODO: server engine can't run multiple statements in procedures because the ctx keeps getting cancelled - // from the TrackedRowIter. - // Uncancel query? - // Use subcontexts? subCtx := sql.NewContext(context.Background()) subCtx.Session = ctx.Session @@ -873,6 +857,5 @@ func Call(ctx *sql.Context, iNode InterpreterNode) (sql.RowIter, *InterpreterSta iNode.SetSchema(retSch) } - // TODO: probably need to set result schema for these too return rowIters[len(rowIters)-1], stack, nil } diff --git a/sql/rowexec/rel.go b/sql/rowexec/rel.go index 5e31c4eaa6..79a706b43d 100644 --- a/sql/rowexec/rel.go +++ b/sql/rowexec/rel.go @@ -730,7 +730,6 @@ func (b *BaseBuilder) buildExternalProcedure(ctx *sql.Context, n *plan.ExternalP } for i, paramDefinition := range n.ParamDefinitions { if paramDefinition.Direction == plan.ProcedureParamDirection_Inout || paramDefinition.Direction == plan.ProcedureParamDirection_Out { - // TODO: not sure if we should still be doing this exprParam := n.Params[i] funcParamVal := funcParams[i+1].Elem().Interface() err := exprParam.Set(ctx, funcParamVal, exprParam.Type()) diff --git a/sql/session.go b/sql/session.go index c8c66beb7a..4bcb8ecd66 100644 --- a/sql/session.go +++ b/sql/session.go @@ -178,8 +178,6 @@ type Session interface { // ValidateSession provides integrators a chance to do any custom validation of this session before any query is // executed in it. For example, Dolt uses this hook to validate that the session's working set is valid. ValidateSession(ctx *Context) error - - //SetInStoredProcedure(val bool) } // PersistableSession supports serializing/deserializing global system variables/ From 58466c3b1606aac42b1933e7104804216bf74824 Mon Sep 17 00:00:00 2001 From: James Cor Date: Fri, 11 Apr 2025 16:05:36 -0700 Subject: [PATCH 091/111] skip for the server --- enginetest/queries/procedure_queries.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/enginetest/queries/procedure_queries.go b/enginetest/queries/procedure_queries.go index a2f416084a..78c3c49f45 100644 --- a/enginetest/queries/procedure_queries.go +++ b/enginetest/queries/procedure_queries.go @@ -2848,6 +2848,7 @@ var ProcedureCreateInSubroutineTests = []ScriptTest{ }, }, { + SkipResultCheckOnServerEngine: true, Query: "call create_proc()", Expected: []sql.Row{ {types.NewOkResult(0)}, @@ -2875,6 +2876,7 @@ var ProcedureCreateInSubroutineTests = []ScriptTest{ }, }, { + SkipResultCheckOnServerEngine: true, Query: "call insert_proc()", Expected: []sql.Row{ {types.NewOkResult(3)}, @@ -2889,6 +2891,7 @@ var ProcedureCreateInSubroutineTests = []ScriptTest{ }, }, { + SkipResultCheckOnServerEngine: true, Query: "call insert_proc()", ExpectedErrStr: "duplicate primary key given: [1]", }, @@ -2900,6 +2903,7 @@ var ProcedureCreateInSubroutineTests = []ScriptTest{ }, }, { + SkipResultCheckOnServerEngine: true, Query: "call update_proc()", Expected: []sql.Row{ {types.OkResult{RowsAffected: 2, Info: plan.UpdateInfo{Matched: 2, Updated: 2}}}, @@ -2914,6 +2918,7 @@ var ProcedureCreateInSubroutineTests = []ScriptTest{ }, }, { + SkipResultCheckOnServerEngine: true, Query: "call update_proc()", Expected: []sql.Row{ {types.OkResult{RowsAffected: 0, Info: plan.UpdateInfo{Matched: 2}}}, @@ -2927,6 +2932,7 @@ var ProcedureCreateInSubroutineTests = []ScriptTest{ }, }, { + SkipResultCheckOnServerEngine: true, Query: "call drop_proc()", Expected: []sql.Row{ {types.NewOkResult(0)}, From 0d1adac080634793fa2453c99da393b959c180f1 Mon Sep 17 00:00:00 2001 From: jycor Date: Fri, 11 Apr 2025 23:07:02 +0000 Subject: [PATCH 092/111] [ga-format-pr] Run ./format_repo.sh to fix formatting --- enginetest/queries/procedure_queries.go | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/enginetest/queries/procedure_queries.go b/enginetest/queries/procedure_queries.go index 78c3c49f45..7ca990a260 100644 --- a/enginetest/queries/procedure_queries.go +++ b/enginetest/queries/procedure_queries.go @@ -2849,7 +2849,7 @@ var ProcedureCreateInSubroutineTests = []ScriptTest{ }, { SkipResultCheckOnServerEngine: true, - Query: "call create_proc()", + Query: "call create_proc()", Expected: []sql.Row{ {types.NewOkResult(0)}, }, @@ -2877,7 +2877,7 @@ var ProcedureCreateInSubroutineTests = []ScriptTest{ }, { SkipResultCheckOnServerEngine: true, - Query: "call insert_proc()", + Query: "call insert_proc()", Expected: []sql.Row{ {types.NewOkResult(3)}, }, @@ -2892,8 +2892,8 @@ var ProcedureCreateInSubroutineTests = []ScriptTest{ }, { SkipResultCheckOnServerEngine: true, - Query: "call insert_proc()", - ExpectedErrStr: "duplicate primary key given: [1]", + Query: "call insert_proc()", + ExpectedErrStr: "duplicate primary key given: [1]", }, { @@ -2904,7 +2904,7 @@ var ProcedureCreateInSubroutineTests = []ScriptTest{ }, { SkipResultCheckOnServerEngine: true, - Query: "call update_proc()", + Query: "call update_proc()", Expected: []sql.Row{ {types.OkResult{RowsAffected: 2, Info: plan.UpdateInfo{Matched: 2, Updated: 2}}}, }, @@ -2919,7 +2919,7 @@ var ProcedureCreateInSubroutineTests = []ScriptTest{ }, { SkipResultCheckOnServerEngine: true, - Query: "call update_proc()", + Query: "call update_proc()", Expected: []sql.Row{ {types.OkResult{RowsAffected: 0, Info: plan.UpdateInfo{Matched: 2}}}, }, @@ -2933,7 +2933,7 @@ var ProcedureCreateInSubroutineTests = []ScriptTest{ }, { SkipResultCheckOnServerEngine: true, - Query: "call drop_proc()", + Query: "call drop_proc()", Expected: []sql.Row{ {types.NewOkResult(0)}, }, From 4d62d318f2e44949a6cb8d676f7ae3cb80e0f1f0 Mon Sep 17 00:00:00 2001 From: James Cor Date: Mon, 14 Apr 2025 15:03:27 -0700 Subject: [PATCH 093/111] bump --- go.mod | 2 +- go.sum | 6 ++---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/go.mod b/go.mod index 24bed70fd0..c461a3e5bf 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,7 @@ require ( github.com/dolthub/go-icu-regex v0.0.0-20250327004329-6799764f2dad github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71 github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81 - github.com/dolthub/vitess v0.0.0-20250414165810-f0031a6472b7 + github.com/dolthub/vitess v0.0.0-20250414213151-810e7add1b8e github.com/go-kit/kit v0.10.0 github.com/go-sql-driver/mysql v1.7.2-0.20231213112541-0004702b931d github.com/gocraft/dbr/v2 v2.7.2 diff --git a/go.sum b/go.sum index 1611fa4492..d7e1481889 100644 --- a/go.sum +++ b/go.sum @@ -58,10 +58,8 @@ github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71 h1:bMGS25NWAGTE github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71/go.mod h1:2/2zjLQ/JOOSbbSboojeg+cAwcRV0fDLzIiWch/lhqI= github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81 h1:7/v8q9XGFa6q5Ap4Z/OhNkAMBaK5YeuEzwJt+NZdhiE= github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81/go.mod h1:siLfyv2c92W1eN/R4QqG/+RjjX5W2+gCTRjZxBjI3TY= -github.com/dolthub/vitess v0.0.0-20250410233614-8d8c7a5b3d6b h1:2wE+qJwJ5SRIzz+dJQT8XbkpK+g8/pFt34AU/iJ5K+Y= -github.com/dolthub/vitess v0.0.0-20250410233614-8d8c7a5b3d6b/go.mod h1:1gQZs/byeHLMSul3Lvl3MzioMtOW1je79QYGyi2fd70= -github.com/dolthub/vitess v0.0.0-20250414165810-f0031a6472b7 h1:4Y043kZgAH1WhOER0nk+02KPKxJX8Ir6yK7cGzY04c4= -github.com/dolthub/vitess v0.0.0-20250414165810-f0031a6472b7/go.mod h1:1gQZs/byeHLMSul3Lvl3MzioMtOW1je79QYGyi2fd70= +github.com/dolthub/vitess v0.0.0-20250414213151-810e7add1b8e h1:EASQ7Rjk4lbQjwqQaBmjjGxpbRRQqlnudEdmJZGjk/A= +github.com/dolthub/vitess v0.0.0-20250414213151-810e7add1b8e/go.mod h1:1gQZs/byeHLMSul3Lvl3MzioMtOW1je79QYGyi2fd70= github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/eapache/go-resiliency v1.1.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs= github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU= From b7d7c6143bab34ffe248bcb350c5cd28b8f96602 Mon Sep 17 00:00:00 2001 From: James Cor Date: Fri, 18 Apr 2025 15:13:39 -0700 Subject: [PATCH 094/111] bump --- go.mod | 2 +- go.sum | 8 ++------ 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/go.mod b/go.mod index 5f3e502498..f7285c5349 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,7 @@ require ( github.com/dolthub/go-icu-regex v0.0.0-20250327004329-6799764f2dad github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71 github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81 - github.com/dolthub/vitess v0.0.0-20250417230335-b8d80bc39341 + github.com/dolthub/vitess v0.0.0-20250418221234-8d272f40a217 github.com/go-kit/kit v0.10.0 github.com/go-sql-driver/mysql v1.7.2-0.20231213112541-0004702b931d github.com/gocraft/dbr/v2 v2.7.2 diff --git a/go.sum b/go.sum index 720b5a1056..91ce51c915 100644 --- a/go.sum +++ b/go.sum @@ -58,12 +58,8 @@ github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71 h1:bMGS25NWAGTE github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71/go.mod h1:2/2zjLQ/JOOSbbSboojeg+cAwcRV0fDLzIiWch/lhqI= github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81 h1:7/v8q9XGFa6q5Ap4Z/OhNkAMBaK5YeuEzwJt+NZdhiE= github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81/go.mod h1:siLfyv2c92W1eN/R4QqG/+RjjX5W2+gCTRjZxBjI3TY= -github.com/dolthub/vitess v0.0.0-20250410233614-8d8c7a5b3d6b h1:2wE+qJwJ5SRIzz+dJQT8XbkpK+g8/pFt34AU/iJ5K+Y= -github.com/dolthub/vitess v0.0.0-20250410233614-8d8c7a5b3d6b/go.mod h1:1gQZs/byeHLMSul3Lvl3MzioMtOW1je79QYGyi2fd70= -github.com/dolthub/vitess v0.0.0-20250414165810-f0031a6472b7 h1:4Y043kZgAH1WhOER0nk+02KPKxJX8Ir6yK7cGzY04c4= -github.com/dolthub/vitess v0.0.0-20250414165810-f0031a6472b7/go.mod h1:1gQZs/byeHLMSul3Lvl3MzioMtOW1je79QYGyi2fd70= -github.com/dolthub/vitess v0.0.0-20250417230335-b8d80bc39341 h1:qebIGlJEgi/mSXVZ39P77cklPuuIl8gApyTVMnKm79s= -github.com/dolthub/vitess v0.0.0-20250417230335-b8d80bc39341/go.mod h1:1gQZs/byeHLMSul3Lvl3MzioMtOW1je79QYGyi2fd70= +github.com/dolthub/vitess v0.0.0-20250418221234-8d272f40a217 h1:QcVH/2VQrEv+azfhqUU1wkkP6fsbJfzx5JO1dqx/DwY= +github.com/dolthub/vitess v0.0.0-20250418221234-8d272f40a217/go.mod h1:1gQZs/byeHLMSul3Lvl3MzioMtOW1je79QYGyi2fd70= github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/eapache/go-resiliency v1.1.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs= github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU= From e64f37f902e23bcc094ff972af7809e924dc3151 Mon Sep 17 00:00:00 2001 From: James Cor Date: Mon, 21 Apr 2025 11:11:38 -0700 Subject: [PATCH 095/111] some feedback --- sql/base_session.go | 4 ++++ sql/core.go | 4 ++++ sql/plan/call.go | 4 +--- sql/plan/procedure.go | 1 + sql/procedures/interpreter_logic.go | 18 +++++++++++++----- sql/procedures/interpreter_stack.go | 4 ++++ sql/rowexec/proc.go | 10 ++++++---- 7 files changed, 33 insertions(+), 12 deletions(-) diff --git a/sql/base_session.go b/sql/base_session.go index dd775a0952..63b03ae9fa 100644 --- a/sql/base_session.go +++ b/sql/base_session.go @@ -254,6 +254,7 @@ func (s *BaseSession) IncrementStatusVariable(ctx *Context, statVarName string, return } +// NewStoredProcParam creates a new Stored Procedure Parameter in the Session func (s *BaseSession) NewStoredProcParam(name string, param *StoredProcParam) { if _, ok := s.storedProcParams[name]; ok { return @@ -261,6 +262,7 @@ func (s *BaseSession) NewStoredProcParam(name string, param *StoredProcParam) { s.storedProcParams[name] = param } +// GetStoredProcParam retrieves the named stored procedure parameter, from the Session, returning nil if not found. func (s *BaseSession) GetStoredProcParam(name string) *StoredProcParam { if param, ok := s.storedProcParams[name]; ok { return param @@ -268,6 +270,8 @@ func (s *BaseSession) GetStoredProcParam(name string) *StoredProcParam { return nil } +// SetStoredProcParam sets the named Stored Procedure Parameter from the Session to val and marks it as HasSet. +// If the Parameter has not been initialized, this will throw an error. func (s *BaseSession) SetStoredProcParam(name string, val any) error { param := s.GetStoredProcParam(name) if param == nil { diff --git a/sql/core.go b/sql/core.go index c169b0eec6..3318835884 100644 --- a/sql/core.go +++ b/sql/core.go @@ -880,6 +880,9 @@ func IncrementStatusVariable(ctx *Context, name string, val int) { ctx.Session.IncrementStatusVariable(ctx, name, val) } +// StoredProcParam is a Parameter for a Stored Procedure. +// Stored Procedures Parameters can be referenced from within other Stored Procedures, so we need to store them +// somewhere that is accessible between interpreter calls to the engine. type StoredProcParam struct { Type Type Value any @@ -887,6 +890,7 @@ type StoredProcParam struct { Reference *StoredProcParam } +// SetValue saves val to the StoredProcParam, and set HasBeenSet to true. func (s *StoredProcParam) SetValue(val any) { s.Value = val s.HasBeenSet = true diff --git a/sql/plan/call.go b/sql/plan/call.go index b3ecba4ec2..00d28e9807 100644 --- a/sql/plan/call.go +++ b/sql/plan/call.go @@ -23,8 +23,6 @@ import ( "github.com/dolthub/go-mysql-server/sql/types" ) -// TODO: we need different types of calls: one for external procedures one for stored procedures - type Call struct { db sql.Database Name string @@ -39,7 +37,7 @@ type Call struct { Runner sql.StatementRunner Ops []procedures.InterpreterOperation - // TODO: sure whatever + // retain the result schema resSch sql.Schema } diff --git a/sql/plan/procedure.go b/sql/plan/procedure.go index 50222ad68c..9dcddb6dcb 100644 --- a/sql/plan/procedure.go +++ b/sql/plan/procedure.go @@ -138,6 +138,7 @@ func (p *Procedure) Resolved() bool { return true } +// IsReadOnly implements the sql.Node interface. func (p *Procedure) IsReadOnly() bool { if p.ExternalProc != nil { return p.ExternalProc.IsReadOnly() diff --git a/sql/procedures/interpreter_logic.go b/sql/procedures/interpreter_logic.go index 5e6dd291bb..f13e8b86df 100644 --- a/sql/procedures/interpreter_logic.go +++ b/sql/procedures/interpreter_logic.go @@ -40,6 +40,8 @@ type InterpreterNode interface { SetSchema(sch sql.Schema) } +// replaceVariablesInExpr will search for every ast.Node and handle each one on a case by case basis. +// If a new ast.Node is added to the vitess parser we may need to add a case for it here. func replaceVariablesInExpr(ctx *sql.Context, stack *InterpreterStack, expr ast.SQLNode, asOf *ast.AsOf) (ast.SQLNode, error) { switch e := expr.(type) { case *ast.ColName: @@ -286,6 +288,9 @@ func query(ctx *sql.Context, runner sql.StatementRunner, stmt ast.Statement) (sq if rErr == io.EOF { break } + if cErr := rowIter.Close(ctx); cErr != nil { + return nil, nil, cErr + } return nil, nil, rErr } rows = append(rows, row) @@ -402,6 +407,9 @@ func execOp(ctx *sql.Context, runner sql.StatementRunner, stack *InterpreterStac return 0, nil, nil, nil, err } if _, err = rowIter.Next(ctx); err != io.EOF { + if rErr := rowIter.Close(ctx); rErr != nil { + return 0, nil, nil, nil, rErr + } return 0, nil, nil, nil, err } if err = rowIter.Close(ctx); err != nil { @@ -507,7 +515,7 @@ func execOp(ctx *sql.Context, runner sql.StatementRunner, stack *InterpreterStac return 0, nil, nil, nil, fmt.Errorf("warnings not yet implemented") } } else { - cond := stack.GetCondition(strings.ToLower(signalStmt.ConditionName)) + cond := stack.GetCondition(signalStmt.ConditionName) if cond == nil { return 0, nil, nil, nil, sql.ErrDeclareConditionNotFound.New(signalStmt.ConditionName) } @@ -587,7 +595,7 @@ func execOp(ctx *sql.Context, runner sql.StatementRunner, stack *InterpreterStac case OpCode_Open: openCur := operation.PrimaryData.(*ast.OpenCursor) - cursor := stack.GetCursor(strings.ToLower(openCur.Name)) + cursor := stack.GetCursor(openCur.Name) if cursor == nil { return 0, nil, nil, nil, sql.ErrCursorNotFound.New(openCur.Name) } @@ -607,7 +615,7 @@ func execOp(ctx *sql.Context, runner sql.StatementRunner, stack *InterpreterStac case OpCode_Fetch: fetchCur := operation.PrimaryData.(*ast.FetchCursor) - cursor := stack.GetCursor(strings.ToLower(fetchCur.Name)) + cursor := stack.GetCursor(fetchCur.Name) if cursor == nil { return 0, nil, nil, nil, sql.ErrCursorNotFound.New(fetchCur.Name) } @@ -641,7 +649,7 @@ func execOp(ctx *sql.Context, runner sql.StatementRunner, stack *InterpreterStac case OpCode_Close: closeCur := operation.PrimaryData.(*ast.CloseCursor) - cursor := stack.GetCursor(strings.ToLower(closeCur.Name)) + cursor := stack.GetCursor(closeCur.Name) if cursor == nil { return 0, nil, nil, nil, sql.ErrCursorNotFound.New(closeCur.Name) } @@ -680,7 +688,7 @@ func execOp(ctx *sql.Context, runner sql.StatementRunner, stack *InterpreterStac return 0, nil, nil, nil, err } - err = stack.SetVariable(strings.ToLower(operation.Target), row[0]) + err = stack.SetVariable(operation.Target, row[0]) if err != nil { err = ctx.Session.SetStoredProcParam(operation.Target, row[0]) if err != nil { diff --git a/sql/procedures/interpreter_stack.go b/sql/procedures/interpreter_stack.go index 1bc75cf8b7..1125345583 100644 --- a/sql/procedures/interpreter_stack.go +++ b/sql/procedures/interpreter_stack.go @@ -17,6 +17,7 @@ package procedures import ( "fmt" "strconv" + "strings" "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/types" @@ -205,6 +206,7 @@ func (is *InterpreterStack) NewVariableAlias(alias string, variable *Interpreter // GetVariable traverses the stack (starting from the top) to find a variable with a matching name. Returns nil if no // variable was found. func (is *InterpreterStack) GetVariable(name string) *InterpreterVariable { + name = strings.ToLower(name) for i := 0; i < is.stack.Len(); i++ { if iv, ok := is.stack.PeekDepth(i).variables[name]; ok { return iv @@ -248,6 +250,7 @@ func (is *InterpreterStack) NewCondition(name string, sqlState string, mysqlErrC // GetCondition traverses the stack (starting from the top) to find a condition with a matching name. Returns nil if no // variable was found. func (is *InterpreterStack) GetCondition(name string) *InterpreterCondition { + name = strings.ToLower(name) for i := 0; i < is.stack.Len(); i++ { if ic, ok := is.stack.PeekDepth(i).conditions[name]; ok { return ic @@ -266,6 +269,7 @@ func (is *InterpreterStack) NewCursor(name string, selStmt ast.SelectStatement) // GetCursor traverses the stack (starting from the top) to find a condition with a matching name. Returns nil if no // variable was found. func (is *InterpreterStack) GetCursor(name string) *InterpreterCursor { + name = strings.ToLower(name) for i := 0; i < is.stack.Len(); i++ { if ic, ok := is.stack.PeekDepth(i).cursors[name]; ok { return ic diff --git a/sql/rowexec/proc.go b/sql/rowexec/proc.go index 150cce4c07..b0fd053443 100644 --- a/sql/rowexec/proc.go +++ b/sql/rowexec/proc.go @@ -262,11 +262,13 @@ func (b *BaseBuilder) buildCall(ctx *sql.Context, n *plan.Call, row sql.Row) (sq // We might close transactions in the procedure, so we need to start a new one if we're not in one already if sess, ok := ctx.Session.(sql.TransactionSession); ok { - tx, tErr := sess.StartTransaction(ctx, sql.ReadWrite) - if tErr != nil { - return nil, tErr + if tx := ctx.GetTransaction(); tx == nil { + tx, err = sess.StartTransaction(ctx, sql.ReadWrite) + if err != nil { + return nil, err + } + ctx.SetTransaction(tx) } - ctx.SetTransaction(tx) } return &callIter{ From 8450a2a69cd0a26316b1f13febf7cacefbd4173f Mon Sep 17 00:00:00 2001 From: James Cor Date: Tue, 22 Apr 2025 12:05:04 -0700 Subject: [PATCH 096/111] preserve transactions --- sql/rowexec/proc.go | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/sql/rowexec/proc.go b/sql/rowexec/proc.go index b0fd053443..f456354862 100644 --- a/sql/rowexec/proc.go +++ b/sql/rowexec/proc.go @@ -227,6 +227,11 @@ func (b *BaseBuilder) buildCall(ctx *sql.Context, n *plan.Call, row sql.Row) (sq } } + // Preserve existing transaction + oldTx := ctx.GetTransaction() + defer ctx.SetTransaction(oldTx) + ctx.SetTransaction(nil) + rowIter, _, err := procedures.Call(ctx, n) if err != nil { return nil, err @@ -260,17 +265,6 @@ func (b *BaseBuilder) buildCall(ctx *sql.Context, n *plan.Call, row sql.Row) (sq } } - // We might close transactions in the procedure, so we need to start a new one if we're not in one already - if sess, ok := ctx.Session.(sql.TransactionSession); ok { - if tx := ctx.GetTransaction(); tx == nil { - tx, err = sess.StartTransaction(ctx, sql.ReadWrite) - if err != nil { - return nil, err - } - ctx.SetTransaction(tx) - } - } - return &callIter{ call: n, innerIter: rowIter, From 4f4eac7b36b518883dbf8bc00bca95a811bd61f0 Mon Sep 17 00:00:00 2001 From: James Cor Date: Tue, 22 Apr 2025 13:01:59 -0700 Subject: [PATCH 097/111] comments --- sql/base_session.go | 2 ++ sql/planbuilder/proc.go | 5 ++++- sql/procedures/interpreter_operation.go | 3 +-- sql/session.go | 7 +++---- 4 files changed, 10 insertions(+), 7 deletions(-) diff --git a/sql/base_session.go b/sql/base_session.go index 63b03ae9fa..9ea4aefa9c 100644 --- a/sql/base_session.go +++ b/sql/base_session.go @@ -256,6 +256,7 @@ func (s *BaseSession) IncrementStatusVariable(ctx *Context, statVarName string, // NewStoredProcParam creates a new Stored Procedure Parameter in the Session func (s *BaseSession) NewStoredProcParam(name string, param *StoredProcParam) { + name = strings.ToLower(name) if _, ok := s.storedProcParams[name]; ok { return } @@ -264,6 +265,7 @@ func (s *BaseSession) NewStoredProcParam(name string, param *StoredProcParam) { // GetStoredProcParam retrieves the named stored procedure parameter, from the Session, returning nil if not found. func (s *BaseSession) GetStoredProcParam(name string) *StoredProcParam { + name = strings.ToLower(name) if param, ok := s.storedProcParams[name]; ok { return param } diff --git a/sql/planbuilder/proc.go b/sql/planbuilder/proc.go index 3a3693b5b0..70ba6308ab 100644 --- a/sql/planbuilder/proc.go +++ b/sql/planbuilder/proc.go @@ -347,12 +347,15 @@ func (b *Builder) buildCall(inScope *scope, c *ast.Call) (outScope *scope) { params := make([]sql.Expression, len(c.Params)) for i, param := range c.Params { + // While it is possible to detect a parameter count mismatch here and throw an error, + // there's some weirdness involving external procedures. The analyzer rule applyProceduresCall will + // catch this discrepancy. if len(proc.Params) == len(c.Params) { procParam := proc.Params[i] rspp := &sql.StoredProcParam{Type: procParam.Type} b.ctx.Session.NewStoredProcParam(procParam.Name, rspp) if col, isCol := param.(*ast.ColName); isCol { - colName := col.Name.String() // TODO: to lower? + colName := col.Name.String() if spp := b.ctx.Session.GetStoredProcParam(colName); spp != nil { iv := &procedures.InterpreterVariable{ Type: spp.Type, diff --git a/sql/procedures/interpreter_operation.go b/sql/procedures/interpreter_operation.go index bc833bcff4..d2d1d08e4f 100644 --- a/sql/procedures/interpreter_operation.go +++ b/sql/procedures/interpreter_operation.go @@ -12,8 +12,7 @@ package procedures import ast "github.com/dolthub/vitess/go/vt/sqlparser" -// OpCode states the operation to be performed. Most operations have a direct analogue to a Pl/pgSQL operation, however -// some exist only in Doltgres (specific to our interpreter implementation). +// OpCode is the internal representation queries run by Stored Procedures. type OpCode uint16 const ( diff --git a/sql/session.go b/sql/session.go index 6f13313071..1dd6867390 100644 --- a/sql/session.go +++ b/sql/session.go @@ -92,13 +92,12 @@ type Session interface { GetAllStatusVariables(ctx *Context) map[string]StatusVarValue // IncrementStatusVariable increments the value of the status variable by the integer value IncrementStatusVariable(ctx *Context, statVarName string, val int) - + // NewStoredProcParam creates a new Stored Procedure Parameter in the Session. NewStoredProcParam(name string, param *StoredProcParam) - + // GetStoredProcParam finds and returns the Stored Procedure Parameter by the given name. GetStoredProcParam(name string) *StoredProcParam - + // SetStoredProcParam sets the Stored Procedure Parameter of the given name to the given val. SetStoredProcParam(name string, val any) error - // GetCurrentDatabase gets the current database for this session GetCurrentDatabase() string // SetCurrentDatabase sets the current database for this session From 74a0a8eef6ebe1c297367317505a4d79610f9524 Mon Sep 17 00:00:00 2001 From: James Cor Date: Wed, 23 Apr 2025 19:18:45 -0700 Subject: [PATCH 098/111] progress --- sql/planbuilder/proc.go | 2 -- sql/procedures/interpreter_logic.go | 10 ++++++++++ sql/procedures/interpreter_operation.go | 18 +++++++++--------- sql/procedures/statements.go | 9 +-------- sql/rowexec/rel.go | 5 +---- 5 files changed, 21 insertions(+), 23 deletions(-) diff --git a/sql/planbuilder/proc.go b/sql/planbuilder/proc.go index 70ba6308ab..9961329bea 100644 --- a/sql/planbuilder/proc.go +++ b/sql/planbuilder/proc.go @@ -314,14 +314,12 @@ func (b *Builder) buildCall(inScope *scope, c *ast.Call) (outScope *scope) { } if esp != nil { proc, err = resolveExternalStoredProcedure(*esp) - // TODO: return plan.NewExternalCall here } else if spdb, ok := db.(sql.StoredProcedureDatabase); ok { var procDetails sql.StoredProcedureDetails procDetails, ok, err = spdb.GetStoredProcedure(b.ctx, procName) if err == nil { if ok { proc, innerQFlags, err = BuildProcedureHelper(b.ctx, b.cat, false, inScope, db, asOf, procDetails) - // TODO: somewhat hacky way of preserving this flag // This is necessary so that the resolveSubqueries analyzer rule // will apply NodeExecBuilder to Subqueries in procedure body if innerQFlags.IsSet(sql.QFlagScalarSubquery) { diff --git a/sql/procedures/interpreter_logic.go b/sql/procedures/interpreter_logic.go index f13e8b86df..97b6795fc7 100644 --- a/sql/procedures/interpreter_logic.go +++ b/sql/procedures/interpreter_logic.go @@ -272,7 +272,14 @@ func replaceVariablesInExpr(ctx *sql.Context, stack *InterpreterStack, expr ast. } e.Where.Expr = newExpr.(ast.Expr) } + case *ast.ConvertExpr: + newExpr, err := replaceVariablesInExpr(ctx, stack, e.Expr, asOf) + if err != nil { + return nil, err + } + e.Expr = newExpr.(ast.Expr) } + return expr, nil } @@ -679,6 +686,9 @@ func execOp(ctx *sql.Context, runner sql.StatementRunner, stack *InterpreterStac } row, err := rowIter.Next(ctx) if err != nil { + if cErr := rowIter.Close(ctx); cErr != nil { + return 0, nil, nil, nil, cErr + } return 0, nil, nil, nil, err } if _, err = rowIter.Next(ctx); err != io.EOF { diff --git a/sql/procedures/interpreter_operation.go b/sql/procedures/interpreter_operation.go index d2d1d08e4f..e2e6506976 100644 --- a/sql/procedures/interpreter_operation.go +++ b/sql/procedures/interpreter_operation.go @@ -16,20 +16,20 @@ import ast "github.com/dolthub/vitess/go/vt/sqlparser" type OpCode uint16 const ( - OpCode_Select OpCode = iota - OpCode_Declare // https://www.postgresql.org/docs/15/plpgsql-declarations.html + OpCode_Select OpCode = iota + OpCode_Declare OpCode_Signal OpCode_Open OpCode_Fetch OpCode_Close OpCode_Set - OpCode_If // https://www.postgresql.org/docs/15/plpgsql-control-structures.html#PLPGSQL-CONDITIONALS - OpCode_Goto // All control-flow structures can be represented using Goto - OpCode_Execute // Everything that's not a SELECT - OpCode_Exception // https://www.postgresql.org/docs/15/plpgsql-control-structures.html#PLPGSQL-ERROR-TRAPPING - OpCode_Return // https://www.postgresql.org/docs/15/plpgsql-control-structures.html#PLPGSQL-STATEMENTS-RETURNING - OpCode_ScopeBegin // This is used for scope control, specific to Doltgres - OpCode_ScopeEnd // This is used for scope control, specific to Doltgres + OpCode_If + OpCode_Goto + OpCode_Execute + OpCode_Exception + OpCode_Return + OpCode_ScopeBegin + OpCode_ScopeEnd ) // InterpreterOperation is an operation that will be performed by the interpreter. diff --git a/sql/procedures/statements.go b/sql/procedures/statements.go index 745d2b87b3..2503735920 100644 --- a/sql/procedures/statements.go +++ b/sql/procedures/statements.go @@ -14,7 +14,7 @@ package procedures -// Statement represents a PL/pgSQL statement. +// Statement represents a Stored Procedure Statement. type Statement interface { // OperationSize reports the number of operations that the statement will convert to. OperationSize() int32 @@ -75,13 +75,6 @@ func (stmt Block) AppendOperations(ops *[]InterpreterOperation, stack *Interpret OpCode: OpCode_ScopeBegin, }) for _, variable := range stmt.Variable { - //if !variable.IsParameter { - // *ops = append(*ops, InterpreterOperation{ - // OpCode: OpCode_Declare, - // PrimaryData: variable.Type, - // Target: variable.Name, - // }) - //} stack.NewVariableWithValue(variable.Name, nil, nil) } for _, innerStmt := range stmt.Body { diff --git a/sql/rowexec/rel.go b/sql/rowexec/rel.go index 79a706b43d..041ed8f525 100644 --- a/sql/rowexec/rel.go +++ b/sql/rowexec/rel.go @@ -736,10 +736,7 @@ func (b *BaseBuilder) buildExternalProcedure(ctx *sql.Context, n *plan.ExternalP if err != nil { return nil, err } - err = ctx.Session.SetStoredProcParam(exprParam.Name(), funcParamVal) - if err != nil { - return nil, err - } + _ = ctx.Session.SetStoredProcParam(exprParam.Name(), funcParamVal) } } // It's not invalid to return a nil RowIter, as having no rows to return is expected of many stored procedures. From 15d34c0d3fc001e9f9c8602bcc69c128f85650bb Mon Sep 17 00:00:00 2001 From: James Cor Date: Wed, 23 Apr 2025 23:27:19 -0700 Subject: [PATCH 099/111] progress --- sql/procedures/interpreter_logic.go | 15 +++++++++++++++ sql/procedures/interpreter_stack.go | 18 ++++++++++++++++++ 2 files changed, 33 insertions(+) diff --git a/sql/procedures/interpreter_logic.go b/sql/procedures/interpreter_logic.go index 97b6795fc7..618eece610 100644 --- a/sql/procedures/interpreter_logic.go +++ b/sql/procedures/interpreter_logic.go @@ -147,6 +147,9 @@ func replaceVariablesInExpr(ctx *sql.Context, stack *InterpreterStack, expr ast. if e.AsOf == nil && asOf != nil { e.AsOf = asOf.Time } + if len(e.ProcName.Qualifier.String()) == 0 { + e.ProcName.Qualifier = ast.NewTableIdent(stack.GetDatabase()) + } case *ast.Limit: newOffset, err := replaceVariablesInExpr(ctx, stack, e.Offset, asOf) if err != nil { @@ -251,12 +254,18 @@ func replaceVariablesInExpr(ctx *sql.Context, stack *InterpreterStack, expr ast. e.Values[i] = newExpr.(ast.ValTuple) } case *ast.Insert: + if asOf != nil { + return nil, sql.ErrProcedureCallAsOfReadOnly.New() + } newExpr, err := replaceVariablesInExpr(ctx, stack, e.Rows, asOf) if err != nil { return nil, err } e.Rows = newExpr.(ast.InsertRows) case *ast.Delete: + if asOf != nil { + return nil, sql.ErrProcedureCallAsOfReadOnly.New() + } if e.Where != nil { newExpr, err := replaceVariablesInExpr(ctx, stack, e.Where.Expr, asOf) if err != nil { @@ -265,6 +274,9 @@ func replaceVariablesInExpr(ctx *sql.Context, stack *InterpreterStack, expr ast. e.Where.Expr = newExpr.(ast.Expr) } case *ast.Update: + if asOf != nil { + return nil, sql.ErrProcedureCallAsOfReadOnly.New() + } if e.Where != nil { newExpr, err := replaceVariablesInExpr(ctx, stack, e.Where.Expr, asOf) if err != nil { @@ -828,6 +840,9 @@ func Call(ctx *sql.Context, iNode InterpreterNode) (sql.RowIter, *InterpreterSta var retSch sql.Schema runner := iNode.GetRunner() statements := iNode.GetStatements() + if dbNode, isDbNode := iNode.(sql.Databaser); isDbNode { + stack.SetDatabase(dbNode.Database().Name()) + } for { counter++ if counter < 0 { diff --git a/sql/procedures/interpreter_stack.go b/sql/procedures/interpreter_stack.go index 1125345583..3ba2201227 100644 --- a/sql/procedures/interpreter_stack.go +++ b/sql/procedures/interpreter_stack.go @@ -152,6 +152,9 @@ type InterpreterScopeDetails struct { // labels mark the counter of the start of a loop or block. labels map[string]int + + // database is the current database for this scope. + database string } // InterpreterStack represents the working information that an interpreter will use during execution. It is not exactly @@ -304,6 +307,21 @@ func (is *InterpreterStack) NewLabel(name string, index int) { is.stack.Peek().labels[name] = index } +// GetDatabase returns the current database for this scope. +func (is *InterpreterStack) GetDatabase() string { + for i := 0; i < is.stack.Len(); i++ { + if db := is.stack.PeekDepth(i).database; db != "" { + return db + } + } + return "" +} + +// SetDatabase sets the current database for this scope. +func (is *InterpreterStack) SetDatabase(db string) { + is.stack.Peek().database = db +} + // GetLabel traverses the stack (starting from the top) to find a label with a matching name. Returns -1 if no // variable was found. func (is *InterpreterStack) GetLabel(name string) int { From a34f76e34a69b6b92c2cec1145278d8474f9f8c8 Mon Sep 17 00:00:00 2001 From: James Cor Date: Thu, 24 Apr 2025 01:46:31 -0700 Subject: [PATCH 100/111] progress --- sql/procedures/interpreter_logic.go | 36 +++++++++++++++++++++++++ sql/procedures/interpreter_operation.go | 1 + sql/procedures/parse.go | 7 +++++ 3 files changed, 44 insertions(+) diff --git a/sql/procedures/interpreter_logic.go b/sql/procedures/interpreter_logic.go index 618eece610..7bc57fcfbd 100644 --- a/sql/procedures/interpreter_logic.go +++ b/sql/procedures/interpreter_logic.go @@ -718,6 +718,42 @@ func execOp(ctx *sql.Context, runner sql.StatementRunner, stack *InterpreterStac } } + case OpCode_Call: + stmt, err := replaceVariablesInExpr(ctx, stack, operation.PrimaryData, asOf) + if err != nil { + return 0, nil, nil, nil, err + } + // put stack variables into session variables + callStmt := stmt.(*ast.Call) + stackToParam := make(map[*InterpreterVariable]*sql.StoredProcParam) + for _, param := range callStmt.Params { + colName, isColName := param.(*ast.ColName) + if !isColName { + continue + } + paramName := colName.Name.String() + iv := stack.GetVariable(paramName) + if iv == nil { + continue + } + spp := &sql.StoredProcParam{ + Type: iv.Type, + Value: iv.Value, + } + ctx.Session.NewStoredProcParam(paramName, spp) + stackToParam[iv] = spp + } + sch, rowIter, err := query(ctx, runner, callStmt) + if err != nil { + return 0, nil, nil, nil, err + } + // assign stored proc params to stack variables + for iv, spp := range stackToParam { + iv.Value = spp.Value + } + + return counter, sch, nil, rowIter, err + case OpCode_If: selectStmt := operation.PrimaryData.(*ast.Select) if selectStmt.SelectExprs == nil { diff --git a/sql/procedures/interpreter_operation.go b/sql/procedures/interpreter_operation.go index e2e6506976..47c8a40774 100644 --- a/sql/procedures/interpreter_operation.go +++ b/sql/procedures/interpreter_operation.go @@ -23,6 +23,7 @@ const ( OpCode_Fetch OpCode_Close OpCode_Set + OpCode_Call OpCode_If OpCode_Goto OpCode_Execute diff --git a/sql/procedures/parse.go b/sql/procedures/parse.go index 4abcf65b8c..9e9de9bd12 100644 --- a/sql/procedures/parse.go +++ b/sql/procedures/parse.go @@ -141,6 +141,13 @@ func ConvertStmt(ops *[]*InterpreterOperation, stack *InterpreterStack, stmt ast } *ops = append(*ops, setOp) + case *ast.Call: + callOp := &InterpreterOperation{ + OpCode: OpCode_Call, + PrimaryData: s, + } + *ops = append(*ops, callOp) + case *ast.IfStatement: var ifElseGotoOps []*InterpreterOperation for _, ifCond := range s.Conditions { From b67e4ecf250a3063a514c1d666c3c6c179947002 Mon Sep 17 00:00:00 2001 From: James Cor Date: Fri, 25 Apr 2025 11:11:47 -0700 Subject: [PATCH 101/111] bump --- go.mod | 2 +- go.sum | 10 ++-------- 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/go.mod b/go.mod index 53bde2914c..63db593da9 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,7 @@ require ( github.com/dolthub/go-icu-regex v0.0.0-20250327004329-6799764f2dad github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71 github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81 - github.com/dolthub/vitess v0.0.0-20250423221552-f731ee5c5379 + github.com/dolthub/vitess v0.0.0-20250424225619-bfb20390c1d1 github.com/go-kit/kit v0.10.0 github.com/go-sql-driver/mysql v1.7.2-0.20231213112541-0004702b931d github.com/gocraft/dbr/v2 v2.7.2 diff --git a/go.sum b/go.sum index 03ff90b801..593f8000cf 100644 --- a/go.sum +++ b/go.sum @@ -58,14 +58,8 @@ github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71 h1:bMGS25NWAGTE github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71/go.mod h1:2/2zjLQ/JOOSbbSboojeg+cAwcRV0fDLzIiWch/lhqI= github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81 h1:7/v8q9XGFa6q5Ap4Z/OhNkAMBaK5YeuEzwJt+NZdhiE= github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81/go.mod h1:siLfyv2c92W1eN/R4QqG/+RjjX5W2+gCTRjZxBjI3TY= -github.com/dolthub/vitess v0.0.0-20250410233614-8d8c7a5b3d6b h1:2wE+qJwJ5SRIzz+dJQT8XbkpK+g8/pFt34AU/iJ5K+Y= -github.com/dolthub/vitess v0.0.0-20250410233614-8d8c7a5b3d6b/go.mod h1:1gQZs/byeHLMSul3Lvl3MzioMtOW1je79QYGyi2fd70= -github.com/dolthub/vitess v0.0.0-20250414165810-f0031a6472b7 h1:4Y043kZgAH1WhOER0nk+02KPKxJX8Ir6yK7cGzY04c4= -github.com/dolthub/vitess v0.0.0-20250414165810-f0031a6472b7/go.mod h1:1gQZs/byeHLMSul3Lvl3MzioMtOW1je79QYGyi2fd70= -github.com/dolthub/vitess v0.0.0-20250417230335-b8d80bc39341 h1:qebIGlJEgi/mSXVZ39P77cklPuuIl8gApyTVMnKm79s= -github.com/dolthub/vitess v0.0.0-20250417230335-b8d80bc39341/go.mod h1:1gQZs/byeHLMSul3Lvl3MzioMtOW1je79QYGyi2fd70= -github.com/dolthub/vitess v0.0.0-20250423221552-f731ee5c5379 h1:3nPFx23Ol0djIPf9rDw/y38yEn1BXqTXOUkYrWfxrEI= -github.com/dolthub/vitess v0.0.0-20250423221552-f731ee5c5379/go.mod h1:1gQZs/byeHLMSul3Lvl3MzioMtOW1je79QYGyi2fd70= +github.com/dolthub/vitess v0.0.0-20250424225619-bfb20390c1d1 h1:ZXssuc0ZNqKUD7xQCd0/xLT+nKrmAetNYb5v7xgU0U0= +github.com/dolthub/vitess v0.0.0-20250424225619-bfb20390c1d1/go.mod h1:1gQZs/byeHLMSul3Lvl3MzioMtOW1je79QYGyi2fd70= github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/eapache/go-resiliency v1.1.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs= github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU= From 7862ef1037b3075a74a77a49728cb160d3561f16 Mon Sep 17 00:00:00 2001 From: James Cor Date: Fri, 25 Apr 2025 12:39:44 -0700 Subject: [PATCH 102/111] fix? --- sql/base_session.go | 7 ++++--- sql/planbuilder/proc.go | 6 +++--- sql/procedures/interpreter_logic.go | 2 +- sql/session.go | 2 +- 4 files changed, 9 insertions(+), 8 deletions(-) diff --git a/sql/base_session.go b/sql/base_session.go index 9ea4aefa9c..7d2a43037c 100644 --- a/sql/base_session.go +++ b/sql/base_session.go @@ -255,12 +255,13 @@ func (s *BaseSession) IncrementStatusVariable(ctx *Context, statVarName string, } // NewStoredProcParam creates a new Stored Procedure Parameter in the Session -func (s *BaseSession) NewStoredProcParam(name string, param *StoredProcParam) { +func (s *BaseSession) NewStoredProcParam(name string, param *StoredProcParam) *StoredProcParam { name = strings.ToLower(name) - if _, ok := s.storedProcParams[name]; ok { - return + if spp, ok := s.storedProcParams[name]; ok { + return spp } s.storedProcParams[name] = param + return param } // GetStoredProcParam retrieves the named stored procedure parameter, from the Session, returning nil if not found. diff --git a/sql/planbuilder/proc.go b/sql/planbuilder/proc.go index 9961329bea..cbd844a543 100644 --- a/sql/planbuilder/proc.go +++ b/sql/planbuilder/proc.go @@ -350,8 +350,8 @@ func (b *Builder) buildCall(inScope *scope, c *ast.Call) (outScope *scope) { // catch this discrepancy. if len(proc.Params) == len(c.Params) { procParam := proc.Params[i] - rspp := &sql.StoredProcParam{Type: procParam.Type} - b.ctx.Session.NewStoredProcParam(procParam.Name, rspp) + rSpp := &sql.StoredProcParam{Type: procParam.Type} + rSpp = b.ctx.Session.NewStoredProcParam(procParam.Name, rSpp) if col, isCol := param.(*ast.ColName); isCol { colName := col.Name.String() if spp := b.ctx.Session.GetStoredProcParam(colName); spp != nil { @@ -360,7 +360,7 @@ func (b *Builder) buildCall(inScope *scope, c *ast.Call) (outScope *scope) { Value: spp.Value, } param = iv.ToAST() - rspp.Reference = spp + rSpp.Reference = spp } } } diff --git a/sql/procedures/interpreter_logic.go b/sql/procedures/interpreter_logic.go index 7bc57fcfbd..c22ff44df8 100644 --- a/sql/procedures/interpreter_logic.go +++ b/sql/procedures/interpreter_logic.go @@ -740,7 +740,7 @@ func execOp(ctx *sql.Context, runner sql.StatementRunner, stack *InterpreterStac Type: iv.Type, Value: iv.Value, } - ctx.Session.NewStoredProcParam(paramName, spp) + spp = ctx.Session.NewStoredProcParam(paramName, spp) stackToParam[iv] = spp } sch, rowIter, err := query(ctx, runner, callStmt) diff --git a/sql/session.go b/sql/session.go index 1dd6867390..4aa5f0747c 100644 --- a/sql/session.go +++ b/sql/session.go @@ -93,7 +93,7 @@ type Session interface { // IncrementStatusVariable increments the value of the status variable by the integer value IncrementStatusVariable(ctx *Context, statVarName string, val int) // NewStoredProcParam creates a new Stored Procedure Parameter in the Session. - NewStoredProcParam(name string, param *StoredProcParam) + NewStoredProcParam(name string, param *StoredProcParam) *StoredProcParam // GetStoredProcParam finds and returns the Stored Procedure Parameter by the given name. GetStoredProcParam(name string) *StoredProcParam // SetStoredProcParam sets the Stored Procedure Parameter of the given name to the given val. From c601f3f79b26224554af33929e99bc909977da10 Mon Sep 17 00:00:00 2001 From: James Cor Date: Fri, 25 Apr 2025 13:29:45 -0700 Subject: [PATCH 103/111] remove unused functions --- sql/procedures/statements.go | 30 ------------------------------ 1 file changed, 30 deletions(-) diff --git a/sql/procedures/statements.go b/sql/procedures/statements.go index 2503735920..777d6ddeb6 100644 --- a/sql/procedures/statements.go +++ b/sql/procedures/statements.go @@ -16,8 +16,6 @@ package procedures // Statement represents a Stored Procedure Statement. type Statement interface { - // OperationSize reports the number of operations that the statement will convert to. - OperationSize() int32 // AppendOperations adds the statement to the operation slice. AppendOperations(ops *[]InterpreterOperation, stack *InterpreterStack) error } @@ -54,20 +52,6 @@ type Block struct { var _ Statement = Block{} -// OperationSize implements the interface Statement. -func (stmt Block) OperationSize() int32 { - total := int32(2) // We start with 2 since we'll have ScopeBegin and ScopeEnd - for _, variable := range stmt.Variable { - if !variable.IsParameter { - total++ - } - } - for _, innerStmt := range stmt.Body { - total += innerStmt.OperationSize() - } - return total -} - // AppendOperations implements the interface Statement. func (stmt Block) AppendOperations(ops *[]InterpreterOperation, stack *InterpreterStack) error { stack.PushScope() @@ -97,11 +81,6 @@ type ExecuteSQL struct { var _ Statement = ExecuteSQL{} -// OperationSize implements the interface Statement. -func (ExecuteSQL) OperationSize() int32 { - return 1 -} - // AppendOperations implements the interface Statement. func (stmt ExecuteSQL) AppendOperations(ops *[]InterpreterOperation, stack *InterpreterStack) error { *ops = append(*ops, InterpreterOperation{ @@ -201,12 +180,3 @@ type Variable struct { Type string IsParameter bool } - -// OperationSizeForStatements returns the sum of OperationSize for every statement. -func OperationSizeForStatements(stmts []Statement) int32 { - total := int32(0) - for _, stmt := range stmts { - total += stmt.OperationSize() - } - return total -} From 3cbd94dc3a93e301f62972699872acae7d37bc0f Mon Sep 17 00:00:00 2001 From: James Cor Date: Fri, 25 Apr 2025 21:01:05 -0700 Subject: [PATCH 104/111] rename interpreter expression --- sql/analyzer/interpreter.go | 2 +- sql/procedures.go | 6 ------ sql/procedures/interpreter_logic.go | 6 ++++++ 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/sql/analyzer/interpreter.go b/sql/analyzer/interpreter.go index 9a2b559bf4..5ae1ac4473 100644 --- a/sql/analyzer/interpreter.go +++ b/sql/analyzer/interpreter.go @@ -34,7 +34,7 @@ func interpreter(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, s } newNode, sameExpr, err := transform.NodeExprs(newNode, func(expr sql.Expression) (sql.Expression, transform.TreeIdentity, error) { - if interp, ok := expr.(sql.Interpreter); ok { + if interp, ok := expr.(procedures.InterpreterExpr); ok { return interp.SetStatementRunner(ctx, a.Runner), transform.NewTree, nil } return expr, transform.SameTree, nil diff --git a/sql/procedures.go b/sql/procedures.go index ed84b1a34c..20c766cd78 100644 --- a/sql/procedures.go +++ b/sql/procedures.go @@ -21,12 +21,6 @@ import ( "github.com/dolthub/vitess/go/vt/sqlparser" ) -// Interpreter is an interface that implements an interpreter. These are typically used for functions (which may be -// implemented as a set of operations that are interpreted during runtime). -type Interpreter interface { - SetStatementRunner(ctx *Context, runner StatementRunner) Expression -} - // StatementRunner is essentially an interface that the engine will implement. We cannot directly reference the engine // here as it will cause an import cycle, so this may be updated to suit any function changes that the engine // experiences. diff --git a/sql/procedures/interpreter_logic.go b/sql/procedures/interpreter_logic.go index c22ff44df8..da9517c7f2 100644 --- a/sql/procedures/interpreter_logic.go +++ b/sql/procedures/interpreter_logic.go @@ -30,6 +30,12 @@ import ( ast "github.com/dolthub/vitess/go/vt/sqlparser" ) +// InterpreterExpr is an interface that implements an interpreter. These are typically used for functions (which may be +// implemented as a set of operations that are interpreted during runtime). +type InterpreterExpr interface { + SetStatementRunner(ctx *sql.Context, runner sql.StatementRunner) sql.Expression +} + // InterpreterNode is an interface that implements an interpreter. These are typically used for functions (which may be // implemented as a set of operations that are interpreted during runtime). type InterpreterNode interface { From bd614429a4ab55cba2b550e0c0c166b67cdd96c7 Mon Sep 17 00:00:00 2001 From: James Cor Date: Fri, 25 Apr 2025 21:19:57 -0700 Subject: [PATCH 105/111] close iters when popping --- sql/procedures/interpreter_logic.go | 6 +++--- sql/procedures/interpreter_stack.go | 13 +++++++++++-- sql/procedures/statements.go | 22 ++++++++++++---------- 3 files changed, 26 insertions(+), 15 deletions(-) diff --git a/sql/procedures/interpreter_logic.go b/sql/procedures/interpreter_logic.go index da9517c7f2..95dea60da2 100644 --- a/sql/procedures/interpreter_logic.go +++ b/sql/procedures/interpreter_logic.go @@ -804,7 +804,7 @@ func execOp(ctx *sql.Context, runner sql.StatementRunner, stack *InterpreterStac case OpCode_ScopeBegin: stack.PushScope() case OpCode_ScopeEnd: - stack.PopScope() + stack.PopScope(ctx) default: // No-op } @@ -816,7 +816,7 @@ func execOp(ctx *sql.Context, runner sql.StatementRunner, stack *InterpreterStac } switch statements[counter].OpCode { case OpCode_ScopeBegin: - stack.PopScope() + stack.PopScope(ctx) case OpCode_ScopeEnd: stack.PushScope() default: @@ -843,7 +843,7 @@ func execOp(ctx *sql.Context, runner sql.StatementRunner, stack *InterpreterStac stack.PushScope() case OpCode_ScopeEnd: - stack.PopScope() + stack.PopScope(ctx) default: panic("unimplemented opcode") diff --git a/sql/procedures/interpreter_stack.go b/sql/procedures/interpreter_stack.go index 3ba2201227..ea8b89c2f5 100644 --- a/sql/procedures/interpreter_stack.go +++ b/sql/procedures/interpreter_stack.go @@ -346,6 +346,15 @@ func (is *InterpreterStack) PushScope() { } // PopScope removes the current scope. -func (is *InterpreterStack) PopScope() { - is.stack.Pop() +func (is *InterpreterStack) PopScope(ctx *sql.Context) { + scope := is.stack.Pop() + for _, cursor := range scope.cursors { + if cursor == nil { + continue + } + if cursor.RowIter == nil { + continue + } + cursor.RowIter.Close(ctx) + } } diff --git a/sql/procedures/statements.go b/sql/procedures/statements.go index 777d6ddeb6..54934064cf 100644 --- a/sql/procedures/statements.go +++ b/sql/procedures/statements.go @@ -14,10 +14,12 @@ package procedures +import "github.com/dolthub/go-mysql-server/sql" + // Statement represents a Stored Procedure Statement. type Statement interface { // AppendOperations adds the statement to the operation slice. - AppendOperations(ops *[]InterpreterOperation, stack *InterpreterStack) error + AppendOperations(ctx *sql.Context, ops *[]InterpreterOperation, stack *InterpreterStack) error } // Assignment represents an assignment statement. @@ -35,7 +37,7 @@ func (Assignment) OperationSize() int32 { } // AppendOperations implements the interface Statement. -func (stmt Assignment) AppendOperations(ops *[]InterpreterOperation, stack *InterpreterStack) error { +func (stmt Assignment) AppendOperations(ctx *sql.Context, ops *[]InterpreterOperation, stack *InterpreterStack) error { //*ops = append(*ops, InterpreterOperation{ // OpCode: OpCode_Assign, // Target: stmt.VariableName, @@ -53,7 +55,7 @@ type Block struct { var _ Statement = Block{} // AppendOperations implements the interface Statement. -func (stmt Block) AppendOperations(ops *[]InterpreterOperation, stack *InterpreterStack) error { +func (stmt Block) AppendOperations(ctx *sql.Context, ops *[]InterpreterOperation, stack *InterpreterStack) error { stack.PushScope() *ops = append(*ops, InterpreterOperation{ OpCode: OpCode_ScopeBegin, @@ -62,14 +64,14 @@ func (stmt Block) AppendOperations(ops *[]InterpreterOperation, stack *Interpret stack.NewVariableWithValue(variable.Name, nil, nil) } for _, innerStmt := range stmt.Body { - if err := innerStmt.AppendOperations(ops, stack); err != nil { + if err := innerStmt.AppendOperations(ctx, ops, stack); err != nil { return err } } *ops = append(*ops, InterpreterOperation{ OpCode: OpCode_ScopeEnd, }) - stack.PopScope() + stack.PopScope(ctx) return nil } @@ -82,7 +84,7 @@ type ExecuteSQL struct { var _ Statement = ExecuteSQL{} // AppendOperations implements the interface Statement. -func (stmt ExecuteSQL) AppendOperations(ops *[]InterpreterOperation, stack *InterpreterStack) error { +func (stmt ExecuteSQL) AppendOperations(_ *sql.Context, ops *[]InterpreterOperation, stack *InterpreterStack) error { *ops = append(*ops, InterpreterOperation{ OpCode: OpCode_Execute, Target: stmt.Target, @@ -103,7 +105,7 @@ func (Goto) OperationSize() int32 { } // AppendOperations implements the interface Statement. -func (stmt Goto) AppendOperations(ops *[]InterpreterOperation, stack *InterpreterStack) error { +func (stmt Goto) AppendOperations(_ *sql.Context, ops *[]InterpreterOperation, _ *InterpreterStack) error { *ops = append(*ops, InterpreterOperation{ OpCode: OpCode_Goto, Index: len(*ops) + int(stmt.Offset), @@ -125,7 +127,7 @@ func (If) OperationSize() int32 { } // AppendOperations implements the interface Statement. -func (stmt If) AppendOperations(ops *[]InterpreterOperation, stack *InterpreterStack) error { +func (stmt If) AppendOperations(_ *sql.Context, _ *[]InterpreterOperation, _ *InterpreterStack) error { //*ops = append(*ops, InterpreterOperation{ // OpCode: OpCode_If, // PrimaryData: "SELECT ;", @@ -147,7 +149,7 @@ func (Perform) OperationSize() int32 { } // AppendOperations implements the interface Statement. -func (stmt Perform) AppendOperations(ops *[]InterpreterOperation, stack *InterpreterStack) error { +func (stmt Perform) AppendOperations(_ *sql.Context, _ *[]InterpreterOperation, _ *InterpreterStack) error { //*ops = append(*ops, InterpreterOperation{ // OpCode: OpCode_Perform, //}) @@ -167,7 +169,7 @@ func (Return) OperationSize() int32 { } // AppendOperations implements the interface Statement. -func (stmt Return) AppendOperations(ops *[]InterpreterOperation, stack *InterpreterStack) error { +func (stmt Return) AppendOperations(_ *sql.Context, ops *[]InterpreterOperation, _ *InterpreterStack) error { *ops = append(*ops, InterpreterOperation{ OpCode: OpCode_Return, }) From 9b72c7ead650607522eba15bb6944cb874e5b169 Mon Sep 17 00:00:00 2001 From: James Cor Date: Fri, 25 Apr 2025 21:21:57 -0700 Subject: [PATCH 106/111] fixing comments --- sql/procedures/parse.go | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/sql/procedures/parse.go b/sql/procedures/parse.go index 9e9de9bd12..e3992ae30d 100644 --- a/sql/procedures/parse.go +++ b/sql/procedures/parse.go @@ -227,7 +227,6 @@ func ConvertStmt(ops *[]*InterpreterOperation, stack *InterpreterStack, stmt ast if s.Else == nil { // throw an error if when there is no else block // this is just an empty case statement that will always hit the else - // todo: alternatively, use an error opcode errOp := &InterpreterOperation{ OpCode: OpCode_Exception, Error: mysql.NewSQLError(1339, "20000", "Case not found for CASE statement"), @@ -362,8 +361,7 @@ func ConvertStmt(ops *[]*InterpreterOperation, stack *InterpreterStack, stmt ast return nil } -// Parse parses the given CREATE FUNCTION string (which must be the entire string, not just the body) into a Block -// containing the contents of the body. +// Parse takes the ast.Statement and converts it series of OpCodes. func Parse(stmt ast.Statement) ([]*InterpreterOperation, error) { ops := make([]*InterpreterOperation, 0, 64) stack := NewInterpreterStack() From 9d3cb5950e33b240073690688900b7cb68150c56 Mon Sep 17 00:00:00 2001 From: James Cor Date: Mon, 28 Apr 2025 16:56:10 -0700 Subject: [PATCH 107/111] better context --- sql/procedures/interpreter_logic.go | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/sql/procedures/interpreter_logic.go b/sql/procedures/interpreter_logic.go index 95dea60da2..29b69899ac 100644 --- a/sql/procedures/interpreter_logic.go +++ b/sql/procedures/interpreter_logic.go @@ -15,8 +15,7 @@ package procedures import ( - "context" - "errors" + "errors" "fmt" "io" "strconv" @@ -894,7 +893,7 @@ func Call(ctx *sql.Context, iNode InterpreterNode) (sql.RowIter, *InterpreterSta break } - subCtx := sql.NewContext(context.Background()) + subCtx := sql.NewContext(ctx.Context) subCtx.Session = ctx.Session operation := statements[counter] From f6f4b918477a47445e4e7001ba0e2b99e42bf2d6 Mon Sep 17 00:00:00 2001 From: jycor Date: Mon, 28 Apr 2025 23:57:27 +0000 Subject: [PATCH 108/111] [ga-format-pr] Run ./format_repo.sh to fix formatting --- sql/procedures/interpreter_logic.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/procedures/interpreter_logic.go b/sql/procedures/interpreter_logic.go index 29b69899ac..67962775c6 100644 --- a/sql/procedures/interpreter_logic.go +++ b/sql/procedures/interpreter_logic.go @@ -15,7 +15,7 @@ package procedures import ( - "errors" + "errors" "fmt" "io" "strconv" From a6f8e51f73271ee42162647f19c33481012a9075 Mon Sep 17 00:00:00 2001 From: James Cor Date: Tue, 29 Apr 2025 16:07:21 -0700 Subject: [PATCH 109/111] add missing ast nodes --- sql/procedures/interpreter_logic.go | 34 ++++++++++++++++++++++++----- 1 file changed, 28 insertions(+), 6 deletions(-) diff --git a/sql/procedures/interpreter_logic.go b/sql/procedures/interpreter_logic.go index 67962775c6..7841439958 100644 --- a/sql/procedures/interpreter_logic.go +++ b/sql/procedures/interpreter_logic.go @@ -112,6 +112,34 @@ func replaceVariablesInExpr(ctx *sql.Context, stack *InterpreterStack, expr ast. } e.Left = newLeftExpr.(ast.Expr) e.Right = newRightExpr.(ast.Expr) + case *ast.AndExpr: + newLeftExpr, err := replaceVariablesInExpr(ctx, stack, e.Left, asOf) + if err != nil { + return nil, err + } + newRightExpr, err := replaceVariablesInExpr(ctx, stack, e.Right, asOf) + if err != nil { + return nil, err + } + e.Left = newLeftExpr.(ast.Expr) + e.Right = newRightExpr.(ast.Expr) + case *ast.OrExpr: + newLeftExpr, err := replaceVariablesInExpr(ctx, stack, e.Left, asOf) + if err != nil { + return nil, err + } + newRightExpr, err := replaceVariablesInExpr(ctx, stack, e.Right, asOf) + if err != nil { + return nil, err + } + e.Left = newLeftExpr.(ast.Expr) + e.Right = newRightExpr.(ast.Expr) + case *ast.NotExpr: + newExpr, err := replaceVariablesInExpr(ctx, stack, e.Expr, asOf) + if err != nil { + return nil, err + } + e.Expr = newExpr.(ast.Expr) case *ast.FuncExpr: for i := range e.Exprs { newExpr, err := replaceVariablesInExpr(ctx, stack, e.Exprs[i], asOf) @@ -120,12 +148,6 @@ func replaceVariablesInExpr(ctx *sql.Context, stack *InterpreterStack, expr ast. } e.Exprs[i] = newExpr.(ast.SelectExpr) } - case *ast.NotExpr: - newExpr, err := replaceVariablesInExpr(ctx, stack, e.Expr, asOf) - if err != nil { - return nil, err - } - e.Expr = newExpr.(ast.Expr) case *ast.Set: for _, setExpr := range e.Exprs { newExpr, err := replaceVariablesInExpr(ctx, stack, setExpr.Expr, asOf) From 9b13d2cebbac2544f3fe111e830f5ef0cacd0132 Mon Sep 17 00:00:00 2001 From: James Cor Date: Wed, 30 Apr 2025 09:55:16 -0700 Subject: [PATCH 110/111] remove unused file --- sql/procedures/statements.go | 184 ----------------------------------- 1 file changed, 184 deletions(-) delete mode 100644 sql/procedures/statements.go diff --git a/sql/procedures/statements.go b/sql/procedures/statements.go deleted file mode 100644 index 54934064cf..0000000000 --- a/sql/procedures/statements.go +++ /dev/null @@ -1,184 +0,0 @@ -// Copyright 2025 Dolthub, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package procedures - -import "github.com/dolthub/go-mysql-server/sql" - -// Statement represents a Stored Procedure Statement. -type Statement interface { - // AppendOperations adds the statement to the operation slice. - AppendOperations(ctx *sql.Context, ops *[]InterpreterOperation, stack *InterpreterStack) error -} - -// Assignment represents an assignment statement. -type Assignment struct { - VariableName string - Expression string - VariableIndex int32 // TODO: figure out what this is used for, probably to get around shadowed variables? -} - -var _ Statement = Assignment{} - -// OperationSize implements the interface Statement. -func (Assignment) OperationSize() int32 { - return 1 -} - -// AppendOperations implements the interface Statement. -func (stmt Assignment) AppendOperations(ctx *sql.Context, ops *[]InterpreterOperation, stack *InterpreterStack) error { - //*ops = append(*ops, InterpreterOperation{ - // OpCode: OpCode_Assign, - // Target: stmt.VariableName, - //}) - return nil -} - -// Block contains a collection of statements, alongside the variables that were declared for the block. Only the -// top-level block will contain parameter variables. -type Block struct { - Variable []Variable - Body []Statement -} - -var _ Statement = Block{} - -// AppendOperations implements the interface Statement. -func (stmt Block) AppendOperations(ctx *sql.Context, ops *[]InterpreterOperation, stack *InterpreterStack) error { - stack.PushScope() - *ops = append(*ops, InterpreterOperation{ - OpCode: OpCode_ScopeBegin, - }) - for _, variable := range stmt.Variable { - stack.NewVariableWithValue(variable.Name, nil, nil) - } - for _, innerStmt := range stmt.Body { - if err := innerStmt.AppendOperations(ctx, ops, stack); err != nil { - return err - } - } - *ops = append(*ops, InterpreterOperation{ - OpCode: OpCode_ScopeEnd, - }) - stack.PopScope(ctx) - return nil -} - -// ExecuteSQL represents a standard SQL statement's execution (including the INTO syntax). -type ExecuteSQL struct { - Statement string - Target string -} - -var _ Statement = ExecuteSQL{} - -// AppendOperations implements the interface Statement. -func (stmt ExecuteSQL) AppendOperations(_ *sql.Context, ops *[]InterpreterOperation, stack *InterpreterStack) error { - *ops = append(*ops, InterpreterOperation{ - OpCode: OpCode_Execute, - Target: stmt.Target, - }) - return nil -} - -// Goto jumps to the counter at the given offset. -type Goto struct { - Offset int32 -} - -var _ Statement = Goto{} - -// OperationSize implements the interface Statement. -func (Goto) OperationSize() int32 { - return 1 -} - -// AppendOperations implements the interface Statement. -func (stmt Goto) AppendOperations(_ *sql.Context, ops *[]InterpreterOperation, _ *InterpreterStack) error { - *ops = append(*ops, InterpreterOperation{ - OpCode: OpCode_Goto, - Index: len(*ops) + int(stmt.Offset), - }) - return nil -} - -// If represents an IF condition, alongside its Goto offset if the condition is true. -type If struct { - Condition string - GotoOffset int32 -} - -var _ Statement = If{} - -// OperationSize implements the interface Statement. -func (If) OperationSize() int32 { - return 1 -} - -// AppendOperations implements the interface Statement. -func (stmt If) AppendOperations(_ *sql.Context, _ *[]InterpreterOperation, _ *InterpreterStack) error { - //*ops = append(*ops, InterpreterOperation{ - // OpCode: OpCode_If, - // PrimaryData: "SELECT ;", - // Index: len(*ops) + int(stmt.GotoOffset), - //}) - return nil -} - -// Perform represents a PERFORM statement. -type Perform struct { - Statement string -} - -var _ Statement = Perform{} - -// OperationSize implements the interface Statement. -func (Perform) OperationSize() int32 { - return 1 -} - -// AppendOperations implements the interface Statement. -func (stmt Perform) AppendOperations(_ *sql.Context, _ *[]InterpreterOperation, _ *InterpreterStack) error { - //*ops = append(*ops, InterpreterOperation{ - // OpCode: OpCode_Perform, - //}) - return nil -} - -// Return represents a RETURN statement. -type Return struct { - Expression string -} - -var _ Statement = Return{} - -// OperationSize implements the interface Statement. -func (Return) OperationSize() int32 { - return 1 -} - -// AppendOperations implements the interface Statement. -func (stmt Return) AppendOperations(_ *sql.Context, ops *[]InterpreterOperation, _ *InterpreterStack) error { - *ops = append(*ops, InterpreterOperation{ - OpCode: OpCode_Return, - }) - return nil -} - -// Variable represents a variable. These are exclusively found within Block. -type Variable struct { - Name string - Type string - IsParameter bool -} From f1067220069b31e5ef6f00893326ff8fca9b7958 Mon Sep 17 00:00:00 2001 From: James Cor Date: Wed, 30 Apr 2025 11:08:32 -0700 Subject: [PATCH 111/111] bump --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index 63db593da9..f7a0fdfc1e 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,7 @@ require ( github.com/dolthub/go-icu-regex v0.0.0-20250327004329-6799764f2dad github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71 github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81 - github.com/dolthub/vitess v0.0.0-20250424225619-bfb20390c1d1 + github.com/dolthub/vitess v0.0.0-20250430180243-0eee73763bc5 github.com/go-kit/kit v0.10.0 github.com/go-sql-driver/mysql v1.7.2-0.20231213112541-0004702b931d github.com/gocraft/dbr/v2 v2.7.2 diff --git a/go.sum b/go.sum index 593f8000cf..1904876371 100644 --- a/go.sum +++ b/go.sum @@ -58,8 +58,8 @@ github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71 h1:bMGS25NWAGTE github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71/go.mod h1:2/2zjLQ/JOOSbbSboojeg+cAwcRV0fDLzIiWch/lhqI= github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81 h1:7/v8q9XGFa6q5Ap4Z/OhNkAMBaK5YeuEzwJt+NZdhiE= github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81/go.mod h1:siLfyv2c92W1eN/R4QqG/+RjjX5W2+gCTRjZxBjI3TY= -github.com/dolthub/vitess v0.0.0-20250424225619-bfb20390c1d1 h1:ZXssuc0ZNqKUD7xQCd0/xLT+nKrmAetNYb5v7xgU0U0= -github.com/dolthub/vitess v0.0.0-20250424225619-bfb20390c1d1/go.mod h1:1gQZs/byeHLMSul3Lvl3MzioMtOW1je79QYGyi2fd70= +github.com/dolthub/vitess v0.0.0-20250430180243-0eee73763bc5 h1:eyC/UHnNsCham/65hV9p/Si0S+cq774kbgk0/KPFYws= +github.com/dolthub/vitess v0.0.0-20250430180243-0eee73763bc5/go.mod h1:1gQZs/byeHLMSul3Lvl3MzioMtOW1je79QYGyi2fd70= github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/eapache/go-resiliency v1.1.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs= github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU=