@@ -10,34 +10,59 @@ import (
1010 "github.com/apache/arrow/go/v12/arrow/ipc"
1111 "github.com/databricks/databricks-sql-go/internal/cli_service"
1212 "github.com/databricks/databricks-sql-go/internal/config"
13- dbsqlerr "github.com/databricks/databricks-sql-go/internal/errors"
1413 "github.com/databricks/databricks-sql-go/internal/rows/rowscanner"
1514 "github.com/databricks/databricks-sql-go/rows"
1615)
1716
1817func NewArrowRecordIterator (ctx context.Context , rpi rowscanner.ResultPageIterator , bi BatchIterator , arrowSchemaBytes []byte , cfg config.Config ) rows.ArrowBatchIterator {
1918 ari := arrowRecordIterator {
20- cfg : cfg ,
21- batchIterator : bi ,
22- resultPageIterator : rpi ,
23- ctx : ctx ,
24- arrowSchemaBytes : arrowSchemaBytes ,
19+ cfg : cfg ,
20+ ctx : ctx ,
21+ arrowSchemaBytes : arrowSchemaBytes ,
2522 }
2623
27- return & ari
24+ if bi != nil && rpi != nil {
25+ // Both initial batch iterator and result page iterator
26+ // Extract the raw iterator from the initial batch iterator and create a composite
27+ if batchIter , ok := bi .(* batchIterator ); ok {
28+ pagedRaw := & pagedRawBatchIterator {
29+ ctx : ctx ,
30+ resultPageIterator : rpi ,
31+ cfg : & cfg ,
32+ startRowOffset : 0 ,
33+ }
34+ compositeRaw := NewInitialThenPagedRawIterator (batchIter .rawIterator , pagedRaw )
35+ ari .batchIterator = NewBatchIterator (compositeRaw , arrowSchemaBytes , & cfg )
36+ } else {
37+ // Fallback: use initial batch iterator, ignore pagination for now
38+ ari .batchIterator = bi
39+ }
40+ } else if bi != nil {
41+ // Only initial batch iterator
42+ ari .batchIterator = bi
43+ } else if rpi != nil {
44+ // Only result page iterator
45+ pagedRawIter := & pagedRawBatchIterator {
46+ ctx : ctx ,
47+ resultPageIterator : rpi ,
48+ cfg : & cfg ,
49+ startRowOffset : 0 ,
50+ }
51+ ari .batchIterator = NewBatchIterator (pagedRawIter , arrowSchemaBytes , & cfg )
52+ }
2853
54+ return & ari
2955}
3056
3157// A type implemented DBSQLArrowBatchIterator
3258type arrowRecordIterator struct {
33- ctx context.Context
34- cfg config.Config
35- batchIterator BatchIterator
36- resultPageIterator rowscanner.ResultPageIterator
37- currentBatch SparkArrowBatch
38- isFinished bool
39- arrowSchemaBytes []byte
40- arrowSchema * arrow.Schema
59+ ctx context.Context
60+ cfg config.Config
61+ batchIterator BatchIterator
62+ currentBatch SparkArrowBatch
63+ isFinished bool
64+ arrowSchemaBytes []byte
65+ arrowSchema * arrow.Schema
4166}
4267
4368var _ rows.ArrowBatchIterator = (* arrowRecordIterator )(nil )
@@ -80,18 +105,13 @@ func (ri *arrowRecordIterator) Close() {
80105 if ri .batchIterator != nil {
81106 ri .batchIterator .Close ()
82107 }
83-
84- if ri .resultPageIterator != nil {
85- ri .resultPageIterator .Close ()
86- }
87108 }
88109}
89110
90111func (ri * arrowRecordIterator ) checkFinished () {
91112 finished := ri .isFinished ||
92113 ((ri .currentBatch == nil || ! ri .currentBatch .HasNext ()) &&
93- (ri .batchIterator == nil || ! ri .batchIterator .HasNext ()) &&
94- (ri .resultPageIterator == nil || ! ri .resultPageIterator .HasNext ()))
114+ (ri .batchIterator == nil || ! ri .batchIterator .HasNext ()))
95115
96116 if finished {
97117 // Reached end of result set so Close
@@ -101,86 +121,39 @@ func (ri *arrowRecordIterator) checkFinished() {
101121
102122// Update the current batch if necessary
103123func (ri * arrowRecordIterator ) getCurrentBatch () error {
104-
105124 // only need to update if no current batch or current batch has no more records
106125 if ri .currentBatch == nil || ! ri .currentBatch .HasNext () {
107-
108- // ensure up to date batch iterator
109- err := ri .getBatchIterator ()
110- if err != nil {
111- return err
112- }
113-
114126 // release current batch
115127 if ri .currentBatch != nil {
116128 ri .currentBatch .Close ()
117129 }
118130
119131 // Get next batch from batch iterator
120- ri .currentBatch , err = ri .batchIterator .Next ()
121- if err != nil {
122- return err
132+ if ri .batchIterator == nil {
133+ return io .EOF
123134 }
124- }
125-
126- return nil
127- }
128135
129- // Update batch iterator if necessary
130- func (ri * arrowRecordIterator ) getBatchIterator () error {
131- // only need to update if there is no batch iterator or the
132- // batch iterator has no more batches
133- if ri .batchIterator == nil || ! ri .batchIterator .HasNext () {
134- if ri .batchIterator != nil {
135- // release any resources held by the current batch iterator
136- ri .batchIterator .Close ()
137- ri .batchIterator = nil
138- }
139-
140- // Get the next page of the result set
141- resp , err := ri .resultPageIterator .Next ()
136+ var err error
137+ ri .currentBatch , err = ri .batchIterator .Next ()
142138 if err != nil {
143139 return err
144140 }
145141
146- // Check the result format
147- resultFormat := resp .ResultSetMetadata .GetResultFormat ()
148- if resultFormat != cli_service .TSparkRowSetType_ARROW_BASED_SET && resultFormat != cli_service .TSparkRowSetType_URL_BASED_SET {
149- return dbsqlerr .NewDriverError (ri .ctx , errArrowRowsNotArrowFormat , nil )
150- }
151-
142+ // Update schema bytes if we don't have them yet and the batch iterator got them
152143 if ri .arrowSchemaBytes == nil {
153- ri . arrowSchemaBytes = resp . ResultSetMetadata . ArrowSchema
154- }
155-
156- // Create a new batch iterator for the batches in the result page
157- bi , err := ri . newBatchIterator ( resp )
158- if err != nil {
159- return err
144+ if batchIter , ok := ri . batchIterator .( * batchIterator ); ok {
145+ if pagedIter , ok := batchIter . rawIterator .( * pagedRawBatchIterator ); ok {
146+ if schemaBytes := pagedIter . GetSchemaBytes (); schemaBytes != nil {
147+ ri . arrowSchemaBytes = schemaBytes
148+ }
149+ }
150+ }
160151 }
161-
162- ri .batchIterator = bi
163152 }
164153
165154 return nil
166155}
167156
168- // Create a new batch iterator from a page of the result set
169- func (ri * arrowRecordIterator ) newBatchIterator (fr * cli_service.TFetchResultsResp ) (BatchIterator , error ) {
170- rowSet := fr .Results
171- var rawBi RawBatchIterator
172- var err error
173- if len (rowSet .ResultLinks ) > 0 {
174- rawBi , err = NewCloudRawBatchIterator (ri .ctx , rowSet .ResultLinks , rowSet .StartRowOffset , & ri .cfg )
175- } else {
176- rawBi , err = NewLocalRawBatchIterator (ri .ctx , rowSet .ArrowBatches , rowSet .StartRowOffset , ri .arrowSchemaBytes , & ri .cfg )
177- }
178- if err != nil {
179- return nil , err
180- }
181- return NewBatchIterator (rawBi , ri .arrowSchemaBytes , & ri .cfg ), nil
182- }
183-
184157// Return the schema of the records.
185158func (ri * arrowRecordIterator ) Schema () (* arrow.Schema , error ) {
186159 // Return cached schema if available
@@ -213,3 +186,51 @@ func (ri *arrowRecordIterator) Schema() (*arrow.Schema, error) {
213186 ri .arrowSchema = reader .Schema ()
214187 return ri .arrowSchema , nil
215188}
189+
190+ // InitialThenPagedRawIterator handles initial raw iterator first, then paged raw iterator
191+ type InitialThenPagedRawIterator struct {
192+ InitialRaw RawBatchIterator
193+ PagedRaw RawBatchIterator
194+ }
195+
196+ // NewInitialThenPagedRawIterator creates a composite iterator
197+ func NewInitialThenPagedRawIterator (initial , paged RawBatchIterator ) RawBatchIterator {
198+ return & InitialThenPagedRawIterator {
199+ InitialRaw : initial ,
200+ PagedRaw : paged ,
201+ }
202+ }
203+
204+ func (i * InitialThenPagedRawIterator ) Next () (* cli_service.TSparkArrowBatch , error ) {
205+ if i .InitialRaw != nil && i .InitialRaw .HasNext () {
206+ return i .InitialRaw .Next ()
207+ }
208+ if i .PagedRaw != nil {
209+ return i .PagedRaw .Next ()
210+ }
211+ return nil , io .EOF
212+ }
213+
214+ func (i * InitialThenPagedRawIterator ) HasNext () bool {
215+ return (i .InitialRaw != nil && i .InitialRaw .HasNext ()) ||
216+ (i .PagedRaw != nil && i .PagedRaw .HasNext ())
217+ }
218+
219+ func (i * InitialThenPagedRawIterator ) Close () {
220+ if i .InitialRaw != nil {
221+ i .InitialRaw .Close ()
222+ }
223+ if i .PagedRaw != nil {
224+ i .PagedRaw .Close ()
225+ }
226+ }
227+
228+ func (i * InitialThenPagedRawIterator ) GetStartRowOffset () int64 {
229+ if i .InitialRaw != nil && i .InitialRaw .HasNext () {
230+ return i .InitialRaw .GetStartRowOffset ()
231+ }
232+ if i .PagedRaw != nil {
233+ return i .PagedRaw .GetStartRowOffset ()
234+ }
235+ return 0
236+ }
0 commit comments