@@ -3,15 +3,18 @@ package scheduler
 import (
 	"context"
 	"fmt"
+	"strconv"
 	"testing"
 	"time"
 
 	"github.com/apache/arrow-go/v18/arrow"
 	"github.com/apache/arrow-go/v18/arrow/array"
+	"github.com/apache/arrow-go/v18/arrow/memory"
 	"github.com/cloudquery/plugin-sdk/v4/message"
 	"github.com/cloudquery/plugin-sdk/v4/scalar"
 	"github.com/cloudquery/plugin-sdk/v4/schema"
 	"github.com/rs/zerolog"
+	"github.com/samber/lo"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 )
@@ -184,9 +187,82 @@ func testTableRelationSuccess() *schema.Table {
 	}
 }
 
+const chunkSize = 200
+
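+// testTableSuccessWithRowsChunkResolverSendSingleItemToResChan exercises PreResourceChunkResolver
+// when the table resolver sends items to the resource channel one at a time.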
+func testTableSuccessWithRowsChunkResolverSendSingleItemToResChan() *schema.Table {
+	return &schema.Table{
+		Name: "test_table_success_with_rows_chunk_resolver",
+		Resolver: func(ctx context.Context, meta schema.ClientMeta, parent *schema.Resource, res chan<- any) error {
+			for i := range chunkSize {
+				res <- i
+			}
+			return nil
+		},
+		PreResourceChunkResolver: &schema.RowsChunkResolver{
+			ChunkSize: chunkSize,
+			RowsResolver: func(ctx context.Context, meta schema.ClientMeta, resourcesChunk []*schema.Resource) error {
+				for _, resource := range resourcesChunk {
+					resource.Set("test_column", strconv.Itoa(resource.Item.(int)))
+				}
+				return nil
+			},
+		},
+		Columns: []schema.Column{
+			{
+				Name: "test_column",
+				Type: arrow.BinaryTypes.String,
+			},
+		},
+	}
+}
+
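+// testTableSuccessWithRowsChunkResolverSendSliceToResChan exercises PreResourceChunkResolver
+// when the table resolver sends a whole slice of items to the resource channel in one send.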
+func testTableSuccessWithRowsChunkResolverSendSliceToResChan() *schema.Table {
+	return &schema.Table{
+		Name: "test_table_success_with_rows_chunk_resolver",
+		Resolver: func(ctx context.Context, meta schema.ClientMeta, parent *schema.Resource, res chan<- any) error {
+			data := make([]int, chunkSize)
+			for i := range chunkSize {
+				data[i] = i
+			}
+			res <- data
+			return nil
+		},
+		PreResourceChunkResolver: &schema.RowsChunkResolver{
+			ChunkSize: chunkSize,
+			RowsResolver: func(ctx context.Context, meta schema.ClientMeta, resourcesChunk []*schema.Resource) error {
+				for _, resource := range resourcesChunk {
+					resource.Set("test_column", strconv.Itoa(resource.Item.(int)))
+				}
+				return nil
+			},
+		},
+		Columns: []schema.Column{
+			{
+				Name: "test_column",
+				Type: arrow.BinaryTypes.String,
+			},
+		},
+	}
+}
+
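+// expectedChunkedResolverData builds the records the chunked-resolver tables are expected to
+// produce: chunkSize stringified row indexes in test_column, grouped into records of
+// rowsPerRecord rows each.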
+func expectedChunkedResolverData(schema *arrow.Schema) []arrow.Record {
+	const rowsPerRecord = 50
+	data := make([]arrow.Record, chunkSize/rowsPerRecord)
+	for i := range data {
+		builder := array.NewRecordBuilder(memory.DefaultAllocator, schema)
+		for j := range rowsPerRecord {
+			builder.Field(0).(*array.StringBuilder).Append(strconv.Itoa(i*rowsPerRecord + j))
+		}
+		record := builder.NewRecord()
+		data[i] = record
+	}
+	return data
+}
+
 type syncTestCase struct {
 	table             *schema.Table
 	data              []scalar.Vector
+	dataAsRecords     []arrow.Record
 	deterministicCQID bool
 	err               error
 }
@@ -298,6 +374,14 @@ var syncTestCases = []syncTestCase{
 			},
 		},
 	},
+	{
+		table:         testTableSuccessWithRowsChunkResolverSendSingleItemToResChan(),
+		dataAsRecords: expectedChunkedResolverData(testTableSuccessWithRowsChunkResolverSendSingleItemToResChan().ToArrowSchema()),
+	},
+	{
+		table:         testTableSuccessWithRowsChunkResolverSendSliceToResChan(),
+		dataAsRecords: expectedChunkedResolverData(testTableSuccessWithRowsChunkResolverSendSliceToResChan().ToArrowSchema()),
+	},
 }
 
 type optionsTestCase struct {
@@ -348,36 +432,24 @@ func testSyncTable(t *testing.T, tc syncTestCase, strategy Strategy, determinist
 		WithStrategy(strategy),
 	}, extra...)
 	sc := NewScheduler(opts...)
-	msgs := make(chan message.SyncMessage, 10)
+	// Use a buffered channel so we don't need a separate goroutine to drain messages during the sync; it just has to be large enough to hold every sync message.
+	msgs := make(chan message.SyncMessage, 500)
 	err := sc.Sync(ctx, &c, tables, msgs, WithSyncDeterministicCQID(deterministicCQID))
 	require.ErrorIs(t, err, tc.err)
 	close(msgs)
 
-	var i int
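+	// Expected data can be supplied either as scalar vectors or directly as Arrow records;
+	// normalize both forms to a slice of records before comparing.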
+	dataAsRecords := tc.dataAsRecords
+	if dataAsRecords == nil {
+		dataAsRecords = lo.Map(tc.data, func(item scalar.Vector, _ int) arrow.Record {
+			return item.ToArrowRecord(tc.table.ToArrowSchema())
+		})
+	}
+
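+	// Collect every inserted record; other message types are validated inline below.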
+	gotRecords := make([]arrow.Record, 0)
 	for msg := range msgs {
 		switch v := msg.(type) {
 		case *message.SyncInsert:
-			record := v.Record
-			rec := tc.data[i].ToArrowRecord(record.Schema())
-			if !array.RecordEqual(rec, record) {
-				// For records that include CqIDColumn, we can't verify equality because it is generated by the scheduler, unless deterministicCQID is true
-				onlyCqIDInequality := false
-				for col := range rec.Columns() {
-					if !deterministicCQID && rec.ColumnName(col) == schema.CqIDColumn.Name {
-						onlyCqIDInequality = true
-						continue
-					}
-					lc := rec.Column(col)
-					rc := record.Column(col)
-					if !array.Equal(lc, rc) {
-						onlyCqIDInequality = false
-					}
-				}
-				if !onlyCqIDInequality {
-					t.Fatalf("expected at i=%d: %v. got %v", i, tc.data[i], record)
-				}
-			}
-			i++
+			gotRecords = append(gotRecords, v.Record)
 		case *message.SyncMigrateTable:
 			migratedTable := v.Table
 
@@ -402,8 +474,47 @@ func testSyncTable(t *testing.T, tc syncTestCase, strategy Strategy, determinist
 			t.Fatalf("expected insert message. got %T", msg)
 		}
 	}
-	if len(tc.data) != i {
-		t.Fatalf("expected %d resources. got %d", len(tc.data), i)
+
+	// The SDK may batch multiple rows into a single record, so slice both the expected and the received records into single-row records before comparing.
+	slicedExpectedRecords := make([]arrow.Record, 0)
+	for _, record := range dataAsRecords {
+		for j := int64(0); j < record.NumRows(); j++ {
+			slicedRecord := record.NewSlice(j, j+1)
+			slicedExpectedRecords = append(slicedExpectedRecords, slicedRecord)
+		}
+	}
+	gotSlicedRecords := make([]arrow.Record, 0)
+	for _, record := range gotRecords {
+		for j := int64(0); j < record.NumRows(); j++ {
+			slicedRecord := record.NewSlice(j, j+1)
+			gotSlicedRecords = append(gotSlicedRecords, slicedRecord)
+		}
+	}
+	if len(slicedExpectedRecords) != len(gotSlicedRecords) {
+		t.Fatalf("expected %d rows. got %d", len(slicedExpectedRecords), len(gotSlicedRecords))
+	}
+
+	for _, expectedRecord := range slicedExpectedRecords {
+		// Records can be returned in any order, so we need to find the matching record
+		_, found := lo.Find(gotSlicedRecords, func(gotRecord arrow.Record) bool {
+			if deterministicCQID {
+				return array.RecordEqual(gotRecord, expectedRecord)
+			}
+			for col := range gotRecord.Columns() {
+				// skip equality check for random CQID values
+				if gotRecord.ColumnName(col) == schema.CqIDColumn.Name {
+					continue
+				}
+				if !array.Equal(gotRecord.Column(col), expectedRecord.Column(col)) {
+					return false
+				}
+			}
+			return true
+		})
+
+		if !found {
+			t.Fatalf("expected record %v not found", expectedRecord)
+		}
 	}
 }
 