From 4e80776c8df2c490db29d2979178dee8c5d5ee6d Mon Sep 17 00:00:00 2001 From: Jason Fulghum Date: Fri, 14 Feb 2025 14:53:02 -0800 Subject: [PATCH 1/9] First pass on prototype of postgres INSERT RETURNING support --- sql/plan/insert.go | 51 ++++++++++++++++++++++++++++++++++++---- sql/planbuilder/dml.go | 8 +++++++ sql/rowexec/dml.go | 2 ++ sql/rowexec/dml_iters.go | 23 ++++++++++++++++++ sql/rowexec/insert.go | 43 +++++++++++++++++++++++---------- 5 files changed, 110 insertions(+), 17 deletions(-) diff --git a/sql/plan/insert.go b/sql/plan/insert.go index c52acf843d..13910ec29e 100644 --- a/sql/plan/insert.go +++ b/sql/plan/insert.go @@ -70,6 +70,10 @@ type InsertInto struct { // a |Values| node with only literal expressions. LiteralValueSource bool + // Returning is a list of expressions to return after the insert operation. This feature is not supported + // in MySQL's syntax, but is exposed through PostgreSQL's syntax. + Returning []sql.Expression + // FirstGenerateAutoIncRowIdx is the index of the first row inserted that increments last_insert_id. FirstGeneratedAutoIncRowIdx int } @@ -119,6 +123,31 @@ func (ii *InsertInto) Schema() sql.Schema { if ii.IsReplace { return append(ii.Destination.Schema(), ii.Destination.Schema()...) } + + // Postgres allows the returned values of the insert statement to be controlled, so if returning expressions + // were specified, then we return a different schema. + // TODO: does anything else depend on the schema returned by insert statements? triggers? + // TODO: Do we need to check for the expressions being fully resolved? (probably!) + if ii.Returning != nil { + // TODO: If we don't return the destination schema anymore... does that mess up other things, like trigger processing? + // TODO: we need to look at the expressions in the returning clause + + // TODO: We need to make sure the Returning expressions get resolved by the analyzer! Will need to be special cased... :-/ + returningExprsResovled := true + returningSchema := sql.Schema{} + for _, expr := range ii.Returning { + if !expr.Resolved() { + returningExprsResovled = false + break + } + returningSchema = append(returningSchema, transform.ExpressionToColumn(expr, "")) + } + + if returningExprsResovled { + return returningSchema + } + } + return ii.Destination.Schema() } @@ -227,23 +256,28 @@ func (ii *InsertInto) DebugString() string { // Expressions implements the sql.Expressioner interface. func (ii *InsertInto) Expressions() []sql.Expression { - return append(ii.OnDupExprs, ii.checks.ToExpressions()...) + exprs := append(ii.OnDupExprs, ii.checks.ToExpressions()...) + exprs = append(exprs, ii.Returning...) + return exprs } // WithExpressions implements the sql.Expressioner interface. func (ii *InsertInto) WithExpressions(newExprs ...sql.Expression) (sql.Node, error) { - if len(newExprs) != len(ii.OnDupExprs)+len(ii.checks) { - return nil, sql.ErrInvalidChildrenNumber.New(ii, len(newExprs), len(ii.OnDupExprs)+len(ii.checks)) + if len(newExprs) != len(ii.OnDupExprs)+len(ii.checks)+len(ii.Returning) { + return nil, sql.ErrInvalidChildrenNumber.New(ii, len(newExprs), len(ii.OnDupExprs)+len(ii.checks)+len(ii.Returning)) } nii := *ii nii.OnDupExprs = newExprs[:len(nii.OnDupExprs)] + newExprs = newExprs[len(nii.OnDupExprs):] var err error - nii.checks, err = nii.checks.FromExpressions(newExprs[len(nii.OnDupExprs):]) + nii.checks, err = nii.checks.FromExpressions(newExprs[:len(nii.checks)]) if err != nil { return nil, err } + newExprs = newExprs[len(nii.checks):] + nii.Returning = newExprs return &nii, nil } @@ -263,6 +297,15 @@ func (ii *InsertInto) Resolved() bool { return false } } + + if ii.Returning != nil { + for _, expr := range ii.Returning { + if !expr.Resolved() { + return false + } + } + } + return true } diff --git a/sql/planbuilder/dml.go b/sql/planbuilder/dml.go index d43a6326a4..b0bd6d73e7 100644 --- a/sql/planbuilder/dml.go +++ b/sql/planbuilder/dml.go @@ -147,6 +147,14 @@ func (b *Builder) buildInsert(inScope *scope, i *ast.Insert) (outScope *scope) { ins := plan.NewInsertInto(db, plan.NewInsertDestination(sch, dest), srcScope.node, isReplace, columns, onDupExprs, ignore) ins.LiteralValueSource = srcLiteralOnly + if i.Returning != nil { + returningExprs := make([]sql.Expression, len(i.Returning)) + for i, selectExpr := range i.Returning { + returningExprs[i] = b.selectExprToExpression(destScope, selectExpr) + } + ins.Returning = returningExprs + } + b.validateInsert(ins) outScope = destScope diff --git a/sql/rowexec/dml.go b/sql/rowexec/dml.go index 944183d16c..00f19f83eb 100644 --- a/sql/rowexec/dml.go +++ b/sql/rowexec/dml.go @@ -90,6 +90,8 @@ func (b *BaseBuilder) buildInsertInto(ctx *sql.Context, ii *plan.InsertInto, row ctx: ctx, ignore: ii.Ignore, firstGeneratedAutoIncRowIdx: ii.FirstGeneratedAutoIncRowIdx, + returnExprs: ii.Returning, + returnSchema: ii.Schema(), } var ed sql.EditOpenerCloser diff --git a/sql/rowexec/dml_iters.go b/sql/rowexec/dml_iters.go index e9366649bf..38f78f484f 100644 --- a/sql/rowexec/dml_iters.go +++ b/sql/rowexec/dml_iters.go @@ -569,6 +569,29 @@ func AddAccumulatorIter(ctx *sql.Context, iter sql.RowIter) (sql.RowIter, sql.Sc childIter := i.GetChildIter() childIter, sch := AddAccumulatorIter(ctx, childIter) return i.WithChildIter(childIter), sch + case *plan.TableEditorIter: + // TODO: If the TableEditorIter has RETURNING expressions, then we do NOT actually add the accumulatorIter + innerIter := i.InnerIter() + if insertIter, ok := innerIter.(*insertIter); ok { + if insertIter.returnExprs != nil { + return insertIter, insertIter.returnSchema + } else { + // TODO: How do we use the default logic if this isn't true... ? For now, just copying... + clientFoundRowsToggled := (ctx.Client().Capabilities & mysql.CapabilityClientFoundRows) > 0 + rowHandler := getRowHandler(clientFoundRowsToggled, iter) + if rowHandler == nil { + return iter, nil + } + return &accumulatorIter{ + iter: iter, + updateRowHandler: rowHandler, + }, types.OkResultSchema + } + } + + // But... do we add a different iterator? How does Postgresql do this? does it return all rows? + // TODO: Where do we get the correct schema from? 🤔 + return iter, nil default: clientFoundRowsToggled := (ctx.Client().Capabilities & mysql.CapabilityClientFoundRows) > 0 rowHandler := getRowHandler(clientFoundRowsToggled, iter) diff --git a/sql/rowexec/insert.go b/sql/rowexec/insert.go index 573c5fc7f7..ae652111a3 100644 --- a/sql/rowexec/insert.go +++ b/sql/rowexec/insert.go @@ -30,19 +30,21 @@ import ( ) type insertIter struct { - schema sql.Schema - inserter sql.RowInserter - replacer sql.RowReplacer - updater sql.RowUpdater - rowSource sql.RowIter - unlocker func() - ctx *sql.Context - insertExprs []sql.Expression - updateExprs []sql.Expression - checks sql.CheckConstraints - tableNode sql.Node - closed bool - ignore bool + schema sql.Schema + inserter sql.RowInserter + replacer sql.RowReplacer + updater sql.RowUpdater + rowSource sql.RowIter + unlocker func() + ctx *sql.Context + insertExprs []sql.Expression + updateExprs []sql.Expression + checks sql.CheckConstraints + tableNode sql.Node + closed bool + ignore bool + returnExprs []sql.Expression + returnSchema sql.Schema firstGeneratedAutoIncRowIdx int } @@ -173,6 +175,21 @@ func (i *insertIter) Next(ctx *sql.Context) (returnRow sql.Row, returnErr error) i.updateLastInsertId(ctx, row) + if i.returnExprs != nil { + var retExprRow sql.Row + // TODO: Why does the GetField expression pull field 1 instead of field 0? + emptyOne := []interface{}{nil} + row = append(emptyOne, row...) + for _, returnExpr := range i.returnExprs { + result, err := returnExpr.Eval(ctx, row) + if err != nil { + return nil, err + } + retExprRow = append(retExprRow, result) + } + return retExprRow, nil + } + return row, nil } From 0c5cab854735c17cf9591383ed47050dfe6a64c6 Mon Sep 17 00:00:00 2001 From: Zach Musgrave Date: Fri, 14 Mar 2025 15:23:19 -0700 Subject: [PATCH 2/9] newline --- sql/plan/insert.go | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/plan/insert.go b/sql/plan/insert.go index 13910ec29e..55e58db115 100644 --- a/sql/plan/insert.go +++ b/sql/plan/insert.go @@ -276,6 +276,7 @@ func (ii *InsertInto) WithExpressions(newExprs ...sql.Expression) (sql.Node, err if err != nil { return nil, err } + newExprs = newExprs[len(nii.checks):] nii.Returning = newExprs From 423468f7701072e9b9e4d1c663de13079a2fc77a Mon Sep 17 00:00:00 2001 From: Zach Musgrave Date: Fri, 14 Mar 2025 16:30:47 -0700 Subject: [PATCH 3/9] simplifications --- sql/plan/insert.go | 29 ++++++----------------------- 1 file changed, 6 insertions(+), 23 deletions(-) diff --git a/sql/plan/insert.go b/sql/plan/insert.go index 55a5a896c9..c3a1377f58 100644 --- a/sql/plan/insert.go +++ b/sql/plan/insert.go @@ -17,6 +17,7 @@ package plan import ( "strings" + "github.com/dolthub/go-mysql-server/sql/expression" "gopkg.in/src-d/go-errors.v1" "github.com/dolthub/go-mysql-server/sql" @@ -135,20 +136,13 @@ func (ii *InsertInto) Schema() sql.Schema { // TODO: If we don't return the destination schema anymore... does that mess up other things, like trigger processing? // TODO: we need to look at the expressions in the returning clause - // TODO: We need to make sure the Returning expressions get resolved by the analyzer! Will need to be special cased... :-/ - returningExprsResovled := true + // We know that returning exprs are resolved here, because you can't call Schema() safely until Resolved() is true. returningSchema := sql.Schema{} for _, expr := range ii.Returning { - if !expr.Resolved() { - returningExprsResovled = false - break - } returningSchema = append(returningSchema, transform.ExpressionToColumn(expr, "")) } - if returningExprsResovled { - return returningSchema - } + return returningSchema } return ii.Destination.Schema() @@ -299,26 +293,15 @@ func (ii *InsertInto) Resolved() bool { if !ii.Destination.Resolved() || !ii.Source.Resolved() { return false } - for _, updateExpr := range ii.OnDupExprs { - if !updateExpr.Resolved() { - return false - } - } + for _, checkExpr := range ii.checks { if !checkExpr.Expr.Resolved() { return false } } - if ii.Returning != nil { - for _, expr := range ii.Returning { - if !expr.Resolved() { - return false - } - } - } - - return true + return expression.ExpressionsResolved(ii.OnDupExprs...) && + expression.ExpressionsResolved(ii.Returning...) } // InsertDestination is a wrapper for a table to be used with InsertInto.Destination that allows the schema to be From 0990e824e7276de8ded03207b864dba52fb1c81b Mon Sep 17 00:00:00 2001 From: Zach Musgrave Date: Fri, 14 Mar 2025 17:27:40 -0700 Subject: [PATCH 4/9] Fix bug with on duplicate key update --- sql/rowexec/dml_iters.go | 6 ++---- sql/rowexec/insert.go | 2 +- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/sql/rowexec/dml_iters.go b/sql/rowexec/dml_iters.go index 345371336f..4e3263afe0 100644 --- a/sql/rowexec/dml_iters.go +++ b/sql/rowexec/dml_iters.go @@ -591,10 +591,10 @@ func AddAccumulatorIter(ctx *sql.Context, iter sql.RowIter) (sql.RowIter, sql.Sc childIter, sch := AddAccumulatorIter(ctx, childIter) return i.WithChildIter(childIter), sch case *plan.TableEditorIter: - // TODO: If the TableEditorIter has RETURNING expressions, then we do NOT actually add the accumulatorIter + // If the TableEditorIter has RETURNING expressions, then we do NOT actually add the accumulatorIter innerIter := i.InnerIter() if insertIter, ok := innerIter.(*insertIter); ok { - if insertIter.returnExprs != nil { + if len(insertIter.returnExprs) > 0 { return insertIter, insertIter.returnSchema } else { // TODO: How do we use the default logic if this isn't true... ? For now, just copying... @@ -610,8 +610,6 @@ func AddAccumulatorIter(ctx *sql.Context, iter sql.RowIter) (sql.RowIter, sql.Sc } } - // But... do we add a different iterator? How does Postgresql do this? does it return all rows? - // TODO: Where do we get the correct schema from? 🤔 return iter, nil default: clientFoundRowsToggled := (ctx.Client().Capabilities & mysql.CapabilityClientFoundRows) > 0 diff --git a/sql/rowexec/insert.go b/sql/rowexec/insert.go index 4137310125..4106723fae 100644 --- a/sql/rowexec/insert.go +++ b/sql/rowexec/insert.go @@ -179,7 +179,7 @@ func (i *insertIter) Next(ctx *sql.Context) (returnRow sql.Row, returnErr error) if i.returnExprs != nil { var retExprRow sql.Row - // TODO: Why does the GetField expression pull field 1 instead of field 0? + // TODO NEXT: Why does the GetField expression pull field 1 instead of field 0? emptyOne := []interface{}{nil} row = append(emptyOne, row...) for _, returnExpr := range i.returnExprs { From f7a7650f5558877d38bcc51c8330830c9e432dd2 Mon Sep 17 00:00:00 2001 From: Zach Musgrave Date: Mon, 17 Mar 2025 15:36:28 -0700 Subject: [PATCH 5/9] Bug fixes for insert returning --- sql/analyzer/fix_exec_indexes.go | 7 ++++++- sql/rowexec/insert.go | 3 --- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/sql/analyzer/fix_exec_indexes.go b/sql/analyzer/fix_exec_indexes.go index 92382758e4..2c30e3db94 100644 --- a/sql/analyzer/fix_exec_indexes.go +++ b/sql/analyzer/fix_exec_indexes.go @@ -494,6 +494,10 @@ func (s *idxScope) visitSelf(n sql.Node) error { newCheck.Expr = newE s.checks = append(s.checks, &newCheck) } + for _, r := range n.Returning { + newE := fixExprToScope(r, dstScope) + s.expressions = append(s.expressions, newE) + } case *plan.Update: newScope := s.copy() srcScope := s.childScopes[0] @@ -543,7 +547,8 @@ func (s *idxScope) finalizeSelf(n sql.Node) (sql.Node, error) { nn := *n nn.Source = s.children[0] nn.Destination = s.children[1] - nn.OnDupExprs = s.expressions + nn.OnDupExprs = s.expressions[:len(n.OnDupExprs)] + nn.Returning = s.expressions[len(n.OnDupExprs):] return nn.WithChecks(s.checks), nil default: s.ids = columnIdsForNode(n) diff --git a/sql/rowexec/insert.go b/sql/rowexec/insert.go index 4106723fae..7288d05d32 100644 --- a/sql/rowexec/insert.go +++ b/sql/rowexec/insert.go @@ -179,9 +179,6 @@ func (i *insertIter) Next(ctx *sql.Context) (returnRow sql.Row, returnErr error) if i.returnExprs != nil { var retExprRow sql.Row - // TODO NEXT: Why does the GetField expression pull field 1 instead of field 0? - emptyOne := []interface{}{nil} - row = append(emptyOne, row...) for _, returnExpr := range i.returnExprs { result, err := returnExpr.Eval(ctx, row) if err != nil { From f3293c69d51285d29732732199959d3ddd9a8641 Mon Sep 17 00:00:00 2001 From: Zach Musgrave Date: Mon, 17 Mar 2025 15:56:19 -0700 Subject: [PATCH 6/9] Formatting --- sql/plan/insert.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/plan/insert.go b/sql/plan/insert.go index c3a1377f58..9fd0f1d3bc 100644 --- a/sql/plan/insert.go +++ b/sql/plan/insert.go @@ -17,10 +17,10 @@ package plan import ( "strings" - "github.com/dolthub/go-mysql-server/sql/expression" "gopkg.in/src-d/go-errors.v1" "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/expression" "github.com/dolthub/go-mysql-server/sql/transform" ) @@ -281,7 +281,7 @@ func (ii *InsertInto) WithExpressions(newExprs ...sql.Expression) (sql.Node, err if err != nil { return nil, err } - + newExprs = newExprs[len(nii.checks):] nii.Returning = newExprs @@ -301,7 +301,7 @@ func (ii *InsertInto) Resolved() bool { } return expression.ExpressionsResolved(ii.OnDupExprs...) && - expression.ExpressionsResolved(ii.Returning...) + expression.ExpressionsResolved(ii.Returning...) } // InsertDestination is a wrapper for a table to be used with InsertInto.Destination that allows the schema to be From 43b5a649e07ac03c91c0c43b0e8afc289bf7c39b Mon Sep 17 00:00:00 2001 From: Zach Musgrave Date: Mon, 17 Mar 2025 16:28:16 -0700 Subject: [PATCH 7/9] bug fix --- sql/rowexec/insert.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/rowexec/insert.go b/sql/rowexec/insert.go index 7288d05d32..659c508fcc 100644 --- a/sql/rowexec/insert.go +++ b/sql/rowexec/insert.go @@ -177,7 +177,7 @@ func (i *insertIter) Next(ctx *sql.Context) (returnRow sql.Row, returnErr error) i.updateLastInsertId(ctx, row) - if i.returnExprs != nil { + if len(i.returnExprs) > 0 { var retExprRow sql.Row for _, returnExpr := range i.returnExprs { result, err := returnExpr.Eval(ctx, row) From 6df0972abb474f6febb70383a27c325b7007fa5e Mon Sep 17 00:00:00 2001 From: Zach Musgrave Date: Mon, 17 Mar 2025 16:53:26 -0700 Subject: [PATCH 8/9] Bug fix for omitting an accumulator for many DML nodes --- sql/rowexec/dml_iters.go | 42 ++++++++++++++++------------------------ 1 file changed, 17 insertions(+), 25 deletions(-) diff --git a/sql/rowexec/dml_iters.go b/sql/rowexec/dml_iters.go index 4e3263afe0..57dfd72555 100644 --- a/sql/rowexec/dml_iters.go +++ b/sql/rowexec/dml_iters.go @@ -593,35 +593,27 @@ func AddAccumulatorIter(ctx *sql.Context, iter sql.RowIter) (sql.RowIter, sql.Sc case *plan.TableEditorIter: // If the TableEditorIter has RETURNING expressions, then we do NOT actually add the accumulatorIter innerIter := i.InnerIter() - if insertIter, ok := innerIter.(*insertIter); ok { - if len(insertIter.returnExprs) > 0 { - return insertIter, insertIter.returnSchema - } else { - // TODO: How do we use the default logic if this isn't true... ? For now, just copying... - clientFoundRowsToggled := (ctx.Client().Capabilities & mysql.CapabilityClientFoundRows) > 0 - rowHandler := getRowHandler(clientFoundRowsToggled, iter) - if rowHandler == nil { - return iter, nil - } - return &accumulatorIter{ - iter: iter, - updateRowHandler: rowHandler, - }, types.OkResultSchema - } + if insertIter, ok := innerIter.(*insertIter); ok && len(insertIter.returnExprs) > 0 { + return insertIter, insertIter.returnSchema } - return iter, nil + return defaultAccumulatorIter(ctx, iter) default: - clientFoundRowsToggled := (ctx.Client().Capabilities & mysql.CapabilityClientFoundRows) > 0 - rowHandler := getRowHandler(clientFoundRowsToggled, iter) - if rowHandler == nil { - return iter, nil - } - return &accumulatorIter{ - iter: iter, - updateRowHandler: rowHandler, - }, types.OkResultSchema + return defaultAccumulatorIter(ctx, iter) + } +} + +// defaultAccumulatorIter returns the default accumulator iter for a DML node +func defaultAccumulatorIter(ctx *sql.Context, iter sql.RowIter) (sql.RowIter, sql.Schema) { + clientFoundRowsToggled := (ctx.Client().Capabilities & mysql.CapabilityClientFoundRows) > 0 + rowHandler := getRowHandler(clientFoundRowsToggled, iter) + if rowHandler == nil { + return iter, nil } + return &accumulatorIter{ + iter: iter, + updateRowHandler: rowHandler, + }, types.OkResultSchema } func (a *accumulatorIter) Next(ctx *sql.Context) (r sql.Row, err error) { From 2a94ec7b7f367222aca316c9519b27ea996393f9 Mon Sep 17 00:00:00 2001 From: Zach Musgrave Date: Mon, 17 Mar 2025 17:02:11 -0700 Subject: [PATCH 9/9] Remove out of date TODOs --- sql/plan/insert.go | 5 ----- 1 file changed, 5 deletions(-) diff --git a/sql/plan/insert.go b/sql/plan/insert.go index 9fd0f1d3bc..5c7a24da12 100644 --- a/sql/plan/insert.go +++ b/sql/plan/insert.go @@ -130,12 +130,7 @@ func (ii *InsertInto) Schema() sql.Schema { // Postgres allows the returned values of the insert statement to be controlled, so if returning expressions // were specified, then we return a different schema. - // TODO: does anything else depend on the schema returned by insert statements? triggers? - // TODO: Do we need to check for the expressions being fully resolved? (probably!) if ii.Returning != nil { - // TODO: If we don't return the destination schema anymore... does that mess up other things, like trigger processing? - // TODO: we need to look at the expressions in the returning clause - // We know that returning exprs are resolved here, because you can't call Schema() safely until Resolved() is true. returningSchema := sql.Schema{} for _, expr := range ii.Returning {