diff --git a/dataloader.go b/dataloader.go index 65de9ed..c4ea540 100644 --- a/dataloader.go +++ b/dataloader.go @@ -254,9 +254,6 @@ func (l *Loader[K, V]) Load(originalContext context.Context, key K) Thunk[V] { return v } - defer l.batchLock.Unlock() - defer l.cacheLock.Unlock() - thunk := func() (V, error) { <-req.done result := req.result.Load() @@ -294,6 +291,12 @@ func (l *Loader[K, V]) Load(originalContext context.Context, key K) Thunk[V] { } } + // NOTE: It is intended that these are not unlocked with a `defer`. This is due to the `defer finish(thunk)` above. + // There is a locking bug where, if you have a tracer that calls the thunk to read the results, the dataloader runs + // into a deadlock scenario, as `finish` is called before these mutexes are free'd on the same goroutine. + l.batchLock.Unlock() + l.cacheLock.Unlock() + return thunk } diff --git a/dataloader_test.go b/dataloader_test.go index 08692ee..3473e9a 100644 --- a/dataloader_test.go +++ b/dataloader_test.go @@ -153,6 +153,22 @@ func TestLoader(t *testing.T) { } }) + t.Run("test Load method does not create a deadlock mutex condition", func(t *testing.T) { + t.Parallel() + + loader, _ := IDLoader(1, WithTracer[string, string](&TracerWithThunkReading[string, string]{})) + + value, err := loader.Load(context.Background(), "1")() + if err != nil { + t.Error(err.Error()) + } + if value != "1" { + t.Error("load didn't return the right value") + } + + // By this function completing, we confirm that there is not a deadlock condition, else the test will hang + }) + t.Run("test LoadMany returns errors", func(t *testing.T) { t.Parallel() errorLoader, _ := ErrorLoader[string](0) @@ -202,6 +218,26 @@ func TestLoader(t *testing.T) { } }) + t.Run("test LoadMany method does not create a deadlock mutex condition", func(t *testing.T) { + t.Parallel() + + loader, _ := IDLoader(1, WithTracer[string, string](&TracerWithThunkReading[string, string]{})) + + values, errs := loader.LoadMany(context.Background(), []string{"1", "2", "3"})() + for _, err := range errs { + if err != nil { + t.Error(err.Error()) + } + } + for _, value := range values { + if value == "" { + t.Error("unexpected empty value in LoadMany returned") + } + } + + // By this function completing, we confirm that there is not a deadlock condition, else the test will hang + }) + t.Run("test thunkmany does not contain race conditions", func(t *testing.T) { t.Parallel() identityLoader, _ := IDLoader[string](0) @@ -590,7 +626,7 @@ func TestLoader(t *testing.T) { } // test helpers -func IDLoader[K comparable](max int) (*Loader[K, K], *[][]K) { +func IDLoader[K comparable](max int, options ...Option[K, K]) (*Loader[K, K], *[][]K) { var mu sync.Mutex var loadCalls [][]K identityLoader := NewBatchedLoader(func(_ context.Context, keys []K) []*Result[K] { @@ -602,7 +638,7 @@ func IDLoader[K comparable](max int) (*Loader[K, K], *[][]K) { results = append(results, &Result[K]{key, nil}) } return results - }, WithBatchCapacity[K, K](max)) + }, append([]Option[K, K]{WithBatchCapacity[K, K](max)}, options...)...) return identityLoader, &loadCalls } func BatchOnlyLoader[K comparable](max int) (*Loader[K, K], *[][]K) { @@ -788,6 +824,28 @@ func FaultyLoader[K comparable]() (*Loader[K, K], *[][]K) { return loader, &loadCalls } +type TracerWithThunkReading[K comparable, V any] struct{} + +var _ Tracer[string, struct{}] = (*TracerWithThunkReading[string, struct{}])(nil) + +func (_ *TracerWithThunkReading[K, V]) TraceLoad(ctx context.Context, key K) (context.Context, TraceLoadFinishFunc[V]) { + return ctx, func(thunk Thunk[V]) { + _, _ = thunk() + } +} + +func (_ *TracerWithThunkReading[K, V]) TraceLoadMany(ctx context.Context, keys []K) (context.Context, TraceLoadManyFinishFunc[V]) { + return ctx, func(thunks ThunkMany[V]) { + _, _ = thunks() + } +} + +func (_ *TracerWithThunkReading[K, V]) TraceBatch(ctx context.Context, keys []K) (context.Context, TraceBatchFinishFunc[V]) { + return ctx, func(thunks []*Result[V]) { + // + } +} + /* Benchmarks */