Skip to content

Commit c038a45

Browse files
ivanauthclaude
andcommitted
fix: prevent race in ConcurrentBatch when a batch fails with workers=1
The test TestConcurrentBatchWhenOneBatchFailsAndWorkersIsOne was flaky because there was a race between the semaphore release and errgroup's context cancellation. When a batch failed, it would release the semaphore (via defer) before errgroup cancelled the context, allowing the next iteration's sem.Acquire to succeed and execute an additional batch. Fix this by using an atomic.Bool flag that is set immediately when a batch returns an error. The main loop checks this flag after acquiring the semaphore, ensuring no new batches are launched after a failure. Co-Authored-By: Claude Opus 4.5 <[email protected]>
1 parent 85c0916 commit c038a45

File tree

1 file changed

+15
-1
lines changed

1 file changed

+15
-1
lines changed

internal/grpcutil/batch.go

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"errors"
66
"runtime"
7+
"sync/atomic"
78

89
"golang.org/x/sync/errgroup"
910
"golang.org/x/sync/semaphore"
@@ -47,6 +48,7 @@ func ConcurrentBatch(ctx context.Context, n int, batchSize int, maxWorkers int,
4748
maxWorkers = runtime.GOMAXPROCS(0)
4849
}
4950

51+
var failed atomic.Bool
5052
sem := semaphore.NewWeighted(int64(maxWorkers))
5153
g, ctx := errgroup.WithContext(ctx)
5254
numBatches := (n + batchSize - 1) / batchSize
@@ -55,12 +57,24 @@ func ConcurrentBatch(ctx context.Context, n int, batchSize int, maxWorkers int,
5557
break
5658
}
5759

60+
// After acquiring the semaphore, check whether a previous batch
61+
// has already failed. This handles the race where a failing batch
62+
// releases the semaphore before the errgroup cancels the context.
63+
if failed.Load() {
64+
sem.Release(1)
65+
break
66+
}
67+
5868
batchNum := i
5969
g.Go(func() error {
6070
defer sem.Release(1)
6171
start := batchNum * batchSize
6272
end := minimum(start+batchSize, n)
63-
return each(ctx, batchNum, start, end)
73+
err := each(ctx, batchNum, start, end)
74+
if err != nil {
75+
failed.Store(true)
76+
}
77+
return err
6478
})
6579
}
6680
return g.Wait()

0 commit comments

Comments
 (0)