Skip to content

Commit fa3ed27

Browse files
committed
refactor on arrow record iterator
1 parent f707f16 commit fa3ed27

File tree

6 files changed

+266
-156
lines changed

6 files changed

+266
-156
lines changed

internal/rows/arrowbased/arrowRecordIterator.go

Lines changed: 101 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -10,34 +10,59 @@ import (
1010
"github.com/apache/arrow/go/v12/arrow/ipc"
1111
"github.com/databricks/databricks-sql-go/internal/cli_service"
1212
"github.com/databricks/databricks-sql-go/internal/config"
13-
dbsqlerr "github.com/databricks/databricks-sql-go/internal/errors"
1413
"github.com/databricks/databricks-sql-go/internal/rows/rowscanner"
1514
"github.com/databricks/databricks-sql-go/rows"
1615
)
1716

1817
func NewArrowRecordIterator(ctx context.Context, rpi rowscanner.ResultPageIterator, bi BatchIterator, arrowSchemaBytes []byte, cfg config.Config) rows.ArrowBatchIterator {
1918
ari := arrowRecordIterator{
20-
cfg: cfg,
21-
batchIterator: bi,
22-
resultPageIterator: rpi,
23-
ctx: ctx,
24-
arrowSchemaBytes: arrowSchemaBytes,
19+
cfg: cfg,
20+
ctx: ctx,
21+
arrowSchemaBytes: arrowSchemaBytes,
2522
}
2623

27-
return &ari
24+
if bi != nil && rpi != nil {
25+
// Both initial batch iterator and result page iterator
26+
// Extract the raw iterator from the initial batch iterator and create a composite
27+
if batchIter, ok := bi.(*batchIterator); ok {
28+
pagedRaw := &pagedRawBatchIterator{
29+
ctx: ctx,
30+
resultPageIterator: rpi,
31+
cfg: &cfg,
32+
startRowOffset: 0,
33+
}
34+
compositeRaw := NewInitialThenPagedRawIterator(batchIter.rawIterator, pagedRaw)
35+
ari.batchIterator = NewBatchIterator(compositeRaw, arrowSchemaBytes, &cfg)
36+
} else {
37+
// Fallback: use initial batch iterator, ignore pagination for now
38+
ari.batchIterator = bi
39+
}
40+
} else if bi != nil {
41+
// Only initial batch iterator
42+
ari.batchIterator = bi
43+
} else if rpi != nil {
44+
// Only result page iterator
45+
pagedRawIter := &pagedRawBatchIterator{
46+
ctx: ctx,
47+
resultPageIterator: rpi,
48+
cfg: &cfg,
49+
startRowOffset: 0,
50+
}
51+
ari.batchIterator = NewBatchIterator(pagedRawIter, arrowSchemaBytes, &cfg)
52+
}
2853

54+
return &ari
2955
}
3056

3157
// A type implemented DBSQLArrowBatchIterator
3258
type arrowRecordIterator struct {
33-
ctx context.Context
34-
cfg config.Config
35-
batchIterator BatchIterator
36-
resultPageIterator rowscanner.ResultPageIterator
37-
currentBatch SparkArrowBatch
38-
isFinished bool
39-
arrowSchemaBytes []byte
40-
arrowSchema *arrow.Schema
59+
ctx context.Context
60+
cfg config.Config
61+
batchIterator BatchIterator
62+
currentBatch SparkArrowBatch
63+
isFinished bool
64+
arrowSchemaBytes []byte
65+
arrowSchema *arrow.Schema
4166
}
4267

4368
var _ rows.ArrowBatchIterator = (*arrowRecordIterator)(nil)
@@ -80,18 +105,13 @@ func (ri *arrowRecordIterator) Close() {
80105
if ri.batchIterator != nil {
81106
ri.batchIterator.Close()
82107
}
83-
84-
if ri.resultPageIterator != nil {
85-
ri.resultPageIterator.Close()
86-
}
87108
}
88109
}
89110

90111
func (ri *arrowRecordIterator) checkFinished() {
91112
finished := ri.isFinished ||
92113
((ri.currentBatch == nil || !ri.currentBatch.HasNext()) &&
93-
(ri.batchIterator == nil || !ri.batchIterator.HasNext()) &&
94-
(ri.resultPageIterator == nil || !ri.resultPageIterator.HasNext()))
114+
(ri.batchIterator == nil || !ri.batchIterator.HasNext()))
95115

