@@ -29,43 +29,49 @@ import (
2929)
3030
3131func resolveInsertRows (ctx * sql.Context , a * Analyzer , n sql.Node , scope * plan.Scope , sel RuleSelector , qFlags * sql.QueryFlags ) (sql.Node , transform.TreeIdentity , error ) {
32- if _ , ok := n .(* plan.TriggerExecutor ); ok {
33- return n , transform .SameTree , nil
34- } else if _ , ok := n .(* plan.CreateProcedure ); ok {
32+ switch n .(type ) {
33+ case * plan.TriggerExecutor , * plan.CreateProcedure :
3534 return n , transform .SameTree , nil
3635 }
36+
3737 // We capture all INSERTs along the tree, such as those inside of block statements.
38- return transform .Node (n , func (n sql.Node ) (sql.Node , transform.TreeIdentity , error ) {
38+ var selFunc transform.SelectorFunc = func (c transform.Context ) bool {
39+ switch c .Node .(type ) {
40+ case * plan.InsertInto :
41+ return true
42+ default :
43+ return false
44+ }
45+ }
46+
47+ var ctxFunc transform.CtxFunc = func (c transform.Context ) (sql.Node , transform.TreeIdentity , error ) {
3948 insert , ok := n .(* plan.InsertInto )
4049 if ! ok {
4150 return n , transform .SameTree , nil
4251 }
4352
53+ source := insert .Source
4454 table := getResolvedTable (insert .Destination )
45-
4655 insertable , err := plan .GetInsertable (table )
4756 if err != nil {
4857 return nil , transform .SameTree , err
4958 }
59+ dstSchema := insertable .Schema ()
5060
51- source := insert .Source
52- // TriggerExecutor has already been analyzed
53- if _ , ok := insert .Source .(* plan.TriggerExecutor ); ! ok && ! insert .LiteralValueSource {
54- // Analyze the source of the insert independently
55- if _ , ok := insert .Source .(* plan.Values ); ok {
61+ // Analyze the source of the insert independently
62+ if ! insert .LiteralValueSource {
63+ if _ , isValues := source .(* plan.Values ); isValues {
5664 scope = scope .NewScope (plan .NewProject (
57- expression .SchemaToGetFields (insert . Source .Schema ()[:len (insert .ColumnNames )], sql.ColSet {}),
58- plan .NewSubqueryAlias ("dummy" , "" , insert . Source ),
65+ expression .SchemaToGetFields (source .Schema ()[:len (insert .ColumnNames )], sql.ColSet {}),
66+ plan .NewSubqueryAlias ("dummy" , "" , source ),
5967 ))
6068 }
61- source , _ , err = a .analyzeWithSelector (ctx , insert . Source , scope , SelectAllBatches , newInsertSourceSelector (sel ), qFlags )
69+ source , _ , err = a .analyzeWithSelector (ctx , source , scope , SelectAllBatches , newInsertSourceSelector (sel ), qFlags )
6270 if err != nil {
6371 return nil , transform .SameTree , err
6472 }
6573 }
6674
67- dstSchema := insertable .Schema ()
68-
6975 // normalize the column name
7076 columnNames := make ([]string , len (insert .ColumnNames ))
7177 for i , name := range insert .ColumnNames {
@@ -81,13 +87,15 @@ func resolveInsertRows(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Sc
8187 }
8288
8389 // The schema of the destination node and the underlying table differ subtly in terms of defaults
84- project , firstGeneratedAutoIncRowIdx , err := wrapRowSource (ctx , source , insertable , insert . Destination . Schema () , columnNames )
90+ project , firstGeneratedAutoIncRowIdx , err := wrapRowSource (ctx , source , insertable , dstSchema , columnNames )
8591 if err != nil {
8692 return nil , transform .SameTree , err
8793 }
8894
89- return insert .WithSource (project ).WithAutoIncrementIdx (firstGeneratedAutoIncRowIdx ), transform .NewTree , nil
90- })
95+ newInsert := insert .WithSource (project ).WithAutoIncrementIdx (firstGeneratedAutoIncRowIdx )
96+ return newInsert , transform .NewTree , nil
97+ }
98+ return transform .NodeWithCtx (n , selFunc , ctxFunc )
9199}
92100
93101func validateInsertRows (ctx * sql.Context , a * Analyzer , n sql.Node , scope * plan.Scope , sel RuleSelector , qFlags * sql.QueryFlags ) (sql.Node , transform.TreeIdentity , error ) {
@@ -177,6 +185,19 @@ func findColIdx(colName string, colNames []string) int {
177185// the underlying table in the same order. Also, returns an integer value that indicates when this row source will
178186// result in an automatically generated value for an auto_increment column.
179187func wrapRowSource (ctx * sql.Context , insertSource sql.Node , destTbl sql.Table , schema sql.Schema , columnNames []string ) (sql.Node , int , error ) {
188+ // if source is a triggerExec node, we need to wrap the child in the project not the triggerExec node itself
189+ if trigExec , isTrigExec := insertSource .(* plan.TriggerExecutor ); isTrigExec {
190+ newLeft , firstGeneratedAutoIncRowIdx , err := wrapRowSource (ctx , trigExec .Left (), destTbl , schema , columnNames )
191+ if err != nil {
192+ return nil , - 1 , err
193+ }
194+ newTrigExec , err := trigExec .WithChildren (newLeft , trigExec .Right ())
195+ if err != nil {
196+ return nil , - 1 , err
197+ }
198+ return newTrigExec , firstGeneratedAutoIncRowIdx , nil
199+ }
200+
180201 projExprs := make ([]sql.Expression , len (schema ))
181202 firstGeneratedAutoIncRowIdx := - 1
182203
@@ -188,6 +209,10 @@ func wrapRowSource(ctx *sql.Context, insertSource sql.Node, destTbl sql.Table, s
188209 defaultExpr = col .Generated
189210 }
190211
212+ if ! col .Nullable && defaultExpr == nil && ! col .AutoIncrement {
213+ return nil , - 1 , sql .ErrInsertIntoNonNullableDefaultNullColumn .New (col .Name )
214+ }
215+
191216 var err error
192217 colNameToIdx := make (map [string ]int )
193218 for i , c := range schema {
@@ -227,26 +252,30 @@ func wrapRowSource(ctx *sql.Context, insertSource sql.Node, destTbl sql.Table, s
227252 firstGeneratedAutoIncRowIdx = 0
228253 } else {
229254 // Additionally, the first NULL, DEFAULT, or empty value is what the last_insert_id should be set to.
230- switch src := insertSource .(type ) {
231- case * plan.Values :
232- for ii , tup := range src .ExpressionTuples {
233- expr := tup [colIdx ]
234- if unwrap , ok := expr .(* expression.Wrapper ); ok {
235- expr = unwrap .Unwrap ()
236- }
237- if _ , isDef := expr .(* sql.ColumnDefaultValue ); isDef {
238- firstGeneratedAutoIncRowIdx = ii
239- break
240- }
241- if lit , isLit := expr .(* expression.Literal ); isLit {
242- // If a literal NULL or if 0 is specified and the NO_AUTO_VALUE_ON_ZERO SQL mode is
243- // not active, then MySQL will fill in an auto_increment value.
244- if types .Null .Equals (lit .Type ()) ||
245- (! sql .LoadSqlMode (ctx ).ModeEnabled (sql .NoAutoValueOnZero ) && isZero (lit )) {
246- firstGeneratedAutoIncRowIdx = ii
247- break
248- }
249- }
255+ src , isValues := insertSource .(* plan.Values )
256+ if ! isValues {
257+ continue
258+ }
259+
260+ for ii , tup := range src .ExpressionTuples {
261+ expr := tup [colIdx ]
262+ if unwrap , ok := expr .(* expression.Wrapper ); ok {
263+ expr = unwrap .Unwrap ()
264+ }
265+ if _ , isDef := expr .(* sql.ColumnDefaultValue ); isDef {
266+ firstGeneratedAutoIncRowIdx = ii
267+ break
268+ }
269+ lit , isLit := expr .(* expression.Literal )
270+ if ! isLit {
271+ continue
272+ }
273+ // If a literal NULL or if 0 is specified and the NO_AUTO_VALUE_ON_ZERO SQL mode is
274+ // not active, then MySQL will fill in an auto_increment value.
275+ if types .Null .Equals (lit .Type ()) ||
276+ (! sql .LoadSqlMode (ctx ).ModeEnabled (sql .NoAutoValueOnZero ) && isZero (lit )) {
277+ firstGeneratedAutoIncRowIdx = ii
278+ break
250279 }
251280 }
252281 }
0 commit comments