Skip to content

Commit 6d7be0b

Browse files
authored
feat: Allow sync to be cancelled when in progress (#1334)
In some cases, when the resource channel is processing a large collection of objects, the sync process can continue even if the context cancel has been triggered. This is problematic for billing and usage since a sync can significantly overshoot the remaining rows quota. To allow an in-progress sync to abort quicker, we need to modify a number of channel patterns from using `range` to using `select` - which will allow the `context.Done` channel to be monitored. This PR also adds the `OnSyncFinisher` hook which is called when the sync has finished adding messages. This allows the client code to call the batch updater close method without missing any updates. fixes: cloudquery/cloudquery-issues#750
1 parent f15a89d commit 6d7be0b

File tree

5 files changed

+120
-6
lines changed

5 files changed

+120
-6
lines changed

internal/servers/plugin/v3/plugin.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,10 @@ func (s *Server) Sync(req *pb.Sync_Request, stream pb.Plugin_SyncServer) error {
210210
}
211211
}
212212

213+
if err := s.Plugin.OnSyncFinish(ctx); err != nil {
214+
return status.Errorf(codes.Internal, "failed to finish sync: %v", err)
215+
}
216+
213217
return syncErr
214218
}
215219

plugin/plugin.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,19 @@ func (p *Plugin) OnBeforeSend(ctx context.Context, msg message.SyncMessage) (mes
141141
return msg, nil
142142
}
143143

144+
// OnSyncFinisher is an interface that can be implemented by a plugin client to be notified when a sync finishes.
145+
type OnSyncFinisher interface {
146+
OnSyncFinish(context.Context) error
147+
}
148+
149+
// OnSyncFinish gets called after a sync finishes.
150+
func (p *Plugin) OnSyncFinish(ctx context.Context) error {
151+
if v, ok := p.client.(OnSyncFinisher); ok {
152+
return v.OnSyncFinish(ctx)
153+
}
154+
return nil
155+
}
156+
144157
// IsStaticLinkingEnabled whether static linking is to be enabled
145158
func (p *Plugin) IsStaticLinkingEnabled() bool {
146159
return p.staticLinking

scheduler/scheduler.go

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"errors"
66
"fmt"
7+
"github.com/apache/arrow/go/v14/arrow"
78
"runtime/debug"
89
"sync/atomic"
910
"time"
@@ -182,15 +183,23 @@ func (s *Scheduler) Sync(ctx context.Context, client schema.ClientMeta, tables s
182183
}
183184
}()
184185
for resource := range resources {
185-
vector := resource.GetValues()
186-
bldr := array.NewRecordBuilder(memory.DefaultAllocator, resource.Table.ToArrowSchema())
187-
scalar.AppendToRecordBuilder(bldr, vector)
188-
rec := bldr.NewRecord()
189-
res <- &message.SyncInsert{Record: rec}
186+
select {
187+
case res <- &message.SyncInsert{Record: resourceToRecord(resource)}:
188+
case <-ctx.Done():
189+
return ctx.Err()
190+
}
190191
}
191192
return nil
192193
}
193194

195+
func resourceToRecord(resource *schema.Resource) arrow.Record {
196+
vector := resource.GetValues()
197+
bldr := array.NewRecordBuilder(memory.DefaultAllocator, resource.Table.ToArrowSchema())
198+
scalar.AppendToRecordBuilder(bldr, vector)
199+
rec := bldr.NewRecord()
200+
return rec
201+
}
202+
194203
func (s *syncClient) logTablesMetrics(tables schema.Tables, client Client) {
195204
clientName := client.ID()
196205
for _, table := range tables {

scheduler/scheduler_dfs.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,10 @@ func (s *syncClient) resolveResourcesDfs(ctx context.Context, table *schema.Tabl
176176
atomic.AddUint64(&tableMetrics.Errors, 1)
177177
return
178178
}
179-
resourcesChan <- resolvedResource
179+
select {
180+
case resourcesChan <- resolvedResource:
181+
case <-ctx.Done():
182+
}
180183
}()
181184
}
182185
wg.Wait()

scheduler/scheduler_test.go

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@ package scheduler
22

33
import (
44
"context"
5+
"fmt"
6+
"github.com/stretchr/testify/assert"
7+
"github.com/stretchr/testify/require"
58
"testing"
69

710
"github.com/apache/arrow/go/v14/arrow"
@@ -40,6 +43,22 @@ func testColumnResolverPanic(context.Context, schema.ClientMeta, *schema.Resourc
4043
panic("ColumnResolver")
4144
}
4245

46+
func testTableSuccessWithData(data []any) *schema.Table {
47+
return &schema.Table{
48+
Name: "test_table_success",
49+
Resolver: func(_ context.Context, _ schema.ClientMeta, _ *schema.Resource, res chan<- any) error {
50+
res <- data
51+
return nil
52+
},
53+
Columns: []schema.Column{
54+
{
55+
Name: "test_column",
56+
Type: arrow.PrimitiveTypes.Int64,
57+
},
58+
},
59+
}
60+
}
61+
4362
func testTableSuccess() *schema.Table {
4463
return &schema.Table{
4564
Name: "test_table_success",
@@ -233,6 +252,72 @@ func TestScheduler(t *testing.T) {
233252
}
234253
}
235254

255+
func TestScheduler_Cancellation(t *testing.T) {
256+
data := make([]any, 100)
257+
258+
tests := []struct {
259+
name string
260+
data []any
261+
cancel bool
262+
messageCount int
263+
}{
264+
{
265+
name: "should consume all message",
266+
data: data,
267+
cancel: false,
268+
messageCount: len(data) + 1, // 9 data + 1 migration message
269+
},
270+
{
271+
name: "should not consume all message on cancel",
272+
data: data,
273+
cancel: true,
274+
messageCount: len(data) + 1, // 9 data + 1 migration message
275+
},
276+
}
277+
278+
for _, strategy := range AllStrategies {
279+
for _, tc := range tests {
280+
tc := tc
281+
t.Run(fmt.Sprintf("%s_%s", tc.name, strategy.String()), func(t *testing.T) {
282+
sc := NewScheduler(WithLogger(zerolog.New(zerolog.NewTestWriter(t))), WithStrategy(strategy))
283+
284+
messages := make(chan message.SyncMessage)
285+
ctx, cancel := context.WithCancel(context.Background())
286+
defer cancel()
287+
288+
go func() {
289+
err := sc.Sync(
290+
ctx,
291+
&testExecutionClient{},
292+
[]*schema.Table{testTableSuccessWithData(tc.data)},
293+
messages,
294+
)
295+
if tc.cancel {
296+
assert.Equal(t, err, context.Canceled)
297+
} else {
298+
require.NoError(t, err)
299+
}
300+
close(messages)
301+
}()
302+
303+
messageConsumed := 0
304+
for range messages {
305+
if tc.cancel {
306+
cancel()
307+
}
308+
messageConsumed++
309+
}
310+
311+
if tc.cancel {
312+
assert.NotEqual(t, tc.messageCount, messageConsumed)
313+
} else {
314+
assert.Equal(t, tc.messageCount, messageConsumed)
315+
}
316+
})
317+
}
318+
}
319+
}
320+
236321
func testSyncTable(t *testing.T, tc syncTestCase, strategy Strategy, deterministicCQID bool) {
237322
ctx := context.Background()
238323
tables := []*schema.Table{}

0 commit comments

Comments
 (0)