Skip to content

[WIP] Version 1.0 #9

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 24 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 18 additions & 16 deletions anyvalue.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"database/sql"
"database/sql/driver"
"fmt"
"slices"
"unicode/utf8"
)

Expand All @@ -14,23 +15,24 @@ var (
_ fmt.GoStringer = AnyValue{}
)

// AnyValue wraps a driver.Value and is useful for
// AnyValue can hold any value and is useful for
// generic code that can handle unknown column types.
//
// AnyValue implements the following interfaces:
// database/sql.Scanner
// database/sql/driver.Valuer
// fmt.Stringer
// fmt.GoStringer
// - database/sql.Scanner
// - database/sql/driver.Valuer
// - fmt.Stringer
// - fmt.GoStringer
//
// When scanned, Val can have one of the following underlying types:
// int64
// float64
// bool
// []byte
// string
// time.Time
// nil - for NULL values
// When scanned with the Scan method
// Val will have one of the following types:
// - int64
// - float64
// - bool
// - []byte
// - string
// - time.Time
// - nil (for SQL NULL values)
type AnyValue struct {
Val any
}
Expand All @@ -39,7 +41,7 @@ type AnyValue struct {
func (any *AnyValue) Scan(val any) error {
if b, ok := val.([]byte); ok {
// Copy bytes because they won't be valid after this method call
any.Val = append([]byte(nil), b...)
any.Val = slices.Clone(b)
} else {
any.Val = val
}
Expand All @@ -52,7 +54,7 @@ func (any AnyValue) Value() (driver.Value, error) {
}

// String returns the value formatted as string using fmt.Sprint
// except when it's of type []byte and valid UTF-8,
// except when it is of type []byte and valid UTF-8,
// then it is directly converted into a string.
func (any AnyValue) String() string {
if b, ok := any.Val.([]byte); ok && utf8.Valid(b) {
Expand All @@ -64,7 +66,7 @@ func (any AnyValue) String() string {
// GoString returns a Go representation of the wrapped value.
func (any AnyValue) GoString() string {
if b, ok := any.Val.([]byte); ok && utf8.Valid(b) {
return fmt.Sprintf("[]byte(%q)", b)
return fmt.Sprintf("[]byte(%#v)", string(b))
}
return fmt.Sprintf("%#v", any.Val)
}
114 changes: 103 additions & 11 deletions impl/arrays.go → arrays.go
Original file line number Diff line number Diff line change
@@ -1,22 +1,56 @@
package impl
package sqldb

import (
"database/sql"
"database/sql/driver"
"reflect"

"github.com/lib/pq"
)

func WrapForArray(a interface{}) interface {
driver.Valuer
sql.Scanner
} {
return pq.Array(a)
// func WrapForArray(a any) interface {
// driver.Valuer
// sql.Scanner
// } {
// return pq.Array(a)
// }

type ArrayHandler interface {
AsArrayScanner(dest any) sql.Scanner
}

func MakeArrayScannable(dest []any, arrayHandler ArrayHandler) []any {
if arrayHandler == nil {
return dest
}
var wrappedDest []any
for i, d := range dest {
if ShouldWrapForArrayScanning(reflect.ValueOf(d).Elem()) {
if wrappedDest == nil {
// Allocate new slice for wrapped element
wrappedDest = make([]any, len(dest))
// Copy previous elements
for h := 0; h < i; h++ {
wrappedDest[h] = dest[h]
}
}
wrappedDest[i] = arrayHandler.AsArrayScanner(d)
} else if wrappedDest != nil {
wrappedDest[i] = d
}
}
if wrappedDest != nil {
return wrappedDest
}
return dest
}

func ShouldWrapForArray(v reflect.Value) bool {
func ShouldWrapForArrayScanning(v reflect.Value) bool {
t := v.Type()
if t.Implements(typeOfSQLScanner) {
return false
}
if t.Kind() == reflect.Pointer && !v.IsNil() {
v = v.Elem()
t = v.Type()
}
switch t.Kind() {
case reflect.Slice:
if t.Elem() == typeOfByte {
Expand All @@ -29,6 +63,64 @@ func ShouldWrapForArray(v reflect.Value) bool {
return false
}

// IsSliceOrArray returns true if passed value is a slice or array,
// or a pointer to a slice or array and in case of a slice
// not of type []byte.
func IsSliceOrArray(value any) bool {
if value == nil {
return false
}
v := reflect.ValueOf(value)
if v.Kind() == reflect.Pointer {
if v.IsNil() {
return false
}
v = v.Elem()
}
t := v.Type()
k := t.Kind()
return k == reflect.Slice && t != typeOfByteSlice || k == reflect.Array
}

// IsNonDriverValuerSliceOrArrayType returns true if passed type
// does not implement driver.Valuer and is a slice or array,
// or a pointer to a slice or array and in case of a slice
// not of type []byte.
func IsNonDriverValuerSliceOrArrayType(t reflect.Type) bool {
if t == nil || t.Implements(typeOfDriverValuer) {
return false
}
k := t.Kind()
if k == reflect.Pointer {
t = t.Elem()
k = t.Kind()
}
return k == reflect.Slice && t != typeOfByteSlice || k == reflect.Array
}

// func FormatArrays(args []any) []any {
// var wrappedArgs []any
// for i, arg := range args {
// if ShouldFormatArray(arg) {
// if wrappedArgs == nil {
// // Allocate new slice for wrapped element
// wrappedArgs = make([]any, len(args))
// // Copy previous elements
// for h := 0; h < i; h++ {
// wrappedArgs[h] = args[h]
// }
// }
// wrappedArgs[i], _ = pq.Array(arg).Value()
// } else if wrappedArgs != nil {
// wrappedArgs[i] = arg
// }
// }
// if wrappedArgs != nil {
// return wrappedArgs
// }
// return args
// }

// type ArrayScanner struct {
// Dest reflect.Value
// }
Expand Down Expand Up @@ -81,7 +173,7 @@ func ShouldWrapForArray(v reflect.Value) bool {
// } else {
// newDest = reflect.New(a.Dest.Type()).Elem()
// }
// if reflect.PtrTo(elemType).Implements(typeOfSQLScanner) {
// if reflect.PointerTo(elemType).Implements(typeOfSQLScanner) {
// for i, elemStr := range elems {
// err = newDest.Index(i).Addr().Interface().(sql.Scanner).Scan(elemStr)
// if err != nil {
Expand Down
94 changes: 94 additions & 0 deletions arrays_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
package sqldb

import (
"database/sql"
"database/sql/driver"
"encoding/json"
"reflect"
"testing"

"github.com/lib/pq"
"github.com/stretchr/testify/assert"

"github.com/domonda/go-types/nullable"
)

func defaultAsArrayScanner(a any) sql.Scanner {
return pq.Array(a) // TODO replace with own implementation
}

func TestShouldWrapForArrayScanning(t *testing.T) {
tests := []struct {
v reflect.Value
want bool
}{
{v: reflect.ValueOf([]byte(nil)), want: false},
{v: reflect.ValueOf([]byte{}), want: false},
{v: reflect.ValueOf(""), want: false},
{v: reflect.ValueOf(0), want: false},
{v: reflect.ValueOf(json.RawMessage([]byte("null"))), want: false},
{v: reflect.ValueOf(nullable.JSON([]byte("null"))), want: false},
{v: reflect.ValueOf(new(sql.NullInt64)).Elem(), want: false},
{v: reflect.ValueOf(defaultAsArrayScanner([]int{0, 1})), want: false},

{v: reflect.ValueOf(new([3]string)).Elem(), want: true},
{v: reflect.ValueOf(new([]string)).Elem(), want: true},
{v: reflect.ValueOf(new([]sql.NullString)).Elem(), want: true},
}
for _, tt := range tests {
got := ShouldWrapForArrayScanning(tt.v)
assert.Equal(t, tt.want, got)
}
}

func TestIsNonDriverValuerSliceOrArrayType(t *testing.T) {
tests := []struct {
t reflect.Type
want bool
}{
{t: reflect.TypeOf(nil), want: false},
{t: reflect.TypeOf(0), want: false},
{t: reflect.TypeOf(new(int)), want: false},
{t: reflect.TypeOf("string"), want: false},
{t: reflect.TypeOf([]byte("string")), want: false},
{t: reflect.TypeOf(new([]byte)), want: false},
{t: reflect.TypeOf(pq.BoolArray{true}), want: false},
{t: reflect.TypeOf(new(pq.BoolArray)), want: false},
{t: reflect.TypeOf(new(*[]int)), want: false}, // pointer to a pointer to a slice
{t: reflect.TypeOf((*driver.Valuer)(nil)), want: false},
{t: reflect.TypeOf((*driver.Valuer)(nil)).Elem(), want: false},

{t: reflect.TypeOf([3]int{1, 2, 3}), want: true},
{t: reflect.TypeOf((*[3]int)(nil)), want: true},
{t: reflect.TypeOf([]int{1, 2, 3}), want: true},
{t: reflect.TypeOf((*[]int)(nil)), want: true},
{t: reflect.TypeOf((*[][]byte)(nil)), want: true},
}
for _, tt := range tests {
got := IsNonDriverValuerSliceOrArrayType(tt.t)
assert.Equalf(t, tt.want, got, "IsNonDriverValuerSliceOrArrayType(%s)", tt.t)
}
}

// func TestWrapArgsForArrays(t *testing.T) {
// tests := []struct {
// args []any
// want []any
// }{
// {args: nil, want: nil},
// {args: []any{}, want: []any{}},
// {args: []any{0}, want: []any{0}},
// {args: []any{nil}, want: []any{nil}},
// {args: []any{new(int)}, want: []any{new(int)}},
// {args: []any{0, []int{0, 1}, "string"}, want: []any{0, wrapArgForArray([]int{0, 1}), "string"}},
// {args: []any{wrapArgForArray([]int{0, 1})}, want: []any{wrapArgForArray([]int{0, 1})}},
// {args: []any{[]byte("don't wrap []byte")}, want: []any{[]byte("don't wrap []byte")}},
// {args: []any{pq.BoolArray{true}}, want: []any{pq.BoolArray{true}}},
// {args: []any{[3]int{1, 2, 3}}, want: []any{wrapArgForArray([3]int{1, 2, 3})}},
// {args: []any{wrapArgForArray([3]int{1, 2, 3})}, want: []any{wrapArgForArray([3]int{1, 2, 3})}},
// }
// for _, tt := range tests {
// got := WrapArgsForArrays(tt.args)
// assert.Equal(t, tt.want, got)
// }
// }
6 changes: 0 additions & 6 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,18 +47,12 @@ type Config struct {
//
// If ConnMaxLifetime <= 0, connections are not closed due to a connection's age.
ConnMaxLifetime time.Duration `json:"connMaxLifetime,omitempty"`

DefaultIsolationLevel sql.IsolationLevel `json:"-"`
Err error `json:"-"`
}

// Validate returns Config.Err if it is not nil
// or an error if the Config does not have
// a Driver, Host, or Database.
func (c *Config) Validate() error {
if c.Err != nil {
return c.Err
}
if c.Driver == "" {
return fmt.Errorf("missing sqldb.Config.Driver")
}
Expand Down
Loading