@@ -182,26 +182,28 @@ func (w *StreamingBatchWriter) Write(ctx context.Context, msgs <-chan message.Wr
182182 errCh := make (chan error )
183183 defer close (errCh )
184184
185- go func () {
186- for err := range errCh {
187- w .logger .Err (err ).Msg ("error from StreamingBatchWriter" )
188- }
189- }()
185+ for {
186+ select {
187+ case msg , ok := <- msgs :
188+ if ! ok {
189+ return w .Close (ctx )
190+ }
190191
191- for msg := range msgs {
192- msgType := writers .MsgID (msg )
193- if w .lastMsgType != writers .MsgTypeUnset && w .lastMsgType != msgType {
194- if err := w .Flush (ctx ); err != nil {
192+ msgType := writers .MsgID (msg )
193+ if w .lastMsgType != writers .MsgTypeUnset && w .lastMsgType != msgType {
194+ if err := w .Flush (ctx ); err != nil {
195+ return err
196+ }
197+ }
198+ w .lastMsgType = msgType
199+ if err := w .startWorker (ctx , errCh , msg ); err != nil {
195200 return err
196201 }
197- }
198- w .lastMsgType = msgType
199- if err := w .startWorker (ctx , errCh , msg ); err != nil {
202+
203+ case err := <- errCh :
200204 return err
201205 }
202206 }
203-
204- return w .Close (ctx )
205207}
206208
207209func (w * StreamingBatchWriter ) startWorker (ctx context.Context , errCh chan <- error , msg message.WriteMessage ) error {
@@ -221,13 +223,14 @@ func (w *StreamingBatchWriter) startWorker(ctx context.Context, errCh chan<- err
221223 case * message.WriteMigrateTable :
222224 w .workersLock .Lock ()
223225 defer w .workersLock .Unlock ()
226+
224227 if w .migrateWorker != nil {
225228 w .migrateWorker .ch <- m
226229 return nil
227230 }
228- ch := make ( chan * message. WriteMigrateTable )
231+
229232 w .migrateWorker = & streamingWorkerManager [* message.WriteMigrateTable ]{
230- ch : ch ,
233+ ch : make ( chan * message. WriteMigrateTable ) ,
231234 writeFunc : w .client .MigrateTable ,
232235
233236 flush : make (chan chan bool ),
@@ -241,17 +244,19 @@ func (w *StreamingBatchWriter) startWorker(ctx context.Context, errCh chan<- err
241244 w .workersWaitGroup .Add (1 )
242245 go w .migrateWorker .run (ctx , & w .workersWaitGroup , tableName )
243246 w .migrateWorker .ch <- m
247+
244248 return nil
245249 case * message.WriteDeleteStale :
246250 w .workersLock .Lock ()
247251 defer w .workersLock .Unlock ()
252+
248253 if w .deleteStaleWorker != nil {
249254 w .deleteStaleWorker .ch <- m
250255 return nil
251256 }
252- ch := make ( chan * message. WriteDeleteStale )
257+
253258 w .deleteStaleWorker = & streamingWorkerManager [* message.WriteDeleteStale ]{
254- ch : ch ,
259+ ch : make ( chan * message. WriteDeleteStale ) ,
255260 writeFunc : w .client .DeleteStale ,
256261
257262 flush : make (chan chan bool ),
@@ -265,19 +270,29 @@ func (w *StreamingBatchWriter) startWorker(ctx context.Context, errCh chan<- err
265270 w .workersWaitGroup .Add (1 )
266271 go w .deleteStaleWorker .run (ctx , & w .workersWaitGroup , tableName )
267272 w .deleteStaleWorker .ch <- m
273+
268274 return nil
269275 case * message.WriteInsert :
270276 w .workersLock .RLock ()
271- wr , ok := w .insertWorkers [tableName ]
277+ worker , ok := w .insertWorkers [tableName ]
272278 w .workersLock .RUnlock ()
273279 if ok {
274- wr .ch <- m
280+ worker .ch <- m
275281 return nil
276282 }
277283
278- ch := make (chan * message.WriteInsert )
279- wr = & streamingWorkerManager [* message.WriteInsert ]{
280- ch : ch ,
284+ w .workersLock .Lock ()
285+ activeWorker , ok := w .insertWorkers [tableName ]
286+ if ok {
287+ w .workersLock .Unlock ()
288+ // some other goroutine could have already added the worker
289+ // just send the message to it & discard our allocated worker
290+ activeWorker .ch <- m
291+ return nil
292+ }
293+
294+ worker = & streamingWorkerManager [* message.WriteInsert ]{
295+ ch : make (chan * message.WriteInsert ),
281296 writeFunc : w .client .WriteTable ,
282297
283298 flush : make (chan chan bool ),
@@ -287,33 +302,27 @@ func (w *StreamingBatchWriter) startWorker(ctx context.Context, errCh chan<- err
287302 batchTimeout : w .batchTimeout ,
288303 tickerFn : w .tickerFn ,
289304 }
290- w .workersLock .Lock ()
291- wrOld , ok := w .insertWorkers [tableName ]
292- if ok {
293- w .workersLock .Unlock ()
294- // some other goroutine could have already added the worker
295- // just send the message to it & discard our allocated worker
296- wrOld .ch <- m
297- return nil
298- }
299- w .insertWorkers [tableName ] = wr
305+
306+ w .insertWorkers [tableName ] = worker
300307 w .workersLock .Unlock ()
301308
302309 w .workersWaitGroup .Add (1 )
303- go wr .run (ctx , & w .workersWaitGroup , tableName )
304- ch <- m
310+ go worker .run (ctx , & w .workersWaitGroup , tableName )
311+ worker .ch <- m
312+
305313 return nil
306314 case * message.WriteDeleteRecord :
307315 w .workersLock .Lock ()
308316 defer w .workersLock .Unlock ()
317+
309318 if w .deleteRecordWorker != nil {
310319 w .deleteRecordWorker .ch <- m
311320 return nil
312321 }
313- ch := make ( chan * message. WriteDeleteRecord )
322+
314323 // TODO: flush all workers for nested tables as well (See https://github.com/cloudquery/plugin-sdk/issues/1296)
315324 w .deleteRecordWorker = & streamingWorkerManager [* message.WriteDeleteRecord ]{
316- ch : ch ,
325+ ch : make ( chan * message. WriteDeleteRecord ) ,
317326 writeFunc : w .client .DeleteRecords ,
318327
319328 flush : make (chan chan bool ),
@@ -327,6 +336,7 @@ func (w *StreamingBatchWriter) startWorker(ctx context.Context, errCh chan<- err
327336 w .workersWaitGroup .Add (1 )
328337 go w .deleteRecordWorker .run (ctx , & w .workersWaitGroup , tableName )
329338 w .deleteRecordWorker .ch <- m
339+
330340 return nil
331341 default :
332342 return fmt .Errorf ("unhandled message type: %T" , msg )
@@ -348,35 +358,40 @@ type streamingWorkerManager[T message.WriteMessage] struct {
348358func (s * streamingWorkerManager [T ]) run (ctx context.Context , wg * sync.WaitGroup , tableName string ) {
349359 defer wg .Done ()
350360 var (
351- clientCh chan T
352- clientErrCh chan error
353- open bool
361+ inputCh chan T
362+ outputCh chan error
363+ open bool
354364 )
355365
356366 ensureOpened := func () {
357367 if open {
358368 return
359369 }
360370
361- clientCh = make (chan T )
362- clientErrCh = make (chan error , 1 )
371+ inputCh = make (chan T )
372+ outputCh = make (chan error )
363373 go func () {
364- defer close (clientErrCh )
374+ defer close (outputCh )
365375 defer func () {
366- if err := recover (); err != nil {
367- clientErrCh <- fmt .Errorf ("panic: %v" , err )
376+ if msg := recover (); msg != nil {
377+ switch v := msg .(type ) {
378+ case error :
379+ outputCh <- fmt .Errorf ("panic: %w [recovered]" , v )
380+ default :
381+ outputCh <- fmt .Errorf ("panic: %v [recovered]" , msg )
382+ }
368383 }
369384 }()
370- clientErrCh <- s .writeFunc (ctx , clientCh )
385+ result := s .writeFunc (ctx , inputCh )
386+ outputCh <- result
371387 }()
388+
372389 open = true
373390 }
391+
374392 closeFlush := func () {
375393 if open {
376- close (clientCh )
377- if err := <- clientErrCh ; err != nil {
378- s .errCh <- fmt .Errorf ("handler failed on %s: %w" , tableName , err )
379- }
394+ close (inputCh )
380395 s .limit .Reset ()
381396 }
382397 open = false
@@ -400,7 +415,7 @@ func (s *streamingWorkerManager[T]) run(ctx context.Context, wg *sync.WaitGroup,
400415 if add != nil {
401416 ensureOpened ()
402417 s .limit .AddSlice (add )
403- clientCh <- any (& message.WriteInsert {Record : add .Record }).(T )
418+ inputCh <- any (& message.WriteInsert {Record : add .Record }).(T )
404419 }
405420 if len (toFlush ) > 0 || rest != nil || s .limit .ReachedLimit () {
406421 // flush current batch
@@ -410,7 +425,7 @@ func (s *streamingWorkerManager[T]) run(ctx context.Context, wg *sync.WaitGroup,
410425 for _ , sliceToFlush := range toFlush {
411426 ensureOpened ()
412427 s .limit .AddRows (sliceToFlush .NumRows ())
413- clientCh <- any (& message.WriteInsert {Record : sliceToFlush }).(T )
428+ inputCh <- any (& message.WriteInsert {Record : sliceToFlush }).(T )
414429 closeFlush ()
415430 ticker .Reset (s .batchTimeout )
416431 }
@@ -419,11 +434,11 @@ func (s *streamingWorkerManager[T]) run(ctx context.Context, wg *sync.WaitGroup,
419434 if rest != nil {
420435 ensureOpened ()
421436 s .limit .AddSlice (rest )
422- clientCh <- any (& message.WriteInsert {Record : rest .Record }).(T )
437+ inputCh <- any (& message.WriteInsert {Record : rest .Record }).(T )
423438 }
424439 } else {
425440 ensureOpened ()
426- clientCh <- r
441+ inputCh <- r
427442 s .limit .AddRows (1 )
428443 if s .limit .ReachedLimit () {
429444 closeFlush ()
@@ -441,6 +456,11 @@ func (s *streamingWorkerManager[T]) run(ctx context.Context, wg *sync.WaitGroup,
441456 ticker .Reset (s .batchTimeout )
442457 }
443458 done <- true
459+ case err := <- outputCh :
460+ if err != nil {
461+ s .errCh <- fmt .Errorf ("handler failed on %s: %w" , tableName , err )
462+ return
463+ }
444464 case <- ctxDone :
445465 // this means the request was cancelled
446466 return // after this NO other call will succeed
0 commit comments