Skip to content

Commit 1dea99c

Browse files
authored
feat: Add internal columns helpers (#2105)
#### Summary

Related to cloudquery/cloud#4919 (internal issue). Adds some useful helpers for adding CQ internal columns to records.
1 parent c27765c commit 1dea99c

File tree

2 files changed

+233
-0
lines changed

2 files changed

+233
-0
lines changed

schema/arrow.go

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,14 @@
11
package schema
22

33
import (
4+
"crypto/sha1"
5+
"time"
6+
47
"github.com/apache/arrow-go/v18/arrow"
8+
"github.com/apache/arrow-go/v18/arrow/array"
9+
"github.com/apache/arrow-go/v18/arrow/memory"
10+
"github.com/cloudquery/plugin-sdk/v4/types"
11+
"github.com/google/uuid"
512
)
613

714
const (
@@ -40,3 +47,98 @@ func (s Schemas) SchemaByName(name string) *arrow.Schema {
4047
}
4148
return nil
4249
}
50+
51+
func hashRecord(record arrow.Record) arrow.Array {
52+
numRows := int(record.NumRows())
53+
fields := record.Schema().Fields()
54+
hashArray := types.NewUUIDBuilder(memory.DefaultAllocator)
55+
hashArray.Reserve(numRows)
56+
for row := range numRows {
57+
rowHash := sha1.New()
58+
for col := 0; col < int(record.NumCols()); col++ {
59+
fieldName := fields[col].Name
60+
rowHash.Write([]byte(fieldName))
61+
value := record.Column(col).ValueStr(row)
62+
_, _ = rowHash.Write([]byte(value))
63+
}
64+
// This part ensures that we conform to the UUID spec
65+
hashArray.Append(uuid.NewSHA1(uuid.NameSpaceURL, rowHash.Sum(nil)))
66+
}
67+
return hashArray.NewArray()
68+
}
69+
70+
func nullUUIDsForRecord(numRows int) arrow.Array {
71+
uuidArray := types.NewUUIDBuilder(memory.DefaultAllocator)
72+
uuidArray.AppendNulls(numRows)
73+
return uuidArray.NewArray()
74+
}
75+
76+
func StringArrayFromValue(value string, nRows int) arrow.Array {
77+
arrayBuilder := array.NewStringBuilder(memory.DefaultAllocator)
78+
arrayBuilder.Reserve(nRows)
79+
for range nRows {
80+
arrayBuilder.AppendString(value)
81+
}
82+
return arrayBuilder.NewArray()
83+
}
84+
85+
func TimestampArrayFromTime(t time.Time, unit arrow.TimeUnit, timeZone string, nRows int) (arrow.Array, error) {
86+
ts, err := arrow.TimestampFromTime(t, unit)
87+
if err != nil {
88+
return nil, err
89+
}
90+
arrayBuilder := array.NewTimestampBuilder(memory.DefaultAllocator, &arrow.TimestampType{Unit: unit, TimeZone: timeZone})
91+
arrayBuilder.Reserve(nRows)
92+
for range nRows {
93+
arrayBuilder.Append(ts)
94+
}
95+
return arrayBuilder.NewArray(), nil
96+
}
97+
98+
func ReplaceFieldInRecord(src arrow.Record, fieldName string, field arrow.Array) (record arrow.Record, err error) {
99+
fieldIndexes := src.Schema().FieldIndices(fieldName)
100+
for i := range fieldIndexes {
101+
record, err = src.SetColumn(fieldIndexes[i], field)
102+
if err != nil {
103+
return nil, err
104+
}
105+
}
106+
return record, nil
107+
}
108+
109+
func AddInternalColumnsToRecord(record arrow.Record, cqClientIDValue string) (arrow.Record, error) {
110+
schema := record.Schema()
111+
nRows := int(record.NumRows())
112+
113+
newFields := []arrow.Field{}
114+
newColumns := []arrow.Array{}
115+
116+
var err error
117+
if !schema.HasField(CqIDColumn.Name) {
118+
cqID := hashRecord(record)
119+
newFields = append(newFields, CqIDColumn.ToArrowField())
120+
newColumns = append(newColumns, cqID)
121+
}
122+
if !schema.HasField(CqParentIDColumn.Name) {
123+
cqParentID := nullUUIDsForRecord(nRows)
124+
newFields = append(newFields, CqParentIDColumn.ToArrowField())
125+
newColumns = append(newColumns, cqParentID)
126+
}
127+
128+
clientIDArray := StringArrayFromValue(cqClientIDValue, nRows)
129+
if !schema.HasField(CqClientIDColumn.Name) {
130+
newFields = append(newFields, CqClientIDColumn.ToArrowField())
131+
newColumns = append(newColumns, clientIDArray)
132+
} else {
133+
record, err = ReplaceFieldInRecord(record, CqClientIDColumn.Name, clientIDArray)
134+
if err != nil {
135+
return nil, err
136+
}
137+
}
138+
139+
allFields := append(schema.Fields(), newFields...)
140+
allColumns := append(record.Columns(), newColumns...)
141+
metadata := schema.Metadata()
142+
newSchema := arrow.NewSchema(allFields, &metadata)
143+
return array.NewRecord(newSchema, allColumns, int64(nRows)), nil
144+
}

schema/arrow_test.go

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,15 @@ package schema
33
import (
44
"fmt"
55
"strings"
6+
"testing"
67

78
"github.com/apache/arrow-go/v18/arrow"
89
"github.com/apache/arrow-go/v18/arrow/array"
10+
"github.com/apache/arrow-go/v18/arrow/memory"
11+
"github.com/cloudquery/plugin-sdk/v4/types"
12+
"github.com/google/uuid"
13+
"github.com/samber/lo"
14+
"github.com/stretchr/testify/require"
915
)
1016

1117
func RecordDiff(l arrow.Record, r arrow.Record) string {
@@ -31,3 +37,128 @@ func RecordDiff(l arrow.Record, r arrow.Record) string {
3137
}
3238
return sb.String()
3339
}
40+
41+
func buildTestRecord(withClientIDValue string) arrow.Record {
42+
testFields := []arrow.Field{
43+
{Name: "id", Type: arrow.PrimitiveTypes.Int64, Nullable: true},
44+
{Name: "name", Type: arrow.BinaryTypes.String, Nullable: true},
45+
{Name: "value", Type: arrow.PrimitiveTypes.Float64, Nullable: true},
46+
{Name: "bool", Type: arrow.FixedWidthTypes.Boolean, Nullable: true},
47+
{Name: "uuid", Type: types.UUID, Nullable: true},
48+
}
49+
if withClientIDValue != "" {
50+
testFields = append(testFields, CqClientIDColumn.ToArrowField())
51+
}
52+
schema := arrow.NewSchema(testFields, nil)
53+
54+
testValuesCount := 10
55+
builders := []array.Builder{
56+
array.NewInt64Builder(memory.DefaultAllocator),
57+
array.NewStringBuilder(memory.DefaultAllocator),
58+
array.NewFloat64Builder(memory.DefaultAllocator),
59+
array.NewBooleanBuilder(memory.DefaultAllocator),
60+
types.NewUUIDBuilder(memory.DefaultAllocator),
61+
}
62+
for _, builder := range builders {
63+
builder.Reserve(testValuesCount)
64+
switch b := builder.(type) {
65+
case *array.Int64Builder:
66+
for i := range testValuesCount {
67+
b.Append(int64(i))
68+
}
69+
case *array.StringBuilder:
70+
for i := range testValuesCount {
71+
b.AppendString(fmt.Sprintf("test%d", i))
72+
}
73+
case *array.Float64Builder:
74+
for i := range testValuesCount {
75+
b.Append(float64(i))
76+
}
77+
case *array.BooleanBuilder:
78+
for i := range testValuesCount {
79+
b.Append(i%2 == 0)
80+
}
81+
case *types.UUIDBuilder:
82+
for i := range testValuesCount {
83+
b.Append(uuid.NewSHA1(uuid.NameSpaceURL, []byte(fmt.Sprintf("test%d", i))))
84+
}
85+
}
86+
}
87+
if withClientIDValue != "" {
88+
builder := array.NewStringBuilder(memory.DefaultAllocator)
89+
builder.Reserve(testValuesCount)
90+
for range testValuesCount {
91+
builder.AppendString(withClientIDValue)
92+
}
93+
builders = append(builders, builder)
94+
}
95+
values := lo.Map(builders, func(builder array.Builder, _ int) arrow.Array {
96+
return builder.NewArray()
97+
})
98+
return array.NewRecord(schema, values, int64(testValuesCount))
99+
}
100+
101+
func TestAddInternalColumnsToRecord(t *testing.T) {
102+
tests := []struct {
103+
name string
104+
record arrow.Record
105+
cqClientIDValue string
106+
expectedNewColumns int64
107+
}{
108+
{
109+
name: "add _cq_id,_cq_parent_id,_cq_client_id",
110+
record: buildTestRecord(""),
111+
cqClientIDValue: "new_client_id",
112+
expectedNewColumns: 3,
113+
},
114+
{
115+
name: "add cq_client_id,cq_id replace existing _cq_client_id",
116+
record: buildTestRecord("existing_client_id"),
117+
cqClientIDValue: "new_client_id",
118+
expectedNewColumns: 2,
119+
},
120+
}
121+
for _, tt := range tests {
122+
t.Run(tt.name, func(t *testing.T) {
123+
got, err := AddInternalColumnsToRecord(tt.record, tt.cqClientIDValue)
124+
require.NoError(t, err)
125+
require.Equal(t, tt.record.NumRows(), got.NumRows())
126+
require.Equal(t, tt.record.NumCols()+tt.expectedNewColumns, got.NumCols())
127+
128+
gotSchema := got.Schema()
129+
cqIDFields := gotSchema.FieldIndices(CqIDColumn.Name)
130+
require.Len(t, cqIDFields, 1)
131+
132+
cqParentIDFields := gotSchema.FieldIndices(CqParentIDColumn.Name)
133+
require.Len(t, cqParentIDFields, 1)
134+
135+
cqClientIDFields := gotSchema.FieldIndices(CqClientIDColumn.Name)
136+
require.Len(t, cqClientIDFields, 1)
137+
138+
cqIDArray := got.Column(cqIDFields[0])
139+
require.Equal(t, types.UUID, cqIDArray.DataType())
140+
require.Equal(t, tt.record.NumRows(), int64(cqIDArray.Len()))
141+
142+
cqParentIDArray := got.Column(cqParentIDFields[0])
143+
require.Equal(t, types.UUID, cqParentIDArray.DataType())
144+
require.Equal(t, tt.record.NumRows(), int64(cqParentIDArray.Len()))
145+
146+
cqClientIDArray := got.Column(cqClientIDFields[0])
147+
require.Equal(t, arrow.BinaryTypes.String, cqClientIDArray.DataType())
148+
require.Equal(t, tt.record.NumRows(), int64(cqClientIDArray.Len()))
149+
150+
for i := range cqIDArray.Len() {
151+
cqID := cqIDArray.GetOneForMarshal(i).(uuid.UUID)
152+
require.NotEmpty(t, cqID)
153+
}
154+
for i := range cqParentIDArray.Len() {
155+
cqParentID := cqParentIDArray.GetOneForMarshal(i)
156+
require.Nil(t, cqParentID)
157+
}
158+
for i := range cqClientIDArray.Len() {
159+
cqClientID := cqClientIDArray.GetOneForMarshal(i).(string)
160+
require.Equal(t, tt.cqClientIDValue, cqClientID)
161+
}
162+
})
163+
}
164+
}

0 commit comments

Comments
 (0)