@@ -20,6 +20,7 @@ import (
2020 "gopkg.in/src-d/go-errors.v1"
2121
2222 "github.com/dolthub/go-mysql-server/sql"
23+ "github.com/dolthub/go-mysql-server/sql/expression"
2324 "github.com/dolthub/go-mysql-server/sql/transform"
2425)
2526
@@ -70,6 +71,10 @@ type InsertInto struct {
7071 // a |Values| node with only literal expressions.
7172 LiteralValueSource bool
7273
74+ // Returning is a list of expressions to return after the insert operation. This feature is not supported
75+ // in MySQL's syntax, but is exposed through PostgreSQL's syntax.
76+ Returning []sql.Expression
77+
7378 // FirstGenerateAutoIncRowIdx is the index of the first row inserted that increments last_insert_id.
7479 FirstGeneratedAutoIncRowIdx int
7580
@@ -122,6 +127,19 @@ func (ii *InsertInto) Schema() sql.Schema {
122127 if ii .IsReplace {
123128 return append (ii .Destination .Schema (), ii .Destination .Schema ()... )
124129 }
130+
131+ // Postgres allows the returned values of the insert statement to be controlled, so if returning expressions
132+ // were specified, then we return a different schema.
133+ if ii .Returning != nil {
134+ // We know that returning exprs are resolved here, because you can't call Schema() safely until Resolved() is true.
135+ returningSchema := sql.Schema {}
136+ for _ , expr := range ii .Returning {
137+ returningSchema = append (returningSchema , transform .ExpressionToColumn (expr , "" ))
138+ }
139+
140+ return returningSchema
141+ }
142+
125143 return ii .Destination .Schema ()
126144}
127145
@@ -238,24 +256,30 @@ func (ii *InsertInto) DebugString() string {
238256
239257// Expressions implements the sql.Expressioner interface.
240258func (ii * InsertInto ) Expressions () []sql.Expression {
241- return append (ii .OnDupExprs , ii .checks .ToExpressions ()... )
259+ exprs := append (ii .OnDupExprs , ii .checks .ToExpressions ()... )
260+ exprs = append (exprs , ii .Returning ... )
261+ return exprs
242262}
243263
244264// WithExpressions implements the sql.Expressioner interface.
245265func (ii * InsertInto ) WithExpressions (newExprs ... sql.Expression ) (sql.Node , error ) {
246- if len (newExprs ) != len (ii .OnDupExprs )+ len (ii .checks ) {
247- return nil , sql .ErrInvalidChildrenNumber .New (ii , len (newExprs ), len (ii .OnDupExprs )+ len (ii .checks ))
266+ if len (newExprs ) != len (ii .OnDupExprs )+ len (ii .checks )+ len ( ii . Returning ) {
267+ return nil , sql .ErrInvalidChildrenNumber .New (ii , len (newExprs ), len (ii .OnDupExprs )+ len (ii .checks )+ len ( ii . Returning ) )
248268 }
249269
250270 nii := * ii
251271 nii .OnDupExprs = newExprs [:len (nii .OnDupExprs )]
272+ newExprs = newExprs [len (nii .OnDupExprs ):]
252273
253274 var err error
254- nii .checks , err = nii .checks .FromExpressions (newExprs [len (nii .OnDupExprs ): ])
275+ nii .checks , err = nii .checks .FromExpressions (newExprs [: len (nii .checks ) ])
255276 if err != nil {
256277 return nil , err
257278 }
258279
280+ newExprs = newExprs [len (nii .checks ):]
281+ nii .Returning = newExprs
282+
259283 return & nii , nil
260284}
261285
@@ -264,17 +288,15 @@ func (ii *InsertInto) Resolved() bool {
264288 if ! ii .Destination .Resolved () || ! ii .Source .Resolved () {
265289 return false
266290 }
267- for _ , updateExpr := range ii .OnDupExprs {
268- if ! updateExpr .Resolved () {
269- return false
270- }
271- }
291+
272292 for _ , checkExpr := range ii .checks {
273293 if ! checkExpr .Expr .Resolved () {
274294 return false
275295 }
276296 }
277- return true
297+
298+ return expression .ExpressionsResolved (ii .OnDupExprs ... ) &&
299+ expression .ExpressionsResolved (ii .Returning ... )
278300}
279301
280302// InsertDestination is a wrapper for a table to be used with InsertInto.Destination that allows the schema to be
0 commit comments