Skip to content

Commit f05b992

Browse files
committed
added impl.WrapForArray and ShouldWrapForArray
1 parent 82eb6e5 commit f05b992

File tree

6 files changed

+249
-226
lines changed

6 files changed

+249
-226
lines changed

impl/arrays.go

Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
package impl
2+
3+
import (
4+
"database/sql"
5+
"database/sql/driver"
6+
"reflect"
7+
8+
"github.com/lib/pq"
9+
)
10+
11+
func WrapForArray(a interface{}) interface {
12+
driver.Valuer
13+
sql.Scanner
14+
} {
15+
return pq.Array(a)
16+
}
17+
18+
func ShouldWrapForArray(v reflect.Value) bool {
19+
t := v.Type()
20+
switch t.Kind() {
21+
case reflect.Slice:
22+
if t.Elem() == typeOfByte {
23+
return false // Byte slices are scanned as strings
24+
}
25+
return !v.Addr().Type().Implements(typeOfSQLScanner)
26+
case reflect.Array:
27+
return !v.Addr().Type().Implements(typeOfSQLScanner)
28+
}
29+
return false
30+
}
31+
32+
// type ArrayScanner struct {
33+
// Dest reflect.Value
34+
// }
35+
36+
// func MustArrayScanner(destPtr any) sql.Scanner {
37+
// v := reflect.ValueOf(destPtr).Elem()
38+
// if !ShouldWrapForArray(v) {
39+
// panic(fmt.Sprintf("expected pointer to slice or array, got %T", destPtr))
40+
// }
41+
// return &ArrayScanner{Dest: v}
42+
// }
43+
44+
// // Scan implements the sql.Scanner interface.
45+
// func (a *ArrayScanner) Scan(src any) error {
46+
// switch src := src.(type) {
47+
// case []byte:
48+
// return a.scanString(string(src))
49+
// case string:
50+
// return a.scanString(src)
51+
// case nil:
52+
// if a.Dest.Kind() != reflect.Slice {
53+
// return fmt.Errorf("can't scan NULL as %s", a.Dest.Type())
54+
// }
55+
// a.Dest.SetZero()
56+
// return nil
57+
// default:
58+
// return fmt.Errorf("can't scan %T as %s", src, a.Dest.Type())
59+
// }
60+
// }
61+
62+
// func (a *ArrayScanner) scanString(src string) error {
63+
// elems, err := nullable.SplitArray(src)
64+
// if err != nil {
65+
// return err
66+
// }
67+
// destIsSlice := a.Dest.Kind() == reflect.Slice
68+
// if !destIsSlice && len(elems) != a.Dest.Len() {
69+
// return fmt.Errorf("can't scan %d elements into array of length %d", len(elems), a.Dest.Len())
70+
// }
71+
// if destIsSlice && len(elems) == 0 {
72+
// a.Dest.SetZero()
73+
// return nil
74+
// }
75+
// elemType := a.Dest.Type().Elem()
76+
// // allocate new slice or array on heap for scanning
77+
// // only assign after scaning of all elements was successful
78+
// var newDest reflect.Value
79+
// if destIsSlice {
80+
// newDest = reflect.MakeSlice(elemType, len(elems), len(elems))
81+
// } else {
82+
// newDest = reflect.New(a.Dest.Type()).Elem()
83+
// }
84+
// if reflect.PtrTo(elemType).Implements(typeOfSQLScanner) {
85+
// for i, elemStr := range elems {
86+
// err = newDest.Index(i).Addr().Interface().(sql.Scanner).Scan(elemStr)
87+
// if err != nil {
88+
// return fmt.Errorf("can't scan %q as element %d of slice %s because of %w", elemStr, i, elemType, err)
89+
// }
90+
// }
91+
// } else {
92+
// for i, elemStr := range elems {
93+
// // TODO elemStr is a string because we splitted an SQL array string literal, can't scan into an int right now
94+
// err = ScanValue(elemStr, newDest.Index(i))
95+
// if err != nil {
96+
// return fmt.Errorf("can't scan %q as element %d of slice %s because of %w", elemStr, i, elemType, err)
97+
// }
98+
// }
99+
// }
100+
// a.Dest.Set(newDest)
101+
// return nil
102+
// }
103+
104+
// func ScanReflectValue(src any, dest reflect.Value) error {
105+
// if dest.Kind() == reflect.Interface {
106+
// if src != nil {
107+
// dest.Set(reflect.ValueOf(src))
108+
// } else {
109+
// dest.SetZero()
110+
// }
111+
// return nil
112+
// }
113+
114+
// if dest.Addr().Type().Implements(typeOfSQLScanner) {
115+
// return dest.Addr().Interface().(sql.Scanner).Scan(src)
116+
// }
117+
118+
// switch x := src.(type) {
119+
// case int64:
120+
// switch dest.Kind() {
121+
// case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
122+
// dest.SetInt(x)
123+
// return nil
124+
// case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
125+
// dest.SetUint(uint64(x))
126+
// return nil
127+
// case reflect.Float32, reflect.Float64:
128+
// dest.SetFloat(float64(x))
129+
// return nil
130+
// }
131+
132+
// case float64:
133+
// switch dest.Kind() {
134+
// case reflect.Float32, reflect.Float64:
135+
// dest.SetFloat(x)
136+
// return nil
137+
// case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
138+
// dest.SetInt(int64(x))
139+
// return nil
140+
// case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
141+
// dest.SetUint(uint64(x))
142+
// return nil
143+
// }
144+
145+
// case bool:
146+
// dest.SetBool(x)
147+
// return nil
148+
149+
// case []byte:
150+
// switch {
151+
// case dest.Kind() == reflect.String:
152+
// dest.SetString(string(x))
153+
// return nil
154+
// case dest.Kind() == reflect.Slice && dest.Type().Elem().Kind() == reflect.Uint8:
155+
// dest.Set(reflect.ValueOf(x))
156+
// return nil
157+
// }
158+
159+
// case string:
160+
// switch {
161+
// case dest.Kind() == reflect.String:
162+
// dest.SetString(x)
163+
// return nil
164+
// case dest.Type() == typeOfByteSlice:
165+
// dest.Set(reflect.ValueOf([]byte(x)))
166+
// return nil
167+
// }
168+
169+
// case time.Time:
170+
// if srcVal := reflect.ValueOf(src); srcVal.Type().AssignableTo(dest.Type()) {
171+
// dest.Set(srcVal)
172+
// return nil
173+
// }
174+
175+
// case nil:
176+
// switch dest.Kind() {
177+
// case reflect.Ptr, reflect.Slice, reflect.Map:
178+
// dest.SetZero()
179+
// return nil
180+
// }
181+
// }
182+
183+
// return fmt.Errorf("can't scan %#v as %s", src, dest.Type())
184+
// }

