@@ -182,28 +182,26 @@ func (w *StreamingBatchWriter) Write(ctx context.Context, msgs <-chan message.Wr
182182 errCh := make (chan error )
183183 defer close (errCh )
184184
185- for {
186- select {
187- case msg , ok := <- msgs :
188- if ! ok {
189- return w .Close (ctx )
190- }
185+ go func () {
186+ for err := range errCh {
187+ w .logger .Err (err ).Msg ("error from StreamingBatchWriter" )
188+ }
189+ }()
191190
192- msgType := writers .MsgID (msg )
193- if w .lastMsgType != writers .MsgTypeUnset && w .lastMsgType != msgType {
194- if err := w .Flush (ctx ); err != nil {
195- return err
196- }
197- }
198- w .lastMsgType = msgType
199- if err := w .startWorker (ctx , errCh , msg ); err != nil {
191+ for msg := range msgs {
192+ msgType := writers .MsgID (msg )
193+ if w .lastMsgType != writers .MsgTypeUnset && w .lastMsgType != msgType {
194+ if err := w .Flush (ctx ); err != nil {
200195 return err
201196 }
202-
203- case err := <- errCh :
197+ }
198+ w .lastMsgType = msgType
199+ if err := w .startWorker (ctx , errCh , msg ); err != nil {
204200 return err
205201 }
206202 }
203+
204+ return w .Close (ctx )
207205}
208206
209207func (w * StreamingBatchWriter ) startWorker (ctx context.Context , errCh chan <- error , msg message.WriteMessage ) error {
@@ -223,14 +221,13 @@ func (w *StreamingBatchWriter) startWorker(ctx context.Context, errCh chan<- err
223221 case * message.WriteMigrateTable :
224222 w .workersLock .Lock ()
225223 defer w .workersLock .Unlock ()
226-
227224 if w .migrateWorker != nil {
228225 w .migrateWorker .ch <- m
229226 return nil
230227 }
231-
228+ ch := make ( chan * message. WriteMigrateTable )
232229 w .migrateWorker = & streamingWorkerManager [* message.WriteMigrateTable ]{
233- ch : make ( chan * message. WriteMigrateTable ) ,
230+ ch : ch ,
234231 writeFunc : w .client .MigrateTable ,
235232
236233 flush : make (chan chan bool ),
@@ -244,19 +241,17 @@ func (w *StreamingBatchWriter) startWorker(ctx context.Context, errCh chan<- err
244241 w .workersWaitGroup .Add (1 )
245242 go w .migrateWorker .run (ctx , & w .workersWaitGroup , tableName )
246243 w .migrateWorker .ch <- m
247-
248244 return nil
249245 case * message.WriteDeleteStale :
250246 w .workersLock .Lock ()
251247 defer w .workersLock .Unlock ()
252-
253248 if w .deleteStaleWorker != nil {
254249 w .deleteStaleWorker .ch <- m
255250 return nil
256251 }
257-
252+ ch := make ( chan * message. WriteDeleteStale )
258253 w .deleteStaleWorker = & streamingWorkerManager [* message.WriteDeleteStale ]{
259- ch : make ( chan * message. WriteDeleteStale ) ,
254+ ch : ch ,
260255 writeFunc : w .client .DeleteStale ,
261256
262257 flush : make (chan chan bool ),
@@ -270,29 +265,19 @@ func (w *StreamingBatchWriter) startWorker(ctx context.Context, errCh chan<- err
270265 w .workersWaitGroup .Add (1 )
271266 go w .deleteStaleWorker .run (ctx , & w .workersWaitGroup , tableName )
272267 w .deleteStaleWorker .ch <- m
273-
274268 return nil
275269 case * message.WriteInsert :
276270 w .workersLock .RLock ()
277- worker , ok := w .insertWorkers [tableName ]
271+ wr , ok := w .insertWorkers [tableName ]
278272 w .workersLock .RUnlock ()
279273 if ok {
280- worker .ch <- m
274+ wr .ch <- m
281275 return nil
282276 }
283277
284- w .workersLock .Lock ()
285- activeWorker , ok := w .insertWorkers [tableName ]
286- if ok {
287- w .workersLock .Unlock ()
288- // some other goroutine could have already added the worker
289- // just send the message to it & discard our allocated worker
290- activeWorker .ch <- m
291- return nil
292- }
293-
294- worker = & streamingWorkerManager [* message.WriteInsert ]{
295- ch : make (chan * message.WriteInsert ),
278+ ch := make (chan * message.WriteInsert )
279+ wr = & streamingWorkerManager [* message.WriteInsert ]{
280+ ch : ch ,
296281 writeFunc : w .client .WriteTable ,
297282
298283 flush : make (chan chan bool ),
@@ -302,27 +287,33 @@ func (w *StreamingBatchWriter) startWorker(ctx context.Context, errCh chan<- err
302287 batchTimeout : w .batchTimeout ,
303288 tickerFn : w .tickerFn ,
304289 }
305-
306- w .insertWorkers [tableName ] = worker
290+ w .workersLock .Lock ()
291+ wrOld , ok := w .insertWorkers [tableName ]
292+ if ok {
293+ w .workersLock .Unlock ()
294+ // some other goroutine could have already added the worker
295+ // just send the message to it & discard our allocated worker
296+ wrOld .ch <- m
297+ return nil
298+ }
299+ w .insertWorkers [tableName ] = wr
307300 w .workersLock .Unlock ()
308301
309302 w .workersWaitGroup .Add (1 )
310- go worker .run (ctx , & w .workersWaitGroup , tableName )
311- worker .ch <- m
312-
303+ go wr .run (ctx , & w .workersWaitGroup , tableName )
304+ ch <- m
313305 return nil
314306 case * message.WriteDeleteRecord :
315307 w .workersLock .Lock ()
316308 defer w .workersLock .Unlock ()
317-
318309 if w .deleteRecordWorker != nil {
319310 w .deleteRecordWorker .ch <- m
320311 return nil
321312 }
322-
313+ ch := make ( chan * message. WriteDeleteRecord )
323314 // TODO: flush all workers for nested tables as well (See https://github.com/cloudquery/plugin-sdk/issues/1296)
324315 w .deleteRecordWorker = & streamingWorkerManager [* message.WriteDeleteRecord ]{
325- ch : make ( chan * message. WriteDeleteRecord ) ,
316+ ch : ch ,
326317 writeFunc : w .client .DeleteRecords ,
327318
328319 flush : make (chan chan bool ),
@@ -336,7 +327,6 @@ func (w *StreamingBatchWriter) startWorker(ctx context.Context, errCh chan<- err
336327 w .workersWaitGroup .Add (1 )
337328 go w .deleteRecordWorker .run (ctx , & w .workersWaitGroup , tableName )
338329 w .deleteRecordWorker .ch <- m
339-
340330 return nil
341331 default :
342332 return fmt .Errorf ("unhandled message type: %T" , msg )
@@ -358,40 +348,35 @@ type streamingWorkerManager[T message.WriteMessage] struct {
358348func (s * streamingWorkerManager [T ]) run (ctx context.Context , wg * sync.WaitGroup , tableName string ) {
359349 defer wg .Done ()
360350 var (
361- inputCh chan T
362- outputCh chan error
363- open bool
351+ clientCh chan T
352+ clientErrCh chan error
353+ open bool
364354 )
365355
366356 ensureOpened := func () {
367357 if open {
368358 return
369359 }
370360
371- inputCh = make (chan T )
372- outputCh = make (chan error )
361+ clientCh = make (chan T )
362+ clientErrCh = make (chan error , 1 )
373363 go func () {
374- defer close (outputCh )
364+ defer close (clientErrCh )
375365 defer func () {
376- if msg := recover (); msg != nil {
377- switch v := msg .(type ) {
378- case error :
379- outputCh <- fmt .Errorf ("panic: %w [recovered]" , v )
380- default :
381- outputCh <- fmt .Errorf ("panic: %v [recovered]" , msg )
382- }
366+ if err := recover (); err != nil {
367+ clientErrCh <- fmt .Errorf ("panic: %v" , err )
383368 }
384369 }()
385- result := s .writeFunc (ctx , inputCh )
386- outputCh <- result
370+ clientErrCh <- s .writeFunc (ctx , clientCh )
387371 }()
388-
389372 open = true
390373 }
391-
392374 closeFlush := func () {
393375 if open {
394- close (inputCh )
376+ close (clientCh )
377+ if err := <- clientErrCh ; err != nil {
378+ s .errCh <- fmt .Errorf ("handler failed on %s: %w" , tableName , err )
379+ }
395380 s .limit .Reset ()
396381 }
397382 open = false
@@ -415,7 +400,7 @@ func (s *streamingWorkerManager[T]) run(ctx context.Context, wg *sync.WaitGroup,
415400 if add != nil {
416401 ensureOpened ()
417402 s .limit .AddSlice (add )
418- inputCh <- any (& message.WriteInsert {Record : add .Record }).(T )
403+ clientCh <- any (& message.WriteInsert {Record : add .Record }).(T )
419404 }
420405 if len (toFlush ) > 0 || rest != nil || s .limit .ReachedLimit () {
421406 // flush current batch
@@ -425,7 +410,7 @@ func (s *streamingWorkerManager[T]) run(ctx context.Context, wg *sync.WaitGroup,
425410 for _ , sliceToFlush := range toFlush {
426411 ensureOpened ()
427412 s .limit .AddRows (sliceToFlush .NumRows ())
428- inputCh <- any (& message.WriteInsert {Record : sliceToFlush }).(T )
413+ clientCh <- any (& message.WriteInsert {Record : sliceToFlush }).(T )
429414 closeFlush ()
430415 ticker .Reset (s .batchTimeout )
431416 }
@@ -434,11 +419,11 @@ func (s *streamingWorkerManager[T]) run(ctx context.Context, wg *sync.WaitGroup,
434419 if rest != nil {
435420 ensureOpened ()
436421 s .limit .AddSlice (rest )
437- inputCh <- any (& message.WriteInsert {Record : rest .Record }).(T )
422+ clientCh <- any (& message.WriteInsert {Record : rest .Record }).(T )
438423 }
439424 } else {
440425 ensureOpened ()
441- inputCh <- r
426+ clientCh <- r
442427 s .limit .AddRows (1 )
443428 if s .limit .ReachedLimit () {
444429 closeFlush ()
@@ -456,11 +441,6 @@ func (s *streamingWorkerManager[T]) run(ctx context.Context, wg *sync.WaitGroup,
456441 ticker .Reset (s .batchTimeout )
457442 }
458443 done <- true
459- case err := <- outputCh :
460- if err != nil {
461- s .errCh <- fmt .Errorf ("handler failed on %s: %w" , tableName , err )
462- return
463- }
464444 case <- ctxDone :
465445 // this means the request was cancelled
466446 return // after this NO other call will succeed
0 commit comments