Skip to content

Commit 2ae1676

Browse files
authored
pubsub: Ensure batch flushes on shutdown even if MinBatchSize isn't met (#3543)
1 parent 1e85529 commit 2ae1676

File tree

2 files changed

+51
-10
lines changed

2 files changed

+51
-10
lines changed

pubsub/batcher/batcher.go

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ type Options struct {
9191
// Maximum number of concurrent handlers. Defaults to 1.
9292
MaxHandlers int
9393
// Minimum size of a batch. Defaults to 1.
94+
// May be ignored during shutdown.
9495
MinBatchSize int
9596
// Maximum size of a batch. 0 means no limit.
9697
MaxBatchSize int
@@ -199,26 +200,33 @@ func (b *Batcher) AddNoWait(item any) <-chan error {
199200
b.pending = append(b.pending, waiter{item, c})
200201
if b.nHandlers < b.opts.MaxHandlers {
201202
// If we can start a handler, do so with the item just added and any others that are pending.
202-
batch := b.nextBatch()
203-
if batch != nil {
204-
b.wg.Add(1)
205-
go func() {
206-
b.callHandler(batch)
207-
b.wg.Done()
208-
}()
209-
b.nHandlers++
210-
}
203+
b.handleBatch(b.nextBatch())
211204
}
212205
// If we can't start a handler, then one of the currently running handlers will
213206
// take our item.
214207
return c
215208
}
216209

210+
// Requires b.mu be held.
211+
func (b *Batcher) handleBatch(batch []waiter) {
212+
if len(batch) == 0 {
213+
return
214+
}
215+
b.wg.Add(1)
216+
go func() {
217+
b.callHandler(batch)
218+
b.wg.Done()
219+
}()
220+
b.nHandlers++
221+
}
222+
217223
// nextBatch returns the batch to process, and updates b.pending.
218224
// It returns nil if there's no batch ready for processing.
219225
// b.mu must be held.
220226
func (b *Batcher) nextBatch() []waiter {
221-
if len(b.pending) < b.opts.MinBatchSize {
227+
// If we're not shutting down, respect minimums. If we're shutting down
228+
// though, we ignore minimums to make sure everything is flushed.
229+
if !b.shutdown && len(b.pending) < b.opts.MinBatchSize {
222230
return nil
223231
}
224232

@@ -282,6 +290,12 @@ func (b *Batcher) callHandler(batch []waiter) {
282290
func (b *Batcher) Shutdown() {
283291
b.mu.Lock()
284292
b.shutdown = true
293+
// If there aren't any handlers running, there might be a partial
294+
// batch. Make sure it gets flushed even if it hasn't reached the
295+
// minimums.
296+
if b.nHandlers == 0 {
297+
b.handleBatch(b.nextBatch())
298+
}
285299
b.mu.Unlock()
286300
b.wg.Wait()
287301
}

pubsub/batcher/batcher_test.go

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,33 @@ func TestShutdown(t *testing.T) {
271271
}
272272
}
273273

274+
// TestMinBatchSizeFlushesOnShutdown ensures that Shutdown() flushes batches, even if
275+
// the pending count is less than the minimum batch size.
276+
func TestMinBatchSizeFlushesOnShutdown(t *testing.T) {
277+
var got [][]int
278+
279+
batchSize := 3
280+
b := batcher.New(reflect.TypeOf(int(0)), &batcher.Options{MinBatchSize: batchSize}, func(items interface{}) error {
281+
got = append(got, items.([]int))
282+
return nil
283+
})
284+
for i := 0; i < (batchSize - 1); i++ {
285+
b.AddNoWait(i)
286+
}
287+
288+
// Ensure that we've received nothing
289+
if len(got) > 0 {
290+
t.Errorf("got batch unexpectedly: %+v", got)
291+
}
292+
293+
b.Shutdown()
294+
295+
want := [][]int{{0, 1}}
296+
if !cmp.Equal(got, want) {
297+
t.Errorf("got %+v, want %+v on shutdown", got, want)
298+
}
299+
}
300+
274301
func TestItemCanBeInterface(t *testing.T) {
275302
readerType := reflect.TypeOf([]io.Reader{}).Elem()
276303
called := false

0 commit comments

Comments
 (0)