Skip to content

Commit 8d07a66

Browse files
committed
reflectStructColumnPointers uses pq.Array to wrap slices and arrays
1 parent 9bcabeb commit 8d07a66

File tree

2 files changed

+13
-7
lines changed

2 files changed

+13
-7
lines changed

impl/foreachrow.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ var (
1515
typeOfContext = reflect.TypeOf((*context.Context)(nil)).Elem()
1616
typeOfSQLScanner = reflect.TypeOf((*sql.Scanner)(nil)).Elem()
1717
typeOfTime = reflect.TypeOf(time.Time{})
18+
typeOfByteSlice = reflect.TypeOf((*[]byte)(nil)).Elem()
1819
)
1920

2021
// ForEachRowCallFunc will call the passed callback with scanned values or a struct for every row.

impl/reflectstruct.go

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"reflect"
77
"strings"
88

9+
"github.com/lib/pq"
910
"golang.org/x/exp/slices"
1011

1112
"github.com/domonda/go-sqldb"
@@ -73,12 +74,10 @@ func ReflectStructColumnPointers(structVal reflect.Value, namer sqldb.StructFiel
7374
}
7475

7576
func reflectStructColumnPointers(structVal reflect.Value, namer sqldb.StructFieldMapper, columns []string, pointers []any) error {
76-
var (
77-
structType = structVal.Type()
78-
)
77+
structType := structVal.Type()
7978
for i := 0; i < structType.NumField(); i++ {
80-
fieldType := structType.Field(i)
81-
_, column, _, use := namer.MapStructField(fieldType)
79+
field := structType.Field(i)
80+
_, column, _, use := namer.MapStructField(field)
8281
if !use {
8382
continue
8483
}
@@ -99,10 +98,16 @@ func reflectStructColumnPointers(structVal reflect.Value, namer sqldb.StructFiel
9998
}
10099

101100
if pointers[colIndex] != nil {
102-
return fmt.Errorf("duplicate mapped column %s onto field %s of struct %s", column, fieldType.Name, structType)
101+
return fmt.Errorf("duplicate mapped column %s onto field %s of struct %s", column, field.Name, structType)
103102
}
104103

105-
pointers[colIndex] = fieldValue.Addr().Interface()
104+
pointer := fieldValue.Addr().Interface()
105+
// If field is a slice or array that does not implement sql.Scanner
106+
// then wrap it with pq.Array to make it scannable
107+
if k := field.Type.Kind(); (k == reflect.Slice || k == reflect.Array) && field.Type != typeOfByteSlice && !fieldValue.Addr().Type().Implements(typeOfSQLScanner) {
108+
pointer = pq.Array(pointer)
109+
}
110+
pointers[colIndex] = pointer
106111
}
107112
return nil
108113
}

0 commit comments

Comments
 (0)