Skip to content

Commit 5093181

Browse files
authored
allow before insert trigger to specify missing column, but better (#2883)
1 parent e8ce0df commit 5093181

File tree

5 files changed

+126
-13
lines changed

5 files changed

+126
-13
lines changed

enginetest/queries/trigger_queries.go

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -506,6 +506,91 @@ var TriggerTests = []ScriptTest{
506506
},
507507
},
508508
},
509+
{
510+
Name: "insert trigger with missing column default value",
511+
SetUpScript: []string{
512+
"CREATE TABLE t (i INT PRIMARY KEY, j INT NOT NULL);",
513+
`
514+
CREATE TRIGGER trig BEFORE INSERT ON t
515+
FOR EACH ROW
516+
BEGIN
517+
SET new.j = 10;
518+
END;`,
519+
},
520+
Assertions: []ScriptTestAssertion{
521+
{
522+
Query: "INSERT INTO t (i) VALUES (1);",
523+
Expected: []sql.Row{
524+
{types.OkResult{RowsAffected: 1}},
525+
},
526+
},
527+
{
528+
Query: "INSERT INTO t (i, j) VALUES (2, null);",
529+
Expected: []sql.Row{
530+
{types.OkResult{RowsAffected: 1}},
531+
},
532+
},
533+
{
534+
Query: "SELECT * FROM t;",
535+
Expected: []sql.Row{
536+
{1, 10},
537+
{2, 10},
538+
},
539+
},
540+
},
541+
},
542+
{
543+
Name: "not null column with trigger that sets null should error",
544+
SetUpScript: []string{
545+
"CREATE TABLE t (i INT PRIMARY KEY, j INT NOT NULL);",
546+
`
547+
CREATE TRIGGER trig BEFORE INSERT ON t
548+
FOR EACH ROW
549+
BEGIN
550+
SET new.j = null;
551+
END;`,
552+
},
553+
Assertions: []ScriptTestAssertion{
554+
{
555+
Query: "INSERT INTO t (i) VALUES (1);",
556+
ExpectedErr: sql.ErrInsertIntoNonNullableDefaultNullColumn,
557+
},
558+
{
559+
Query: "INSERT INTO t (i, j) VALUES (1, 2);",
560+
ExpectedErr: sql.ErrInsertIntoNonNullableProvidedNull,
561+
},
562+
},
563+
},
564+
{
565+
Name: "not null column with before insert trigger should error",
566+
SetUpScript: []string{
567+
"CREATE TABLE t (i INT PRIMARY KEY, j INT NOT NULL);",
568+
`
569+
CREATE TRIGGER trig BEFORE INSERT ON t
570+
FOR EACH ROW
571+
BEGIN
572+
SET new.i = 10 * new.i;
573+
END;`,
574+
},
575+
Assertions: []ScriptTestAssertion{
576+
{
577+
Query: "INSERT INTO t (i) VALUES (1);",
578+
ExpectedErr: sql.ErrInsertIntoNonNullableDefaultNullColumn,
579+
},
580+
{
581+
Query: "INSERT INTO t (i, j) VALUES (1, 2);",
582+
Expected: []sql.Row{
583+
{types.NewOkResult(1)},
584+
},
585+
},
586+
{
587+
Query: "SELECT * FROM t;",
588+
Expected: []sql.Row{
589+
{10, 2},
590+
},
591+
},
592+
},
593+
},
509594

510595
// UPDATE triggers
511596
{

sql/analyzer/inserts.go

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -81,12 +81,23 @@ func resolveInsertRows(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Sc
8181
}
8282

8383
// The schema of the destination node and the underlying table differ subtly in terms of defaults
84-
project, firstGeneratedAutoIncRowIdx, err := wrapRowSource(ctx, source, insertable, insert.Destination.Schema(), columnNames)
84+
var deferredDefaults sql.FastIntSet
85+
project, firstGeneratedAutoIncRowIdx, deferredDefaults, err := wrapRowSource(
86+
ctx,
87+
source,
88+
insertable,
89+
insert.Destination.Schema(),
90+
columnNames,
91+
)
8592
if err != nil {
8693
return nil, transform.SameTree, err
8794
}
8895

89-
return insert.WithSource(project).WithAutoIncrementIdx(firstGeneratedAutoIncRowIdx), transform.NewTree, nil
96+
return insert.WithSource(project).
97+
WithAutoIncrementIdx(firstGeneratedAutoIncRowIdx).
98+
WithDeferredDefaults(deferredDefaults),
99+
transform.NewTree,
100+
nil
90101
})
91102
}
92103

