Skip to content

Commit 4611304

Browse files
authored
Merge pull request #86 from vivek-ng/vivek-ng/fix-panic-caching
Panic errors from batch function should not be cached
2 parents a7ede83 + 58f8c20 commit 4611304

File tree

2 files changed

+112
-5
lines changed

2 files changed

+112
-5
lines changed

dataloader.go

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ package dataloader
44

55
import (
66
"context"
7+
"errors"
78
"fmt"
89
"log"
910
"runtime"
@@ -48,6 +49,17 @@ type ResultMany[V any] struct {
4849
Error []error
4950
}
5051

52+
// PanicErrorWrapper wraps the error interface.
53+
// This is used to check if the error is a panic error.
54+
// We should not cache panic errors.
55+
type PanicErrorWrapper struct {
56+
panicError error
57+
}
58+
59+
func (p *PanicErrorWrapper) Error() string {
60+
return p.panicError.Error()
61+
}
62+
5163
// Loader implements the dataloader.Interface.
5264
type Loader[K comparable, V any] struct {
5365
// the batch function to be used by this loader
@@ -219,6 +231,10 @@ func (l *Loader[K, V]) Load(originalContext context.Context, key K) Thunk[V] {
219231
}
220232
result.mu.RLock()
221233
defer result.mu.RUnlock()
234+
var ev *PanicErrorWrapper
235+
if result.value.Error != nil && errors.As(result.value.Error, &ev) {
236+
l.Clear(ctx, key)
237+
}
222238
return result.value.Data, result.value.Error
223239
}
224240
defer finish(thunk)
@@ -431,7 +447,7 @@ func (b *batcher[K, V]) batch(originalContext context.Context) {
431447

432448
if panicErr != nil {
433449
for _, req := range reqs {
434-
req.channel <- &Result[V]{Error: fmt.Errorf("Panic received in batch function: %v", panicErr)}
450+
req.channel <- &Result[V]{Error: &PanicErrorWrapper{panicError: fmt.Errorf("Panic received in batch function: %v", panicErr)}}
435451
close(req.channel)
436452
}
437453
return

dataloader_test.go

Lines changed: 95 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,30 @@ func TestLoader(t *testing.T) {
5555
}
5656
})
5757

58+
t.Run("test Load Method cache error", func(t *testing.T) {
59+
t.Parallel()
60+
errorCacheLoader, _ := ErrorCacheLoader[string](0)
61+
ctx := context.Background()
62+
futures := []Thunk[string]{}
63+
for i := 0; i < 2; i++ {
64+
futures = append(futures, errorCacheLoader.Load(ctx, strconv.Itoa(i)))
65+
}
66+
67+
for _, f := range futures {
68+
_, err := f()
69+
if err == nil {
70+
t.Error("Error was not propagated")
71+
}
72+
}
73+
nextFuture := errorCacheLoader.Load(ctx, "1")
74+
_, err := nextFuture()
75+
76+
// Normal errors should be cached.
77+
if err == nil {
78+
t.Error("Error from batch function was not cached")
79+
}
80+
})
81+
5882
t.Run("test Load Method Panic Safety in multiple keys", func(t *testing.T) {
5983
t.Parallel()
6084
defer func() {
@@ -63,7 +87,7 @@ func TestLoader(t *testing.T) {
6387
t.Error("Panic Loader's panic should have been handled'")
6488
}
6589
}()
66-
panicLoader, _ := PanicLoader[string](0)
90+
panicLoader, _ := PanicCacheLoader[string](0)
6791
futures := []Thunk[string]{}
6892
ctx := context.Background()
6993
for i := 0; i < 3; i++ {
@@ -75,6 +99,18 @@ func TestLoader(t *testing.T) {
7599
t.Error("Panic was not propagated as an error.")
76100
}
77101
}
102+
103+
futures = []Thunk[string]{}
104+
for i := 0; i < 3; i++ {
105+
futures = append(futures, panicLoader.Load(ctx, strconv.Itoa(1)))
106+
}
107+
108+
for _, f := range futures {
109+
_, err := f()
110+
if err != nil {
111+
t.Error("Panic error from batch function was cached")
112+
}
113+
}
78114
})
79115

80116
t.Run("test LoadMany returns errors", func(t *testing.T) {
@@ -143,13 +179,21 @@ func TestLoader(t *testing.T) {
143179
t.Error("Panic Loader's panic should have been handled'")
144180
}
145181
}()
146-
panicLoader, _ := PanicLoader[string](0)
182+
panicLoader, _ := PanicCacheLoader[string](0)
147183
ctx := context.Background()
148-
future := panicLoader.LoadMany(ctx, []string{"1"})
184+
future := panicLoader.LoadMany(ctx, []string{"1", "2"})
149185
_, errs := future()
150-
if len(errs) < 1 || errs[0].Error() != "Panic received in batch function: Programming error" {
186+
if len(errs) < 2 || errs[0].Error() != "Panic received in batch function: Programming error" {
151187
t.Error("Panic was not propagated as an error.")
152188
}
189+
190+
future = panicLoader.LoadMany(ctx, []string{"1"})
191+
_, errs = future()
192+
193+
if len(errs) > 0 {
194+
t.Error("Panic error from batch function was cached")
195+
}
196+
153197
})
154198

155199
t.Run("test LoadMany method", func(t *testing.T) {
@@ -531,6 +575,53 @@ func PanicLoader[K comparable](max int) (*Loader[K, K], *[][]K) {
531575
}, WithBatchCapacity[K, K](max), withSilentLogger[K, K]())
532576
return panicLoader, &loadCalls
533577
}
578+
579+
func PanicCacheLoader[K comparable](max int) (*Loader[K, K], *[][]K) {
580+
var loadCalls [][]K
581+
panicCacheLoader := NewBatchedLoader(func(_ context.Context, keys []K) []*Result[K] {
582+
if len(keys) > 1 {
583+
panic("Programming error")
584+
}
585+
586+
returnResult := make([]*Result[K], len(keys))
587+
for idx := range returnResult {
588+
returnResult[idx] = &Result[K]{
589+
keys[0],
590+
nil,
591+
}
592+
}
593+
594+
return returnResult
595+
596+
}, WithBatchCapacity[K, K](max), withSilentLogger[K, K]())
597+
return panicCacheLoader, &loadCalls
598+
}
599+
600+
func ErrorCacheLoader[K comparable](max int) (*Loader[K, K], *[][]K) {
601+
var loadCalls [][]K
602+
errorCacheLoader := NewBatchedLoader(func(_ context.Context, keys []K) []*Result[K] {
603+
if len(keys) > 1 {
604+
var results []*Result[K]
605+
for _, key := range keys {
606+
results = append(results, &Result[K]{key, fmt.Errorf("this is a test error")})
607+
}
608+
return results
609+
}
610+
611+
returnResult := make([]*Result[K], len(keys))
612+
for idx := range returnResult {
613+
returnResult[idx] = &Result[K]{
614+
keys[0],
615+
nil,
616+
}
617+
}
618+
619+
return returnResult
620+
621+
}, WithBatchCapacity[K, K](max), withSilentLogger[K, K]())
622+
return errorCacheLoader, &loadCalls
623+
}
624+
534625
func BadLoader[K comparable](max int) (*Loader[K, K], *[][]K) {
535626
var mu sync.Mutex
536627
var loadCalls [][]K

0 commit comments

Comments
 (0)