Skip to content

Commit ba322ae

Browse files
committed
Add more tests and improve error handling
1 parent 957d854 commit ba322ae

File tree

3 files changed

+148
-75
lines changed

3 files changed

+148
-75
lines changed

appender.go

Lines changed: 34 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@ package duckdb
33
import (
44
"database/sql/driver"
55
"errors"
6+
"fmt"
67
"runtime"
78
"sync"
89

910
"github.com/duckdb/duckdb-go/mapping"
11+
"golang.org/x/sync/errgroup"
1012
)
1113

1214
type (
@@ -224,70 +226,46 @@ func (a *Appender) initAppenderChunk() (*Appender, error) {
224226
return a, nil
225227
}
226228

229+
// AppendTableSource appends the data from an AppenderSource to the appender. It repeatedly calls
230+
// `FillRow` or `FillChunk`, depending on the AppenderSource type. If a call returns an error, any
231+
// data inserted by that call to `FillRow` or `FillChunk` is ignored. Data from any earlier calls
232+
// to either of these functions is still processed and appended.
227233
func (a *Appender) AppendTableSource(s AppenderSource) error {
234+
var runParallel = func(maxThreads int, worker func() error) error {
235+
g := errgroup.Group{}
236+
for range min(maxThreads, runtime.GOMAXPROCS(-1)) {
237+
g.Go(worker)
238+
}
239+
return g.Wait()
240+
}
241+
228242
lock := &sync.Mutex{}
229243
// The projection is not used in the chunk, so it must be kept as a 1-1 mapping.
230244
columnCount := mapping.AppenderColumnCount(a.appender)
231245
projection := make([]int, 0, columnCount)
232-
for i := mapping.IdxT(0); i < columnCount; i++ {
246+
for i := range columnCount {
233247
projection = append(projection, int(i))
234248
}
235-
var x any = s
236-
switch s := x.(type) {
249+
switch s := s.(type) {
237250
case rowAppenderSource:
238251
s.Source.Init()
239-
err := appenderRowThread(&parallelRowTSWrapper{s.Source}, lock, a.types, a.appender, projection)
240-
if err != nil {
241-
return err
242-
}
252+
return appenderRowThread(&parallelRowTSWrapper{s.Source}, lock, a.types, a.appender, projection)
243253
case parallelRowAppenderSource:
244-
wg := sync.WaitGroup{}
245-
246254
info := s.Source.Init()
247-
threads := min(info.MaxThreads, runtime.GOMAXPROCS(-1))
248-
var oerr error
249-
for range threads {
250-
wg.Add(1)
251-
go func() {
252-
err := appenderRowThread(s.Source, lock, a.types, a.appender, projection)
253-
if err != nil {
254-
oerr = err
255-
}
256-
wg.Done()
257-
}()
258-
}
259-
wg.Wait()
260-
if oerr != nil {
261-
return oerr
262-
}
255+
return runParallel(info.MaxThreads, func() error {
256+
return appenderRowThread(s.Source, lock, a.types, a.appender, projection)
257+
})
263258
case chunkAppenderSource:
264259
s.Source.Init()
265-
err := appenderChunkThread(&parallelChunkTSWrapper{s.Source}, lock, a.types, a.appender)
266-
if err != nil {
267-
return err
268-
}
260+
return appenderChunkThread(&parallelChunkTSWrapper{s.Source}, lock, a.types, a.appender)
269261
case parallelChunkAppenderSource:
270-
wg := sync.WaitGroup{}
271-
272262
info := s.Source.Init()
273-
threads := min(info.MaxThreads, runtime.GOMAXPROCS(-1))
274-
var oerr error
275-
for range threads {
276-
wg.Add(1)
277-
go func() {
278-
err := appenderChunkThread(s.Source, lock, a.types, a.appender)
279-
if err != nil {
280-
oerr = err
281-
}
282-
wg.Done()
283-
}()
284-
}
285-
wg.Wait()
286-
if oerr != nil {
287-
return oerr
288-
}
263+
return runParallel(info.MaxThreads, func() error {
264+
return appenderChunkThread(s.Source, lock, a.types, a.appender)
265+
})
266+
default:
267+
return fmt.Errorf("unknown AppenderSource type: %T. Must be created with NewAppenderRowSource, NewAppenderParallelRowSource, NewAppenderChunkSource, or NewAppenderParallelChunkSource", s)
289268
}
290-
return nil
291269
}
292270

293271
func (a *Appender) appendRowSlice(args []driver.Value) error {
@@ -341,6 +319,7 @@ func appenderRowThread(s ParallelRowTableSource, lock *sync.Mutex, types []mappi
341319
if err != nil {
342320
return err
343321
}
322+
defer chunk.close()
344323
chunk.projection = projection
345324

346325
for {
@@ -349,13 +328,10 @@ func appenderRowThread(s ParallelRowTableSource, lock *sync.Mutex, types []mappi
349328
r: 0,
350329
}
351330
var next bool
331+
var err error
352332
for ; row.r < mapping.IdxT(maxSize); row.r++ {
353333
next, err = s.FillRow(lstate, row)
354-
if err != nil {
355-
chunk.close()
356-
return err
357-
}
358-
if !next {
334+
if err != nil || !next {
359335
break
360336
}
361337
}
@@ -370,12 +346,14 @@ func appenderRowThread(s ParallelRowTableSource, lock *sync.Mutex, types []mappi
370346
return err
371347
}
372348
lock.Unlock()
349+
if err != nil {
350+
return err
351+
}
373352
if !next {
374353
break
375354
}
376355
chunk.reset(true)
377356
}
378-
chunk.close()
379357
return nil
380358
}
381359

@@ -386,15 +364,13 @@ func appenderChunkThread(s ParallelChunkTableSource, lock *sync.Mutex, types []m
386364
if err != nil {
387365
return err
388366
}
389-
367+
defer chunk.close()
390368
for {
391369
err = s.FillChunk(lstate, chunk)
392370
if err != nil {
393371
return err
394372
}
395-
396373
if chunk.GetSize() == 0 {
397-
chunk.close()
398374
break
399375
}
400376

@@ -407,8 +383,8 @@ func appenderChunkThread(s ParallelChunkTableSource, lock *sync.Mutex, types []m
407383
}
408384
lock.Unlock()
409385
chunk.reset(true)
386+
410387
}
411-
chunk.close()
412388
return nil
413389
}
414390

appender_test.go

Lines changed: 105 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"database/sql"
66
"database/sql/driver"
77
"encoding/json"
8+
"errors"
89
"fmt"
910
"math/big"
1011
"math/rand"
@@ -1103,6 +1104,7 @@ type (
11031104
appStructTableUDF struct {
11041105
n int64
11051106
count int64
1107+
err error // return this when done. Could be nil
11061108
}
11071109

11081110
appParStructTableUDF struct {
@@ -1113,19 +1115,17 @@ type (
11131115
)
11141116

11151117
func (udf *appStructTableUDF) ColumnInfos() []ColumnInfo {
1116-
t, _ := NewTypeInfo(TYPE_BIGINT)
1117-
t2, _ := NewTypeInfo(TYPE_UTINYINT)
11181118
return []ColumnInfo{
1119-
{Name: "id", T: t},
1120-
{Name: "uint8", T: t2},
1119+
{Name: "id", T: typeBigintTableUDF},
1120+
{Name: "uint8", T: typeUTinyintTableUDF},
11211121
}
11221122
}
11231123

11241124
func (udf *appStructTableUDF) Init() {}
11251125

11261126
func (udf *appStructTableUDF) FillRow(row Row) (bool, error) {
11271127
if udf.count >= udf.n {
1128-
return false, nil
1128+
return false, udf.err
11291129
}
11301130
udf.count++
11311131
err := SetRowValue(row, 0, udf.count)
@@ -1153,7 +1153,10 @@ func (udf *appStructTableUDF) Cardinality() *CardinalityInfo {
11531153
}
11541154

11551155
func (udf *appParStructTableUDF) ColumnInfos() []ColumnInfo {
1156-
return []ColumnInfo{{Name: "result", T: typeBigintTableUDF}}
1156+
return []ColumnInfo{
1157+
{Name: "result", T: typeBigintTableUDF},
1158+
{Name: "result2", T: typeBigintTableUDF},
1159+
}
11571160
}
11581161

11591162
func (udf *appParStructTableUDF) Init() ParallelTableSourceInfo {
@@ -1196,7 +1199,6 @@ func (udf *appParStructTableUDF) FillRow(localState any, row Row) (bool, error)
11961199
}
11971200
err = SetRowValue(row, 1, state.start)
11981201
return true, err
1199-
12001202
}
12011203

12021204
func (udf *appParStructTableUDF) GetValue(r, c int) any {
@@ -1247,16 +1249,17 @@ func TestAppendParallelRowSource(t *testing.T) {
12471249
args[i] = &values[i]
12481250
}
12491251

1250-
count := 0
1252+
count := int64(0)
12511253
for r := 0; res.Next(); r++ {
12521254
require.NoError(t, res.Scan(args...))
12531255
for i, value := range values {
12541256
expected := f.GetValue(r, i)
1255-
require.Equal(t, expected, value, "incorrect value", r, i)
1257+
require.Equal(t, expected, value, "incorrect value at row %v column %v", r, i)
12561258
}
12571259
count++
12581260
}
12591261
cleanupAppender(t, c, db, con, a)
1262+
require.Equal(t, f.n, count, "number of rows should be equal")
12601263
}
12611264

12621265
func TestAppendParallelRowSourceSingle(t *testing.T) {
@@ -1289,18 +1292,17 @@ func TestAppendParallelRowSourceSingle(t *testing.T) {
12891292
args[i] = &values[i]
12901293
}
12911294

1292-
count := 0
1295+
count := int64(0)
12931296
for r := 0; res.Next(); r++ {
12941297
require.NoError(t, res.Scan(args...))
12951298
for i, value := range values {
12961299
expected := f.GetValue(r, i)
1297-
//fmt.Println(args)
1298-
//fmt.Println(expected, value)
12991300
require.Equal(t, expected, value, "incorrect value", r, i)
13001301
}
13011302
count++
13021303
}
13031304
cleanupAppender(t, c, db, con, a)
1305+
require.Equal(t, f.n, count, "number of rows should be correct")
13041306
}
13051307

13061308
func TestAppendRowSource(t *testing.T) {
@@ -1332,7 +1334,95 @@ func TestAppendRowSource(t *testing.T) {
13321334
args[i] = &values[i]
13331335
}
13341336

1335-
count := 0
1337+
count := int64(0)
1338+
for r := 0; res.Next(); r++ {
1339+
require.NoError(t, res.Scan(args...))
1340+
for i, value := range values {
1341+
expected := f.GetValue(r, i)
1342+
require.Equal(t, expected, value, "incorrect value", r, i)
1343+
}
1344+
count++
1345+
}
1346+
cleanupAppender(t, c, db, con, a)
1347+
require.Equal(t, f.n, count, "number of rows should be equal")
1348+
}
1349+
1350+
func TestAppendRowSourceError(t *testing.T) {
1351+
t.Parallel()
1352+
sc := `
1353+
CREATE TABLE test (
1354+
id BIGINT,
1355+
uint8 UTINYINT
1356+
)`
1357+
c, db, con, a := prepareAppender(t, sc)
1358+
1359+
f := appStructTableUDF{
1360+
n: 3000,
1361+
err: errors.New("Test test"),
1362+
}
1363+
1364+
err := a.AppendTableSource(NewAppenderRowSource(&f))
1365+
require.Equal(t, err, f.err)
1366+
1367+
err = a.Flush()
1368+
require.NoError(t, err)
1369+
1370+
// Verify results.
1371+
res, err := sql.OpenDB(c).QueryContext(context.Background(), `SELECT * FROM test ORDER BY id`)
1372+
require.NoError(t, err)
1373+
1374+
values := f.GetTypes()
1375+
args := make([]any, len(values))
1376+
for i := range values {
1377+
args[i] = &values[i]
1378+
}
1379+
1380+
count := int64(0)
1381+
for r := 0; res.Next(); r++ {
1382+
require.NoError(t, res.Scan(args...))
1383+
for i, value := range values {
1384+
expected := f.GetValue(r, i)
1385+
require.Equal(t, expected, value, "incorrect value", r, i)
1386+
}
1387+
count++
1388+
}
1389+
cleanupAppender(t, c, db, con, a)
1390+
require.Equal(t, f.n, count, "number of rows should be equal")
1391+
}
1392+
1393+
func roundDownChunk[T int64 | int](x T) T {
1394+
return (x / 2048) * 2048
1395+
}
1396+
1397+
func TestAppendChunkSourceError(t *testing.T) {
1398+
t.Parallel()
1399+
sc := `
1400+
CREATE TABLE test (
1401+
id BIGINT,
1402+
)`
1403+
c, db, con, a := prepareAppender(t, sc)
1404+
1405+
f := chunkIncTableUDF{
1406+
n: 3000,
1407+
err: errors.New("Test test"),
1408+
}
1409+
err := a.AppendTableSource(NewAppenderChunkSource(&f))
1410+
require.Equal(t, err, f.err)
1411+
1412+
err = a.Flush()
1413+
require.NoError(t, err)
1414+
1415+
// Verify results.
1416+
res, err := sql.OpenDB(c).QueryContext(context.Background(), `SELECT * FROM test ORDER BY id`)
1417+
require.NoError(t, err)
1418+
1419+
values := f.GetTypes()
1420+
args := make([]any, len(values))
1421+
for i := range values {
1422+
args[i] = &values[i]
1423+
}
1424+
1425+
count := int64(0)
13361426
for r := 0; res.Next(); r++ {
13371427
require.NoError(t, res.Scan(args...))
13381428
for i, value := range values {
@@ -1342,6 +1432,7 @@ func TestAppendRowSource(t *testing.T) {
13421432
count++
13431433
}
13441434
cleanupAppender(t, c, db, con, a)
1435+
require.Equal(t, roundDownChunk(f.n), count, "number of rows should be equal")
13451436
}
13461437

13471438
func BenchmarkAppenderNested(b *testing.B) {
@@ -1467,7 +1558,7 @@ func benchmarkAppenderSingle[T any](v T) func(*testing.B) {
14671558
tableSQL := fmt.Sprintf(createSingleTableSQL, types[reflect.TypeFor[T]()])
14681559
c, db, con, a := prepareAppender(b, tableSQL)
14691560

1470-
var vec [benchmarkRowsToAppend]T = [benchmarkRowsToAppend]T{}
1561+
vec := [benchmarkRowsToAppend]T{}
14711562
for i := range benchmarkRowsToAppend {
14721563
vec[i] = v
14731564
}

0 commit comments

Comments
 (0)