Commit 269357f
fix(go): fix potential deadlocks in reader (#60)
## What's Changed

## Fix Critical Deadlocks and Race Conditions in Snowflake Record Reader

This PR addresses multiple critical concurrency issues in the Snowflake driver's `recordReader` that could cause complete application hangs under ordinary race conditions.

### Issues Fixed

**1. Critical Deadlock: `Release()` Blocking Forever**

**Problem:** When `Release()` was called while producer goroutines were blocked on channel sends, a permanent deadlock occurred:

* `Release()` cancels the context and attempts to drain channels
* Producer goroutines blocked on `ch <- rec` cannot see the cancellation
* Channels never close because producers never exit
* `Release()` blocks forever on `for rec := range ch`

**Fix:** Added a `done` channel that signals when all producer goroutines have completed. `Release()` now waits for this signal before attempting to drain channels.

**2. Severe Deadlock: Non-Context-Aware Channel Sends**

**Problem:** The channel send operations at lines 694 and 732 checked the context before the send but not during it:

```go
for rr.Next() && ctx.Err() == nil { // Context checked here
    // ...
    ch <- rec // But send blocks here without checking context
}
```

**Fix:** Wrapped all channel sends in `select` statements with context awareness:

```go
select {
case chs[0] <- rec:
    // Successfully sent
case <-ctx.Done():
    rec.Release()
    return ctx.Err()
}
```

**3. Critical Race Condition: Nil Channel Reads**

**Problem:** Channels were created asynchronously in goroutines after `newRecordReader` returned. If `Next()` was called soon after creation, it could read from uninitialized (nil) channels, blocking forever.

**Fix:** Initialize all channels upfront, before starting any goroutines:

```go
chs := make([]chan arrow.RecordBatch, len(batches))
for i := range chs {
    chs[i] = make(chan arrow.RecordBatch, bufferSize)
}
```

**4. Goroutine Leaks on Initialization Errors**

**Problem:** Error paths only cleaned up the first channel, potentially leaking goroutines if initialization failed after concurrent operations had started.

**Fix:** Moved all error-prone initialization (`GetStream`, `NewReader`) before goroutine creation, and added proper cleanup on errors.

---

#### Changes

* Added a `done` channel to the `reader` struct to signal goroutine completion
* Initialize all channels upfront to eliminate race conditions
* Use context-aware sends with `select` statements for all channel operations
* Update `Release()` to wait on the `done` channel before draining
* Reorganize initialization to handle errors before starting goroutines
* Signal completion by closing the `done` channel after all producers finish

#### Reproduction Scenarios Prevented

**Deadlock:**

1. `bufferSize = 1`, producer generates 2 records quickly
2. Channel becomes full after the first record
3. Producer blocks on send
4. Consumer calls `Release()` before `Next()`
5. Without the fix: permanent deadlock
6. With the fix: the producer responds to cancellation and `Release()` completes

**Race condition:**

1. Query returns 3 batches
2. First batch processes quickly
3. `Next()` advances to the second channel
4. Without the fix: reads from a nil channel, blocks forever
5. With the fix: the channel is already initialized and reads work correctly

Backport of apache/arrow-adbc#3870.
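To make the interplay of these fixes concrete, here is a minimal, self-contained sketch of the combined pattern (context-aware sends plus a `done` channel). It is illustrative only, not the driver's code: the `reader` below carries plain ints instead of Arrow records, assumes a single producer, and borrows field names from the commit purely for readability.

```go
package main

import (
	"context"
	"fmt"
)

// reader is a stripped-down stand-in for the driver's recordReader.
type reader struct {
	ch       chan int
	done     chan struct{} // closed once the producer goroutine has exited
	cancelFn context.CancelFunc
}

func newReader(ctx context.Context, bufferSize int) *reader {
	ctx, cancel := context.WithCancel(ctx)
	r := &reader{
		ch:       make(chan int, bufferSize), // created upfront, never nil
		done:     make(chan struct{}),
		cancelFn: cancel,
	}
	go func() {
		defer close(r.done) // runs last: signals "producer finished"
		defer close(r.ch)   // runs first: lets the drain loop terminate
		for i := 0; ; i++ {
			// Context-aware send: a full channel can no longer block forever.
			select {
			case r.ch <- i:
			case <-ctx.Done():
				return
			}
		}
	}()
	return r
}

// Release cancels the producer, waits for it to exit, then drains.
// Without <-r.done, draining could start while the producer is still
// blocked on a send, recreating the original deadlock.
func (r *reader) Release() {
	r.cancelFn()
	<-r.done
	for range r.ch { // in the real reader this calls rec.Release()
	}
}

func main() {
	r := newReader(context.Background(), 1) // bufferSize = 1, as in the repro
	r.Release()                             // returns promptly instead of hanging
	fmt.Println("released without deadlock")
}
```

The defer ordering matters here: the data channel is closed before `done`, so by the time `Release()` unblocks, the drain loop is guaranteed to see a closed channel and terminate.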
1 parent ea369d9

File changed: go/record_reader.go (+66 −26 lines)
```diff
@@ -548,6 +548,7 @@ type reader struct {
 	err error
 
 	cancelFn context.CancelFunc
+	done     chan struct{} // signals all producer goroutines have finished
 }
 
 func newRecordReader(ctx context.Context, alloc memory.Allocator, ld gosnowflake.ArrowStreamLoader, bufferSize, prefetchConcurrency int, useHighPrecision bool, maxTimestampPrecision MaxTimestampPrecision) (array.RecordReader, error) {
@@ -655,48 +656,59 @@ func newRecordReader(ctx context.Context, alloc memory.Allocator, ld gosnowflake
 		return array.NewRecordReader(schema, results)
 	}
 
-	ch := make(chan arrow.RecordBatch, bufferSize)
-	group, ctx := errgroup.WithContext(compute.WithAllocator(ctx, alloc))
-	ctx, cancelFn := context.WithCancel(ctx)
-	group.SetLimit(prefetchConcurrency)
-
-	defer func() {
-		if err != nil {
-			close(ch)
-			cancelFn()
-		}
-	}()
-
-	chs := make([]chan arrow.RecordBatch, len(batches))
-	rdr := &reader{
-		refCount: 1,
-		chs:      chs,
-		err:      nil,
-		cancelFn: cancelFn,
-	}
-
+	// Handle empty batches case early
 	if len(batches) == 0 {
 		schema, err := rowTypesToArrowSchema(ctx, ld, useHighPrecision, maxTimestampPrecision)
 		if err != nil {
 			return nil, err
 		}
+		_, cancelFn := context.WithCancel(ctx)
+		rdr := &reader{
+			refCount: 1,
+			chs:      nil,
+			err:      nil,
+			cancelFn: cancelFn,
+			done:     make(chan struct{}),
+		}
+		close(rdr.done) // No goroutines to wait for
 		rdr.schema, _ = getTransformer(schema, ld, useHighPrecision, maxTimestampPrecision)
 		return rdr, nil
 	}
 
+	// Do all error-prone initialization first, before starting goroutines
 	r, err := batches[0].GetStream(ctx)
 	if err != nil {
 		return nil, errToAdbcErr(adbc.StatusIO, err)
 	}
 
 	rr, err := ipc.NewReader(r, ipc.WithAllocator(alloc))
 	if err != nil {
+		_ = r.Close() // Clean up the stream
 		return nil, adbc.Error{
 			Msg:  err.Error(),
 			Code: adbc.StatusInvalidState,
 		}
 	}
 
+	// Now setup concurrency primitives after error-prone operations
+	group, ctx := errgroup.WithContext(compute.WithAllocator(ctx, alloc))
+	ctx, cancelFn := context.WithCancel(ctx)
+	group.SetLimit(prefetchConcurrency)
+
+	// Initialize all channels upfront to avoid race condition
+	chs := make([]chan arrow.RecordBatch, len(batches))
+	for i := range chs {
+		chs[i] = make(chan arrow.RecordBatch, bufferSize)
+	}
+
+	rdr := &reader{
+		refCount: 1,
+		chs:      chs,
+		err:      nil,
+		cancelFn: cancelFn,
+		done:     make(chan struct{}),
+	}
+
 	var recTransform recordTransformer
 	rdr.schema, recTransform = getTransformer(rr.Schema(), ld, useHighPrecision, maxTimestampPrecision)
 
@@ -706,7 +718,7 @@ func newRecordReader(ctx context.Context, alloc memory.Allocator, ld gosnowflake
 			err = errors.Join(err, r.Close())
 		}()
 		if len(batches) > 1 {
-			defer close(ch)
+			defer close(chs[0])
 		}
 
 		for rr.Next() && ctx.Err() == nil {
@@ -715,18 +727,25 @@ func newRecordReader(ctx context.Context, alloc memory.Allocator, ld gosnowflake
 			if err != nil {
 				return err
 			}
-			ch <- rec
+
+			// Use context-aware send to prevent deadlock
+			select {
+			case chs[0] <- rec:
+				// Successfully sent
+			case <-ctx.Done():
+				// Context cancelled, clean up and exit
+				rec.Release()
+				return ctx.Err()
+			}
 		}
 		return rr.Err()
 	})
 
-	chs[0] = ch
-
 	lastChannelIndex := len(chs) - 1
 	go func() {
 		for i, b := range batches[1:] {
 			batch, batchIdx := b, i+1
-			chs[batchIdx] = make(chan arrow.RecordBatch, bufferSize)
+			// Channels already initialized above, no need to create them here
 			group.Go(func() (err error) {
 				// close channels (except the last) so that Next can move on to the next channel properly
 				if batchIdx != lastChannelIndex {
@@ -753,7 +772,16 @@ func newRecordReader(ctx context.Context, alloc memory.Allocator, ld gosnowflake
 				if err != nil {
 					return err
 				}
-				chs[batchIdx] <- rec
+
+				// Use context-aware send to prevent deadlock
+				select {
+				case chs[batchIdx] <- rec:
+					// Successfully sent
+				case <-ctx.Done():
+					// Context cancelled, clean up and exit
+					rec.Release()
+					return ctx.Err()
+				}
 			}
 
 			return rr.Err()
@@ -768,6 +796,8 @@ func newRecordReader(ctx context.Context, alloc memory.Allocator, ld gosnowflake
 		// don't close the last channel until after the group is finished,
 		// so that Next() can only return after reader.err may have been set
 		close(chs[lastChannelIndex])
+		// Signal that all producer goroutines have finished
+		close(rdr.done)
 	}()
 
 	return rdr, nil
@@ -819,7 +849,17 @@ func (r *reader) Release() {
 		r.rec.Release()
 	}
 	r.cancelFn()
+
+	// Wait for all producer goroutines to finish before draining channels
+	// This prevents deadlock where producers are blocked on sends
+	<-r.done
+
+	// Now safely drain remaining data from channels
+	// All channels should be closed at this point
 	for _, ch := range r.chs {
+		if ch == nil {
+			continue
+		}
 		for rec := range ch {
 			rec.Release()
 		}
```
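One detail worth spelling out from the diff above: issue 3 exists because a receive from a nil channel in Go blocks forever, which is why the fix creates every channel before any goroutine starts. A tiny, hypothetical illustration (plain ints, not Arrow batches, and not driver code):

```go
package main

import "fmt"

func main() {
	chs := make([]chan int, 2) // a slice of channels starts out as nil channels

	// Lazy creation (the old bug): if a consumer reached chs[1] before a
	// goroutine assigned it, the receive below would block forever.
	// v := <-chs[1] // receive from a nil channel: blocks permanently

	// Eager creation (the fix): make every channel before anything runs.
	for i := range chs {
		chs[i] = make(chan int, 1)
	}
	chs[1] <- 42
	fmt.Println(<-chs[1]) // prints 42; no nil-channel read is possible
}
```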
