Skip to content

Commit 7507922

Browse files
committed
support driver.Valuer type in Struct and interpolate methods
1 parent 167e2ad commit 7507922

File tree

4 files changed

+187
-81
lines changed

4 files changed

+187
-81
lines changed

interpolate.go

Lines changed: 103 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
package sqlbuilder
55

66
import (
7+
"database/sql/driver"
78
"fmt"
9+
"reflect"
810
"strconv"
911
"time"
1012
"unicode"
@@ -389,78 +391,13 @@ func encodeValue(buf []byte, arg interface{}, flavor Flavor) ([]byte, error) {
389391
case nil:
390392
buf = append(buf, "NULL"...)
391393

392-
case bool:
393-
if v {
394-
buf = append(buf, "TRUE"...)
394+
case driver.Valuer:
395+
if val, err := v.Value(); err != nil {
396+
return nil, err
395397
} else {
396-
buf = append(buf, "FALSE"...)
398+
return encodeValue(buf, val, flavor)
397399
}
398400

399-
case int:
400-
buf = strconv.AppendInt(buf, int64(v), 10)
401-
402-
case int8:
403-
buf = strconv.AppendInt(buf, int64(v), 10)
404-
405-
case int16:
406-
buf = strconv.AppendInt(buf, int64(v), 10)
407-
408-
case int32:
409-
buf = strconv.AppendInt(buf, int64(v), 10)
410-
411-
case int64:
412-
buf = strconv.AppendInt(buf, v, 10)
413-
414-
case uint:
415-
buf = strconv.AppendUint(buf, uint64(v), 10)
416-
417-
case uint8:
418-
buf = strconv.AppendUint(buf, uint64(v), 10)
419-
420-
case uint16:
421-
buf = strconv.AppendUint(buf, uint64(v), 10)
422-
423-
case uint32:
424-
buf = strconv.AppendUint(buf, uint64(v), 10)
425-
426-
case uint64:
427-
buf = strconv.AppendUint(buf, v, 10)
428-
429-
case float32:
430-
buf = strconv.AppendFloat(buf, float64(v), 'g', -1, 32)
431-
432-
case float64:
433-
buf = strconv.AppendFloat(buf, v, 'g', -1, 64)
434-
435-
case []byte:
436-
if v == nil {
437-
buf = append(buf, "NULL"...)
438-
break
439-
}
440-
441-
switch flavor {
442-
case MySQL:
443-
buf = append(buf, "_binary"...)
444-
buf = quoteStringValue(buf, *(*string)(unsafe.Pointer(&v)), flavor)
445-
446-
case PostgreSQL:
447-
buf = append(buf, "E'\\\\x"...)
448-
buf = appendHex(buf, v)
449-
buf = append(buf, "'::bytea"...)
450-
451-
case SQLite:
452-
buf = append(buf, "X'"...)
453-
buf = appendHex(buf, v)
454-
buf = append(buf, '\'')
455-
456-
case SQLServer:
457-
buf = append(buf, "0x"...)
458-
buf = appendHex(buf, v)
459-
}
460-
461-
case string:
462-
buf = quoteStringValue(buf, v, flavor)
463-
464401
case time.Time:
465402
if v.IsZero() {
466403
buf = append(buf, "'0000-00-00'"...)
@@ -492,7 +429,103 @@ func encodeValue(buf []byte, arg interface{}, flavor Flavor) ([]byte, error) {
492429
buf = quoteStringValue(buf, v.String(), flavor)
493430

494431
default:
495-
return nil, ErrInterpolateUnsupportedArgs
432+
primative := reflect.ValueOf(arg)
433+
434+
switch k := primative.Kind(); k {
435+
case reflect.Bool:
436+
if primative.Bool() {
437+
buf = append(buf, "TRUE"...)
438+
} else {
439+
buf = append(buf, "FALSE"...)
440+
}
441+
442+
case reflect.Int:
443+
buf = strconv.AppendInt(buf, primative.Int(), 10)
444+
445+
case reflect.Int8:
446+
buf = strconv.AppendInt(buf, primative.Int(), 10)
447+
448+
case reflect.Int16:
449+
buf = strconv.AppendInt(buf, primative.Int(), 10)
450+
451+
case reflect.Int32:
452+
buf = strconv.AppendInt(buf, primative.Int(), 10)
453+
454+
case reflect.Int64:
455+
buf = strconv.AppendInt(buf, primative.Int(), 10)
456+
457+
case reflect.Uint:
458+
buf = strconv.AppendUint(buf, primative.Uint(), 10)
459+
460+
case reflect.Uint8:
461+
buf = strconv.AppendUint(buf, primative.Uint(), 10)
462+
463+
case reflect.Uint16:
464+
buf = strconv.AppendUint(buf, primative.Uint(), 10)
465+
466+
case reflect.Uint32:
467+
buf = strconv.AppendUint(buf, primative.Uint(), 10)
468+
469+
case reflect.Uint64:
470+
buf = strconv.AppendUint(buf, primative.Uint(), 10)
471+
472+
case reflect.Float32:
473+
buf = strconv.AppendFloat(buf, primative.Float(), 'g', -1, 32)
474+
475+
case reflect.Float64:
476+
buf = strconv.AppendFloat(buf, primative.Float(), 'g', -1, 64)
477+
478+
case reflect.String:
479+
buf = quoteStringValue(buf, primative.String(), flavor)
480+
481+
case reflect.Slice, reflect.Array:
482+
if k == reflect.Slice && primative.IsNil() {
483+
buf = append(buf, "NULL"...)
484+
break
485+
}
486+
487+
if elem := primative.Type().Elem(); elem.Kind() != reflect.Uint8 {
488+
return nil, ErrInterpolateUnsupportedArgs
489+
}
490+
491+
var data []byte
492+
493+
// Bytes() will panic if primative is an array and cannot be addressed.
494+
// Copy all bytes to data as a fallback.
495+
if k == reflect.Array && !primative.CanAddr() {
496+
l := primative.Len()
497+
data = make([]byte, l)
498+
499+
for i := 0; i < l; i++ {
500+
data[i] = byte(primative.Index(i).Uint())
501+
}
502+
} else {
503+
data = primative.Bytes()
504+
}
505+
506+
switch flavor {
507+
case MySQL:
508+
buf = append(buf, "_binary"...)
509+
buf = quoteStringValue(buf, *(*string)(unsafe.Pointer(&data)), flavor)
510+
511+
case PostgreSQL:
512+
buf = append(buf, "E'\\\\x"...)
513+
buf = appendHex(buf, data)
514+
buf = append(buf, "'::bytea"...)
515+
516+
case SQLite:
517+
buf = append(buf, "X'"...)
518+
buf = appendHex(buf, data)
519+
buf = append(buf, '\'')
520+
521+
case SQLServer:
522+
buf = append(buf, "0x"...)
523+
buf = appendHex(buf, data)
524+
}
525+
526+
default:
527+
return nil, ErrInterpolateUnsupportedArgs
528+
}
496529
}
497530

498531
return buf, nil

interpolate_test.go

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,34 @@
11
package sqlbuilder
22

33
import (
4+
"database/sql/driver"
5+
"errors"
46
"strconv"
57
"testing"
68
"time"
79

810
"github.com/huandu/go-assert"
911
)
1012

13+
type errorValuer int
14+
15+
var ErrErrorValuer = errors.New("error valuer")
16+
17+
func (v errorValuer) Value() (driver.Value, error) {
18+
return 0, ErrErrorValuer
19+
}
20+
1121
func TestFlavorInterpolate(t *testing.T) {
1222
a := assert.New(t)
1323
dt := time.Date(2019, 4, 24, 12, 23, 34, 123456789, time.FixedZone("CST", 8*60*60)) // 2019-04-24 12:23:34.987654321 CST
1424
_, errOutOfRange := strconv.ParseInt("12345678901234567890", 10, 32)
25+
byteArr := [...]byte{'f', 'o', 'o'}
1526
cases := []struct {
16-
flavor Flavor
17-
sql string
18-
args []interface{}
19-
query string
20-
err error
27+
Flavor Flavor
28+
SQL string
29+
Args []interface{}
30+
Query string
31+
Err error
2132
}{
2233
{
2334
MySQL,
@@ -39,6 +50,11 @@ func TestFlavorInterpolate(t *testing.T) {
3950
"SELECT '\\'?', \"\\\"?\", `\\`?`, \\?", []interface{}{MySQL},
4051
"SELECT '\\'?', \"\\\"?\", `\\`?`, \\'MySQL'", nil,
4152
},
53+
{
54+
MySQL,
55+
"SELECT ?", []interface{}{byteArr},
56+
"SELECT _binary'foo'", nil,
57+
},
4258
{
4359
MySQL,
4460
"SELECT ?", nil,
@@ -49,6 +65,16 @@ func TestFlavorInterpolate(t *testing.T) {
4965
"SELECT ?", []interface{}{complex(1, 2)},
5066
"", ErrInterpolateUnsupportedArgs,
5167
},
68+
{
69+
MySQL,
70+
"SELECT ?", []interface{}{[]complex128{complex(1, 2)}},
71+
"", ErrInterpolateUnsupportedArgs,
72+
},
73+
{
74+
MySQL,
75+
"SELECT ?", []interface{}{errorValuer(1)},
76+
"", ErrErrorValuer,
77+
},
5278

5379
{
5480
PostgreSQL,
@@ -141,9 +167,9 @@ func TestFlavorInterpolate(t *testing.T) {
141167

142168
for idx, c := range cases {
143169
a.Use(&idx, &c)
144-
query, err := c.flavor.Interpolate(c.sql, c.args)
170+
query, err := c.Flavor.Interpolate(c.SQL, c.Args)
145171

146-
a.Equal(query, c.query)
147-
a.Assert(err == c.err || err.Error() == c.err.Error())
172+
a.Equal(query, c.Query)
173+
a.Assert(err == c.Err || err.Error() == c.Err.Error())
148174
}
149175
}

struct.go

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ package sqlbuilder
55

66
import (
77
"bytes"
8+
"database/sql/driver"
89
"math"
910
"reflect"
1011
"regexp"
@@ -36,6 +37,8 @@ const (
3637

3738
var optRegex = regexp.MustCompile(`(?P<` + optName + `>\w+)(\((?P<` + optParams + `>.*)\))?`)
3839

40+
var typeOfSQLDriverValuer = reflect.TypeOf((*driver.Valuer)(nil)).Elem()
41+
3942
// Struct represents a struct type.
4043
//
4144
// All methods in Struct are thread-safe.
@@ -179,7 +182,7 @@ func (s *Struct) UpdateForTag(table string, tag string, value interface{}) *Upda
179182
continue
180183
}
181184
} else {
182-
val = dereferencedValue(val)
185+
val = dereferencedFieldValue(val)
183186
}
184187

185188
data := val.Interface()
@@ -237,7 +240,7 @@ func (s *Struct) buildColsAndValuesForTag(ib *InsertBuilder, tag string, value .
237240

238241
for _, item := range value {
239242
v := reflect.ValueOf(item)
240-
v = dereferencedValue(v)
243+
v = dereferencedFieldValue(v)
241244

242245
if v.Type() == s.structType {
243246
vs = append(vs, v)
@@ -265,7 +268,7 @@ func (s *Struct) buildColsAndValuesForTag(ib *InsertBuilder, tag string, value .
265268
nilCnt++
266269
}
267270

268-
val = dereferencedValue(val)
271+
val = dereferencedFieldValue(val)
269272

270273
if val.IsValid() {
271274
values[i] = append(values[i], val.Interface())
@@ -485,6 +488,18 @@ func dereferencedValue(v reflect.Value) reflect.Value {
485488
return v
486489
}
487490

491+
func dereferencedFieldValue(v reflect.Value) reflect.Value {
492+
for k := v.Kind(); k == reflect.Ptr || k == reflect.Interface; k = v.Kind() {
493+
if v.Type().Implements(typeOfSQLDriverValuer) {
494+
break
495+
}
496+
497+
v = v.Elem()
498+
}
499+
500+
return v
501+
}
502+
488503
// isEmptyValue checks if v is zero.
489504
// Following code is borrowed from `IsZero` method in `reflect.Value` since Go 1.13.
490505
func isEmptyValue(v reflect.Value) bool {

struct_test.go

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
package sqlbuilder
55

66
import (
7+
"database/sql/driver"
78
"fmt"
89
"testing"
910
"time"
@@ -811,6 +812,37 @@ func TestStructFieldAs(t *testing.T) {
811812
a.Equal(sql, `UPDATE t SET t1 = ?, t2 = ?, t4 = ?`)
812813
}
813814

815+
type structImplValuer int
816+
817+
func (v *structImplValuer) Value() (driver.Value, error) {
818+
return *v * 2, nil
819+
}
820+
821+
type structContainsValuer struct {
822+
F1 string
823+
F2 *structImplValuer
824+
}
825+
826+
func TestStructFieldsImplValuer(t *testing.T) {
827+
a := assert.New(t)
828+
st := NewStruct(new(structContainsValuer))
829+
f1 := "foo"
830+
f2 := structImplValuer(100)
831+
832+
sql, args := st.Update("t", structContainsValuer{
833+
F1: f1,
834+
F2: &f2,
835+
}).BuildWithFlavor(MySQL)
836+
837+
a.Equal(sql, "UPDATE t SET F1 = ?, F2 = ?")
838+
a.Equal(args[0], f1)
839+
a.Equal(args[1], &f2)
840+
841+
result, err := MySQL.Interpolate(sql, args)
842+
a.NilError(err)
843+
a.Equal(result, "UPDATE t SET F1 = 'foo', F2 = 200")
844+
}
845+
814846
func SomeOtherMapper(string) string {
815847
return ""
816848
}

0 commit comments

Comments
 (0)