diff --git a/enginetest/queries/queries.go b/enginetest/queries/queries.go index 6e853d3562..444b7c3daf 100644 --- a/enginetest/queries/queries.go +++ b/enginetest/queries/queries.go @@ -10146,7 +10146,6 @@ from typestable`, {uint32(1000)}, }, }, - { Query: `select distinct pk1 from two_pk order by pk1`, Expected: []sql.Row{ @@ -10255,14 +10254,12 @@ from typestable`, {""}, }, }, - { Query: "select @@sql_mode = 1", Expected: []sql.Row{ {false}, }, }, - { Query: "explain select 1", SkipServerEngine: true, @@ -10290,6 +10287,20 @@ from typestable`, {" └─ name: "}, }, }, + { + Query: "select quote(i), quote(s) from mytable", + Expected: []sql.Row{ + {"'1'", "'first row'"}, + {"'2'", "'second row'"}, + {"'3'", "'third row'"}, + }, + }, + { + Query: "select i, s from mytable where quote(i) = quote(2)", + Expected: []sql.Row{ + {2, "second row"}, + }, + }, } var KeylessQueries = []QueryTest{ diff --git a/sql/expression/function/registry.go b/sql/expression/function/registry.go index 77c05e17c9..a6bccbc828 100644 --- a/sql/expression/function/registry.go +++ b/sql/expression/function/registry.go @@ -190,6 +190,7 @@ var BuiltIns = []sql.Function{ sql.Function2{Name: "pow", Fn: NewPower}, sql.Function2{Name: "power", Fn: NewPower}, sql.Function1{Name: "quarter", Fn: NewQuarter}, + sql.Function1{Name: "quote", Fn: NewQuote}, sql.Function1{Name: "radians", Fn: NewRadians}, sql.FunctionN{Name: "rand", Fn: NewRand}, sql.FunctionN{Name: "regexp_instr", Fn: NewRegexpInstr}, diff --git a/sql/expression/function/string.go b/sql/expression/function/string.go index bb5aae426f..73f72577b4 100644 --- a/sql/expression/function/string.go +++ b/sql/expression/function/string.go @@ -15,6 +15,7 @@ package function import ( + "bytes" "encoding/hex" "fmt" "math" @@ -614,3 +615,55 @@ func (h *Bitlength) WithChildren(children ...sql.Expression) (sql.Expression, er } return NewBitlength(children[0]), nil } + +type Quote struct { + *UnaryFunc +} + +var _ sql.FunctionExpression = (*Bitlength)(nil) +var _ sql.CollationCoercible = (*Bitlength)(nil) + +func NewQuote(arg sql.Expression) sql.Expression { + return &Quote{UnaryFunc: NewUnaryFunc(arg, "QUOTE", types.Text)} +} + +func (q *Quote) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { + arg, err := q.EvalChild(ctx, row) + if err != nil { + return nil, err + } + + val, _, err := types.Blob.Convert(ctx, arg) + if err != nil { + return nil, err + } + if val == nil { + return nil, nil + } + valBytes := val.([]byte) + + ret := new(bytes.Buffer) + ret.WriteByte('\'') + for _, c := range valBytes { + switch c { + // '\032' is CTRL+Z character + case '\\', '\'', '\032': + ret.WriteByte('\\') + ret.WriteByte(c) + case '\000': + ret.WriteByte('\\') + ret.WriteByte('0') + default: + ret.WriteByte(c) + } + } + ret.WriteByte('\'') + return ret.String(), nil +} + +func (q *Quote) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(q, len(children), 1) + } + return NewQuote(children[0]), nil +} diff --git a/sql/expression/function/string_test.go b/sql/expression/function/string_test.go index f4d968482e..9c255c07ba 100644 --- a/sql/expression/function/string_test.go +++ b/sql/expression/function/string_test.go @@ -163,3 +163,15 @@ func TestBitLength(t *testing.T) { tf.AddSucceeding(128, time.Now()) tf.Test(t, nil, nil) } + +func TestQuote(t *testing.T) { + f := sql.Function1{Name: "quote", Fn: NewQuote} + tf := NewTestFactory(f.Fn) + tf.AddSucceeding(nil, nil) + tf.AddSucceeding("'test'", "test") + tf.AddSucceeding("'0'", false) + tf.AddSucceeding("'1'", true) + tf.AddSucceeding("'12345'", 12345) + tf.AddSucceeding("'\\\\, \\', \\0, \\\032'", "\\, ', \000, \032") + tf.Test(t, nil, nil) +}