Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions enginetest/queries/procedure_queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -2823,6 +2823,63 @@ var ProcedureCreateInSubroutineTests = []ScriptTest{
},
},
},
{
Name: "procedure must not contain CREATE TABLE",
Assertions: []ScriptTestAssertion{
{
Query: "create procedure p() create table t (pk int);",
ExpectedErrStr: "creating tables in stored procedures is currently unsupported and will be added in a future release",
},
{
Query: "create procedure p() begin create table t (pk int); end;",
ExpectedErrStr: "creating tables in stored procedures is currently unsupported and will be added in a future release",
},
},
},
{
Name: "procedure must not contain CREATE TRIGGER",
SetUpScript: []string{
"create table t (i int);",
},
Assertions: []ScriptTestAssertion{
{
Query: "create procedure p() create trigger trig before insert on t for each row begin select 1; end;",
ExpectedErrStr: "creating triggers in stored procedures is currently unsupported and will be added in a future release",
},
{
Query: "create procedure p() begin create trigger trig before insert on t for each row begin select 1; end; end;",
ExpectedErrStr: "creating triggers in stored procedures is currently unsupported and will be added in a future release",
},
},
},
{
Name: "procedure must not contain CREATE DB",
SetUpScript: []string{},
Assertions: []ScriptTestAssertion{
{
Query: "create procedure p() create database procdb;",
ExpectedErrStr: "creating databases in stored procedures is currently unsupported and will be added in a future release",
},
{
Query: "create procedure p() begin create database procdb; end;",
ExpectedErrStr: "creating databases in stored procedures is currently unsupported and will be added in a future release",
},
},
},
{
Name: "procedure must not contain CREATE VIEW",
SetUpScript: []string{},
Assertions: []ScriptTestAssertion{
{
Query: "create procedure p() create view v as select 1;",
ExpectedErrStr: "creating views in stored procedures is currently unsupported and will be added in a future release",
},
{
Query: "create procedure p() begin create view v as select 1; end;",
ExpectedErrStr: "creating views in stored procedures is currently unsupported and will be added in a future release",
},
},
},
}

