Skip to content

Commit c291992

Browse files
Add schema to ArrowBatchIterator (#267)
Add the ability to get the Arrow schema when fetching Arrow batches. Currently, the GetArrowBatches(ctx) method in the DBSQL driver does not expose schema information; this PR adds a Schema() method so callers can get the schema directly, as with the Arrow Flight interface.
2 parents 7263c53 + 0c6803e commit c291992

File tree

3 files changed

+169
-0
lines changed

3 files changed

+169
-0
lines changed

internal/rows/arrowbased/arrowRecordIterator.go

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
package arrowbased
22

33
import (
4+
"bytes"
45
"context"
6+
"fmt"
57
"io"
68

79
"github.com/apache/arrow/go/v12/arrow"
10+
"github.com/apache/arrow/go/v12/arrow/ipc"
811
"github.com/databricks/databricks-sql-go/internal/cli_service"
912
"github.com/databricks/databricks-sql-go/internal/config"
1013
dbsqlerr "github.com/databricks/databricks-sql-go/internal/errors"
@@ -34,6 +37,7 @@ type arrowRecordIterator struct {
3437
currentBatch SparkArrowBatch
3538
isFinished bool
3639
arrowSchemaBytes []byte
40+
arrowSchema *arrow.Schema
3741
}
3842

3943
var _ rows.ArrowBatchIterator = (*arrowRecordIterator)(nil)
@@ -170,3 +174,36 @@ func (ri *arrowRecordIterator) newBatchIterator(fr *cli_service.TFetchResultsRes
170174
return NewLocalBatchIterator(ri.ctx, rowSet.ArrowBatches, rowSet.StartRowOffset, ri.arrowSchemaBytes, &ri.cfg)
171175
}
172176
}
177+
178+
// Return the schema of the records.
179+
func (ri *arrowRecordIterator) Schema() (*arrow.Schema, error) {
180+
// Return cached schema if available
181+
if ri.arrowSchema != nil {
182+
return ri.arrowSchema, nil
183+
}
184+
185+
// Try to get schema bytes if not already available
186+
if ri.arrowSchemaBytes == nil {
187+
if ri.HasNext() {
188+
if err := ri.getCurrentBatch(); err != nil {
189+
return nil, err
190+
}
191+
}
192+
193+
// If still no schema bytes, we can't create a schema
194+
if ri.arrowSchemaBytes == nil {
195+
return nil, fmt.Errorf("no schema available")
196+
}
197+
}
198+
199+
// Convert schema bytes to Arrow schema
200+
reader, err := ipc.NewReader(bytes.NewReader(ri.arrowSchemaBytes))
201+
if err != nil {
202+
return nil, fmt.Errorf("failed to create Arrow IPC reader: %w", err)
203+
}
204+
defer reader.Release()
205+
206+
// Cache and return the schema
207+
ri.arrowSchema = reader.Schema()
208+
return ri.arrowSchema, nil
209+
}

internal/rows/arrowbased/arrowRecordIterator_test.go

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,135 @@ func TestArrowRecordIterator(t *testing.T) {
188188
})
189189
}
190190

