Skip to content

Commit 8d34778

Browse files
author
James Cor
committed
fix
1 parent e8ce0df commit 8d34778

File tree

4 files changed

+41
-13
lines changed

4 files changed

+41
-13
lines changed

sql/analyzer/inserts.go

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -81,12 +81,23 @@ func resolveInsertRows(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Sc
8181
}
8282

8383
// The schema of the destination node and the underlying table differ subtly in terms of defaults
84-
project, firstGeneratedAutoIncRowIdx, err := wrapRowSource(ctx, source, insertable, insert.Destination.Schema(), columnNames)
84+
var missingValFlags []bool
85+
project, firstGeneratedAutoIncRowIdx, missingValFlags, err := wrapRowSource(
86+
ctx,
87+
source,
88+
insertable,
89+
insert.Destination.Schema(),
90+
columnNames,
91+
)
8592
if err != nil {
8693
return nil, transform.SameTree, err
8794
}
8895

89-
return insert.WithSource(project).WithAutoIncrementIdx(firstGeneratedAutoIncRowIdx), transform.NewTree, nil
96+
return insert.WithSource(project).
97+
WithAutoIncrementIdx(firstGeneratedAutoIncRowIdx).
98+
WithMissingValFlags(missingValFlags),
99+
transform.NewTree,
100+
nil
90101
})
91102
}
92103

