Skip to content

Commit b77c904

Browse files
authored
Merge pull request #120 from aidenwallis/aiden/fix-locking-bug-in-tracer
fix: locking bug when tracer reads from thunk in Load method
2 parents 7adf3cc + cb197a6 commit b77c904

File tree

2 files changed

+66
-5
lines changed

2 files changed

+66
-5
lines changed

dataloader.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -254,9 +254,6 @@ func (l *Loader[K, V]) Load(originalContext context.Context, key K) Thunk[V] {
254254
return v
255255
}
256256

257-
defer l.batchLock.Unlock()
258-
defer l.cacheLock.Unlock()
259-
260257
thunk := func() (V, error) {
261258
<-req.done
262259
result := req.result.Load()
@@ -294,6 +291,12 @@ func (l *Loader[K, V]) Load(originalContext context.Context, key K) Thunk[V] {
294291
}
295292
}
296293

294+
// NOTE: It is intended that these are not unlocked with a `defer`. This is due to the `defer finish(thunk)` above.
295+
// There is a locking bug where, if you have a tracer that calls the thunk to read the results, the dataloader runs
296+
// into a deadlock scenario, as `finish` is called before these mutexes are free'd on the same goroutine.
297+
l.batchLock.Unlock()
298+
l.cacheLock.Unlock()
299+
297300
return thunk
298301
}
299302

dataloader_test.go

Lines changed: 60 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,22 @@ func TestLoader(t *testing.T) {
153153
}
154154
})
155155

156+
t.Run("test Load method does not create a deadlock mutex condition", func(t *testing.T) {
157+
t.Parallel()
158+
159+
loader, _ := IDLoader(1, WithTracer[string, string](&TracerWithThunkReading[string, string]{}))
160+
161+
value, err := loader.Load(context.Background(), "1")()
162+
if err != nil {
163+
t.Error(err.Error())
164+
}
165+
if value != "1" {
166+
t.Error("load didn't return the right value")
167+
}
168+
169+
// By this function completing, we confirm that there is not a deadlock condition, else the test will hang
170+
})
171+
156172
t.Run("test LoadMany returns errors", func(t *testing.T) {
157173
t.Parallel()
158174
errorLoader, _ := ErrorLoader[string](0)
@@ -202,6 +218,26 @@ func TestLoader(t *testing.T) {
202218
}
203219
})
204220

221+
t.Run("test LoadMany method does not create a deadlock mutex condition", func(t *testing.T) {
222+
t.Parallel()
223+
224+
loader, _ := IDLoader(1, WithTracer[string, string](&TracerWithThunkReading[string, string]{}))
225+
226+
values, errs := loader.LoadMany(context.Background(), []string{"1", "2", "3"})()
227+
for _, err := range errs {
228+
if err != nil {
229+
t.Error(err.Error())
230+
}
231+
}
232+
for _, value := range values {
233+
if value == "" {
234+
t.Error("unexpected empty value in LoadMany returned")
235+
}
236+
}
237+
238+
// By this function completing, we confirm that there is not a deadlock condition, else the test will hang
239+
})
240+
205241
t.Run("test thunkmany does not contain race conditions", func(t *testing.T) {
206242
t.Parallel()
207243
identityLoader, _ := IDLoader[string](0)
@@ -590,7 +626,7 @@ func TestLoader(t *testing.T) {
590626
}
591627

592628
// test helpers
593-
func IDLoader[K comparable](max int) (*Loader[K, K], *[][]K) {
629+
func IDLoader[K comparable](max int, options ...Option[K, K]) (*Loader[K, K], *[][]K) {
594630
var mu sync.Mutex
595631
var loadCalls [][]K
596632
identityLoader := NewBatchedLoader(func(_ context.Context, keys []K) []*Result[K] {
@@ -602,7 +638,7 @@ func IDLoader[K comparable](max int) (*Loader[K, K], *[][]K) {
602638
results = append(results, &Result[K]{key, nil})
603639
}
604640
return results
605-
}, WithBatchCapacity[K, K](max))
641+
}, append([]Option[K, K]{WithBatchCapacity[K, K](max)}, options...)...)
606642
return identityLoader, &loadCalls
607643
}
608644
func BatchOnlyLoader[K comparable](max int) (*Loader[K, K], *[][]K) {
@@ -788,6 +824,28 @@ func FaultyLoader[K comparable]() (*Loader[K, K], *[][]K) {
788824
return loader, &loadCalls
789825
}
790826

827+
type TracerWithThunkReading[K comparable, V any] struct{}
828+
829+
var _ Tracer[string, struct{}] = (*TracerWithThunkReading[string, struct{}])(nil)
830+
831+
func (_ *TracerWithThunkReading[K, V]) TraceLoad(ctx context.Context, key K) (context.Context, TraceLoadFinishFunc[V]) {
832+
return ctx, func(thunk Thunk[V]) {
833+
_, _ = thunk()
834+
}
835+
}
836+
837+
func (_ *TracerWithThunkReading[K, V]) TraceLoadMany(ctx context.Context, keys []K) (context.Context, TraceLoadManyFinishFunc[V]) {
838+
return ctx, func(thunks ThunkMany[V]) {
839+
_, _ = thunks()
840+
}
841+
}
842+
843+
func (_ *TracerWithThunkReading[K, V]) TraceBatch(ctx context.Context, keys []K) (context.Context, TraceBatchFinishFunc[V]) {
844+
return ctx, func(thunks []*Result[V]) {
845+
//
846+
}
847+
}
848+
791849
/*
792850
Benchmarks
793851
*/

0 commit comments

Comments
 (0)