@@ -117,8 +128,9 @@ func findColIdx(colName string, colNames []string) int {
117128
// wrapRowSource returns a projection that wraps the original row source so that its schema matches the full schema of
118129
// the underlying table in the same order. Also, returns an integer value that indicates when this row source will
119130
// result in an automatically generated value for an auto_increment column.
120-
func wrapRowSource(ctx *sql.Context, insertSource sql.Node, destTbl sql.Table, schema sql.Schema, columnNames []string) (sql.Node, int, error) {
131+
func wrapRowSource(ctx *sql.Context, insertSource sql.Node, destTbl sql.Table, schema sql.Schema, columnNames []string) (sql.Node, int, sql.FastIntSet, error) {
121132
projExprs := make([]sql.Expression, len(schema))
133+
deferredDefaults := sql.NewFastIntSet()
122134
firstGeneratedAutoIncRowIdx := -1
123135

124136
for i, col := range schema {
@@ -130,7 +142,7 @@ func wrapRowSource(ctx *sql.Context, insertSource sql.Node, destTbl sql.Table, s
130142
defaultExpr = col.Generated
131143
}
132144
if !col.Nullable && defaultExpr == nil && !col.AutoIncrement {
133-
return nil, -1, sql.ErrInsertIntoNonNullableDefaultNullColumn.New(col.Name)
145+
deferredDefaults.Add(i)
134146
}
135147

136148
var err error
@@ -151,7 +163,7 @@ func wrapRowSource(ctx *sql.Context, insertSource sql.Node, destTbl sql.Table, s
151163
}
152164
})
153165
if err != nil {
154-
return nil, -1, err
166+
return nil, -1, sql.FastIntSet{}, err
155167
}
156168
projExprs[i] = def
157169
} else {
@@ -163,7 +175,7 @@ func wrapRowSource(ctx *sql.Context, insertSource sql.Node, destTbl sql.Table, s
163175
// wrap it in an AutoIncrement expression.
164176
ai, err := expression.NewAutoIncrement(ctx, destTbl, projExprs[i])
165177
if err != nil {
166-
return nil, -1, err
178+
return nil, -1, sql.FastIntSet{}, err
167179
}
168180
projExprs[i] = ai
169181

@@ -206,7 +218,7 @@ func wrapRowSource(ctx *sql.Context, insertSource sql.Node, destTbl sql.Table, s
206218
// ColumnDefaultValue to create the UUID), then update the project to include the AutoUuid expression.
207219
newExpr, identity, err := insertAutoUuidExpression(ctx, columnDefaultValue, autoUuidCol)
208220
if err != nil {
209-
return nil, -1, err
221+
return nil, -1, sql.FastIntSet{}, err
210222
}
211223
if identity == transform.NewTree {
212224
projExprs[autoUuidColIdx] = newExpr
@@ -217,12 +229,12 @@ func wrapRowSource(ctx *sql.Context, insertSource sql.Node, destTbl sql.Table, s
217229
// the AutoUuid expression to it.
218230
err := wrapAutoUuidInValuesTuples(ctx, autoUuidCol, insertSource, columnNames)
219231
if err != nil {
220-
return nil, -1, err
232+
return nil, -1, sql.FastIntSet{}, err
221233
}
222234
}
223235
}
224236

225-
return plan.NewProject(projExprs, insertSource), firstGeneratedAutoIncRowIdx, nil
237+
return plan.NewProject(projExprs, insertSource), firstGeneratedAutoIncRowIdx, deferredDefaults, nil
226238
}
227239

228240
// isZero returns true if the specified literal value |lit| has a value equal to 0.

sql/plan/insert.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,9 @@ type InsertInto struct {
7272

7373
// FirstGenerateAutoIncRowIdx is the index of the first row inserted that increments last_insert_id.
7474
FirstGeneratedAutoIncRowIdx int
75+
76+
// DeferredDefaults marks which columns in the destination schema are expected to have default values.
77+
DeferredDefaults sql.FastIntSet
7578
}
7679

7780
var _ sql.Databaser = (*InsertInto)(nil)
@@ -201,6 +204,14 @@ func (ii *InsertInto) WithAutoIncrementIdx(firstGeneratedAutoIncRowIdx int) *Ins
201204
return &np
202205
}
203206

207+
// WithDeferredDefaults sets the flags for the insert destination columns, which mark which of the columns are expected
208+
// to be filled with the DEFAULT or GENERATED value.
209+
func (ii *InsertInto) WithDeferredDefaults(deferredDefaults sql.FastIntSet) *InsertInto {
210+
np := *ii
211+
np.DeferredDefaults = deferredDefaults
212+
return &np
213+
}
214+
204215
// String implements the fmt.Stringer interface.
205216
func (ii *InsertInto) String() string {
206217
pr := sql.NewTreePrinter()

sql/rowexec/dml.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ func (b *BaseBuilder) buildInsertInto(ctx *sql.Context, ii *plan.InsertInto, row
9090
ctx: ctx,
9191
ignore: ii.Ignore,
9292
firstGeneratedAutoIncRowIdx: ii.FirstGeneratedAutoIncRowIdx,
93+
deferredDefaults: ii.DeferredDefaults,
9394
}
9495

9596
var ed sql.EditOpenerCloser

sql/rowexec/insert.go

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ type insertIter struct {
4545
ignore bool
4646

4747
firstGeneratedAutoIncRowIdx int
48+
49+
deferredDefaults sql.FastIntSet
4850
}
4951

5052
func getInsertExpressions(values sql.Node) []sql.Expression {
@@ -395,12 +397,14 @@ func (i *insertIter) validateNullability(ctx *sql.Context, dstSchema sql.Schema,
395397
for count, col := range dstSchema {
396398
if !col.Nullable && row[count] == nil {
397399
// In the case of an IGNORE we set the nil value to a default and add a warning
398-
if i.ignore {
399-
row[count] = col.Type.Zero()
400-
_ = warnOnIgnorableError(ctx, row, sql.ErrInsertIntoNonNullableProvidedNull.New(col.Name)) // will always return nil
401-
} else {
400+
if !i.ignore {
401+
if i.deferredDefaults.Contains(count) {
402+
return sql.ErrInsertIntoNonNullableDefaultNullColumn.New(col.Name)
403+
}
402404
return sql.ErrInsertIntoNonNullableProvidedNull.New(col.Name)
403405
}
406+
row[count] = col.Type.Zero()
407+
_ = warnOnIgnorableError(ctx, row, sql.ErrInsertIntoNonNullableProvidedNull.New(col.Name)) // will always return nil
404408
}
405409
}
406410
return nil

0 commit comments

Comments
 (0)