diff --git a/bindings_arrow.go b/bindings_arrow.go
index 1097815..8a5fe4e 100644
--- a/bindings_arrow.go
+++ b/bindings_arrow.go
@@ -152,9 +152,10 @@ func SchemaFromArrow(conn Connection, schema *arrow.Schema) (ArrowConvertedSchem
 }
 
 // DataChunkFromArrow converts an Arrow RecordBatch to a DuckDB DataChunk using the provided Connection and ArrowConvertedSchema.
-// The returned DataChunk must be destroyed with DestroyDataChunk.
+// The provided DataChunk must be pre-allocated with the correct schema.
 // The returned ErrorData must be checked for errors and destroyed with DestroyErrorData.
-func DataChunkFromArrow(conn Connection, rec arrow.RecordBatch, schema ArrowConvertedSchema) (DataChunk, ErrorData) {
+func DataChunkFromArrow(conn Connection, rec arrow.RecordBatch, schema ArrowConvertedSchema, chunk DataChunk) ErrorData {
+	// Export Arrow RecordBatch to C ArrowArray
 	arr := C.calloc(1, C.sizeof_struct_ArrowArray)
 	defer func() {
 		cdata.ReleaseCArrowArray((*cdata.CArrowArray)(arr))
@@ -167,16 +168,13 @@ func DataChunkFromArrow(conn Connection, rec arrow.RecordBatch, schema ArrowConv
 	}()
 	cdata.ExportArrowRecordBatch(rec, (*cdata.CArrowArray)(arr), (*cdata.CArrowSchema)(arrs))
 
-	var chunk C.duckdb_data_chunk
-	ed := C.duckdb_data_chunk_from_arrow(conn.data(), (*C.struct_ArrowArray)(arr), schema.data(), &chunk)
+	cd := chunk.data()
+	ed := C.duckdb_data_chunk_from_arrow(conn.data(), (*C.struct_ArrowArray)(arr), schema.data(), &cd)
 	errData := ErrorData{Ptr: unsafe.Pointer(ed)}
 	if debugMode && ed != nil {
 		incrAllocCount("errorData")
 	}
-	if debugMode && chunk != nil {
-		incrAllocCount("chunk")
-	}
-	return DataChunk{Ptr: unsafe.Pointer(chunk)}, errData
+	return errData
 }
 
 // ------------------------------------------------------------------ //
diff --git a/bindings_arrow_test.go b/bindings_arrow_test.go
index a9fb7b8..45b78d7 100644
--- a/bindings_arrow_test.go
+++ b/bindings_arrow_test.go
@@ -4,8 +4,10 @@ package duckdb_go_bindings
 
 import (
 	"testing"
+	"unsafe"
 
 	"github.com/apache/arrow-go/v18/arrow/array"
+	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 )
 
@@ -100,7 +102,8 @@ func TestArrow(t *testing.T) {
 	require.False(t, ErrorDataHasError(ed))
 	defer DestroyArrowConvertedSchema(&convSchema)
 
-	dataChunk, ed := DataChunkFromArrow(conn, newRec, convSchema)
+	dataChunk := CreateDataChunk(types)
+	ed = DataChunkFromArrow(conn, newRec, convSchema, dataChunk)
 	defer DestroyErrorData(&ed)
 	require.False(t, ErrorDataHasError(ed))
 	defer DestroyDataChunk(&dataChunk)
@@ -109,5 +112,26 @@ func TestArrow(t *testing.T) {
 	require.Equal(t, colCount, cc)
 
 	rc := DataChunkGetSize(dataChunk)
-	require.Equal(t, IdxT(rec.NumRows()), rc)
+	assert.Equal(t, IdxT(3), rc)
+
+	// check chunk values
+	vecInt := DataChunkGetVector(chunk, IdxT(0))
+	vecStr := DataChunkGetVector(chunk, IdxT(1))
+	for rowIdx := range 3 {
+		intVal := getPrimitive[int32](vecInt, IdxT(rowIdx))
+		assert.Equal(t, int32(rowIdx+1), intVal)
+
+		strT := getPrimitive[StringT](vecStr, IdxT(rowIdx))
+		strVal := StringTData(&strT)
+		assert.Equal(t, []string{"foo", "bar", ""}[rowIdx], strVal)
+	}
+}
+
+func getPrimitive[T any](vec Vector, rowIdx IdxT) T {
+	dataPtr := VectorGetData(vec)
+	var zero T
+	elementSize := unsafe.Sizeof(zero)
+	offset := uintptr(rowIdx) * elementSize
+	ptr := unsafe.Add(dataPtr, offset)
+	return *(*T)(ptr)
 }