Skip to content

Commit 806c85d

Browse files
authored
fix(writers): StreamingBatchWriter hangs with non-append mode (#1131)
Fixes the hang [issue](cloudquery/cloudquery#12793) with non-append mode when a dest is using `UnimplementedDeleteStale`. - This also makes it so that it closes things when Write is done: Otherwise our testing framework confuses things (by invoking Write multiple times for multiple tests in the same plugin, I think). We're noticing this now because of cloudquery/cloudquery#12765, as other plugins using this writer (file based plugins) don't support migration operations. Also makes it so that you don't need to call Close() on plugin deinit (but it could still be a good idea if Write was interrupted) --------- Co-authored-by: Kemal Hadimli <[email protected]>
1 parent 0c47570 commit 806c85d

File tree

4 files changed

+28
-18
lines changed

4 files changed

+28
-18
lines changed
Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,25 @@
11
package streamingbatchwriter
22

33
import (
4+
"sync"
45
"time"
56

67
"github.com/cloudquery/plugin-sdk/v4/writers"
78
)
89

910
type mockTicker struct {
10-
expire chan time.Time
11+
expire chan time.Time
12+
stopped sync.Once
1113
}
1214

1315
func (t *mockTicker) Stop() {
14-
close(t.expire)
16+
t.stopped.Do(func() {
17+
close(t.expire)
18+
})
19+
}
20+
21+
func (t *mockTicker) Tick() {
22+
t.expire <- time.Now()
1523
}
1624

1725
func (*mockTicker) Reset(time.Duration) {}
@@ -20,12 +28,12 @@ func (t *mockTicker) Chan() <-chan time.Time {
2028
return t.expire
2129
}
2230

23-
func newMockTicker() (writers.TickerFunc, chan<- time.Time) {
31+
func newMockTicker() (writers.TickerFunc, func()) {
2432
expire := make(chan time.Time)
2533
t := &mockTicker{
2634
expire: expire,
2735
}
2836
return func(time.Duration) writers.Ticker {
2937
return t
30-
}, expire
38+
}, t.Tick
3139
}

writers/streamingbatchwriter/streamingbatchwriter.go

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -156,13 +156,17 @@ func (w *StreamingBatchWriter) Close(context.Context) error {
156156
}
157157
w.workersWaitGroup.Wait()
158158

159-
w.insertWorkers = nil
159+
w.insertWorkers = make(map[string]*streamingWorkerManager[*message.WriteInsert])
160+
w.migrateWorker = nil
161+
w.deleteWorker = nil
162+
w.lastMsgType = writers.MsgTypeUnset
160163

161164
return nil
162165
}
163166

164167
func (w *StreamingBatchWriter) Write(ctx context.Context, msgs <-chan message.WriteMessage) error {
165168
errCh := make(chan error)
169+
defer close(errCh)
166170

167171
go func() {
168172
for err := range errCh {
@@ -172,7 +176,7 @@ func (w *StreamingBatchWriter) Write(ctx context.Context, msgs <-chan message.Wr
172176

173177
for msg := range msgs {
174178
msgType := writers.MsgID(msg)
175-
if w.lastMsgType != msgType {
179+
if w.lastMsgType != writers.MsgTypeUnset && w.lastMsgType != msgType {
176180
if err := w.Flush(ctx); err != nil {
177181
return err
178182
}
@@ -183,12 +187,7 @@ func (w *StreamingBatchWriter) Write(ctx context.Context, msgs <-chan message.Wr
183187
}
184188
}
185189

186-
if err := w.Flush(ctx); err != nil {
187-
return err
188-
}
189-
190-
close(errCh)
191-
return nil
190+
return w.Close(ctx)
192191
}
193192

194193
func (w *StreamingBatchWriter) startWorker(ctx context.Context, errCh chan<- error, msg message.WriteMessage) error {

writers/streamingbatchwriter/streamingbatchwriter_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ func TestStreamingBatchTimeout(t *testing.T) {
227227
ch := make(chan message.WriteMessage)
228228

229229
testClient := newClient()
230-
tickerFn, expire := newMockTicker()
230+
tickerFn, tickFn := newMockTicker()
231231

232232
wr, err := New(testClient, withTickerFn(tickerFn))
233233
if err != nil {
@@ -258,7 +258,7 @@ func TestStreamingBatchTimeout(t *testing.T) {
258258
}
259259

260260
// flush
261-
close(expire)
261+
tickFn()
262262
waitForLength(t, testClient.MessageLen, messageTypeInsert, 1)
263263

264264
close(ch)
@@ -332,7 +332,7 @@ func TestStreamingBatchUpserts(t *testing.T) {
332332
ch := make(chan message.WriteMessage)
333333

334334
testClient := newClient()
335-
tickerFn, expire := newMockTicker()
335+
tickerFn, tickFn := newMockTicker()
336336
wr, err := New(testClient, WithBatchSizeRows(2), withTickerFn(tickerFn))
337337
if err != nil {
338338
t.Fatal(err)
@@ -363,7 +363,7 @@ func TestStreamingBatchUpserts(t *testing.T) {
363363
time.Sleep(50 * time.Millisecond)
364364

365365
// flush the batch
366-
close(expire)
366+
tickFn()
367367
waitForLength(t, testClient.MessageLen, messageTypeInsert, 2)
368368

369369
close(ch)

writers/streamingbatchwriter/unimplemented.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,12 @@ func (IgnoreMigrateTable) MigrateTable(_ context.Context, ch <-chan *message.Wri
1818
return nil
1919
}
2020

21-
// UnimplementedDeleteStale is a dummy handler to error on DeleteStale messages
21+
// UnimplementedDeleteStale is a dummy handler to consume and error on DeleteStale messages
2222
type UnimplementedDeleteStale struct{}
2323

24-
func (UnimplementedDeleteStale) DeleteStale(_ context.Context, _ <-chan *message.WriteDeleteStale) error {
24+
func (UnimplementedDeleteStale) DeleteStale(_ context.Context, ch <-chan *message.WriteDeleteStale) error {
25+
// nolint:revive
26+
for range ch {
27+
}
2528
return fmt.Errorf("DeleteStale: %w", plugin.ErrNotImplemented)
2629
}

0 commit comments

Comments
 (0)