Skip to content

Commit c3ef1b5

Browse files
author
James Cor
committed
test
1 parent cd9ddef commit c3ef1b5

File tree

3 files changed

+73
-45
lines changed

3 files changed

+73
-45
lines changed

sql/analyzer/inserts.go

Lines changed: 67 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -29,43 +29,49 @@ import (
2929
)
3030

3131
func resolveInsertRows(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector, qFlags *sql.QueryFlags) (sql.Node, transform.TreeIdentity, error) {
32-
if _, ok := n.(*plan.TriggerExecutor); ok {
33-
return n, transform.SameTree, nil
34-
} else if _, ok := n.(*plan.CreateProcedure); ok {
32+
switch n.(type) {
33+
case *plan.TriggerExecutor, *plan.CreateProcedure:
3534
return n, transform.SameTree, nil
3635
}
36+
3737
// We capture all INSERTs along the tree, such as those inside of block statements.
38-
return transform.Node(n, func(n sql.Node) (sql.Node, transform.TreeIdentity, error) {
38+
var selFunc transform.SelectorFunc = func(c transform.Context) bool {
39+
switch c.Node.(type) {
40+
case *plan.InsertInto:
41+
return true
42+
default:
43+
return false
44+
}
45+
}
46+
47+
var ctxFunc transform.CtxFunc = func(c transform.Context) (sql.Node, transform.TreeIdentity, error) {
3948
insert, ok := n.(*plan.InsertInto)
4049
if !ok {
4150
return n, transform.SameTree, nil
4251
}
4352

53+
source := insert.Source
4454
table := getResolvedTable(insert.Destination)
45-
4655
insertable, err := plan.GetInsertable(table)
4756
if err != nil {
4857
return nil, transform.SameTree, err
4958
}
59+
dstSchema := insertable.Schema()
5060

51-
source := insert.Source
52-
// TriggerExecutor has already been analyzed
53-
if _, ok := insert.Source.(*plan.TriggerExecutor); !ok && !insert.LiteralValueSource {
54-
// Analyze the source of the insert independently
55-
if _, ok := insert.Source.(*plan.Values); ok {
61+
// Analyze the source of the insert independently
62+
if !insert.LiteralValueSource {
63+
if _, isValues := source.(*plan.Values); isValues {
5664
scope = scope.NewScope(plan.NewProject(
57-
expression.SchemaToGetFields(insert.Source.Schema()[:len(insert.ColumnNames)], sql.ColSet{}),
58-
plan.NewSubqueryAlias("dummy", "", insert.Source),
65+
expression.SchemaToGetFields(source.Schema()[:len(insert.ColumnNames)], sql.ColSet{}),
66+
plan.NewSubqueryAlias("dummy", "", source),
5967
))
6068
}
61-
source, _, err = a.analyzeWithSelector(ctx, insert.Source, scope, SelectAllBatches, newInsertSourceSelector(sel), qFlags)
69+
source, _, err = a.analyzeWithSelector(ctx, source, scope, SelectAllBatches, newInsertSourceSelector(sel), qFlags)
6270
if err != nil {
6371
return nil, transform.SameTree, err
6472
}
6573
}
6674

67-
dstSchema := insertable.Schema()
68-
6975
// normalize the column name
7076
columnNames := make([]string, len(insert.ColumnNames))
7177
for i, name := range insert.ColumnNames {
@@ -81,13 +87,15 @@ func resolveInsertRows(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Sc
8187
}
8288

8389
// The schema of the destination node and the underlying table differ subtly in terms of defaults
84-
project, firstGeneratedAutoIncRowIdx, err := wrapRowSource(ctx, source, insertable, insert.Destination.Schema(), columnNames)
90+
project, firstGeneratedAutoIncRowIdx, err := wrapRowSource(ctx, source, insertable, dstSchema, columnNames)
8591
if err != nil {
8692
return nil, transform.SameTree, err
8793
}
8894

89-
return insert.WithSource(project).WithAutoIncrementIdx(firstGeneratedAutoIncRowIdx), transform.NewTree, nil
90-
})
95+
newInsert := insert.WithSource(project).WithAutoIncrementIdx(firstGeneratedAutoIncRowIdx)
96+
return newInsert, transform.NewTree, nil
97+
}
98+
return transform.NodeWithCtx(n, selFunc, ctxFunc)
9199
}
92100

93101
func validateInsertRows(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector, qFlags *sql.QueryFlags) (sql.Node, transform.TreeIdentity, error) {
@@ -177,6 +185,19 @@ func findColIdx(colName string, colNames []string) int {
177185
// the underlying table in the same order. Also, returns an integer value that indicates when this row source will
178186
// result in an automatically generated value for an auto_increment column.
179187
func wrapRowSource(ctx *sql.Context, insertSource sql.Node, destTbl sql.Table, schema sql.Schema, columnNames []string) (sql.Node, int, error) {
188+
// if source is a triggerExec node, we need to wrap the child in the project not the triggerExec node itself
189+
if trigExec, isTrigExec := insertSource.(*plan.TriggerExecutor); isTrigExec {
190+
newLeft, firstGeneratedAutoIncRowIdx, err := wrapRowSource(ctx, trigExec.Left(), destTbl, schema, columnNames)
191+
if err != nil {
192+
return nil, -1, err
193+
}
194+
newTrigExec, err := trigExec.WithChildren(newLeft, trigExec.Right())
195+
if err != nil {
196+
return nil, -1, err
197+
}
198+
return newTrigExec, firstGeneratedAutoIncRowIdx, nil
199+
}
200+
180201
projExprs := make([]sql.Expression, len(schema))
181202
firstGeneratedAutoIncRowIdx := -1
182203

@@ -188,6 +209,10 @@ func wrapRowSource(ctx *sql.Context, insertSource sql.Node, destTbl sql.Table, s
188209
defaultExpr = col.Generated
189210
}
190211

212+
if !col.Nullable && defaultExpr == nil && !col.AutoIncrement {
213+
return nil, -1, sql.ErrInsertIntoNonNullableDefaultNullColumn.New(col.Name)
214+
}
215+
191216
var err error
192217
colNameToIdx := make(map[string]int)
193218
for i, c := range schema {
@@ -227,26 +252,30 @@ func wrapRowSource(ctx *sql.Context, insertSource sql.Node, destTbl sql.Table, s
227252
firstGeneratedAutoIncRowIdx = 0
228253
} else {
229254
// Additionally, the first NULL, DEFAULT, or empty value is what the last_insert_id should be set to.
230-
switch src := insertSource.(type) {
231-
case *plan.Values:
232-
for ii, tup := range src.ExpressionTuples {
233-
expr := tup[colIdx]
234-
if unwrap, ok := expr.(*expression.Wrapper); ok {
235-
expr = unwrap.Unwrap()
236-
}
237-
if _, isDef := expr.(*sql.ColumnDefaultValue); isDef {
238-
firstGeneratedAutoIncRowIdx = ii
239-
break
240-
}
241-
if lit, isLit := expr.(*expression.Literal); isLit {
242-
// If a literal NULL or if 0 is specified and the NO_AUTO_VALUE_ON_ZERO SQL mode is
243-
// not active, then MySQL will fill in an auto_increment value.
244-
if types.Null.Equals(lit.Type()) ||
245-
(!sql.LoadSqlMode(ctx).ModeEnabled(sql.NoAutoValueOnZero) && isZero(lit)) {
246-
firstGeneratedAutoIncRowIdx = ii
247-
break
248-
}
249-
}
255+
src, isValues := insertSource.(*plan.Values)
256+
if !isValues {
257+
continue
258+
}
259+
260+
for ii, tup := range src.ExpressionTuples {
261+
expr := tup[colIdx]
262+
if unwrap, ok := expr.(*expression.Wrapper); ok {
263+
expr = unwrap.Unwrap()
264+
}
265+
if _, isDef := expr.(*sql.ColumnDefaultValue); isDef {
266+
firstGeneratedAutoIncRowIdx = ii
267+
break
268+
}
269+
lit, isLit := expr.(*expression.Literal)
270+
if !isLit {
271+
continue
272+
}
273+
// If a literal NULL or if 0 is specified and the NO_AUTO_VALUE_ON_ZERO SQL mode is
274+
// not active, then MySQL will fill in an auto_increment value.
275+
if types.Null.Equals(lit.Type()) ||
276+
(!sql.LoadSqlMode(ctx).ModeEnabled(sql.NoAutoValueOnZero) && isZero(lit)) {
277+
firstGeneratedAutoIncRowIdx = ii
278+
break
250279
}
251280
}
252281
}

sql/analyzer/rules.go

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,8 @@ package analyzer
1717
func init() {
1818
OnceAfterAll = []Rule{
1919
{assignExecIndexesId, assignExecIndexes},
20-
// resolveInsertRows inserts a projection wrapping values that cannot be seen by fixup
21-
{resolveInsertRowsId, resolveInsertRows},
2220
{applyTriggersId, applyTriggers},
23-
{validateInsertRowsId, validateInsertRows},
21+
{resolveInsertRowsId, resolveInsertRows},
2422
{applyProceduresId, applyProcedures},
2523
{inlineSubqueryAliasRefsId, inlineSubqueryAliasRefs},
2624
{cacheSubqueryAliasesInJoinsId, cacheSubqueryAliasesInJoins},

sql/planbuilder/dml.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -181,15 +181,16 @@ func (b *Builder) buildInsertValues(inScope *scope, v *ast.AliasedValues, column
181181
for i, columnName := range columnNames {
182182
index := destSchema.IndexOfColName(columnName)
183183
if index == -1 {
184+
// ignore missing columns when in trigger context
184185
if !b.TriggerCtx().Call && len(b.TriggerCtx().UnresolvedTables) > 0 {
185186
continue
186187
}
187-
err := sql.ErrUnknownColumn.New(columnName, tableName)
188-
b.handleErr(err)
188+
b.handleErr(sql.ErrUnknownColumn.New(columnName, tableName))
189189
}
190190

191-
columnDefaultValues[i] = destSchema[index].Default
192-
if columnDefaultValues[i] == nil && destSchema[index].Generated != nil {
191+
if destSchema[index].Default != nil {
192+
columnDefaultValues[i] = destSchema[index].Default
193+
} else if columnDefaultValues[i] == nil {
193194
columnDefaultValues[i] = destSchema[index].Generated
194195
}
195196
}

0 commit comments

Comments
 (0)