Skip to content

Commit 767327f

Browse files
authored
fix: Reset timers on flush (#1076)
This updates the ticker logic in the batch writers to reset the ticker whenever a flush happens. This is better: it still guarantees that a message won't be delayed by more than batch_timeout, but we no longer risk flushing a very small batch just because we must flush at regular intervals. The choice of resetting _after_ the flush is deliberate: it means that the maximum amount of time between flushes is: ``` max_time = flush_time + batch_timeout ``` Alternatively, we could reset _before_ the flush, which would give: ``` max_time = batch_timeout ``` but if a flush were to then take longer than the batch timeout, we could end up in a cycle of flushing again immediately after the previous flush finishes.
1 parent 88f08ee commit 767327f

File tree

7 files changed

+99
-65
lines changed

7 files changed

+99
-65
lines changed

writers/batchwriter/batchwriter.go

Lines changed: 6 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -122,8 +122,8 @@ func (w *BatchWriter) Close(context.Context) error {
122122
func (w *BatchWriter) worker(ctx context.Context, tableName string, ch <-chan *message.WriteInsert, flush <-chan chan bool) {
123123
sizeBytes := int64(0)
124124
resources := make([]*message.WriteInsert, 0, w.batchSize)
125-
tick, done := writers.NewTicker(w.batchTimeout)
126-
defer done()
125+
ticker := writers.NewTicker(w.batchTimeout)
126+
defer ticker.Stop()
127127
for {
128128
select {
129129
case r, ok := <-ch:
@@ -136,19 +136,22 @@ func (w *BatchWriter) worker(ctx context.Context, tableName string, ch <-chan *m
136136

137137
if (w.batchSize > 0 && len(resources) >= w.batchSize) || (w.batchSizeBytes > 0 && sizeBytes+util.TotalRecordSize(r.Record) >= int64(w.batchSizeBytes)) {
138138
w.flushTable(ctx, tableName, resources)
139+
ticker.Reset(w.batchTimeout)
139140
resources, sizeBytes = resources[:0], 0
140141
}
141142

142143
resources = append(resources, r)
143144
sizeBytes += util.TotalRecordSize(r.Record)
144-
case <-tick:
145+
case <-ticker.Chan():
145146
if len(resources) > 0 {
146147
w.flushTable(ctx, tableName, resources)
148+
ticker.Reset(w.batchTimeout)
147149
resources, sizeBytes = resources[:0], 0
148150
}
149151
case done := <-flush:
150152
if len(resources) > 0 {
151153
w.flushTable(ctx, tableName, resources)
154+
ticker.Reset(w.batchTimeout)
152155
resources, sizeBytes = resources[:0], 0
153156
}
154157
done <- true
@@ -170,33 +173,6 @@ func (w *BatchWriter) flushTable(ctx context.Context, tableName string, resource
170173
}
171174
}
172175

173-
// func (*BatchWriter) removeDuplicatesByPK(table *schema.Table, resources []*message.Insert) []*message.Insert {
174-
// pkIndices := table.PrimaryKeysIndexes()
175-
// // special case where there's no PK at all
176-
// if len(pkIndices) == 0 {
177-
// return resources
178-
// }
179-
180-
// pks := make(map[string]struct{}, len(resources))
181-
// res := make([]*message.Insert, 0, len(resources))
182-
// for _, r := range resources {
183-
// if r.Record.NumRows() > 1 {
184-
// panic(fmt.Sprintf("record with more than 1 row: %d", r.Record.NumRows()))
185-
// }
186-
// key := pk.String(r.Record)
187-
// _, ok := pks[key]
188-
// if !ok {
189-
// pks[key] = struct{}{}
190-
// res = append(res, r)
191-
// continue
192-
// }
193-
// // duplicate, release
194-
// r.Release()
195-
// }
196-
197-
// return res
198-
// }
199-
200176
func (w *BatchWriter) flushMigrateTables(ctx context.Context) error {
201177
w.migrateTableLock.Lock()
202178
defer w.migrateTableLock.Unlock()

writers/mixedbatchwriter/mixedbatchwriter.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -114,8 +114,8 @@ func (w *MixedBatchWriter) Write(ctx context.Context, msgChan <-chan message.Wri
114114
}
115115
prevMsgType := writers.MsgTypeUnset
116116
var err error
117-
tick, done := w.tickerFn(w.batchTimeout)
118-
defer done()
117+
ticker := w.tickerFn(w.batchTimeout)
118+
defer ticker.Stop()
119119
loop:
120120
for {
121121
select {
@@ -128,6 +128,7 @@ loop:
128128
if err := flush(prevMsgType); err != nil {
129129
return err
130130
}
131+
ticker.Reset(w.batchTimeout)
131132
}
132133
prevMsgType = msgType
133134
switch v := msg.(type) {
@@ -143,10 +144,11 @@ loop:
143144
if err != nil {
144145
return err
145146
}
146-
case <-tick:
147+
case <-ticker.Chan():
147148
if err := flush(prevMsgType); err != nil {
148149
return err
149150
}
151+
ticker.Reset(w.batchTimeout)
150152
prevMsgType = writers.MsgTypeUnset
151153
}
152154
}

writers/mixedbatchwriter/mixedbatchwriter_test.go

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"github.com/apache/arrow/go/v13/arrow/memory"
1111
"github.com/cloudquery/plugin-sdk/v4/message"
1212
"github.com/cloudquery/plugin-sdk/v4/schema"
13+
"github.com/cloudquery/plugin-sdk/v4/writers"
1314
"golang.org/x/sync/errgroup"
1415
)
1516

@@ -214,6 +215,38 @@ func TestMixedBatchWriter(t *testing.T) {
214215
}
215216
}
216217

218+
type mockTicker struct {
219+
C chan time.Time
220+
trigger chan struct{}
221+
}
222+
223+
func (m *mockTicker) Chan() <-chan time.Time {
224+
return m.C
225+
}
226+
227+
func (m *mockTicker) Trigger() chan<- struct{} {
228+
return m.trigger
229+
}
230+
231+
func (m *mockTicker) Stop() {
232+
close(m.C)
233+
}
234+
235+
func (*mockTicker) Reset(_ time.Duration) {}
236+
237+
func newMockTicker(trigger chan struct{}) *mockTicker {
238+
c := make(chan time.Time)
239+
go func() {
240+
for range trigger {
241+
c <- time.Now()
242+
}
243+
}()
244+
return &mockTicker{
245+
C: c,
246+
trigger: trigger,
247+
}
248+
}
249+
217250
func TestMixedBatchWriterTimeout(t *testing.T) {
218251
tm := getTestMessages()
219252
cases := []struct {
@@ -240,17 +273,12 @@ func TestMixedBatchWriterTimeout(t *testing.T) {
240273
receivedBatches: make([][]message.WriteMessage, 0),
241274
}
242275
triggerTimeout := make(chan struct{})
276+
defer close(triggerTimeout)
243277
wr, err := New(client,
244278
WithBatchSize(1000),
245279
WithBatchSizeBytes(1000000),
246-
withTickerFn(func(_ time.Duration) (<-chan time.Time, func()) {
247-
c := make(chan time.Time)
248-
go func() {
249-
for range triggerTimeout {
250-
c <- time.Now()
251-
}
252-
}()
253-
return c, func() { close(c) }
280+
withTickerFn(func(_ time.Duration) writers.Ticker {
281+
return newMockTicker(triggerTimeout)
254282
}),
255283
)
256284
if err != nil {

writers/streamingbatchwriter/mocktimer_test.go

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,22 +6,26 @@ import (
66
"github.com/cloudquery/plugin-sdk/v4/writers"
77
)
88

9-
type mockTimer struct {
9+
type mockTicker struct {
1010
expire chan time.Time
1111
}
1212

13-
func (t *mockTimer) timer(time.Duration) (<-chan time.Time, func()) {
14-
return t.expire, t.close
13+
func (t *mockTicker) Stop() {
14+
close(t.expire)
1515
}
1616

17-
func (t *mockTimer) close() {
18-
close(t.expire)
17+
func (*mockTicker) Reset(time.Duration) {}
18+
19+
func (t *mockTicker) Chan() <-chan time.Time {
20+
return t.expire
1921
}
2022

21-
func newMockTimer() (writers.TickerFunc, chan time.Time) {
23+
func newMockTicker() (writers.TickerFunc, chan<- time.Time) {
2224
expire := make(chan time.Time)
23-
t := &mockTimer{
25+
t := &mockTicker{
2426
expire: expire,
2527
}
26-
return t.timer, expire
28+
return func(time.Duration) writers.Ticker {
29+
return t
30+
}, expire
2731
}

writers/streamingbatchwriter/streamingbatchwriter.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -343,8 +343,8 @@ func (s *streamingWorkerManager[T]) run(ctx context.Context, wg *sync.WaitGroup,
343343
}
344344
defer closeFlush()
345345

346-
tick, done := s.tickerFn(s.batchTimeout)
347-
defer done()
346+
ticker := s.tickerFn(s.batchTimeout)
347+
defer ticker.Stop()
348348
for {
349349
select {
350350
case r, ok := <-s.ch:
@@ -359,19 +359,21 @@ func (s *streamingWorkerManager[T]) run(ctx context.Context, wg *sync.WaitGroup,
359359

360360
if (s.batchSizeRows > 0 && sizeRows >= s.batchSizeRows) || (s.batchSizeBytes > 0 && sizeBytes+recSize >= s.batchSizeBytes) {
361361
closeFlush()
362+
ticker.Reset(s.batchTimeout)
362363
}
363364

364365
ensureOpened()
365366
clientCh <- r
366367
sizeRows++
367368
sizeBytes += recSize
368-
case <-tick:
369+
case <-ticker.Chan():
369370
if sizeRows > 0 {
370371
closeFlush()
371372
}
372373
case done := <-s.flush:
373374
if sizeRows > 0 {
374375
closeFlush()
376+
ticker.Reset(s.batchTimeout)
375377
}
376378
done <- true
377379
}

writers/streamingbatchwriter/streamingbatchwriter_test.go

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -227,9 +227,9 @@ func TestStreamingBatchTimeout(t *testing.T) {
227227
ch := make(chan message.WriteMessage)
228228

229229
testClient := newClient()
230-
timerFn, timerExpire := newMockTimer()
230+
tickerFn, expire := newMockTicker()
231231

232-
wr, err := New(testClient, withTickerFn(timerFn))
232+
wr, err := New(testClient, withTickerFn(tickerFn))
233233
if err != nil {
234234
t.Fatal(err)
235235
}
@@ -258,7 +258,7 @@ func TestStreamingBatchTimeout(t *testing.T) {
258258
}
259259

260260
// flush
261-
close(timerExpire)
261+
close(expire)
262262
waitForLength(t, testClient.MessageLen, messageTypeInsert, 1)
263263

264264
close(ch)
@@ -332,8 +332,8 @@ func TestStreamingBatchUpserts(t *testing.T) {
332332
ch := make(chan message.WriteMessage)
333333

334334
testClient := newClient()
335-
timerFn, timerExpire := newMockTimer()
336-
wr, err := New(testClient, WithBatchSizeRows(2), withTickerFn(timerFn))
335+
tickerFn, expire := newMockTicker()
336+
wr, err := New(testClient, WithBatchSizeRows(2), withTickerFn(tickerFn))
337337
if err != nil {
338338
t.Fatal(err)
339339
}
@@ -363,7 +363,7 @@ func TestStreamingBatchUpserts(t *testing.T) {
363363
time.Sleep(50 * time.Millisecond)
364364

365365
// flush the batch
366-
close(timerExpire)
366+
close(expire)
367367
waitForLength(t, testClient.MessageLen, messageTypeInsert, 2)
368368

369369
close(ch)
@@ -379,9 +379,10 @@ func TestStreamingBatchUpserts(t *testing.T) {
379379
func waitForLength(t *testing.T, checkLen func(messageType) int, msgType messageType, want int) {
380380
t.Helper()
381381
lastValue := -1
382+
timeout := time.After(5 * time.Second)
382383
for {
383384
select {
384-
case <-time.After(time.Second):
385+
case <-timeout:
385386
t.Fatalf("timed out waiting for %v message length %d (last value: %d)", msgType, want, lastValue)
386387
default:
387388
if lastValue = checkLen(msgType); lastValue == want {

writers/ticker.go

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,35 @@ import (
44
"time"
55
)
66

7-
type TickerFunc func(interval time.Duration) (ch <-chan time.Time, done func())
7+
type TickerFunc func(time.Duration) Ticker
88

9-
func NewTicker(interval time.Duration) (<-chan time.Time, func()) {
9+
type Ticker interface {
10+
Stop()
11+
Reset(d time.Duration)
12+
Chan() <-chan time.Time
13+
}
14+
15+
func NewTicker(interval time.Duration) Ticker {
1016
if interval <= 0 {
11-
return nil, nop
17+
return nopTicker{}
1218
}
13-
t := time.NewTicker(interval)
14-
return t.C, t.Stop
19+
return &ticker{time.NewTicker(interval)}
20+
}
21+
22+
type ticker struct {
23+
*time.Ticker
24+
}
25+
26+
func (t *ticker) Chan() <-chan time.Time {
27+
return t.C
1528
}
1629

17-
func nop() {}
30+
type nopTicker struct{}
31+
32+
func (nopTicker) Stop() {}
33+
34+
func (nopTicker) Reset(_ time.Duration) {}
35+
36+
func (nopTicker) Chan() <-chan time.Time {
37+
return nil
38+
}

0 commit comments

Comments
 (0)