Skip to content
This repository was archived by the owner on Sep 7, 2021. It is now read-only.
This repository is currently being migrated. It's locked while the migration is in progress.

Commit 0f2b0dd

Browse files
committed
add quote policy and quote mod support
1 parent c5ee68f commit 0f2b0dd

22 files changed

+340
-191
lines changed

engine.go

Lines changed: 8 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,9 @@ type Engine struct {
5555
cacherLock sync.RWMutex
5656

5757
defaultContext context.Context
58+
59+
quotePolicy QuotePolicy
60+
quoteMode QuoteMode
5861
}
5962

6063
func (engine *Engine) setCacher(tableName string, cacher core.Cacher) {
@@ -175,64 +178,6 @@ func (engine *Engine) SupportInsertMany() bool {
175178
return engine.dialect.SupportInsertMany()
176179
}
177180

178-
func (engine *Engine) quoteColumns(columnStr string) string {
179-
columns := strings.Split(columnStr, ",")
180-
for i := 0; i < len(columns); i++ {
181-
columns[i] = engine.Quote(strings.TrimSpace(columns[i]))
182-
}
183-
return strings.Join(columns, ",")
184-
}
185-
186-
// Quote Use QuoteStr quote the string sql
187-
func (engine *Engine) Quote(value string) string {
188-
value = strings.TrimSpace(value)
189-
if len(value) == 0 {
190-
return value
191-
}
192-
193-
buf := strings.Builder{}
194-
engine.QuoteTo(&buf, value)
195-
196-
return buf.String()
197-
}
198-
199-
// QuoteTo quotes string and writes into the buffer
200-
func (engine *Engine) QuoteTo(buf *strings.Builder, value string) {
201-
if buf == nil {
202-
return
203-
}
204-
205-
value = strings.TrimSpace(value)
206-
if value == "" {
207-
return
208-
}
209-
210-
quotePair := engine.dialect.Quote("")
211-
212-
if value[0] == '`' || len(quotePair) < 2 || value[0] == quotePair[0] { // no quote
213-
_, _ = buf.WriteString(value)
214-
return
215-
} else {
216-
prefix, suffix := quotePair[0], quotePair[1]
217-
218-
_ = buf.WriteByte(prefix)
219-
for i := 0; i < len(value); i++ {
220-
if value[i] == '.' {
221-
_ = buf.WriteByte(suffix)
222-
_ = buf.WriteByte('.')
223-
_ = buf.WriteByte(prefix)
224-
} else {
225-
_ = buf.WriteByte(value[i])
226-
}
227-
}
228-
_ = buf.WriteByte(suffix)
229-
}
230-
}
231-
232-
func (engine *Engine) quote(sql string) string {
233-
return engine.dialect.Quote(sql)
234-
}
235-
236181
// SqlType will be deprecated, please use SQLType instead
237182
//
238183
// Deprecated: use SQLType instead
@@ -467,6 +412,8 @@ func (engine *Engine) dumpTables(tables []*core.Table, w io.Writer, tp ...core.D
467412
return err
468413
}
469414

415+
quoter := newQuoter(dialect, engine.quoteMode, engine.quotePolicy)
416+
470417
for i, table := range tables {
471418
if i > 0 {
472419
_, err = io.WriteString(w, "\n")
@@ -486,10 +433,10 @@ func (engine *Engine) dumpTables(tables []*core.Table, w io.Writer, tp ...core.D
486433
}
487434

488435
cols := table.ColumnsSeq()
489-
colNames := engine.dialect.Quote(strings.Join(cols, engine.dialect.Quote(", ")))
490-
destColNames := dialect.Quote(strings.Join(cols, dialect.Quote(", ")))
436+
colNames := quoteJoin(engine, cols)
437+
destColNames := quoteJoin(quoter, cols)
491438

492-
rows, err := engine.DB().Query("SELECT " + colNames + " FROM " + engine.Quote(table.Name))
439+
rows, err := engine.DB().Query("SELECT " + colNames + " FROM " + engine.quote(table.Name, false))
493440
if err != nil {
494441
return err
495442
}

engine_cond.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,9 @@ func (engine *Engine) buildConds(table *core.Table, bean interface{},
4444
if len(aliasName) > 0 {
4545
nm = aliasName
4646
}
47-
colName = engine.Quote(nm) + "." + engine.Quote(col.Name)
47+
colName = engine.quote(nm, false) + "." + engine.quote(col.Name, true)
4848
} else {
49-
colName = engine.Quote(col.Name)
49+
colName = engine.quote(col.Name, true)
5050
}
5151

5252
fieldValuePtr, err := col.ValueOf(bean)

engine_quote.go

Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
// Copyright 2019 The Xorm Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
package xorm
6+
7+
import (
8+
"fmt"
9+
"strings"
10+
11+
"xorm.io/core"
12+
)
13+
14+
// QuotePolicy describes quote handle policy
15+
type QuotePolicy int
16+
17+
// All QuotePolicies
18+
const (
19+
QuoteAddAlways QuotePolicy = iota
20+
QuoteNoAdd
21+
QuoteAddReserved
22+
)
23+
24+
// QuoteMode quote on which types
25+
type QuoteMode int
26+
27+
// All QuoteModes
28+
const (
29+
QuoteTableAndColumns QuoteMode = iota
30+
QuoteTableOnly
31+
QuoteColumnsOnly
32+
)
33+
34+
// Quoter represents an object has Quote method
35+
type Quoter interface {
36+
Quotes() (byte, byte)
37+
QuotePolicy() QuotePolicy
38+
QuoteMode() QuoteMode
39+
IsReserved(string) bool
40+
}
41+
42+
type quoter struct {
43+
dialect core.Dialect
44+
quoteMode QuoteMode
45+
quotePolicy QuotePolicy
46+
}
47+
48+
func newQuoter(dialect core.Dialect, quoteMode QuoteMode, quotePolicy QuotePolicy) Quoter {
49+
return &quoter{
50+
dialect: dialect,
51+
quoteMode: quoteMode,
52+
quotePolicy: quotePolicy,
53+
}
54+
}
55+
56+
func (q *quoter) Quotes() (byte, byte) {
57+
quotes := q.dialect.Quote("")
58+
return quotes[0], quotes[1]
59+
}
60+
61+
func (q *quoter) QuotePolicy() QuotePolicy {
62+
return q.quotePolicy
63+
}
64+
65+
func (q *quoter) QuoteMode() QuoteMode {
66+
return q.quoteMode
67+
}
68+
69+
func (q *quoter) IsReserved(value string) bool {
70+
return q.dialect.IsReserved(value)
71+
}
72+
73+
func quoteColumns(quoter Quoter, columnStr string) string {
74+
columns := strings.Split(columnStr, ",")
75+
return quoteJoin(quoter, columns)
76+
}
77+
78+
func quoteJoin(quoter Quoter, columns []string) string {
79+
for i := 0; i < len(columns); i++ {
80+
columns[i] = quote(quoter, columns[i], true)
81+
}
82+
return strings.Join(columns, ",")
83+
}
84+
85+
// quote Use QuoteStr quote the string sql
86+
func quote(quoter Quoter, value string, isColumn bool) string {
87+
buf := strings.Builder{}
88+
quoteTo(quoter, &buf, value, isColumn)
89+
return buf.String()
90+
}
91+
92+
// Quote add quotes to the value
93+
func (engine *Engine) quote(value string, isColumn bool) string {
94+
return quote(engine, value, isColumn)
95+
}
96+
97+
// Quote add quotes to the value
98+
func (engine *Engine) Quote(value string, isColumn bool) string {
99+
return engine.quote(value, isColumn)
100+
}
101+
102+
// Quotes return the left quote and right quote
103+
func (engine *Engine) Quotes() (byte, byte) {
104+
quotes := engine.dialect.Quote("")
105+
return quotes[0], quotes[1]
106+
}
107+
108+
// QuoteMode returns quote mode
109+
func (engine *Engine) QuoteMode() QuoteMode {
110+
return engine.quoteMode
111+
}
112+
113+
// QuotePolicy returns quote policy
114+
func (engine *Engine) QuotePolicy() QuotePolicy {
115+
return engine.quotePolicy
116+
}
117+
118+
// IsReserved return true if the value is a reserved word of the database
119+
func (engine *Engine) IsReserved(value string) bool {
120+
return engine.dialect.IsReserved(value)
121+
}
122+
123+
// quoteTo quotes string and writes into the buffer
124+
func quoteTo(quoter Quoter, buf *strings.Builder, value string, isColumn bool) {
125+
if isColumn {
126+
if quoter.QuoteMode() == QuoteTableAndColumns ||
127+
quoter.QuoteMode() == QuoteColumnsOnly {
128+
if quoter.QuotePolicy() == QuoteAddAlways {
129+
realQuoteTo(quoter, buf, value)
130+
return
131+
} else if quoter.QuotePolicy() == QuoteAddReserved && quoter.IsReserved(value) {
132+
realQuoteTo(quoter, buf, value)
133+
return
134+
}
135+
}
136+
buf.WriteString(value)
137+
return
138+
}
139+
140+
if quoter.QuoteMode() == QuoteTableAndColumns ||
141+
quoter.QuoteMode() == QuoteTableOnly {
142+
if quoter.QuotePolicy() == QuoteAddAlways {
143+
realQuoteTo(quoter, buf, value)
144+
return
145+
} else if quoter.QuotePolicy() == QuoteAddReserved && quoter.IsReserved(value) {
146+
realQuoteTo(quoter, buf, value)
147+
return
148+
}
149+
}
150+
buf.WriteString(value)
151+
return
152+
}
153+
154+
func realQuoteTo(quoter Quoter, buf *strings.Builder, value string) {
155+
if buf == nil {
156+
return
157+
}
158+
159+
value = strings.TrimSpace(value)
160+
if value == "" {
161+
return
162+
} else if value == "*" {
163+
buf.WriteString("*")
164+
return
165+
}
166+
167+
quoteLeft, quoteRight := quoter.Quotes()
168+
169+
if value[0] == '`' || value[0] == quoteLeft { // no quote
170+
_, _ = buf.WriteString(value)
171+
return
172+
} else {
173+
_ = buf.WriteByte(quoteLeft)
174+
for i := 0; i < len(value); i++ {
175+
if value[i] == '.' {
176+
_ = buf.WriteByte(quoteRight)
177+
_ = buf.WriteByte('.')
178+
_ = buf.WriteByte(quoteLeft)
179+
} else {
180+
_ = buf.WriteByte(value[i])
181+
}
182+
}
183+
_ = buf.WriteByte(quoteRight)
184+
}
185+
}
186+
187+
func unQuote(quoter Quoter, value string) string {
188+
left, right := quoter.Quotes()
189+
return strings.Trim(value, fmt.Sprintf("%v%v`", left, right))
190+
}
191+
192+
func quoteJoinFunc(cols []string, quoteFunc func(string) string, sep string) string {
193+
for i := range cols {
194+
cols[i] = quoteFunc(cols[i])
195+
}
196+
return strings.Join(cols, sep+" ")
197+
}

engine_quote_test.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
// Copyright 2019 The Xorm Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
package xorm
6+
7+
import (
8+
"testing"
9+
10+
"github.com/stretchr/testify/assert"
11+
)
12+
13+
func TestQuoteColumns(t *testing.T) {
14+
cols := []string{"f1", "f2", "f3"}
15+
quoteFunc := func(value string) string {
16+
return "[" + value + "]"
17+
}
18+
19+
assert.EqualValues(t, "[f1], [f2], [f3]", quoteJoinFunc(cols, quoteFunc, ","))
20+
}

engine_table.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,9 @@ func (engine *Engine) tbNameNoSchema(tablename interface{}) string {
6363
case []string:
6464
t := tablename.([]string)
6565
if len(t) > 1 {
66-
return fmt.Sprintf("%v AS %v", engine.Quote(t[0]), engine.Quote(t[1]))
66+
return fmt.Sprintf("%v AS %v", engine.quote(t[0], false), engine.quote(t[1], false))
6767
} else if len(t) == 1 {
68-
return engine.Quote(t[0])
68+
return engine.quote(t[0], false)
6969
}
7070
case []interface{}:
7171
t := tablename.([]interface{})
@@ -84,15 +84,15 @@ func (engine *Engine) tbNameNoSchema(tablename interface{}) string {
8484
if t.Kind() == reflect.Struct {
8585
table = engine.tbNameForMap(v)
8686
} else {
87-
table = engine.Quote(fmt.Sprintf("%v", f))
87+
table = engine.quote(fmt.Sprintf("%v", f), false)
8888
}
8989
}
9090
}
9191
if l > 1 {
92-
return fmt.Sprintf("%v AS %v", engine.Quote(table),
93-
engine.Quote(fmt.Sprintf("%v", t[1])))
92+
return fmt.Sprintf("%v AS %v", engine.quote(table, false),
93+
engine.quote(fmt.Sprintf("%v", t[1]), false))
9494
} else if l == 1 {
95-
return engine.Quote(table)
95+
return engine.quote(table, false)
9696
}
9797
case TableName:
9898
return tablename.(TableName).TableName()
@@ -107,7 +107,7 @@ func (engine *Engine) tbNameNoSchema(tablename interface{}) string {
107107
if t.Kind() == reflect.Struct {
108108
return engine.tbNameForMap(v)
109109
}
110-
return engine.Quote(fmt.Sprintf("%v", tablename))
110+
return engine.quote(fmt.Sprintf("%v", tablename), false)
111111
}
112112
return ""
113113
}

helpers.go

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -323,10 +323,3 @@ func eraseAny(value string, strToErase ...string) string {
323323

324324
return replacer.Replace(value)
325325
}
326-
327-
func quoteColumns(cols []string, quoteFunc func(string) string, sep string) string {
328-
for i := range cols {
329-
cols[i] = quoteFunc(cols[i])
330-
}
331-
return strings.Join(cols, sep+" ")
332-
}

helpers_test.go

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,3 @@ func TestEraseAny(t *testing.T) {
1616
assert.EqualValues(t, "SELECT * FROM table.[table_name]", eraseAny(raw, "`"))
1717
assert.EqualValues(t, "SELECT * FROM table.table_name", eraseAny(raw, "`", "[", "]"))
1818
}
19-
20-
func TestQuoteColumns(t *testing.T) {
21-
cols := []string{"f1", "f2", "f3"}
22-
quoteFunc := func(value string) string {
23-
return "[" + value + "]"
24-
}
25-
26-
assert.EqualValues(t, "[f1], [f2], [f3]", quoteColumns(cols, quoteFunc, ","))
27-
}

0 commit comments

Comments
 (0)