Skip to content

Commit 28fdce5

Browse files
committed
Use the parser to quote identifiers
1 parent f392710 commit 28fdce5

File tree

4 files changed

+35
-17
lines changed

4 files changed

+35
-17
lines changed

memory/table.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ func stripTblNames(e sql.Expression) (sql.Expression, transform.TreeIdentity, er
134134
case *expression.GetField:
135135
// strip table names
136136
ne := expression.NewGetField(e.Index(), e.Type(), e.Name(), e.IsNullable())
137-
ne = ne.WithBackTickNames(e.IsBackTickNames())
137+
ne = ne.WithQuotedNames(sql.GlobalParser, e.IsQuotedIdentifier())
138138
return ne, transform.NewTree, nil
139139
default:
140140
}

sql/analyzer/resolve_column_defaults.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -307,15 +307,15 @@ func stripTableNamesFromDefault(e *expression.Wrapper) (sql.Expression, transfor
307307
return expression.WrapExpression(&nd), transform.NewTree, nil
308308
}
309309

310-
func backtickDefaultColumnValueNames(ctx *sql.Context, _ *Analyzer, n sql.Node, _ *plan.Scope, _ RuleSelector, qFlags *sql.QueryFlags) (sql.Node, transform.TreeIdentity, error) {
310+
func backtickDefaultColumnValueNames(ctx *sql.Context, a *Analyzer, n sql.Node, _ *plan.Scope, _ RuleSelector, qFlags *sql.QueryFlags) (sql.Node, transform.TreeIdentity, error) {
311311
span, ctx := ctx.Span("backtickDefaultColumnValueNames")
312312
defer span.End()
313313

314314
return transform.Node(n, func(n sql.Node) (sql.Node, transform.TreeIdentity, error) {
315315
switch node := n.(type) {
316316
case *plan.AlterDefaultSet:
317317
eWrapper := expression.WrapExpression(node.Default)
318-
newExpr, same, err := backtickDefault(eWrapper)
318+
newExpr, same, err := quoteIdentifiers(a.Parser, eWrapper)
319319
if err != nil {
320320
return node, transform.SameTree, err
321321
}
@@ -335,7 +335,7 @@ func backtickDefaultColumnValueNames(ctx *sql.Context, _ *Analyzer, n sql.Node,
335335
return e, transform.SameTree, nil
336336
}
337337

338-
return backtickDefault(eWrapper)
338+
return quoteIdentifiers(a.Parser, eWrapper)
339339
})
340340
case *plan.ResolvedTable:
341341
ct, ok := node.Table.(*information_schema.ColumnsTable)
@@ -354,7 +354,7 @@ func backtickDefaultColumnValueNames(ctx *sql.Context, _ *Analyzer, n sql.Node,
354354
return e, transform.SameTree, nil
355355
}
356356

357-
return backtickDefault(eWrapper)
357+
return quoteIdentifiers(a.Parser, eWrapper)
358358
})
359359

360360
if err != nil {
@@ -376,7 +376,7 @@ func backtickDefaultColumnValueNames(ctx *sql.Context, _ *Analyzer, n sql.Node,
376376
})
377377
}
378378

379-
func backtickDefault(wrap *expression.Wrapper) (sql.Expression, transform.TreeIdentity, error) {
379+
func quoteIdentifiers(parser sql.Parser, wrap *expression.Wrapper) (sql.Expression, transform.TreeIdentity, error) {
380380
newDefault, ok := wrap.Unwrap().(*sql.ColumnDefaultValue)
381381
if !ok {
382382
return wrap, transform.SameTree, nil
@@ -388,7 +388,7 @@ func backtickDefault(wrap *expression.Wrapper) (sql.Expression, transform.TreeId
388388

389389
newExpr, same, err := transform.Expr(newDefault.Expr, func(expr sql.Expression) (sql.Expression, transform.TreeIdentity, error) {
390390
if e, isGf := expr.(*expression.GetField); isGf {
391-
return e.WithBackTickNames(true), transform.NewTree, nil
391+
return e.WithQuotedNames(parser,true), transform.NewTree, nil
392392
}
393393
return expr, transform.SameTree, nil
394394
})

sql/expression/get_field.go

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,11 @@ type GetField struct {
3838
fieldType2 sql.Type2
3939
nullable bool
4040

41-
backTickNames bool
41+
// parser is the parser used to parse the expression and print it
42+
parser sql.Parser
43+
44+
// quoteName indicates whether the field name should be quoted when printed with String()
45+
quoteName bool
4246
}
4347

4448
var _ sql.Expression = (*GetField)(nil)
@@ -161,10 +165,14 @@ func (p *GetField) WithChildren(children ...sql.Expression) (sql.Expression, err
161165
}
162166

163167
func (p *GetField) String() string {
164-
if p.table == "" {
165-
if p.backTickNames {
166-
return "`" + p.name + "`"
168+
if p.quoteName {
169+
if p.table == "" {
170+
return p.parser.QuoteIdentifier(p.name)
167171
}
172+
return p.parser.QuoteIdentifier(p.table) + "." + p.parser.QuoteIdentifier(p.name)
173+
}
174+
175+
if p.table == "" {
168176
return p.name
169177
}
170178
return p.table + "." + p.name
@@ -188,16 +196,17 @@ func (p *GetField) WithIndex(n int) sql.Expression {
188196
return &p2
189197
}
190198

191-
// WithBackTickNames returns a copy of this expression with the backtick names flag set to the given value.
192-
func (p *GetField) WithBackTickNames(backtick bool) *GetField {
199+
// WithQuotedNames returns a copy of this expression with the backtick names flag set to the given value.
200+
func (p *GetField) WithQuotedNames(parser sql.Parser, quoteNames bool) *GetField {
193201
p2 := *p
194-
p2.backTickNames = backtick
202+
p2.quoteName = quoteNames
203+
p2.parser = parser
195204
return &p2
196205
}
197206

198-
// IsBackTickNames returns whether the field name should be quoted with backticks.
199-
func (p *GetField) IsBackTickNames() bool {
200-
return p.backTickNames
207+
// IsQuotedIdentifier returns whether the field name should be quoted.
208+
func (p *GetField) IsQuotedIdentifier() bool {
209+
return p.quoteName
201210
}
202211

203212
// CollationCoercibility implements the interface sql.CollationCoercible.

sql/parser.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ package sql
1616

1717
import (
1818
"context"
19+
"fmt"
1920
trace2 "runtime/trace"
2021
"strings"
2122
"unicode"
@@ -44,6 +45,10 @@ type Parser interface {
4445
// the index of the start of the next query. If |query| represents a no-op statement, such as ";" or "-- comment",
4546
// then implementations must return Vitess' ErrEmpty error.
4647
ParseOneWithOptions(context.Context, string, ast.ParserOptions) (ast.Statement, int, error)
48+
// QuoteIdentifier returns the identifier given quoted according to this parser's dialect. This is used to
49+
// standardize identifiers that cannot be parsed without quoting, because they break the normal identifier naming
50+
// rules (such as containing spaces)
51+
QuoteIdentifier(identifier string) string
4752
}
4853

4954
var _ Parser = &MysqlParser{}
@@ -99,3 +104,7 @@ func RemoveSpaceAndDelimiter(query string, d rune) string {
99104
return r == d || unicode.IsSpace(r)
100105
})
101106
}
107+
108+
func (m *MysqlParser) QuoteIdentifier(identifier string) string {
109+
return fmt.Sprintf("`%s`", strings.ReplaceAll(identifier, "`", "``"))
110+
}

0 commit comments

Comments
 (0)