@@ -117,8 +128,9 @@ func findColIdx(colName string, colNames []string) int {
117128
// wrapRowSource returns a projection that wraps the original row source so that its schema matches the full schema of
118129
// the underlying table in the same order. Also, returns an integer value that indicates when this row source will
119130
// result in an automatically generated value for an auto_increment column.
120-
func wrapRowSource(ctx *sql.Context, insertSource sql.Node, destTbl sql.Table, schema sql.Schema, columnNames []string) (sql.Node, int, error) {
131+
func wrapRowSource(ctx *sql.Context, insertSource sql.Node, destTbl sql.Table, schema sql.Schema, columnNames []string) (sql.Node, int, []bool, error) {
121132
projExprs := make([]sql.Expression, len(schema))
133+
missingVals := make([]bool, len(schema))
122134
firstGeneratedAutoIncRowIdx := -1
123135

124136
for i, col := range schema {
@@ -130,7 +142,7 @@ func wrapRowSource(ctx *sql.Context, insertSource sql.Node, destTbl sql.Table, s
130142
defaultExpr = col.Generated
131143
}
132144
if !col.Nullable && defaultExpr == nil && !col.AutoIncrement {
133-
return nil, -1, sql.ErrInsertIntoNonNullableDefaultNullColumn.New(col.Name)
145+
missingVals[colIdx] = true
134146
}
135147

136148
var err error
@@ -151,7 +163,7 @@ func wrapRowSource(ctx *sql.Context, insertSource sql.Node, destTbl sql.Table, s
151163
}
152164
})
153165
if err != nil {
154-
return nil, -1, err
166+
return nil, -1, nil, err
155167
}
156168
projExprs[i] = def
157169
} else {
@@ -163,7 +175,7 @@ func wrapRowSource(ctx *sql.Context, insertSource sql.Node, destTbl sql.Table, s
163175
// wrap it in an AutoIncrement expression.
164176
ai, err := expression.NewAutoIncrement(ctx, destTbl, projExprs[i])
165177
if err != nil {
166-
return nil, -1, err
178+
return nil, -1, nil, err
167179
}
168180
projExprs[i] = ai
169181

@@ -206,7 +218,7 @@ func wrapRowSource(ctx *sql.Context, insertSource sql.Node, destTbl sql.Table, s
206218
// ColumnDefaultValue to create the UUID), then update the project to include the AutoUuid expression.
207219
newExpr, identity, err := insertAutoUuidExpression(ctx, columnDefaultValue, autoUuidCol)
208220
if err != nil {
209-
return nil, -1, err
221+
return nil, -1, nil, err
210222
}
211223
if identity == transform.NewTree {
212224
projExprs[autoUuidColIdx] = newExpr
@@ -217,12 +229,12 @@ func wrapRowSource(ctx *sql.Context, insertSource sql.Node, destTbl sql.Table, s
217229
// the AutoUuid expression to it.
218230
err := wrapAutoUuidInValuesTuples(ctx, autoUuidCol, insertSource, columnNames)
219231
if err != nil {
220-
return nil, -1, err
232+
return nil, -1, nil, err
221233
}
222234
}
223235
}
224236

225-
return plan.NewProject(projExprs, insertSource), firstGeneratedAutoIncRowIdx, nil
237+
return plan.NewProject(projExprs, insertSource), firstGeneratedAutoIncRowIdx, missingVals, nil
226238
}
227239

228240
// isZero returns true if the specified literal value |lit| has a value equal to 0.

sql/plan/insert.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,9 @@ type InsertInto struct {
7272

7373
// FirstGenerateAutoIncRowIdx is the index of the first row inserted that increments last_insert_id.
7474
FirstGeneratedAutoIncRowIdx int
75+
76+
// MissingValFlags marks which columns in the destination schema are expected to have default values.
77+
MissingValFlags []bool
7578
}
7679

7780
var _ sql.Databaser = (*InsertInto)(nil)
@@ -201,6 +204,14 @@ func (ii *InsertInto) WithAutoIncrementIdx(firstGeneratedAutoIncRowIdx int) *Ins
201204
return &np
202205
}
203206

207+
// WithMissingValFlags sets the flags for the insert destination columns, which mark which of the columns are expected
208+
// to be filled with the DEFAULT or GENERATED value.
209+
func (ii *InsertInto) WithMissingValFlags(missingValFlags []bool) *InsertInto {
210+
np := *ii
211+
np.MissingValFlags = missingValFlags
212+
return &np
213+
}
214+
204215
// String implements the fmt.Stringer interface.
205216
func (ii *InsertInto) String() string {
206217
pr := sql.NewTreePrinter()

sql/rowexec/dml.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ func (b *BaseBuilder) buildInsertInto(ctx *sql.Context, ii *plan.InsertInto, row
9090
ctx: ctx,
9191
ignore: ii.Ignore,
9292
firstGeneratedAutoIncRowIdx: ii.FirstGeneratedAutoIncRowIdx,
93+
missingValFlags: ii.MissingValFlags,
9394
}
9495

9596
var ed sql.EditOpenerCloser

sql/rowexec/insert.go

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ type insertIter struct {
4545
ignore bool
4646

4747
firstGeneratedAutoIncRowIdx int
48+
49+
missingFlagVal []bool
4850
}
4951

5052
func getInsertExpressions(values sql.Node) []sql.Expression {
@@ -395,12 +397,14 @@ func (i *insertIter) validateNullability(ctx *sql.Context, dstSchema sql.Schema,
395397
for count, col := range dstSchema {
396398
if !col.Nullable && row[count] == nil {
397399
// In the case of an IGNORE we set the nil value to a default and add a warning
398-
if i.ignore {
399-
row[count] = col.Type.Zero()
400-
_ = warnOnIgnorableError(ctx, row, sql.ErrInsertIntoNonNullableProvidedNull.New(col.Name)) // will always return nil
401-
} else {
400+
if !i.ignore {
401+
if i.missingFlagVal[count] {
402+
return sql.ErrInsertIntoNonNullableDefaultNullColumn.New(col.Name)
403+
}
402404
return sql.ErrInsertIntoNonNullableProvidedNull.New(col.Name)
403405
}
406+
row[count] = col.Type.Zero()
407+
_ = warnOnIgnorableError(ctx, row, sql.ErrInsertIntoNonNullableProvidedNull.New(col.Name)) // will always return nil
404408
}
405409
}
406410
return nil

0 commit comments

Comments
 (0)