5 changes: 3 additions & 2 deletions plugin/nulls.go
@@ -93,10 +93,11 @@ func (s *WriterTestSuite) replaceNullsByEmptyNestedArray(arr arrow.Array) arrow.

func (s *WriterTestSuite) handleNulls(record arrow.Record) arrow.Record {
cols := record.Columns()
newCols := make([]arrow.Array, len(cols))
for c, col := range cols {
cols[c] = s.handleNullsArray(col)
newCols[c] = s.handleNullsArray(col)
}
return array.NewRecord(record.Schema(), cols, record.NumRows())
return array.NewRecord(record.Schema(), newCols, record.NumRows())
}

func (s *WriterTestSuite) handleNullsArray(arr arrow.Array) arrow.Array {
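The plugin/nulls.go change stops mutating the record's own column slice in place: results now go into a freshly allocated slice, and a new record is built from that. A minimal, library-free Go sketch of the same copy-instead-of-mutate idea (the `transform` function and the string slices are hypothetical stand-ins, not the SDK's types):

```go
package main

import "fmt"

// transform is a stand-in for handleNullsArray (hypothetical).
func transform(s string) string { return s + "'" }

// goodHandle mirrors the fix: write results into a new slice and
// leave the caller's input untouched.
func goodHandle(cols []string) []string {
	newCols := make([]string, len(cols))
	for i, c := range cols {
		newCols[i] = transform(c)
	}
	return newCols
}

func main() {
	orig := []string{"a", "b"}
	out := goodHandle(orig)
	fmt.Println(orig, out) // [a b] [a' b'] — the original columns are preserved
}
```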
136 changes: 80 additions & 56 deletions writers/streamingbatchwriter/streamingbatchwriter.go
@@ -21,6 +21,7 @@ package streamingbatchwriter

import (
"context"
"errors"
"fmt"
"sync"
"time"
@@ -178,30 +179,36 @@ func (w *StreamingBatchWriter) Close(context.Context) error {
return nil
}

func (w *StreamingBatchWriter) Write(ctx context.Context, msgs <-chan message.WriteMessage) error {
func (w *StreamingBatchWriter) Write(ctx context.Context, msgs <-chan message.WriteMessage) (retErr error) {
errCh := make(chan error)
defer close(errCh)

go func() {
for err := range errCh {
w.logger.Err(err).Msg("error from StreamingBatchWriter")
}
defer func() {
err := w.Close(ctx)
retErr = errors.Join(retErr, err)
Member Author: New improvement: always call w.Close even if we already returned an error, and join the errors.

}()
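As a rough illustration of the pattern that comment describes (the writer type below is a hypothetical stand-in, not the SDK's): the deferred close always runs, even when the body returns early with an error, and errors.Join keeps both errors visible to the caller.

```go
package main

import (
	"errors"
	"fmt"
)

// closer is a stand-in for a writer that must always be closed (hypothetical).
type closer struct{}

func (c *closer) Close() error { return errors.New("close failed") }

// write always closes c, even when the body returns early with an error,
// and joins both errors so neither is lost.
func write(c *closer) (retErr error) {
	defer func() {
		retErr = errors.Join(retErr, c.Close())
	}()
	return errors.New("write failed")
}

func main() {
	fmt.Println(write(&closer{})) // reports both "write failed" and "close failed"
}
```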

for msg := range msgs {
msgType := writers.MsgID(msg)
if w.lastMsgType != writers.MsgTypeUnset && w.lastMsgType != msgType {
if err := w.Flush(ctx); err != nil {
for {
select {
case msg, ok := <-msgs:
if !ok {
return nil
}

msgType := writers.MsgID(msg)
if w.lastMsgType != writers.MsgTypeUnset && w.lastMsgType != msgType {
if err := w.Flush(ctx); err != nil {
return err
}
}
w.lastMsgType = msgType
if err := w.startWorker(ctx, errCh, msg); err != nil {
return err
}
}
w.lastMsgType = msgType
if err := w.startWorker(ctx, errCh, msg); err != nil {

case err := <-errCh:
return err
}
}

return w.Close(ctx)
}

func (w *StreamingBatchWriter) startWorker(ctx context.Context, errCh chan<- error, msg message.WriteMessage) error {
@@ -221,13 +228,14 @@ func (w *StreamingBatchWriter) startWorker(ctx context.Context, errCh chan<- err
case *message.WriteMigrateTable:
w.workersLock.Lock()
defer w.workersLock.Unlock()

if w.migrateWorker != nil {
w.migrateWorker.ch <- m
return nil
}
ch := make(chan *message.WriteMigrateTable)

w.migrateWorker = &streamingWorkerManager[*message.WriteMigrateTable]{
ch: ch,
ch: make(chan *message.WriteMigrateTable),
writeFunc: w.client.MigrateTable,

flush: make(chan chan bool),
@@ -241,17 +249,19 @@ func (w *StreamingBatchWriter) startWorker(ctx context.Context, errCh chan<- err
w.workersWaitGroup.Add(1)
go w.migrateWorker.run(ctx, &w.workersWaitGroup, tableName)
w.migrateWorker.ch <- m

return nil
case *message.WriteDeleteStale:
w.workersLock.Lock()
defer w.workersLock.Unlock()

if w.deleteStaleWorker != nil {
w.deleteStaleWorker.ch <- m
return nil
}
ch := make(chan *message.WriteDeleteStale)

w.deleteStaleWorker = &streamingWorkerManager[*message.WriteDeleteStale]{
ch: ch,
ch: make(chan *message.WriteDeleteStale),
writeFunc: w.client.DeleteStale,

flush: make(chan chan bool),
@@ -265,19 +275,29 @@ func (w *StreamingBatchWriter) startWorker(ctx context.Context, errCh chan<- err
w.workersWaitGroup.Add(1)
go w.deleteStaleWorker.run(ctx, &w.workersWaitGroup, tableName)
w.deleteStaleWorker.ch <- m

return nil
case *message.WriteInsert:
w.workersLock.RLock()
wr, ok := w.insertWorkers[tableName]
worker, ok := w.insertWorkers[tableName]
w.workersLock.RUnlock()
if ok {
wr.ch <- m
worker.ch <- m
return nil
}

w.workersLock.Lock()
activeWorker, ok := w.insertWorkers[tableName]
if ok {
w.workersLock.Unlock()
// some other goroutine could have already added the worker
// just send the message to it & discard our allocated worker
activeWorker.ch <- m
return nil
}

ch := make(chan *message.WriteInsert)
wr = &streamingWorkerManager[*message.WriteInsert]{
ch: ch,
worker = &streamingWorkerManager[*message.WriteInsert]{
ch: make(chan *message.WriteInsert),
writeFunc: w.client.WriteTable,

flush: make(chan chan bool),
@@ -287,33 +307,27 @@ func (w *StreamingBatchWriter) startWorker(ctx context.Context, errCh chan<- err
batchTimeout: w.batchTimeout,
tickerFn: w.tickerFn,
}
w.workersLock.Lock()
wrOld, ok := w.insertWorkers[tableName]
if ok {
w.workersLock.Unlock()
// some other goroutine could have already added the worker
// just send the message to it & discard our allocated worker
wrOld.ch <- m
return nil
}
w.insertWorkers[tableName] = wr

w.insertWorkers[tableName] = worker
w.workersLock.Unlock()

w.workersWaitGroup.Add(1)
go wr.run(ctx, &w.workersWaitGroup, tableName)
ch <- m
go worker.run(ctx, &w.workersWaitGroup, tableName)
worker.ch <- m

return nil
case *message.WriteDeleteRecord:
w.workersLock.Lock()
defer w.workersLock.Unlock()

if w.deleteRecordWorker != nil {
w.deleteRecordWorker.ch <- m
return nil
}
ch := make(chan *message.WriteDeleteRecord)

// TODO: flush all workers for nested tables as well (See https://github.com/cloudquery/plugin-sdk/issues/1296)
w.deleteRecordWorker = &streamingWorkerManager[*message.WriteDeleteRecord]{
ch: ch,
ch: make(chan *message.WriteDeleteRecord),
writeFunc: w.client.DeleteRecords,

flush: make(chan chan bool),
@@ -327,6 +341,7 @@ func (w *StreamingBatchWriter) startWorker(ctx context.Context, errCh chan<- err
w.workersWaitGroup.Add(1)
go w.deleteRecordWorker.run(ctx, &w.workersWaitGroup, tableName)
w.deleteRecordWorker.ch <- m

return nil
default:
return fmt.Errorf("unhandled message type: %T", msg)
@@ -348,35 +363,40 @@ type streamingWorkerManager[T message.WriteMessage] struct {
func (s *streamingWorkerManager[T]) run(ctx context.Context, wg *sync.WaitGroup, tableName string) {
defer wg.Done()
var (
clientCh chan T
clientErrCh chan error
open bool
inputCh chan T
outputCh chan error
open bool
)

ensureOpened := func() {
if open {
return
}

clientCh = make(chan T)
clientErrCh = make(chan error, 1)
inputCh = make(chan T)
outputCh = make(chan error)
go func() {
defer close(clientErrCh)
defer close(outputCh)
defer func() {
if err := recover(); err != nil {
clientErrCh <- fmt.Errorf("panic: %v", err)
if msg := recover(); msg != nil {
switch v := msg.(type) {
case error:
outputCh <- fmt.Errorf("panic: %w [recovered]", v)
default:
outputCh <- fmt.Errorf("panic: %v [recovered]", msg)
}
}
}()
clientErrCh <- s.writeFunc(ctx, clientCh)
result := s.writeFunc(ctx, inputCh)
outputCh <- result
}()

open = true
}

closeFlush := func() {
if open {
close(clientCh)
if err := <-clientErrCh; err != nil {
s.errCh <- fmt.Errorf("handler failed on %s: %w", tableName, err)
}
close(inputCh)
s.limit.Reset()
}
open = false
@@ -394,13 +414,12 @@ func (s *streamingWorkerManager[T]) run(ctx context.Context, wg *sync.WaitGroup,
if !ok {
return
}

if ins, ok := any(r).(*message.WriteInsert); ok {
add, toFlush, rest := batch.SliceRecord(ins.Record, s.limit)
if add != nil {
ensureOpened()
s.limit.AddSlice(add)
clientCh <- any(&message.WriteInsert{Record: add.Record}).(T)
inputCh <- any(&message.WriteInsert{Record: add.Record}).(T)
}
if len(toFlush) > 0 || rest != nil || s.limit.ReachedLimit() {
// flush current batch
@@ -410,7 +429,7 @@ func (s *streamingWorkerManager[T]) run(ctx context.Context, wg *sync.WaitGroup,
for _, sliceToFlush := range toFlush {
ensureOpened()
s.limit.AddRows(sliceToFlush.NumRows())
clientCh <- any(&message.WriteInsert{Record: sliceToFlush}).(T)
inputCh <- any(&message.WriteInsert{Record: sliceToFlush}).(T)
closeFlush()
ticker.Reset(s.batchTimeout)
}
@@ -419,11 +438,11 @@ func (s *streamingWorkerManager[T]) run(ctx context.Context, wg *sync.WaitGroup,
if rest != nil {
ensureOpened()
s.limit.AddSlice(rest)
clientCh <- any(&message.WriteInsert{Record: rest.Record}).(T)
inputCh <- any(&message.WriteInsert{Record: rest.Record}).(T)
}
} else {
ensureOpened()
clientCh <- r
inputCh <- r
s.limit.AddRows(1)
if s.limit.ReachedLimit() {
closeFlush()
@@ -441,6 +460,11 @@ func (s *streamingWorkerManager[T]) run(ctx context.Context, wg *sync.WaitGroup,
ticker.Reset(s.batchTimeout)
}
done <- true
case err := <-outputCh:
if err != nil {
s.errCh <- fmt.Errorf("handler failed on %s: %w", tableName, err)
return
}
case <-ctxDone:
// this means the request was cancelled
return // after this NO other call will succeed
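The WriteInsert branch in this diff looks up a per-table worker under the read lock, then re-checks after taking the write lock before registering a new one, so a worker created concurrently by another goroutine is reused instead of clobbered. A condensed sketch of that check/lock/re-check pattern, with simplified stand-in types (not the SDK's actual structs):

```go
package main

import "sync"

type worker struct {
	ch chan string
}

type registry struct {
	mu      sync.RWMutex
	workers map[string]*worker
}

// send delivers msg to the table's worker, creating the worker on first use.
func (r *registry) send(table, msg string) {
	// Fast path: most messages hit an existing worker under the read lock.
	r.mu.RLock()
	w, ok := r.workers[table]
	r.mu.RUnlock()
	if ok {
		w.ch <- msg
		return
	}

	r.mu.Lock()
	// Re-check: another goroutine may have registered the worker while we
	// were waiting for the write lock; use it and discard ours.
	if existing, ok := r.workers[table]; ok {
		r.mu.Unlock()
		existing.ch <- msg
		return
	}
	w = &worker{ch: make(chan string)}
	r.workers[table] = w
	r.mu.Unlock()

	go func() {
		for m := range w.ch {
			_ = m // hand off to the real writer here
		}
	}()
	w.ch <- msg
}

func main() {
	r := &registry{workers: map[string]*worker{}}
	r.send("t1", "row-1")
	r.send("t1", "row-2")
}
```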