Skip to content

Commit 0c6803e

Browse files
Add schema to ArrowBatchIterator
Signed-off-by: Vikrant Puppala <[email protected]>
1 parent c4d5d18 commit 0c6803e

File tree

3 files changed

+169
-0
lines changed

3 files changed

+169
-0
lines changed

internal/rows/arrowbased/arrowRecordIterator.go

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
package arrowbased
22

33
import (
4+
"bytes"
45
"context"
6+
"fmt"
57
"io"
68

79
"github.com/apache/arrow/go/v12/arrow"
10+
"github.com/apache/arrow/go/v12/arrow/ipc"
811
"github.com/databricks/databricks-sql-go/internal/cli_service"
912
"github.com/databricks/databricks-sql-go/internal/config"
1013
dbsqlerr "github.com/databricks/databricks-sql-go/internal/errors"
@@ -34,6 +37,7 @@ type arrowRecordIterator struct {
3437
currentBatch SparkArrowBatch
3538
isFinished bool
3639
arrowSchemaBytes []byte
40+
arrowSchema *arrow.Schema
3741
}
3842

3943
var _ rows.ArrowBatchIterator = (*arrowRecordIterator)(nil)
@@ -170,3 +174,36 @@ func (ri *arrowRecordIterator) newBatchIterator(fr *cli_service.TFetchResultsRes
170174
return NewLocalBatchIterator(ri.ctx, rowSet.ArrowBatches, rowSet.StartRowOffset, ri.arrowSchemaBytes, &ri.cfg)
171175
}
172176
}
177+
178+
// Return the schema of the records.
179+
func (ri *arrowRecordIterator) Schema() (*arrow.Schema, error) {
180+
// Return cached schema if available
181+
if ri.arrowSchema != nil {
182+
return ri.arrowSchema, nil
183+
}
184+
185+
// Try to get schema bytes if not already available
186+
if ri.arrowSchemaBytes == nil {
187+
if ri.HasNext() {
188+
if err := ri.getCurrentBatch(); err != nil {
189+
return nil, err
190+
}
191+
}
192+
193+
// If still no schema bytes, we can't create a schema
194+
if ri.arrowSchemaBytes == nil {
195+
return nil, fmt.Errorf("no schema available")
196+
}
197+
}
198+
199+
// Convert schema bytes to Arrow schema
200+
reader, err := ipc.NewReader(bytes.NewReader(ri.arrowSchemaBytes))
201+
if err != nil {
202+
return nil, fmt.Errorf("failed to create Arrow IPC reader: %w", err)
203+
}
204+
defer reader.Release()
205+
206+
// Cache and return the schema
207+
ri.arrowSchema = reader.Schema()
208+
return ri.arrowSchema, nil
209+
}

internal/rows/arrowbased/arrowRecordIterator_test.go

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,135 @@ func TestArrowRecordIterator(t *testing.T) {
188188
})
189189
}
190190

