1- package mixedbatchwriter_test
1+ package mixedbatchwriter
22
33import (
44 "context"
@@ -10,7 +10,7 @@ import (
1010 "github.com/apache/arrow/go/v13/arrow/memory"
1111 "github.com/cloudquery/plugin-sdk/v4/message"
1212 "github.com/cloudquery/plugin-sdk/v4/schema"
13- "github.com/cloudquery/plugin-sdk/v4/writers/mixedbatchwriter "
13+ "golang.org/x/sync/errgroup "
1414)
1515
1616type testMixedBatchClient struct {
@@ -44,11 +44,18 @@ func (c *testMixedBatchClient) DeleteStaleBatch(_ context.Context, messages mess
4444 return nil
4545}
4646
47- var _ mixedbatchwriter. Client = (* testMixedBatchClient )(nil )
47+ var _ Client = (* testMixedBatchClient )(nil )
4848
49- func TestMixedBatchWriter (t * testing.T ) {
50- ctx := context .Background ()
49+ type testMessages struct {
50+ migrateTable1 * message.WriteMigrateTable
51+ migrateTable2 * message.WriteMigrateTable
52+ insert1 * message.WriteInsert
53+ insert2 * message.WriteInsert
54+ deleteStale1 * message.WriteDeleteStale
55+ deleteStale2 * message.WriteDeleteStale
56+ }
5157
58+ func getTestMessages () testMessages {
5259 // message to create table1
5360 table1 := & schema.Table {
5461 Name : "table1" ,
@@ -105,6 +112,18 @@ func TestMixedBatchWriter(t *testing.T) {
105112 SyncTime : time .Now (),
106113 }
107114
115+ return testMessages {
116+ migrateTable1 : msgMigrateTable1 ,
117+ migrateTable2 : msgMigrateTable2 ,
118+ insert1 : msgInsertTable1 ,
119+ insert2 : msgInsertTable2 ,
120+ deleteStale1 : msgDeleteStale1 ,
121+ deleteStale2 : msgDeleteStale2 ,
122+ }
123+ }
124+
125+ func TestMixedBatchWriter (t * testing.T ) {
126+ tm := getTestMessages ()
108127 testCases := []struct {
109128 name string
110129 messages []message.WriteMessage
@@ -113,64 +132,65 @@ func TestMixedBatchWriter(t *testing.T) {
113132 {
114133 name : "create table, insert, delete stale" ,
115134 messages : []message.WriteMessage {
116- msgMigrateTable1 ,
117- msgMigrateTable2 ,
118- msgInsertTable1 ,
119- msgInsertTable2 ,
120- msgDeleteStale1 ,
121- msgDeleteStale2 ,
135+ tm . migrateTable1 ,
136+ tm . migrateTable2 ,
137+ tm . insert1 ,
138+ tm . insert2 ,
139+ tm . deleteStale1 ,
140+ tm . deleteStale2 ,
122141 },
123142 wantBatches : [][]message.WriteMessage {
124- {msgMigrateTable1 , msgMigrateTable2 },
125- {msgInsertTable1 , msgInsertTable2 },
126- {msgDeleteStale1 , msgDeleteStale2 },
143+ {tm . migrateTable1 , tm . migrateTable2 },
144+ {tm . insert1 , tm . insert2 },
145+ {tm . deleteStale1 , tm . deleteStale2 },
127146 },
128147 },
129148 {
130149 name : "interleaved messages" ,
131150 messages : []message.WriteMessage {
132- msgMigrateTable1 ,
133- msgInsertTable1 ,
134- msgDeleteStale1 ,
135- msgMigrateTable2 ,
136- msgInsertTable2 ,
137- msgDeleteStale2 ,
151+ tm . migrateTable1 ,
152+ tm . insert1 ,
153+ tm . deleteStale1 ,
154+ tm . migrateTable2 ,
155+ tm . insert2 ,
156+ tm . deleteStale2 ,
138157 },
139158 wantBatches : [][]message.WriteMessage {
140- {msgMigrateTable1 },
141- {msgInsertTable1 },
142- {msgDeleteStale1 },
143- {msgMigrateTable2 },
144- {msgInsertTable2 },
145- {msgDeleteStale2 },
159+ {tm . migrateTable1 },
160+ {tm . insert1 },
161+ {tm . deleteStale1 },
162+ {tm . migrateTable2 },
163+ {tm . insert2 },
164+ {tm . deleteStale2 },
146165 },
147166 },
148167 {
149168 name : "interleaved messages" ,
150169 messages : []message.WriteMessage {
151- msgMigrateTable1 ,
152- msgMigrateTable2 ,
153- msgInsertTable1 ,
154- msgDeleteStale2 ,
155- msgInsertTable2 ,
156- msgDeleteStale1 ,
170+ tm . migrateTable1 ,
171+ tm . migrateTable2 ,
172+ tm . insert1 ,
173+ tm . deleteStale2 ,
174+ tm . insert2 ,
175+ tm . deleteStale1 ,
157176 },
158177 wantBatches : [][]message.WriteMessage {
159- {msgMigrateTable1 , msgMigrateTable2 },
160- {msgInsertTable1 },
161- {msgDeleteStale2 },
162- {msgInsertTable2 },
163- {msgDeleteStale1 },
178+ {tm . migrateTable1 , tm . migrateTable2 },
179+ {tm . insert1 },
180+ {tm . deleteStale2 },
181+ {tm . insert2 },
182+ {tm . deleteStale1 },
164183 },
165184 },
166185 }
167186
168187 for _ , tc := range testCases {
169188 t .Run (tc .name , func (t * testing.T ) {
189+ ctx := context .Background ()
170190 client := & testMixedBatchClient {
171191 receivedBatches : make ([][]message.WriteMessage , 0 ),
172192 }
173- wr , err := mixedbatchwriter . New (client )
193+ wr , err := New (client )
174194 if err != nil {
175195 t .Fatal (err )
176196 }
@@ -193,3 +213,75 @@ func TestMixedBatchWriter(t *testing.T) {
193213 })
194214 }
195215}
216+
217+ func TestMixedBatchWriterTimeout (t * testing.T ) {
218+ tm := getTestMessages ()
219+ cases := []struct {
220+ name string
221+ messages []message.WriteMessage
222+ wantBatches [][]message.WriteMessage
223+ }{
224+ {
225+ name : "one_message_batches" ,
226+ messages : []message.WriteMessage {
227+ tm .insert1 ,
228+ tm .insert2 ,
229+ },
230+ wantBatches : [][]message.WriteMessage {
231+ {tm .insert1 },
232+ {tm .insert2 },
233+ },
234+ },
235+ }
236+ for _ , tc := range cases {
237+ t .Run (tc .name , func (t * testing.T ) {
238+ ctx := context .Background ()
239+ client := & testMixedBatchClient {
240+ receivedBatches : make ([][]message.WriteMessage , 0 ),
241+ }
242+ triggerTimeout := make (chan struct {})
243+ wr , err := New (client ,
244+ WithBatchSize (1000 ),
245+ WithBatchSizeBytes (1000000 ),
246+ withTimerFn (func (_ time.Duration ) <- chan time.Time {
247+ c := make (chan time.Time )
248+ go func () {
249+ <- triggerTimeout
250+ c <- time .Now ()
251+ }()
252+ return c
253+ }),
254+ )
255+ if err != nil {
256+ t .Fatal (err )
257+ }
258+ ch := make (chan message.WriteMessage )
259+
260+ eg := errgroup.Group {}
261+ eg .Go (func () error {
262+ return wr .Write (ctx , ch )
263+ })
264+
265+ for _ , msg := range tc .messages {
266+ ch <- msg
267+ time .Sleep (100 * time .Millisecond )
268+ triggerTimeout <- struct {}{}
269+ time .Sleep (100 * time .Millisecond )
270+ }
271+ close (ch )
272+ err = eg .Wait ()
273+ if err != nil {
274+ t .Fatalf ("got error %v, want nil" , err )
275+ }
276+
277+ if len (client .receivedBatches ) != len (tc .wantBatches ) {
278+ t .Fatalf ("got %d batches, want %d" , len (client .receivedBatches ), len (tc .wantBatches ))
279+ }
280+ for i , wantBatch := range tc .wantBatches {
281+ if len (client .receivedBatches [i ]) != len (wantBatch ) {
282+ t .Fatalf ("got %d messages in batch %d, want %d" , len (client .receivedBatches [i ]), i , len (wantBatch ))
283+ }
284+ }
285+ })
286+ }
287+ }
0 commit comments