Skip to content

Commit a3dc12d

Browse files
authored
fix(retrieve): Allowing a slice of pointed structs (#6)
Allowing a slice of pointed structs
1 parent da27a9f commit a3dc12d

File tree

1 file changed

+21
-6
lines changed

1 file changed

+21
-6
lines changed

query_builder.go

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package pagefilter
22

33
import (
4+
"errors"
45
"fmt"
56
"net/http"
67
"reflect"
@@ -9,6 +10,11 @@ import (
910
"github.com/jmoiron/sqlx"
1011
)
1112

13+
var (
14+
// ErrNoDestination is returned when the destination is nil
15+
ErrNoDestination = errors.New("destination is nil")
16+
)
17+
1218
// Paginator is the struct that provides the paging.
1319
type Paginator struct {
1420
db DB
@@ -228,22 +234,31 @@ func (p *Paginator) Pivot() (string, error) {
228234
}
229235

230236
// Retrieve pulls the next page given the pivot point and requires a destination *[]struct to load the data into.
231-
func (p *Paginator) Retrieve(pivot string, dest interface{}) error {
237+
func (p *Paginator) Retrieve(pivot string, dest any) error {
238+
if dest == nil {
239+
return ErrNoDestination
240+
}
241+
232242
// Gracefully locate all the columns to load.
233243
t := reflect.TypeOf(dest)
234244
if t.Kind() != reflect.Ptr {
235245
return fmt.Errorf("unexpected type %s (expected pointer)", t.Kind())
236246
}
237-
if t = t.Elem(); t.Kind() != reflect.Slice {
247+
t = t.Elem()
248+
if t.Kind() != reflect.Slice {
238249
return fmt.Errorf("unexpected type %s (expected slice)", t.Kind())
239250
}
240-
if t = t.Elem(); t.Kind() != reflect.Struct {
241-
return fmt.Errorf("unexpected type %s (expected struct)", t.Kind())
251+
elemType := t.Elem()
252+
if elemType.Kind() == reflect.Ptr {
253+
elemType = elemType.Elem()
254+
}
255+
if elemType.Kind() != reflect.Struct {
256+
return fmt.Errorf("unexpected type %s (expected struct)", elemType.Kind())
242257
}
243258

244259
var cols strings.Builder
245-
for i := 0; i < t.NumField(); i++ {
246-
field := t.Field(i)
260+
for i := 0; i < elemType.NumField(); i++ {
261+
field := elemType.Field(i)
247262
dbTag := field.Tag.Get("db")
248263
switch dbTag {
249264
case "":

0 commit comments

Comments
 (0)