191+
func TestArrowRecordIteratorSchema(t *testing.T) {
192+
// Test with arrowSchemaBytes available
193+
t.Run("schema with initial schema bytes", func(t *testing.T) {
194+
logger := dbsqllog.WithContext("connectionId", "correlationId", "")
195+
196+
executeStatementResp := cli_service.TExecuteStatementResp{}
197+
loadTestData2(t, "directResultsMultipleFetch/ExecuteStatement.json", &executeStatementResp)
198+
199+
fetchResp1 := cli_service.TFetchResultsResp{}
200+
loadTestData2(t, "directResultsMultipleFetch/FetchResults1.json", &fetchResp1)
201+
202+
var fetchesInfo []fetchResultsInfo
203+
simpleClient := getSimpleClient(&fetchesInfo, []cli_service.TFetchResultsResp{fetchResp1})
204+
rpi := rowscanner.NewResultPageIterator(
205+
rowscanner.NewDelimiter(0, 0),
206+
5000,
207+
nil,
208+
false,
209+
simpleClient,
210+
"connectionId",
211+
"correlationId",
212+
logger,
213+
)
214+
215+
cfg := *config.WithDefaults()
216+
217+
bi, err := NewLocalBatchIterator(
218+
context.Background(),
219+
executeStatementResp.DirectResults.ResultSet.Results.ArrowBatches,
220+
0,
221+
executeStatementResp.DirectResults.ResultSetMetadata.ArrowSchema,
222+
&cfg,
223+
)
224+
assert.Nil(t, err)
225+
226+
// Create arrowRecordIterator with schema bytes already available
227+
rs := NewArrowRecordIterator(
228+
context.Background(),
229+
rpi,
230+
bi,
231+
executeStatementResp.DirectResults.ResultSetMetadata.ArrowSchema,
232+
cfg,
233+
)
234+
defer rs.Close()
235+
236+
// Test Schema() method
237+
schema, schemaErr := rs.Schema()
238+
assert.NoError(t, schemaErr)
239+
assert.NotNil(t, schema)
240+
241+
// Cache works - we should get same schema object on second call
242+
secondSchema, schemaErr2 := rs.Schema()
243+
assert.NoError(t, schemaErr2)
244+
assert.Same(t, schema, secondSchema)
245+
})
246+
247+
// Test with arrowSchemaBytes that needs to be populated via a batch
248+
t.Run("schema with lazy loading", func(t *testing.T) {
249+
logger := dbsqllog.WithContext("connectionId", "correlationId", "")
250+
251+
fetchResp1 := cli_service.TFetchResultsResp{}
252+
loadTestData2(t, "multipleFetch/FetchResults1.json", &fetchResp1)
253+
254+
var fetchesInfo []fetchResultsInfo
255+
simpleClient := getSimpleClient(&fetchesInfo, []cli_service.TFetchResultsResp{fetchResp1})
256+
rpi := rowscanner.NewResultPageIterator(
257+
rowscanner.NewDelimiter(0, 0),
258+
5000,
259+
nil,
260+
false,
261+
simpleClient,
262+
"connectionId",
263+
"correlationId",
264+
logger,
265+
)
266+
267+
cfg := *config.WithDefaults()
268+
269+
// Create arrowRecordIterator without initial schema bytes
270+
rs := NewArrowRecordIterator(context.Background(), rpi, nil, nil, cfg)
271+
defer rs.Close()
272+
273+
// Schema() should trigger loading a batch to get schema
274+
schema, schemaErr := rs.Schema()
275+
assert.NoError(t, schemaErr)
276+
assert.NotNil(t, schema)
277+
278+
// Cache works - we should get same schema object on second call
279+
secondSchema, schemaErr2 := rs.Schema()
280+
assert.NoError(t, schemaErr2)
281+
assert.Same(t, schema, secondSchema)
282+
})
283+
284+
// Test with no schema available
285+
t.Run("schema with no data available", func(t *testing.T) {
286+
logger := dbsqllog.WithContext("connectionId", "correlationId", "")
287+
288+
// Instead of using an empty response list, let's create a custom client
289+
// that returns an error when trying to fetch results
290+
failingClient := &client.TestClient{
291+
FnFetchResults: func(ctx context.Context, req *cli_service.TFetchResultsReq) (*cli_service.TFetchResultsResp, error) {
292+
return nil, fmt.Errorf("no data available")
293+
},
294+
}
295+
296+
rpi := rowscanner.NewResultPageIterator(
297+
rowscanner.NewDelimiter(0, 0),
298+
5000,
299+
nil,
300+
false,
301+
failingClient,
302+
"connectionId",
303+
"correlationId",
304+
logger,
305+
)
306+
307+
cfg := *config.WithDefaults()
308+
309+
// Create arrowRecordIterator without schema bytes and with failing client
310+
rs := NewArrowRecordIterator(context.Background(), rpi, nil, nil, cfg)
311+
defer rs.Close()
312+
313+
// Schema() should return error since no schema can be obtained
314+
schema, schemaErr := rs.Schema()
315+
assert.Error(t, schemaErr)
316+
assert.Nil(t, schema)
317+
})
318+
}
319+
191320
type fetchResultsInfo struct {
192321
direction cli_service.TFetchOrientation
193322
resultStartRec int

rows/rows.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,7 @@ type ArrowBatchIterator interface {
2020

2121
// Release any resources in use by the iterator.
2222
Close()
23+
24+
// Return the schema of the records.
25+
Schema() (*arrow.Schema, error)
2326
}

0 commit comments

Comments
 (0)