var NoDbProcedureTests = []ScriptTestAssertion{
Expand Down
1 change: 0 additions & 1 deletion sql/analyzer/rule_ids.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ const (
resolveUnionsId // resolveUnions
ValidateColumnDefaultsId // validateColumnDefaults
validateCreateTriggerId // validateCreateTrigger
validateCreateProcedureId // validateCreateProcedure
validateReadOnlyDatabaseId // validateReadOnlyDatabase
validateReadOnlyTransactionId // validateReadOnlyTransaction
validateDatabaseSetId // validateDatabaseSet
Expand Down
113 changes: 56 additions & 57 deletions sql/analyzer/ruleid_string.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion sql/analyzer/rules.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ var OnceBeforeDefault = []Rule{
{validateCreateTableId, validateCreateTable},
{validateAlterTableId, validateAlterTable},
{validateExprSemId, validateExprSem},
{validateCreateProcedureId, validateCreateProcedure},
{resolveDropConstraintId, resolveDropConstraint},
{resolveAlterColumnId, resolveAlterColumn},
{validateDropTablesId, validateDropTables},
Expand Down
82 changes: 1 addition & 81 deletions sql/analyzer/stored_procedures.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,6 @@ package analyzer
import (
"fmt"
"slices"
"strings"

"gopkg.in/src-d/go-errors.v1"

"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/expression"
Expand Down Expand Up @@ -84,11 +81,8 @@ func loadStoredProcedures(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan

// analyzeCreateProcedure checks the plan.CreateProcedure and returns a valid plan.Procedure or an error
func analyzeCreateProcedure(ctx *sql.Context, a *Analyzer, cp *plan.CreateProcedure, scope *plan.Scope, sel RuleSelector, qFlags *sql.QueryFlags) (*plan.Procedure, error) {
err := validateStoredProcedure(ctx, cp.Procedure)
if err != nil {
return nil, err
}
var analyzedNode sql.Node
var err error
analyzedNode, _, err = analyzeProcedureBodies(ctx, a, cp.Procedure, false, scope, sel, qFlags)
if err != nil {
return nil, err
Expand Down Expand Up @@ -164,80 +158,6 @@ func analyzeProcedureBodies(ctx *sql.Context, a *Analyzer, node sql.Node, skipCa
return node, transform.NewTree, nil
}

// validateCreateProcedure handles CreateProcedure nodes, ensuring that all nodes in Procedure are supported.
func validateCreateProcedure(ctx *sql.Context, a *Analyzer, node sql.Node, scope *plan.Scope, sel RuleSelector, qFlags *sql.QueryFlags) (sql.Node, transform.TreeIdentity, error) {
cp, ok := node.(*plan.CreateProcedure)
if !ok {
return node, transform.SameTree, nil
}

err := validateStoredProcedure(ctx, cp.Procedure)
if err != nil {
return nil, transform.SameTree, err
}

return node, transform.SameTree, nil
}

// validateStoredProcedure handles Procedure nodes, resolving references to the parameters, along with ensuring
// that all logic contained within the stored procedure body is valid.
func validateStoredProcedure(_ *sql.Context, proc *plan.Procedure) error {
// For now, we don't support creating any of the following within stored procedures.
// These will be removed in the future, but cause issues with the current execution plan.
var err error
spUnsupportedErr := errors.NewKind("creating %s in stored procedures is currently unsupported " +
"and will be added in a future release")
transform.Inspect(proc, func(n sql.Node) bool {
switch n.(type) {
case *plan.CreateTable:
err = spUnsupportedErr.New("tables")
case *plan.CreateTrigger:
err = spUnsupportedErr.New("triggers")
case *plan.CreateProcedure:
err = spUnsupportedErr.New("procedures")
case *plan.CreateDB:
err = spUnsupportedErr.New("databases")
case *plan.CreateForeignKey:
err = spUnsupportedErr.New("foreign keys")
case *plan.CreateIndex:
err = spUnsupportedErr.New("indexes")
case *plan.CreateView:
err = spUnsupportedErr.New("views")
default:
return true
}
return false
})
if err != nil {
return err
}

transform.Inspect(proc, func(n sql.Node) bool {
switch n := n.(type) {
case *plan.Call:
if proc.Name == strings.ToLower(n.Name) {
err = sql.ErrProcedureRecursiveCall.New(proc.Name)
}
case *plan.LockTables: // Blocked in vitess, but this is for safety
err = sql.ErrProcedureInvalidBodyStatement.New("LOCK TABLES")
case *plan.UnlockTables: // Blocked in vitess, but this is for safety
err = sql.ErrProcedureInvalidBodyStatement.New("UNLOCK TABLES")
case *plan.Use: // Blocked in vitess, but this is for safety
err = sql.ErrProcedureInvalidBodyStatement.New("USE")
case *plan.LoadData:
err = sql.ErrProcedureInvalidBodyStatement.New("LOAD DATA")
default:
return true
}
return false
})
if err != nil {
return err
}

return nil
}

// applyProcedures applies the relevant stored procedures to the node given (if necessary).
func applyProcedures(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector, qFlags *sql.QueryFlags) (sql.Node, transform.TreeIdentity, error) {
if _, ok := n.(*plan.CreateProcedure); ok {
Expand Down
18 changes: 18 additions & 0 deletions sql/planbuilder/create_ddl.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/expression"
"github.com/dolthub/go-mysql-server/sql/plan"
"github.com/dolthub/go-mysql-server/sql/transform"
"github.com/dolthub/go-mysql-server/sql/types"
)

Expand Down Expand Up @@ -132,6 +133,9 @@ func getCurrentUserForDefiner(ctx *sql.Context, definer string) string {
}

func (b *Builder) buildCreateProcedure(inScope *scope, subQuery string, fullQuery string, c *ast.DDL) (outScope *scope) {
b.qFlags.Set(sql.QFlagCreateProcedure)
defer func() { b.qFlags.Unset(sql.QFlagCreateProcedure) }()

var params []plan.ProcedureParam
for _, param := range c.ProcedureSpec.Params {
var direction plan.ProcedureParamDirection
Expand Down Expand Up @@ -200,6 +204,20 @@ func (b *Builder) buildCreateProcedure(inScope *scope, subQuery string, fullQuer
bodyStr := strings.TrimSpace(fullQuery[c.SubStatementPositionStart:c.SubStatementPositionEnd])

bodyScope := b.buildSubquery(inScope, c.ProcedureSpec.Body, bodyStr, fullQuery)
b.validateStoredProcedure(bodyScope.node)

// Check for recursive calls to same procedure
transform.Inspect(bodyScope.node, func(node sql.Node) bool {
switch n := node.(type) {
case *plan.Call:
if strings.EqualFold(procName, n.Name) {
b.handleErr(sql.ErrProcedureRecursiveCall.New(procName))
}
return false
default:
return true
}
})

var db sql.Database = nil
dbName := c.ProcedureSpec.ProcName.Qualifier.String()
Expand Down
Loading
Loading