impl/reflectstruct_test.go renamed to impl/arrays_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ import (
99
"github.com/domonda/go-types/nullable"
1010
)
1111

12-
func Test_shouldWrapArray(t *testing.T) {
12+
func TestShouldWrapForArray(t *testing.T) {
1313
tests := []struct {
1414
v reflect.Value
1515
want bool
@@ -27,7 +27,7 @@ func Test_shouldWrapArray(t *testing.T) {
2727
{v: reflect.ValueOf(new([]sql.NullString)).Elem(), want: true},
2828
}
2929
for _, tt := range tests {
30-
if got := shouldWrapArray(tt.v); got != tt.want {
30+
if got := ShouldWrapForArray(tt.v); got != tt.want {
3131
t.Errorf("shouldWrapArray() = %v, want %v", got, tt.want)
3232
}
3333
}

impl/foreachrow.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ var (
1616
typeOfSQLScanner = reflect.TypeOf((*sql.Scanner)(nil)).Elem()
1717
typeOfTime = reflect.TypeOf(time.Time{})
1818
typeOfByte = reflect.TypeOf(byte(0))
19+
typeOfByteSlice = reflect.TypeOf((*[]byte)(nil)).Elem()
1920
)
2021

2122
// ForEachRowCallFunc will call the passed callback with scanned values or a struct for every row.

impl/reflectstruct.go

Lines changed: 3 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ import (
66
"reflect"
77
"strings"
88

9-
"github.com/lib/pq"
109
"golang.org/x/exp/slices"
1110

1211
"github.com/domonda/go-sqldb"
@@ -104,29 +103,15 @@ func reflectStructColumnPointers(structVal reflect.Value, namer sqldb.StructFiel
104103
pointer := fieldValue.Addr().Interface()
105104
// If field is a slice or array that does not implement sql.Scanner
106105
// and it's not a string scannable []byte type underneath
107-
// then wrap it with pq.Array to make it scannable
108-
if shouldWrapArray(fieldValue) {
109-
pointer = pq.Array(pointer)
106+
// then wrap it with WrapForArray to make it scannable
107+
if ShouldWrapForArray(fieldValue) {
108+
pointer = WrapForArray(pointer)
110109
}
111110
pointers[colIndex] = pointer
112111
}
113112
return nil
114113
}
115114

116-
func shouldWrapArray(v reflect.Value) bool {
117-
t := v.Type()
118-
switch t.Kind() {
119-
case reflect.Slice:
120-
if t.Elem() == typeOfByte {
121-
return false // Byte slices are scanned as strings
122-
}
123-
return !v.Addr().Type().Implements(typeOfSQLScanner)
124-
case reflect.Array:
125-
return !v.Addr().Type().Implements(typeOfSQLScanner)
126-
}
127-
return false
128-
}
129-
130115
func ignoreColumn(filters []sqldb.ColumnFilter, name string, flags sqldb.FieldFlag, fieldType reflect.StructField, fieldValue reflect.Value) bool {
131116
for _, filter := range filters {
132117
if filter.IgnoreColumn(name, flags, fieldType, fieldValue) {

impl/rowsscanner.go

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@ package impl
22

33
import (
44
"context"
5+
"errors"
56
"fmt"
7+
"reflect"
68

79
sqldb "github.com/domonda/go-sqldb"
810
)
@@ -93,3 +95,60 @@ func (s *RowsScanner) ForEachRowCall(callback any) error {
9395
}
9496
return s.ForEachRow(forEachRowFunc)
9597
}
98+
99+
// ScanRowsAsSlice scans all srcRows as slice into dest.
100+
// The rows must either have only one column compatible with the element type of the slice,
101+
// or if multiple columns are returned then the slice element type must me a struct or struction pointer
102+
// so that every column maps on exactly one struct field using structFieldNamer.
103+
// In case of single column rows, nil must be passed for structFieldNamer.
104+
// ScanRowsAsSlice calls srcRows.Close().
105+
func ScanRowsAsSlice(ctx context.Context, srcRows Rows, dest any, structFieldNamer sqldb.StructFieldMapper) error {
106+
defer srcRows.Close()
107+
108+
destVal := reflect.ValueOf(dest)
109+
if destVal.Kind() != reflect.Ptr {
110+
return fmt.Errorf("scan dest is not a pointer but %s", destVal.Type())
111+
}
112+
if destVal.IsNil() {
113+
return errors.New("scan dest is nil")
114+
}
115+
slice := destVal.Elem()
116+
if slice.Kind() != reflect.Slice {
117+
return fmt.Errorf("scan dest is not pointer to slice but %s", destVal.Type())
118+
}
119+
sliceElemType := slice.Type().Elem()
120+
121+
newSlice := reflect.MakeSlice(slice.Type(), 0, 32)
122+
123+
for srcRows.Next() {
124+
if ctx.Err() != nil {
125+
return ctx.Err()
126+
}
127+
128+
newSlice = reflect.Append(newSlice, reflect.Zero(sliceElemType))
129+
target := newSlice.Index(newSlice.Len() - 1).Addr()
130+
if structFieldNamer != nil {
131+
err := ScanStruct(srcRows, target.Interface(), structFieldNamer)
132+
if err != nil {
133+
return err
134+
}
135+
} else {
136+
err := srcRows.Scan(target.Interface())
137+
if err != nil {
138+
return err
139+
}
140+
}
141+
}
142+
if srcRows.Err() != nil {
143+
return srcRows.Err()
144+
}
145+
146+
// Assign newSlice if there were no errors
147+
if newSlice.Len() == 0 {
148+
slice.SetLen(0)
149+
} else {
150+
slice.Set(newSlice)
151+
}
152+
153+
return nil
154+
}

0 commit comments

Comments
 (0)