92 changes: 66 additions & 26 deletions go/adbc/driver/snowflake/record_reader.go
@@ -524,6 +524,7 @@ type reader struct {
err error

cancelFn context.CancelFunc
done chan struct{} // signals all producer goroutines have finished
}
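
The new done field relies on Go's close-to-broadcast idiom: closing a channel unblocks every current and future receive on it, so closing done once all producers have returned gives any number of waiters a race-free "no more sends can happen" signal. A minimal sketch of the idiom, with illustrative names (records, done) rather than the driver's actual types:

package main

import "fmt"

// Close-to-broadcast sketch: close(done) tells every receiver that no
// producer can still be running. All names here are illustrative only.
func main() {
	records := make(chan int, 8)
	done := make(chan struct{})

	go func() {
		defer close(done) // runs last: the producer has fully exited
		for i := 0; i < 8; i++ {
			records <- i
		}
		close(records)
	}()

	<-done // blocks until the producer goroutine has returned
	for rec := range records {
		fmt.Println("drained:", rec) // safe: records is already closed
	}
}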

func newRecordReader(ctx context.Context, alloc memory.Allocator, ld gosnowflake.ArrowStreamLoader, bufferSize, prefetchConcurrency int, useHighPrecision bool, maxTimestampPrecision MaxTimestampPrecision) (array.RecordReader, error) {
@@ -631,48 +632,59 @@ func newRecordReader(ctx context.Context, alloc memory.Allocator, ld gosnowflake
return array.NewRecordReader(schema, results)
}

ch := make(chan arrow.RecordBatch, bufferSize)
group, ctx := errgroup.WithContext(compute.WithAllocator(ctx, alloc))
ctx, cancelFn := context.WithCancel(ctx)
group.SetLimit(prefetchConcurrency)

defer func() {
if err != nil {
close(ch)
cancelFn()
}
}()

chs := make([]chan arrow.RecordBatch, len(batches))
rdr := &reader{
refCount: 1,
chs: chs,
err: nil,
cancelFn: cancelFn,
}

// Handle empty batches case early
if len(batches) == 0 {
schema, err := rowTypesToArrowSchema(ctx, ld, useHighPrecision, maxTimestampPrecision)
if err != nil {
return nil, err
}
_, cancelFn := context.WithCancel(ctx)
rdr := &reader{
refCount: 1,
chs: nil,
err: nil,
cancelFn: cancelFn,
done: make(chan struct{}),
}
close(rdr.done) // No goroutines to wait for
rdr.schema, _ = getTransformer(schema, ld, useHighPrecision, maxTimestampPrecision)
return rdr, nil
}

// Do all error-prone initialization first, before starting goroutines
r, err := batches[0].GetStream(ctx)
if err != nil {
return nil, errToAdbcErr(adbc.StatusIO, err)
}

rr, err := ipc.NewReader(r, ipc.WithAllocator(alloc))
if err != nil {
_ = r.Close() // Clean up the stream
return nil, adbc.Error{
Msg: err.Error(),
Code: adbc.StatusInvalidState,
}
}

// Now setup concurrency primitives after error-prone operations
group, ctx := errgroup.WithContext(compute.WithAllocator(ctx, alloc))
ctx, cancelFn := context.WithCancel(ctx)
group.SetLimit(prefetchConcurrency)

// Initialize all channels upfront to avoid a race condition
chs := make([]chan arrow.RecordBatch, len(batches))
for i := range chs {
chs[i] = make(chan arrow.RecordBatch, bufferSize)
}

rdr := &reader{
refCount: 1,
chs: chs,
err: nil,
cancelFn: cancelFn,
done: make(chan struct{}),
}

var recTransform recordTransformer
rdr.schema, recTransform = getTransformer(rr.Schema(), ld, useHighPrecision, maxTimestampPrecision)
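
Pre-allocating every channel before any goroutine starts is what removes the race: previously chs[batchIdx] was assigned inside a spawned goroutine while Next() could already be reading the slice. A small self-contained sketch of the pattern, with illustrative names:

package main

import (
	"fmt"
	"sync"
)

// Sketch: allocate every channel slot before starting producers, so no
// goroutine writes chs[i] while the consumer reads it. Illustrative names.
func main() {
	const n = 3
	chs := make([]chan int, n)
	for i := range chs {
		chs[i] = make(chan int, 1) // every slot is valid before any goroutine runs
	}

	var wg sync.WaitGroup
	for i := range chs {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			chs[i] <- i * 10 // only element values cross goroutines, not the slice
			close(chs[i])
		}(i)
	}

	for i, ch := range chs { // consumer can range immediately: no nil slots
		fmt.Println(i, <-ch)
	}
	wg.Wait()
}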

@@ -682,7 +694,7 @@ func newRecordReader(ctx context.Context, alloc memory.Allocator, ld gosnowflake
err = errors.Join(err, r.Close())
}()
if len(batches) > 1 {
defer close(ch)
defer close(chs[0])
}

for rr.Next() && ctx.Err() == nil {
@@ -691,18 +703,25 @@ func newRecordReader(ctx context.Context, alloc memory.Allocator, ld gosnowflake
if err != nil {
return err
}
ch <- rec

// Use context-aware send to prevent deadlock
select {
case chs[0] <- rec:
// Successfully sent
case <-ctx.Done():
// Context cancelled, clean up and exit
rec.Release()
return ctx.Err()
}
}
return rr.Err()
})

chs[0] = ch

lastChannelIndex := len(chs) - 1
go func() {
for i, b := range batches[1:] {
batch, batchIdx := b, i+1
chs[batchIdx] = make(chan arrow.RecordBatch, bufferSize)
// Channels already initialized above, no need to create them here
group.Go(func() (err error) {
// close channels (except the last) so that Next can move on to the next channel properly
if batchIdx != lastChannelIndex {
@@ -729,7 +748,16 @@ func newRecordReader(ctx context.Context, alloc memory.Allocator, ld gosnowflake
if err != nil {
return err
}
chs[batchIdx] <- rec

// Use context-aware send to prevent deadlock
select {
case chs[batchIdx] <- rec:
// Successfully sent
case <-ctx.Done():
// Context cancelled, clean up and exit
rec.Release()
return ctx.Err()
}
}

return rr.Err()
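
Both producer loops now share this select-based send, the standard remedy for a producer that would otherwise block forever on a full channel once its consumer has gone away. A self-contained sketch of the pattern (sendOrCancel and the release callback are illustrative, not driver APIs):

package main

import (
	"context"
	"errors"
	"fmt"
)

// sendOrCancel mirrors the select above: deliver rec, or give up as soon
// as ctx is cancelled, releasing the undelivered record. Illustrative names.
func sendOrCancel(ctx context.Context, ch chan<- int, rec int, release func()) error {
	select {
	case ch <- rec:
		return nil // delivered to the consumer
	case <-ctx.Done():
		release() // drop the record we could not deliver
		return ctx.Err()
	}
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	ch := make(chan int) // unbuffered and never read: a bare send would block forever
	cancel()
	err := sendOrCancel(ctx, ch, 42, func() { fmt.Println("released 42") })
	fmt.Println(errors.Is(err, context.Canceled)) // true
}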
@@ -744,6 +772,8 @@ func newRecordReader(ctx context.Context, alloc memory.Allocator, ld gosnowflake
// don't close the last channel until after the group is finished,
// so that Next() can only return after reader.err may have been set
close(chs[lastChannelIndex])
// Signal that all producer goroutines have finished
close(rdr.done)
}()

return rdr, nil
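
The ordering in the coordinator goroutine above is deliberate: the group finishes first (so reader.err is recorded), the last channel is closed next (so Next() cannot return before that), and done is closed last (so Release() may drain). A compressed sketch of that ordering, with illustrative names:

package main

import (
	"fmt"

	"golang.org/x/sync/errgroup"
)

// Sketch of the shutdown ordering: wait for all producers, close the final
// channel, then close done. Reordering these steps reintroduces the races
// this change removes. Names are illustrative.
func main() {
	var group errgroup.Group
	last := make(chan int, 2)
	done := make(chan struct{})

	group.Go(func() error {
		last <- 1
		last <- 2
		return nil
	})

	go func() {
		_ = group.Wait() // 1: every producer has returned, errors recorded
		close(last)      // 2: consumers can now observe end-of-stream
		close(done)      // 3: cleanup paths may drain safely
	}()

	<-done
	for v := range last {
		fmt.Println(v)
	}
}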
@@ -795,7 +825,17 @@ func (r *reader) Release() {
r.rec.Release()
}
r.cancelFn()

// Wait for all producer goroutines to finish before draining the channels.
// This prevents a deadlock in which producers are still blocked on sends.
<-r.done

// Now safely drain remaining data from channels
// All channels should be closed at this point
for _, ch := range r.chs {
if ch == nil {
continue
}
for rec := range ch {
rec.Release()
}
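
Putting it together, Release() is safe because of this ordering: cancelFn unblocks any producer stuck in its context-aware send, receiving from done guarantees every producer has exited and every channel has been closed, and only then are the channels drained. A condensed sketch (releaseAll is an illustrative helper, not a driver function):

package main

import (
	"context"
	"fmt"
)

// releaseAll condenses the Release() ordering: cancel, wait, drain.
// Draining before the wait could race with in-flight sends; skipping the
// cancel could leave a producer blocked forever. Illustrative names.
func releaseAll(cancel context.CancelFunc, done <-chan struct{}, chs []chan int) {
	cancel() // unblock producers stuck in their context-aware sends
	<-done   // from here on, every channel has been closed by its producer
	for _, ch := range chs {
		if ch == nil {
			continue // slot never populated (e.g. the zero-batches path)
		}
		for rec := range ch { // terminates because the channel is closed
			fmt.Println("releasing", rec)
		}
	}
}

func main() {
	_, cancel := context.WithCancel(context.Background())
	done := make(chan struct{})
	ch := make(chan int, 1)
	ch <- 7
	close(ch)
	close(done) // pretend the producers have already finished
	releaseAll(cancel, done, []chan int{ch, nil})
}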