@@ -21,6 +21,7 @@ import (
2121
2222 "github.com/dolthub/go-mysql-server/sql"
2323 "github.com/dolthub/go-mysql-server/sql/expression"
24+ "github.com/dolthub/go-mysql-server/sql/transform"
2425)
2526
2627var ErrUpdateNotSupported = errors .NewKind ("table doesn't support UPDATE" )
@@ -35,6 +36,10 @@ type Update struct {
3536 IsJoin bool
3637 HasSingleRel bool
3738 IsProcNested bool
39+
40+ // Returning is a list of expressions to return after the update operation. This feature is not
41+ // supported in MySQL's syntax, but is exposed through PostgreSQL's syntax.
42+ Returning []sql.Expression
3843}
3944
4045var _ sql.Node = (* Update )(nil )
@@ -112,6 +117,24 @@ func GetDatabase(node sql.Node) sql.Database {
112117 return nil
113118}
114119
120+ // Schema implements the sql.Node interface.
121+ func (u * Update ) Schema () sql.Schema {
122+ // Postgres allows the returned values of the update statement to be controlled, so if returning
123+ // expressions were specified, then we return a different schema.
124+ if u .Returning != nil {
125+ // We know that returning exprs are resolved here, because you can't call Schema()
126+ // safely until Resolved() is true.
127+ returningSchema := sql.Schema {}
128+ for _ , expr := range u .Returning {
129+ returningSchema = append (returningSchema , transform .ExpressionToColumn (expr , "" ))
130+ }
131+
132+ return returningSchema
133+ }
134+
135+ return u .Child .Schema ()
136+ }
137+
115138func (u * Update ) Checks () sql.CheckConstraints {
116139 return u .checks
117140}
@@ -140,24 +163,32 @@ func (u *Update) Database() string {
140163}
141164
142165func (u * Update ) Expressions () []sql.Expression {
143- return u .checks .ToExpressions ()
166+ exprs := append ([]sql.Expression {}, u .checks .ToExpressions ()... )
167+ exprs = append (exprs , u .Returning ... )
168+ return exprs
144169}
145170
146171func (u * Update ) Resolved () bool {
147- return u .Child .Resolved () && expression .ExpressionsResolved (u .checks .ToExpressions ()... )
172+ return u .Child .Resolved () &&
173+ expression .ExpressionsResolved (u .checks .ToExpressions ()... ) &&
174+ expression .ExpressionsResolved (u .Returning ... )
175+
148176}
149177
150178func (u Update ) WithExpressions (newExprs ... sql.Expression ) (sql.Node , error ) {
151- if len (newExprs ) != len (u .checks ) {
152- return nil , sql .ErrInvalidChildrenNumber .New (u , len (newExprs ), len (u .checks ))
179+ expectedLength := len (u .checks ) + len (u .Returning )
180+ if len (newExprs ) != expectedLength {
181+ return nil , sql .ErrInvalidChildrenNumber .New (u , len (newExprs ), expectedLength )
153182 }
154183
155184 var err error
156- u .checks , err = u .checks .FromExpressions (newExprs )
185+ u .checks , err = u .checks .FromExpressions (newExprs [: len ( u . checks )] )
157186 if err != nil {
158187 return nil , err
159188 }
160189
190+ u .Returning = newExprs [len (u .checks ):]
191+
161192 return & u , nil
162193}
163194
0 commit comments