96116
if finished {
97117
// Reached end of result set so Close
@@ -101,86 +121,39 @@ func (ri *arrowRecordIterator) checkFinished() {
101121

102122
// Update the current batch if necessary
103123
func (ri *arrowRecordIterator) getCurrentBatch() error {
104-
105124
// only need to update if no current batch or current batch has no more records
106125
if ri.currentBatch == nil || !ri.currentBatch.HasNext() {
107-
108-
// ensure up to date batch iterator
109-
err := ri.getBatchIterator()
110-
if err != nil {
111-
return err
112-
}
113-
114126
// release current batch
115127
if ri.currentBatch != nil {
116128
ri.currentBatch.Close()
117129
}
118130

119131
// Get next batch from batch iterator
120-
ri.currentBatch, err = ri.batchIterator.Next()
121-
if err != nil {
122-
return err
132+
if ri.batchIterator == nil {
133+
return io.EOF
123134
}
124-
}
125-
126-
return nil
127-
}
128135

129-
// Update batch iterator if necessary
130-
func (ri *arrowRecordIterator) getBatchIterator() error {
131-
// only need to update if there is no batch iterator or the
132-
// batch iterator has no more batches
133-
if ri.batchIterator == nil || !ri.batchIterator.HasNext() {
134-
if ri.batchIterator != nil {
135-
// release any resources held by the current batch iterator
136-
ri.batchIterator.Close()
137-
ri.batchIterator = nil
138-
}
139-
140-
// Get the next page of the result set
141-
resp, err := ri.resultPageIterator.Next()
136+
var err error
137+
ri.currentBatch, err = ri.batchIterator.Next()
142138
if err != nil {
143139
return err
144140
}
145141

146-
// Check the result format
147-
resultFormat := resp.ResultSetMetadata.GetResultFormat()
148-
if resultFormat != cli_service.TSparkRowSetType_ARROW_BASED_SET && resultFormat != cli_service.TSparkRowSetType_URL_BASED_SET {
149-
return dbsqlerr.NewDriverError(ri.ctx, errArrowRowsNotArrowFormat, nil)
150-
}
151-
142+
// Update schema bytes if we don't have them yet and the batch iterator got them
152143
if ri.arrowSchemaBytes == nil {
153-
ri.arrowSchemaBytes = resp.ResultSetMetadata.ArrowSchema
154-
}
155-
156-
// Create a new batch iterator for the batches in the result page
157-
bi, err := ri.newBatchIterator(resp)
158-
if err != nil {
159-
return err
144+
if batchIter, ok := ri.batchIterator.(*batchIterator); ok {
145+
if pagedIter, ok := batchIter.rawIterator.(*pagedRawBatchIterator); ok {
146+
if schemaBytes := pagedIter.GetSchemaBytes(); schemaBytes != nil {
147+
ri.arrowSchemaBytes = schemaBytes
148+
}
149+
}
150+
}
160151
}
161-
162-
ri.batchIterator = bi
163152
}
164153

165154
return nil
166155
}
167156

168-
// Create a new batch iterator from a page of the result set
169-
func (ri *arrowRecordIterator) newBatchIterator(fr *cli_service.TFetchResultsResp) (BatchIterator, error) {
170-
rowSet := fr.Results
171-
var rawBi RawBatchIterator
172-
var err error
173-
if len(rowSet.ResultLinks) > 0 {
174-
rawBi, err = NewCloudRawBatchIterator(ri.ctx, rowSet.ResultLinks, rowSet.StartRowOffset, &ri.cfg)
175-
} else {
176-
rawBi, err = NewLocalRawBatchIterator(ri.ctx, rowSet.ArrowBatches, rowSet.StartRowOffset, ri.arrowSchemaBytes, &ri.cfg)
177-
}
178-
if err != nil {
179-
return nil, err
180-
}
181-
return NewBatchIterator(rawBi, ri.arrowSchemaBytes, &ri.cfg), nil
182-
}
183-
184157
// Return the schema of the records.
185158
func (ri *arrowRecordIterator) Schema() (*arrow.Schema, error) {
186159
// Return cached schema if available
@@ -213,3 +186,51 @@ func (ri *arrowRecordIterator) Schema() (*arrow.Schema, error) {
213186
ri.arrowSchema = reader.Schema()
214187
return ri.arrowSchema, nil
215188
}
189+
190+
// InitialThenPagedRawIterator handles initial raw iterator first, then paged raw iterator
191+
type InitialThenPagedRawIterator struct {
192+
InitialRaw RawBatchIterator
193+
PagedRaw RawBatchIterator
194+
}
195+
196+
// NewInitialThenPagedRawIterator creates a composite iterator
197+
func NewInitialThenPagedRawIterator(initial, paged RawBatchIterator) RawBatchIterator {
198+
return &InitialThenPagedRawIterator{
199+
InitialRaw: initial,
200+
PagedRaw: paged,
201+
}
202+
}
203+
204+
func (i *InitialThenPagedRawIterator) Next() (*cli_service.TSparkArrowBatch, error) {
205+
if i.InitialRaw != nil && i.InitialRaw.HasNext() {
206+
return i.InitialRaw.Next()
207+
}
208+
if i.PagedRaw != nil {
209+
return i.PagedRaw.Next()
210+
}
211+
return nil, io.EOF
212+
}
213+
214+
func (i *InitialThenPagedRawIterator) HasNext() bool {
215+
return (i.InitialRaw != nil && i.InitialRaw.HasNext()) ||
216+
(i.PagedRaw != nil && i.PagedRaw.HasNext())
217+
}
218+
219+
func (i *InitialThenPagedRawIterator) Close() {
220+
if i.InitialRaw != nil {
221+
i.InitialRaw.Close()
222+
}
223+
if i.PagedRaw != nil {
224+
i.PagedRaw.Close()
225+
}
226+
}
227+
228+
func (i *InitialThenPagedRawIterator) GetStartRowOffset() int64 {
229+
if i.InitialRaw != nil && i.InitialRaw.HasNext() {
230+
return i.InitialRaw.GetStartRowOffset()
231+
}
232+
if i.PagedRaw != nil {
233+
return i.PagedRaw.GetStartRowOffset()
234+
}
235+
return 0
236+
}

0 commit comments

Comments
 (0)