Skip to content

Commit b57eb10

Browse files
authored
Merge pull request #1152 from dolthub/zachmu/warnings
bug fix for column defaults with function values
2 parents 1644bb9 + a80e17c commit b57eb10

File tree

8 files changed

+351
-15
lines changed

8 files changed

+351
-15
lines changed

server/analyzer/init.go

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@
1515
package analyzer
1616

1717
import (
18-
"github.com/cockroachdb/errors"
19-
2018
"github.com/dolthub/go-mysql-server/sql/analyzer"
2119
)
2220

@@ -37,6 +35,7 @@ const (
3735
ruleId_ResolveType // resolveType
3836
ruleId_ReplaceArithmeticExpressions // replaceArithmeticExpressions
3937
ruleId_OptimizeFunctions // optimizeFunctions
38+
ruleId_ValidateColumnDefaults // validateColumnDefaults
4039
)
4140

4241
// Init adds additional rules to the analyzer to handle Doltgres-specific functionality.
@@ -45,13 +44,13 @@ func Init() {
4544
analyzer.Rule{Id: ruleId_ResolveType, Apply: ResolveType},
4645
analyzer.Rule{Id: ruleId_TypeSanitizer, Apply: TypeSanitizer},
4746
analyzer.Rule{Id: ruleId_AddDomainConstraints, Apply: AddDomainConstraints},
48-
getAnalyzerRule(analyzer.OnceBeforeDefault, analyzer.ValidateColumnDefaultsId),
47+
analyzer.Rule{Id: ruleId_ValidateColumnDefaults, Apply: ValidateColumnDefaults},
4948
analyzer.Rule{Id: ruleId_AssignInsertCasts, Apply: AssignInsertCasts},
5049
analyzer.Rule{Id: ruleId_AssignUpdateCasts, Apply: AssignUpdateCasts},
5150
analyzer.Rule{Id: ruleId_ReplaceIndexedTables, Apply: ReplaceIndexedTables},
5251
)
5352

54-
// Column default validation was moved to occur after type sanitization, so we'll remove it from its original place
53+
// We remove the original column default rule, as we have our own implementation
5554
analyzer.OnceBeforeDefault = removeAnalyzerRules(analyzer.OnceBeforeDefault, analyzer.ValidateColumnDefaultsId)
5655

5756
// PostgreSQL doesn't have the concept of prefix lengths, so we add a rule to implicitly add them
@@ -76,17 +75,6 @@ func Init() {
7675
analyzer.Rule{Id: ruleId_InsertContextRootFinalizer, Apply: InsertContextRootFinalizer})
7776
}
7877