191+
func TestArrowRecordIteratorSchema(t *testing.T) {
192+
// Test with arrowSchemaBytes available
193+
t.Run("schema with initial schema bytes", func(t *testing.T) {
194+
logger := dbsqllog.WithContext("connectionId", "correlationId", "")
195+
196+
executeStatementResp := cli_service.TExecuteStatementResp{}
197+
loadTestData2(t, "directResultsMultipleFetch/ExecuteStatement.json", &executeStatementResp)
198+
199+
fetchResp1 := cli_service.TFetchResultsResp{}
200+
loadTestData2(t, "directResultsMultipleFetch/FetchResults1.json", &fetchResp1)
201+
202+
var fetchesInfo []fetchResultsInfo
203+
simpleClient := getSimpleClient(&fetchesInfo, []cli_service.TFetchResultsResp{fetchResp1})
204+
rpi := rowscanner.NewResultPageIterator(
205+
rowscanner.NewDelimiter(0, 0),
206+
5000,
207+
nil,
208+
false,
209+
simpleClient,
210+
"connectionId",
211+
"correlationId",
212+
logger,
213+
)
214+
215+
cfg := *config.WithDefaults()
216+
217+
bi, err := NewLocalBatchIterator(
218+
context.Background(),
219+
executeStatementResp.DirectResults.ResultSet.Results.ArrowBatches,
220+
0,
221+
executeStatementResp.DirectResults.ResultSetMetadata.ArrowSchema,
222+
&cfg,
223+
)
224+
assert.Nil(t, err)
225+
226+
// Create arrowRecordIterator with schema bytes already available
227+
rs := NewArrowRecordIterator(
228+
context.Background(),
229+
rpi,
230+
bi,
231+
executeStatementResp.DirectResults.ResultSetMetadata.ArrowSchema,
232+
cfg,
233+
)
234+
defer rs.Close()
235+
236+
// Test Schema() method
237+
schema, schemaErr := rs.Schema()
238+
assert.NoError(t, schemaErr)
239+
assert.NotNil(t, schema)
240+
241+
// Cache works - we should get same schema object on second call
242+
secondSchema, schemaErr2 := rs.Schema()
243+
assert.NoError(t, schemaErr2)
244+
assert.Same(t, schema, secondSchema)
245+
})
246+
247+
// Test with arrowSchemaBytes that needs to be populated via a batch
248+
t.Run("schema with lazy loading", func(t *testing.T) {
249+
logger := dbsqllog.WithContext("connectionId", "correlationId", "")
250+
251+
fetchResp1 := cli_service.TFetchResultsResp{}
252+
loadTestData2(t, "multipleFetch/FetchResults1.json", &fetchResp1)
253+
254+
var fetchesInfo []fetchResultsInfo
255+
simpleClient := getSimpleClient(&fetchesInfo, []cli_service.TFetchResultsResp{fetchResp1})
256+
rpi := rowscanner.NewResultPageIterator(
257+
rowscanner.NewDelimiter(0, 0),
258+
5000,
259+
nil,
260+
false,
261+
simpleClient,
262+
"connectionId",
263+
"correlationId",
264+
logger,
265+
)
266+
267+
cfg := *config.WithDefaults()
268+
269+
// Create arrowRecordIterator without initial schema bytes
270+
rs := NewArrowRecordIterator(context.Background(), rpi, nil, nil, cfg)
271+
defer rs.Close()
272+
273+
// Schema() should trigger loading a batch to get schema
274+
schema, schemaErr := rs.Schema()
275+
assert.NoError(t, schemaErr)
276+
assert.NotNil(t, schema)
277+
278+
// Cache works - we should get same schema object on second call
279+
secondSchema, schemaErr2 := rs.Schema()
280+
assert.NoError(t, schemaErr2)
281+
assert.Same(t, schema, secondSchema)
282+
})
283+
284+
// Test with no schema available
285+
t.Run("schema with no data available", func(t *testing.T) {
286+
logger := dbsqllog.WithContext("connectionId", "correlationId", "")
287+
288+
// Instead of using an empty response list, let's create a custom client
289+
// that returns an error when trying to fetch results
290+
failingClient := &client.TestClient{
291+
FnFetchResults: func(ctx context.Context, req *cli_service.TFetchResultsReq) (*cli_service.TFetchResultsResp, error) {
292+
return nil, fmt.Errorf("no data available")
293+
},
294+
}
295+
296+
rpi := rowscanner.NewResultPageIterator(
297+
rowscanner.NewDelimiter(0, 0),
298+
5000,
299+
nil,
300+
false,
301+
failingClient,
302+
"connectionId",
303+
"correlationId",
304+
logger,
305+
)
306+
307+
cfg := *config.WithDefaults()
308+
309+
// Create arrowRecordIterator without schema bytes and with failing client
310+
rs := NewArrowRecordIterator(context.Background(), rpi, nil, nil, cfg)
311+
defer rs.Close()
312+
313+
// Schema() should return error since no schema can be obtained
314+
schema, schemaErr := rs.Schema()
315+
assert.Error(t, schemaErr)
316+
assert.Nil(t, schema)
317+
})
318+
}
319+
191320
type fetchResultsInfo struct {
192321
direction cli_service.TFetchOrientation
193322
resultStartRec int

rows/rows.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,7 @@ type ArrowBatchIterator interface {
2020

2121
// Release any resources in use by the iterator.
2222
Close()
23+
24+
// Return the schema of the records.
25+
Schema() (*arrow.Schema, error)
2326
}

0 commit comments

Comments
 (0)