Skip to content

Commit 7fe7c64

Browse files
authored
feat: Add batch timeout support to mixed batch writer (#1055)
1 parent fbde15a commit 7fe7c64

File tree

2 files changed

+193
-61
lines changed

2 files changed

+193
-61
lines changed

writers/mixedbatchwriter/mixedbatchwriter.go

Lines changed: 63 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package mixedbatchwriter
22

33
import (
44
"context"
5+
"time"
56

67
"github.com/apache/arrow/go/v13/arrow/util"
78
"github.com/cloudquery/plugin-sdk/v4/message"
@@ -16,11 +17,15 @@ type Client interface {
1617
DeleteStaleBatch(ctx context.Context, messages message.WriteDeleteStales) error
1718
}
1819

20+
type timerFn func(timeout time.Duration) <-chan time.Time
21+
1922
type MixedBatchWriter struct {
2023
client Client
2124
logger zerolog.Logger
2225
batchSize int
2326
batchSizeBytes int
27+
batchTimeout time.Duration
28+
timerFn timerFn
2429
}
2530

2631
// Assert at compile-time that MixedBatchWriter implements the Writer interface
@@ -46,10 +51,22 @@ func WithBatchSizeBytes(size int) Option {
4651
}
4752
}
4853

54+
func WithBatchTimeout(timeout time.Duration) Option {
55+
return func(p *MixedBatchWriter) {
56+
p.batchTimeout = timeout
57+
}
58+
}
59+
60+
func withTimerFn(timer timerFn) Option {
61+
return func(p *MixedBatchWriter) {
62+
p.timerFn = timer
63+
}
64+
}
65+
4966
const (
50-
defaultBatchTimeoutSeconds = 20
51-
defaultBatchSize = 10000
52-
defaultBatchSizeBytes = 5 * 1024 * 1024 // 5 MiB
67+
defaultBatchTimeout = 20 * time.Second
68+
defaultBatchSize = 10000
69+
defaultBatchSizeBytes = 5 * 1024 * 1024 // 5 MiB
5370
)
5471

5572
func New(client Client, opts ...Option) (*MixedBatchWriter, error) {
@@ -58,6 +75,8 @@ func New(client Client, opts ...Option) (*MixedBatchWriter, error) {
5875
logger: zerolog.Nop(),
5976
batchSize: defaultBatchSize,
6077
batchSizeBytes: defaultBatchSizeBytes,
78+
batchTimeout: defaultBatchTimeout,
79+
timerFn: timer,
6180
}
6281
for _, opt := range opts {
6382
opt(c)
@@ -81,6 +100,9 @@ func (w *MixedBatchWriter) Write(ctx context.Context, msgChan <-chan message.Wri
81100
writeFunc: w.client.DeleteStaleBatch,
82101
}
83102
flush := func(msgType writers.MsgType) error {
103+
if msgType == writers.MsgTypeUnset {
104+
return nil
105+
}
84106
switch msgType {
85107
case writers.MsgTypeMigrateTable:
86108
return migrateTable.flush(ctx)
@@ -94,31 +116,42 @@ func (w *MixedBatchWriter) Write(ctx context.Context, msgChan <-chan message.Wri
94116
}
95117
prevMsgType := writers.MsgTypeUnset
96118
var err error
97-
for msg := range msgChan {
98-
msgType := writers.MsgID(msg)
99-
if prevMsgType != writers.MsgTypeUnset && prevMsgType != msgType {
119+
tick := w.timerFn(w.batchTimeout)
120+
loop:
121+
for {
122+
select {
123+
case msg, ok := <-msgChan:
124+
if !ok {
125+
break loop
126+
}
127+
msgType := writers.MsgID(msg)
128+
if prevMsgType != msgType {
129+
if err := flush(prevMsgType); err != nil {
130+
return err
131+
}
132+
}
133+
prevMsgType = msgType
134+
switch v := msg.(type) {
135+
case *message.WriteMigrateTable:
136+
err = migrateTable.append(ctx, v)
137+
case *message.WriteInsert:
138+
err = insert.append(ctx, v)
139+
case *message.WriteDeleteStale:
140+
err = deleteStale.append(ctx, v)
141+
default:
142+
panic("unknown message type")
143+
}
144+
if err != nil {
145+
return err
146+
}
147+
case <-tick:
100148
if err := flush(prevMsgType); err != nil {
101149
return err
102150
}
103-
}
104-
prevMsgType = msgType
105-
switch v := msg.(type) {
106-
case *message.WriteMigrateTable:
107-
err = migrateTable.append(ctx, v)
108-
case *message.WriteInsert:
109-
err = insert.append(ctx, v)
110-
case *message.WriteDeleteStale:
111-
err = deleteStale.append(ctx, v)
112-
default:
113-
panic("unknown message type")
114-
}
115-
if err != nil {
116-
return err
151+
prevMsgType = writers.MsgTypeUnset
152+
tick = w.timerFn(w.batchTimeout)
117153
}
118154
}
119-
if prevMsgType == writers.MsgTypeUnset {
120-
return nil
121-
}
122155
return flush(prevMsgType)
123156
}
124157

@@ -182,3 +215,10 @@ func (m *insertBatchManager) flush(ctx context.Context) error {
182215
m.batch = m.batch[:0]
183216
return nil
184217
}
218+
219+
func timer(timeout time.Duration) <-chan time.Time {
220+
if timeout == 0 {
221+
return nil
222+
}
223+
return time.After(timeout)
224+
}

writers/mixedbatchwriter/mixedbatchwriter_test.go

Lines changed: 130 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
package mixedbatchwriter_test
1+
package mixedbatchwriter
22

33
import (
44
"context"
@@ -10,7 +10,7 @@ import (
1010
"github.com/apache/arrow/go/v13/arrow/memory"
1111
"github.com/cloudquery/plugin-sdk/v4/message"
1212
"github.com/cloudquery/plugin-sdk/v4/schema"
13-
"github.com/cloudquery/plugin-sdk/v4/writers/mixedbatchwriter"
13+
"golang.org/x/sync/errgroup"
1414
)
1515

1616
type testMixedBatchClient struct {
@@ -44,11 +44,18 @@ func (c *testMixedBatchClient) DeleteStaleBatch(_ context.Context, messages mess
4444
return nil
4545
}
4646

47-
var _ mixedbatchwriter.Client = (*testMixedBatchClient)(nil)
47+
var _ Client = (*testMixedBatchClient)(nil)
4848

49-
func TestMixedBatchWriter(t *testing.T) {
50-
ctx := context.Background()
49+
type testMessages struct {
50+
migrateTable1 *message.WriteMigrateTable
51+
migrateTable2 *message.WriteMigrateTable
52+
insert1 *message.WriteInsert
53+
insert2 *message.WriteInsert
54+
deleteStale1 *message.WriteDeleteStale
55+
deleteStale2 *message.WriteDeleteStale
56+
}
5157

58+
func getTestMessages() testMessages {
5259
// message to create table1
5360
table1 := &schema.Table{
5461
Name: "table1",
@@ -105,6 +112,18 @@ func TestMixedBatchWriter(t *testing.T) {
105112
SyncTime: time.Now(),
106113
}
107114

115+
return testMessages{
116+
migrateTable1: msgMigrateTable1,
117+
migrateTable2: msgMigrateTable2,
118+
insert1: msgInsertTable1,
119+
insert2: msgInsertTable2,
120+
deleteStale1: msgDeleteStale1,
121+
deleteStale2: msgDeleteStale2,
122+
}
123+
}
124+
125+
func TestMixedBatchWriter(t *testing.T) {
126+
tm := getTestMessages()
108127
testCases := []struct {
109128
name string
110129
messages []message.WriteMessage
@@ -113,64 +132,65 @@ func TestMixedBatchWriter(t *testing.T) {
113132
{
114133
name: "create table, insert, delete stale",
115134
messages: []message.WriteMessage{
116-
msgMigrateTable1,
117-
msgMigrateTable2,
118-
msgInsertTable1,
119-
msgInsertTable2,
120-
msgDeleteStale1,
121-
msgDeleteStale2,
135+
tm.migrateTable1,
136+
tm.migrateTable2,
137+
tm.insert1,
138+
tm.insert2,
139+
tm.deleteStale1,
140+
tm.deleteStale2,
122141
},
123142
wantBatches: [][]message.WriteMessage{
124-
{msgMigrateTable1, msgMigrateTable2},
125-
{msgInsertTable1, msgInsertTable2},
126-
{msgDeleteStale1, msgDeleteStale2},
143+
{tm.migrateTable1, tm.migrateTable2},
144+
{tm.insert1, tm.insert2},
145+
{tm.deleteStale1, tm.deleteStale2},
127146
},
128147
},
129148
{
130149
name: "interleaved messages",
131150
messages: []message.WriteMessage{
132-
msgMigrateTable1,
133-
msgInsertTable1,
134-
msgDeleteStale1,
135-
msgMigrateTable2,
136-
msgInsertTable2,
137-
msgDeleteStale2,
151+
tm.migrateTable1,
152+
tm.insert1,
153+
tm.deleteStale1,
154+
tm.migrateTable2,
155+
tm.insert2,
156+
tm.deleteStale2,
138157
},
139158
wantBatches: [][]message.WriteMessage{
140-
{msgMigrateTable1},
141-
{msgInsertTable1},
142-
{msgDeleteStale1},
143-
{msgMigrateTable2},
144-
{msgInsertTable2},
145-
{msgDeleteStale2},
159+
{tm.migrateTable1},
160+
{tm.insert1},
161+
{tm.deleteStale1},
162+
{tm.migrateTable2},
163+
{tm.insert2},
164+
{tm.deleteStale2},
146165
},
147166
},
148167
{
149168
name: "interleaved messages",
150169
messages: []message.WriteMessage{
151-
msgMigrateTable1,
152-
msgMigrateTable2,
153-
msgInsertTable1,
154-
msgDeleteStale2,
155-
msgInsertTable2,
156-
msgDeleteStale1,
170+
tm.migrateTable1,
171+
tm.migrateTable2,
172+
tm.insert1,
173+
tm.deleteStale2,
174+
tm.insert2,
175+
tm.deleteStale1,
157176
},
158177
wantBatches: [][]message.WriteMessage{
159-
{msgMigrateTable1, msgMigrateTable2},
160-
{msgInsertTable1},
161-
{msgDeleteStale2},
162-
{msgInsertTable2},
163-
{msgDeleteStale1},
178+
{tm.migrateTable1, tm.migrateTable2},
179+
{tm.insert1},
180+
{tm.deleteStale2},
181+
{tm.insert2},
182+
{tm.deleteStale1},
164183
},
165184
},
166185
}
167186

168187
for _, tc := range testCases {
169188
t.Run(tc.name, func(t *testing.T) {
189+
ctx := context.Background()
170190
client := &testMixedBatchClient{
171191
receivedBatches: make([][]message.WriteMessage, 0),
172192
}
173-
wr, err := mixedbatchwriter.New(client)
193+
wr, err := New(client)
174194
if err != nil {
175195
t.Fatal(err)
176196
}
@@ -193,3 +213,75 @@ func TestMixedBatchWriter(t *testing.T) {
193213
})
194214
}
195215
}
216+
217+
func TestMixedBatchWriterTimeout(t *testing.T) {
218+
tm := getTestMessages()
219+
cases := []struct {
220+
name string
221+
messages []message.WriteMessage
222+
wantBatches [][]message.WriteMessage
223+
}{
224+
{
225+
name: "one_message_batches",
226+
messages: []message.WriteMessage{
227+
tm.insert1,
228+
tm.insert2,
229+
},
230+
wantBatches: [][]message.WriteMessage{
231+
{tm.insert1},
232+
{tm.insert2},
233+
},
234+
},
235+
}
236+
for _, tc := range cases {
237+
t.Run(tc.name, func(t *testing.T) {
238+
ctx := context.Background()
239+
client := &testMixedBatchClient{
240+
receivedBatches: make([][]message.WriteMessage, 0),
241+
}
242+
triggerTimeout := make(chan struct{})
243+
wr, err := New(client,
244+
WithBatchSize(1000),
245+
WithBatchSizeBytes(1000000),
246+
withTimerFn(func(_ time.Duration) <-chan time.Time {
247+
c := make(chan time.Time)
248+
go func() {
249+
<-triggerTimeout
250+
c <- time.Now()
251+
}()
252+
return c
253+
}),
254+
)
255+
if err != nil {
256+
t.Fatal(err)
257+
}
258+
ch := make(chan message.WriteMessage)
259+
260+
eg := errgroup.Group{}
261+
eg.Go(func() error {
262+
return wr.Write(ctx, ch)
263+
})
264+
265+
for _, msg := range tc.messages {
266+
ch <- msg
267+
time.Sleep(100 * time.Millisecond)
268+
triggerTimeout <- struct{}{}
269+
time.Sleep(100 * time.Millisecond)
270+
}
271+
close(ch)
272+
err = eg.Wait()
273+
if err != nil {
274+
t.Fatalf("got error %v, want nil", err)
275+
}
276+
277+
if len(client.receivedBatches) != len(tc.wantBatches) {
278+
t.Fatalf("got %d batches, want %d", len(client.receivedBatches), len(tc.wantBatches))
279+
}
280+
for i, wantBatch := range tc.wantBatches {
281+
if len(client.receivedBatches[i]) != len(wantBatch) {
282+
t.Fatalf("got %d messages in batch %d, want %d", len(client.receivedBatches[i]), i, len(wantBatch))
283+
}
284+
}
285+
})
286+
}
287+
}

0 commit comments

Comments
 (0)