79-
// getAnalyzerRule returns the rule matching the given ID.
80-
func getAnalyzerRule(rules []analyzer.Rule, id analyzer.RuleId) analyzer.Rule {
81-
for _, rule := range rules {
82-
if rule.Id == id {
83-
return rule
84-
}
85-
}
86-
// This will only occur if GMS has been changed
87-
panic(errors.Errorf("rule not found: %d", id))
88-
}
89-
9078
// insertAnalyzerRules inserts the given rule(s) before or after the given analyzer.RuleId, returning an updated slice.
9179
func insertAnalyzerRules(rules []analyzer.Rule, id analyzer.RuleId, before bool, additionalRules ...analyzer.Rule) []analyzer.Rule {
9280
newRules := make([]analyzer.Rule, len(rules)+len(additionalRules))
Lines changed: 202 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,202 @@
1+
// Copyright 2020-2021 Dolthub, Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package analyzer
16+
17+
import (
18+
"github.com/dolthub/go-mysql-server/sql"
19+
"github.com/dolthub/go-mysql-server/sql/analyzer"
20+
"github.com/dolthub/go-mysql-server/sql/expression"
21+
"github.com/dolthub/go-mysql-server/sql/plan"
22+
"github.com/dolthub/go-mysql-server/sql/transform"
23+
24+
pgnode "github.com/dolthub/doltgresql/server/node"
25+
)
26+
27+
// validateColumnDefaults ensures that newly created column defaults from a DDL statement are legal for the type of
28+
// column, various other business logic checks to match MySQL's logic.
29+
func ValidateColumnDefaults(ctx *sql.Context, _ *analyzer.Analyzer, n sql.Node, _ *plan.Scope, _ analyzer.RuleSelector, qFlags *sql.QueryFlags) (sql.Node, transform.TreeIdentity, error) {
30+
span, ctx := ctx.Span("validateColumnDefaults")
31+
defer span.End()
32+
33+
return transform.Node(n, func(n sql.Node) (sql.Node, transform.TreeIdentity, error) {
34+
switch node := n.(type) {
35+
case *plan.AlterDefaultSet:
36+
table := getResolvedTable(node)
37+
sch := table.Schema()
38+
index := sch.IndexOfColName(node.ColumnName)
39+
if index == -1 {
40+
return nil, transform.SameTree, sql.ErrColumnNotFound.New(node.ColumnName)
41+
}
42+
col := sch[index]
43+
err := validateColumnDefault(ctx, col, node.Default)
44+
if err != nil {
45+
return node, transform.SameTree, err
46+
}
47+
48+
return node, transform.SameTree, nil
49+
50+
case sql.SchemaTarget:
51+
switch node.(type) {
52+
case *plan.AlterPK, *plan.AddColumn, *plan.ModifyColumn, *plan.AlterDefaultDrop, *plan.CreateTable, *plan.DropColumn, *pgnode.CreateTable:
53+
// DDL nodes must validate any new column defaults, continue to logic below
54+
default:
55+
// other node types are not altering the schema and therefore don't need validation of column defaults
56+
return n, transform.SameTree, nil
57+
}
58+
59+
// There may be multiple DDL nodes in the plan (ALTER TABLE statements can have many clauses), and for each of them
60+
// we need to count the column indexes in the very hacky way outlined above.
61+
i := 0
62+
return transform.NodeExprs(n, func(e sql.Expression) (sql.Expression, transform.TreeIdentity, error) {
63+
eWrapper, ok := e.(*expression.Wrapper)
64+
if !ok {
65+
return e, transform.SameTree, nil
66+
}
67+
68+
defer func() {
69+
i++
70+
}()
71+
72+
eVal := eWrapper.Unwrap()
73+
if eVal == nil {
74+
return e, transform.SameTree, nil
75+
}
76+
colDefault, ok := eVal.(*sql.ColumnDefaultValue)
77+
if !ok {
78+
return e, transform.SameTree, nil
79+
}
80+
81+
col, err := lookupColumnForTargetSchema(ctx, node, i)
82+
if err != nil {
83+
return nil, transform.SameTree, err
84+
}
85+
86+
err = validateColumnDefault(ctx, col, colDefault)
87+
if err != nil {
88+
return nil, transform.SameTree, err
89+
}
90+
91+
return e, transform.SameTree, nil
92+
})
93+
default:
94+
return node, transform.SameTree, nil
95+
}
96+
})
97+
}
98+
99+
// lookupColumnForTargetSchema looks at the target schema for the specified SchemaTarget node and returns
100+
// the column based on the specified index. For most node types, this is simply indexing into the target
101+
// schema but a few types require special handling.
102+
func lookupColumnForTargetSchema(_ *sql.Context, node sql.SchemaTarget, colIndex int) (*sql.Column, error) {
103+
schema := node.TargetSchema()
104+
105+
switch n := node.(type) {
106+
case *plan.ModifyColumn:
107+
if colIndex < len(schema) {
108+
return schema[colIndex], nil
109+
} else {
110+
return n.NewColumn(), nil
111+
}
112+
case *plan.AddColumn:
113+
if colIndex < len(schema) {
114+
return schema[colIndex], nil
115+
} else {
116+
return n.Column(), nil
117+
}
118+
case *plan.AlterDefaultSet:
119+
index := schema.IndexOfColName(n.ColumnName)
120+
if index == -1 {
121+
return nil, sql.ErrTableColumnNotFound.New(n.Table, n.ColumnName)
122+
}
123+
return schema[index], nil
124+
default:
125+
if colIndex < len(schema) {
126+
return schema[colIndex], nil
127+
} else {
128+
// TODO: sql.ErrColumnNotFound would be a better error here, but we need to add all the different node types to
129+
// the switch to get it
130+
return nil, expression.ErrIndexOutOfBounds.New(colIndex, len(schema))
131+
}
132+
}
133+
}
134+
135+
// validateColumnDefault validates that the column default expression is valid for the column type and returns an error
136+
// if not
137+
func validateColumnDefault(ctx *sql.Context, col *sql.Column, colDefault *sql.ColumnDefaultValue) error {
138+
if colDefault == nil {
139+
return nil
140+
}
141+
142+
var err error
143+
sql.Inspect(colDefault.Expr, func(e sql.Expression) bool {
144+
switch e.(type) {
145+
case sql.FunctionExpression, *expression.UnresolvedFunction:
146+
// TODO: functions must be deterministic to be used in column defaults
147+
return true
148+
case *plan.Subquery:
149+
err = sql.ErrColumnDefaultSubquery.New(col.Name)
150+
return false
151+
case *expression.GetField:
152+
if !colDefault.IsParenthesized() {
153+
err = sql.ErrInvalidColumnDefaultValue.New(col.Name)
154+
return false
155+
}
156+
return true
157+
default:
158+
return true
159+
}
160+
})
161+
162+
if err != nil {
163+
return err
164+
}
165+
166+
// validate type of default expression
167+
if err = colDefault.CheckType(ctx); err != nil {
168+
return err
169+
}
170+
171+
return nil
172+
}
173+
174+
// Finds first ResolvedTable node that is a descendant of the node given
175+
// This function will not look inside SubqueryAliases
176+
func getResolvedTable(node sql.Node) *plan.ResolvedTable {
177+
var table *plan.ResolvedTable
178+
transform.Inspect(node, func(n sql.Node) bool {
179+
// Inspect is called on all children of a node even if an earlier child's call returns false.
180+
// We only want the first TableNode match.
181+
if table != nil {
182+
return false
183+
}
184+
switch nn := n.(type) {
185+
case *plan.SubqueryAlias:
186+
// We should not be matching with ResolvedTables inside SubqueryAliases
187+
return false
188+
case *plan.ResolvedTable:
189+
if !plan.IsDualTable(nn) {
190+
table = nn
191+
return false
192+
}
193+
case *plan.IndexedTableAccess:
194+
if rt, ok := nn.TableNode.(*plan.ResolvedTable); ok {
195+
table = rt
196+
return false
197+
}
198+
}
199+
return true
200+
})
201+
return table
202+
}

server/doltgres_handler.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,11 +104,17 @@ func (h *DoltgresHandler) ComBind(ctx context.Context, c *mysql.Conn, query stri
104104

105105
bvs, err := h.convertBindParameters(sqlCtx, bindVars.varTypes, bindVars.formatCodes, bindVars.parameters)
106106
if err != nil {
107+
if printErrorStackTraces {
108+
fmt.Printf("unable to convert bind params: %+v\n", err)
109+
}
107110
return nil, nil, err
108111
}
109112

110113
queryPlan, err := h.e.BoundQueryPlan(sqlCtx, query, stmt, bvs)
111114
if err != nil {
115+
if printErrorStackTraces {
116+
fmt.Printf("unable to bind query plan: %+v\n", err)
117+
}
112118
return nil, nil, err
113119
}
114120

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
// Copyright 2025 Dolthub, Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package functions
16+
17+
import (
18+
"github.com/dolthub/go-mysql-server/sql"
19+
20+
"github.com/dolthub/doltgresql/postgres/parser/uuid"
21+
"github.com/dolthub/doltgresql/server/functions/framework"
22+
pgtypes "github.com/dolthub/doltgresql/server/types"
23+
)
24+
25+
// initGenRandomUuid registers the functions to the catalog.
26+
func initGenRandomUuid() {
27+
framework.RegisterFunction(gen_random_uuid)
28+
}
29+
30+
var gen_random_uuid = framework.Function0{
31+
Name: "gen_random_uuid",
32+
Return: pgtypes.Uuid,
33+
Strict: true,
34+
Callable: func(ctx *sql.Context) (any, error) {
35+
return uuid.NewV4()
36+
},
37+
}

server/functions/init.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ func Init() {
102102
initFloor()
103103
initFormatType()
104104
initGcd()
105+
initGenRandomUuid()
105106
initInitcap()
106107
initLcm()
107108
initLeft()

server/node/create_table.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ type CreateTable struct {
2929
}
3030

3131
var _ sql.ExecSourceRel = (*CreateTable)(nil)
32+
var _ sql.SchemaTarget = (*CreateTable)(nil)
3233

3334
// NewCreateTable returns a new *CreateTable.
3435
func NewCreateTable(createTable *plan.CreateTable, sequences []*CreateSequence) *CreateTable {
@@ -101,3 +102,18 @@ func (c *CreateTable) WithChildren(children ...sql.Node) (sql.Node, error) {
101102
sequences: c.sequences,
102103
}, nil
103104
}
105+
106+
func (c *CreateTable) TargetSchema() sql.Schema {
107+
return c.gmsCreateTable.TargetSchema()
108+
}
109+
110+
func (c CreateTable) WithTargetSchema(schema sql.Schema) (sql.Node, error) {
111+
n, err := c.gmsCreateTable.WithTargetSchema(schema)
112+
if err != nil {
113+
return nil, err
114+
}
115+
116+
c.gmsCreateTable = n.(*plan.CreateTable)
117+
118+
return &c, nil
119+
}

0 commit comments

Comments
 (0)