Skip to content

Commit 0261491

Browse files
tonyghitanicksrandall
authored andcommitted
Change LoadMany() to return nil []error with no errors (#32)
1 parent 1445aaf commit 0261491

File tree

2 files changed

+30
-8
lines changed

2 files changed

+30
-8
lines changed

dataloader.go

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -270,11 +270,13 @@ func (l *Loader) Load(originalContext context.Context, key interface{}) Thunk {
270270
func (l *Loader) LoadMany(originalContext context.Context, keys []interface{}) ThunkMany {
271271
ctx, finish := l.tracer.TraceLoadMany(originalContext, keys)
272272

273-
length := len(keys)
274-
data := make([]interface{}, length)
275-
errors := make([]error, length)
276-
c := make(chan *ResultMany, 1)
277-
wg := sync.WaitGroup{}
273+
var (
274+
length = len(keys)
275+
data = make([]interface{}, length)
276+
errors = make([]error, length)
277+
c = make(chan *ResultMany, 1)
278+
wg sync.WaitGroup
279+
)
278280

279281
wg.Add(length)
280282
for i := range keys {
@@ -289,7 +291,18 @@ func (l *Loader) LoadMany(originalContext context.Context, keys []interface{}) T
289291

290292
go func() {
291293
wg.Wait()
292-
c <- &ResultMany{data, errors}
294+
295+
// errs is nil unless there exists a non-nil error.
296+
// This prevents dataloader from returning a slice of all-nil errors.
297+
var errs []error
298+
for _, e := range errors {
299+
if e != nil {
300+
errs = errors
301+
break
302+
}
303+
}
304+
305+
c <- &ResultMany{Data: data, Error: errs}
293306
close(c)
294307
}()
295308

dataloader_test.go

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,6 @@ func TestLoader(t *testing.T) {
9494
ctx := context.Background()
9595
future := loader.LoadMany(ctx, []interface{}{"1", "2", "3"})
9696
_, err := future()
97-
log.Printf("errs: %#v", err)
9897
if len(err) != 3 {
9998
t.Errorf("LoadMany didn't return right number of errors (should match size of input)")
10099
}
@@ -108,6 +107,16 @@ func TestLoader(t *testing.T) {
108107
}
109108
})
110109

110+
t.Run("test LoadMany returns nil []error when no errors occurred", func(t *testing.T) {
111+
t.Parallel()
112+
loader, _ := IDLoader(0)
113+
ctx := context.Background()
114+
_, err := loader.LoadMany(ctx, []interface{}{"1", "2", "3"})()
115+
if err != nil {
116+
t.Errorf("Expected LoadMany() to return nil error slice when no errors occurred")
117+
}
118+
})
119+
111120
t.Run("test thunkmany does not contain race conditions", func(t *testing.T) {
112121
t.Parallel()
113122
identityLoader, _ := IDLoader(0)
@@ -491,7 +500,7 @@ func OneErrorLoader(max int) (*Loader, *[][]interface{}) {
491500
var mu sync.Mutex
492501
var loadCalls [][]interface{}
493502
identityLoader := NewBatchedLoader(func(_ context.Context, keys []interface{}) []*Result {
494-
results := make([]*Result, max, max)
503+
results := make([]*Result, max)
495504
mu.Lock()
496505
loadCalls = append(loadCalls, keys)
497506
mu.Unlock()

0 commit comments

Comments
 (0)