@@ -3,9 +3,15 @@ package schema
33import (
44 "fmt"
55 "strings"
6+ "testing"
67
78 "github.com/apache/arrow-go/v18/arrow"
89 "github.com/apache/arrow-go/v18/arrow/array"
10+ "github.com/apache/arrow-go/v18/arrow/memory"
11+ "github.com/cloudquery/plugin-sdk/v4/types"
12+ "github.com/google/uuid"
13+ "github.com/samber/lo"
14+ "github.com/stretchr/testify/require"
915)
1016
1117func RecordDiff (l arrow.Record , r arrow.Record ) string {
@@ -31,3 +37,128 @@ func RecordDiff(l arrow.Record, r arrow.Record) string {
3137 }
3238 return sb .String ()
3339}
40+
41+ func buildTestRecord (withClientIDValue string ) arrow.Record {
42+ testFields := []arrow.Field {
43+ {Name : "id" , Type : arrow .PrimitiveTypes .Int64 , Nullable : true },
44+ {Name : "name" , Type : arrow .BinaryTypes .String , Nullable : true },
45+ {Name : "value" , Type : arrow .PrimitiveTypes .Float64 , Nullable : true },
46+ {Name : "bool" , Type : arrow .FixedWidthTypes .Boolean , Nullable : true },
47+ {Name : "uuid" , Type : types .UUID , Nullable : true },
48+ }
49+ if withClientIDValue != "" {
50+ testFields = append (testFields , CqClientIDColumn .ToArrowField ())
51+ }
52+ schema := arrow .NewSchema (testFields , nil )
53+
54+ testValuesCount := 10
55+ builders := []array.Builder {
56+ array .NewInt64Builder (memory .DefaultAllocator ),
57+ array .NewStringBuilder (memory .DefaultAllocator ),
58+ array .NewFloat64Builder (memory .DefaultAllocator ),
59+ array .NewBooleanBuilder (memory .DefaultAllocator ),
60+ types .NewUUIDBuilder (memory .DefaultAllocator ),
61+ }
62+ for _ , builder := range builders {
63+ builder .Reserve (testValuesCount )
64+ switch b := builder .(type ) {
65+ case * array.Int64Builder :
66+ for i := range testValuesCount {
67+ b .Append (int64 (i ))
68+ }
69+ case * array.StringBuilder :
70+ for i := range testValuesCount {
71+ b .AppendString (fmt .Sprintf ("test%d" , i ))
72+ }
73+ case * array.Float64Builder :
74+ for i := range testValuesCount {
75+ b .Append (float64 (i ))
76+ }
77+ case * array.BooleanBuilder :
78+ for i := range testValuesCount {
79+ b .Append (i % 2 == 0 )
80+ }
81+ case * types.UUIDBuilder :
82+ for i := range testValuesCount {
83+ b .Append (uuid .NewSHA1 (uuid .NameSpaceURL , []byte (fmt .Sprintf ("test%d" , i ))))
84+ }
85+ }
86+ }
87+ if withClientIDValue != "" {
88+ builder := array .NewStringBuilder (memory .DefaultAllocator )
89+ builder .Reserve (testValuesCount )
90+ for range testValuesCount {
91+ builder .AppendString (withClientIDValue )
92+ }
93+ builders = append (builders , builder )
94+ }
95+ values := lo .Map (builders , func (builder array.Builder , _ int ) arrow.Array {
96+ return builder .NewArray ()
97+ })
98+ return array .NewRecord (schema , values , int64 (testValuesCount ))
99+ }
100+
101+ func TestAddInternalColumnsToRecord (t * testing.T ) {
102+ tests := []struct {
103+ name string
104+ record arrow.Record
105+ cqClientIDValue string
106+ expectedNewColumns int64
107+ }{
108+ {
109+ name : "add _cq_id,_cq_parent_id,_cq_client_id" ,
110+ record : buildTestRecord ("" ),
111+ cqClientIDValue : "new_client_id" ,
112+ expectedNewColumns : 3 ,
113+ },
114+ {
115+ name : "add cq_client_id,cq_id replace existing _cq_client_id" ,
116+ record : buildTestRecord ("existing_client_id" ),
117+ cqClientIDValue : "new_client_id" ,
118+ expectedNewColumns : 2 ,
119+ },
120+ }
121+ for _ , tt := range tests {
122+ t .Run (tt .name , func (t * testing.T ) {
123+ got , err := AddInternalColumnsToRecord (tt .record , tt .cqClientIDValue )
124+ require .NoError (t , err )
125+ require .Equal (t , tt .record .NumRows (), got .NumRows ())
126+ require .Equal (t , tt .record .NumCols ()+ tt .expectedNewColumns , got .NumCols ())
127+
128+ gotSchema := got .Schema ()
129+ cqIDFields := gotSchema .FieldIndices (CqIDColumn .Name )
130+ require .Len (t , cqIDFields , 1 )
131+
132+ cqParentIDFields := gotSchema .FieldIndices (CqParentIDColumn .Name )
133+ require .Len (t , cqParentIDFields , 1 )
134+
135+ cqClientIDFields := gotSchema .FieldIndices (CqClientIDColumn .Name )
136+ require .Len (t , cqClientIDFields , 1 )
137+
138+ cqIDArray := got .Column (cqIDFields [0 ])
139+ require .Equal (t , types .UUID , cqIDArray .DataType ())
140+ require .Equal (t , tt .record .NumRows (), int64 (cqIDArray .Len ()))
141+
142+ cqParentIDArray := got .Column (cqParentIDFields [0 ])
143+ require .Equal (t , types .UUID , cqParentIDArray .DataType ())
144+ require .Equal (t , tt .record .NumRows (), int64 (cqParentIDArray .Len ()))
145+
146+ cqClientIDArray := got .Column (cqClientIDFields [0 ])
147+ require .Equal (t , arrow .BinaryTypes .String , cqClientIDArray .DataType ())
148+ require .Equal (t , tt .record .NumRows (), int64 (cqClientIDArray .Len ()))
149+
150+ for i := range cqIDArray .Len () {
151+ cqID := cqIDArray .GetOneForMarshal (i ).(uuid.UUID )
152+ require .NotEmpty (t , cqID )
153+ }
154+ for i := range cqParentIDArray .Len () {
155+ cqParentID := cqParentIDArray .GetOneForMarshal (i )
156+ require .Nil (t , cqParentID )
157+ }
158+ for i := range cqClientIDArray .Len () {
159+ cqClientID := cqClientIDArray .GetOneForMarshal (i ).(string )
160+ require .Equal (t , tt .cqClientIDValue , cqClientID )
161+ }
162+ })
163+ }
164+ }
0 commit comments