diff --git a/go/adbc/driver/snowflake/record_reader.go b/go/adbc/driver/snowflake/record_reader.go index 9e50303afd..5ccde3a107 100644 --- a/go/adbc/driver/snowflake/record_reader.go +++ b/go/adbc/driver/snowflake/record_reader.go @@ -524,6 +524,7 @@ type reader struct { err error cancelFn context.CancelFunc + done chan struct{} // signals all producer goroutines have finished } func newRecordReader(ctx context.Context, alloc memory.Allocator, ld gosnowflake.ArrowStreamLoader, bufferSize, prefetchConcurrency int, useHighPrecision bool, maxTimestampPrecision MaxTimestampPrecision) (array.RecordReader, error) { @@ -631,35 +632,26 @@ func newRecordReader(ctx context.Context, alloc memory.Allocator, ld gosnowflake return array.NewRecordReader(schema, results) } - ch := make(chan arrow.RecordBatch, bufferSize) - group, ctx := errgroup.WithContext(compute.WithAllocator(ctx, alloc)) - ctx, cancelFn := context.WithCancel(ctx) - group.SetLimit(prefetchConcurrency) - - defer func() { - if err != nil { - close(ch) - cancelFn() - } - }() - - chs := make([]chan arrow.RecordBatch, len(batches)) - rdr := &reader{ - refCount: 1, - chs: chs, - err: nil, - cancelFn: cancelFn, - } - + // Handle empty batches case early if len(batches) == 0 { schema, err := rowTypesToArrowSchema(ctx, ld, useHighPrecision, maxTimestampPrecision) if err != nil { return nil, err } + _, cancelFn := context.WithCancel(ctx) + rdr := &reader{ + refCount: 1, + chs: nil, + err: nil, + cancelFn: cancelFn, + done: make(chan struct{}), + } + close(rdr.done) // No goroutines to wait for rdr.schema, _ = getTransformer(schema, ld, useHighPrecision, maxTimestampPrecision) return rdr, nil } + // Do all error-prone initialization first, before starting goroutines r, err := batches[0].GetStream(ctx) if err != nil { return nil, errToAdbcErr(adbc.StatusIO, err) @@ -667,12 +659,32 @@ func newRecordReader(ctx context.Context, alloc memory.Allocator, ld gosnowflake rr, err := ipc.NewReader(r, ipc.WithAllocator(alloc)) if err != nil { + _ = r.Close() // Clean up the stream return nil, adbc.Error{ Msg: err.Error(), Code: adbc.StatusInvalidState, } } + // Now setup concurrency primitives after error-prone operations + group, ctx := errgroup.WithContext(compute.WithAllocator(ctx, alloc)) + ctx, cancelFn := context.WithCancel(ctx) + group.SetLimit(prefetchConcurrency) + + // Initialize all channels upfront to avoid race condition + chs := make([]chan arrow.RecordBatch, len(batches)) + for i := range chs { + chs[i] = make(chan arrow.RecordBatch, bufferSize) + } + + rdr := &reader{ + refCount: 1, + chs: chs, + err: nil, + cancelFn: cancelFn, + done: make(chan struct{}), + } + var recTransform recordTransformer rdr.schema, recTransform = getTransformer(rr.Schema(), ld, useHighPrecision, maxTimestampPrecision) @@ -682,7 +694,7 @@ func newRecordReader(ctx context.Context, alloc memory.Allocator, ld gosnowflake err = errors.Join(err, r.Close()) }() if len(batches) > 1 { - defer close(ch) + defer close(chs[0]) } for rr.Next() && ctx.Err() == nil { @@ -691,18 +703,25 @@ func newRecordReader(ctx context.Context, alloc memory.Allocator, ld gosnowflake if err != nil { return err } - ch <- rec + + // Use context-aware send to prevent deadlock + select { + case chs[0] <- rec: + // Successfully sent + case <-ctx.Done(): + // Context cancelled, clean up and exit + rec.Release() + return ctx.Err() + } } return rr.Err() }) - chs[0] = ch - lastChannelIndex := len(chs) - 1 go func() { for i, b := range batches[1:] { batch, batchIdx := b, i+1 - chs[batchIdx] = make(chan arrow.RecordBatch, bufferSize) + // Channels already initialized above, no need to create them here group.Go(func() (err error) { // close channels (except the last) so that Next can move on to the next channel properly if batchIdx != lastChannelIndex { @@ -729,7 +748,16 @@ func newRecordReader(ctx context.Context, alloc memory.Allocator, ld gosnowflake if err != nil { return err } - chs[batchIdx] <- rec + + // Use context-aware send to prevent deadlock + select { + case chs[batchIdx] <- rec: + // Successfully sent + case <-ctx.Done(): + // Context cancelled, clean up and exit + rec.Release() + return ctx.Err() + } } return rr.Err() @@ -744,6 +772,8 @@ func newRecordReader(ctx context.Context, alloc memory.Allocator, ld gosnowflake // don't close the last channel until after the group is finished, // so that Next() can only return after reader.err may have been set close(chs[lastChannelIndex]) + // Signal that all producer goroutines have finished + close(rdr.done) }() return rdr, nil @@ -795,7 +825,17 @@ func (r *reader) Release() { r.rec.Release() } r.cancelFn() + + // Wait for all producer goroutines to finish before draining channels + // This prevents deadlock where producers are blocked on sends + <-r.done + + // Now safely drain remaining data from channels + // All channels should be closed at this point for _, ch := range r.chs { + if ch == nil { + continue + } for rec := range ch { rec.Release() }