Skip to content

Commit 6bb1879

Browse files
PECO-1054 Expose Arrow batches to users, part three (#166)
Added DBSqlRows and DBSQLArrowBatchIterator public interfaces. Added arrowRecordIterator which implements DBSQLArrowBatchIterator. Moved closing the database operation from rows type into resultPageIterator as well as properties that are only used by resultPageIterator. Added GetArrowBatches function to rows and arrowRowScanner types. Added HasNext function to BatchIterator and SparkArrowBatch interfaces. Added example for accessing Arrow batches and updated doc.go
2 parents 5a3a210 + 2d88022 commit 6bb1879

File tree

20 files changed

+1854
-99
lines changed

20 files changed

+1854
-99
lines changed

doc.go

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,80 @@ Example usage:
233233
234234
See the documentation for dbsql/errors for more information.
235235
236+
# Retrieving Arrow Batches
237+
238+
The driver supports the ability to retrieve Apache Arrow record batches.
239+
To work with record batches it is necessary to use sql.Conn.Raw() to access the underlying driver connection to retrieve a driver.Rows instance.
240+
The driver exposes two public interfaces for working with record batches from the rows sub-package:
241+
242+
type Rows interface {
243+
GetArrowBatches(context.Context) (ArrowBatchIterator, error)
244+
}
245+
246+
type ArrowBatchIterator interface {
247+
// Retrieve the next arrow.Record.
248+
// Will return io.EOF if there are no more records
249+
Next() (arrow.Record, error)
250+
251+
// Return true if the iterator contains more batches, false otherwise.
252+
HasNext() bool
253+
254+
// Release any resources in use by the iterator.
255+
Close()
256+
}
257+
258+
The driver.Rows instance retrieved using Conn.Raw() can be converted to a Databricks Rows instance via a type assertion; GetArrowBatches() can then be called on it to retrieve a batch iterator.
259+
If the ArrowBatchIterator is not closed it will leak resources, such as the underlying connection.
260+
Calling code must call Release() on records returned by ArrowBatchIterator.Next().
261+
262+
Example usage:
263+
264+
import (
265+
...
266+
dbsqlrows "github.com/databricks/databricks-sql-go/rows"
267+
)
268+
269+
func main() {
270+
...
271+
db := sql.OpenDB(connector)
272+
defer db.Close()
273+
274+
conn, _ := db.Conn(context.Background())
275+
defer conn.Close()
276+
277+
query := `select * from main.default.taxi_trip_data`
278+
279+
var rows driver.Rows
280+
var err error
281+
err = conn.Raw(func(d interface{}) error {
282+
rows, err = d.(driver.QueryerContext).QueryContext(context.Background(), query, nil)
283+
return err
284+
})
285+
286+
if err != nil {
287+
log.Fatalf("unable to run the query. err: %v", err)
288+
}
289+
defer rows.Close()
290+
291+
batches, err := rows.(dbsqlrows.Rows).GetArrowBatches(context.Background())
292+
if err != nil {
293+
log.Fatalf("unable to get arrow batches. err: %v", err)
294+
}
defer batches.Close()
295+
296+
var iBatch, nRows int
297+
for batches.HasNext() {
298+
b, err := batches.Next()
299+
if err != nil {
300+
log.Fatalf("Failure retrieving batch. err: %v", err)
301+
}
302+
303+
log.Printf("batch %v: nRecords=%v\n", iBatch, b.NumRows())
304+
iBatch += 1
305+
nRows += int(b.NumRows())
b.Release()
306+
}
307+
log.Printf("NRows: %v\n", nRows)
308+
}
309+
236310
# Supported Data Types
237311
238312
==================================

examples/arrrowbatches/main.go

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
package main
2+
3+
import (
4+
"context"
5+
"database/sql"
6+
"database/sql/driver"
7+
"io"
8+
"log"
9+
"os"
10+
"strconv"
11+
"time"
12+
13+
"github.com/apache/arrow/go/v12/arrow"
14+
dbsql "github.com/databricks/databricks-sql-go"
15+
dbsqlrows "github.com/databricks/databricks-sql-go/rows"
16+
"github.com/joho/godotenv"
17+
)
18+
19+
func main() {
20+
// Opening a driver typically will not attempt to connect to the database.
21+
err := godotenv.Load()
22+
if err != nil {
23+
log.Fatal(err.Error())
24+
}
25+
26+
// dbsqllog.SetLogLevel("debug")
27+
28+
port, err := strconv.Atoi(os.Getenv("DATABRICKS_PORT"))
29+
if err != nil {
30+
log.Fatal(err.Error())
31+
}
32+
connector, err := dbsql.NewConnector(
33+
dbsql.WithServerHostname(os.Getenv("DATABRICKS_HOST")),
34+
dbsql.WithPort(port),
35+
dbsql.WithHTTPPath(os.Getenv("DATABRICKS_HTTPPATH")),
36+
dbsql.WithAccessToken(os.Getenv("DATABRICKS_ACCESSTOKEN")),
37+
dbsql.WithMaxRows(10000),
38+
)
39+
40+
if err != nil {
41+
// This will not be a connection error, but a DSN parse error or
42+
// another initialization error.
43+
log.Fatal(err)
44+
}
45+
46+
db := sql.OpenDB(connector)
47+
defer db.Close()
48+
49+
loopWithHasNext(db)
50+
loopWithNext(db)
51+
}
52+
53+
func loopWithHasNext(db *sql.DB) {
54+
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
55+
defer cancel()
56+
57+
conn, _ := db.Conn(ctx)
58+
defer conn.Close()
59+
60+
query := `select * from main.default.diamonds`
61+
62+
var rows driver.Rows
63+
var err error
64+
err = conn.Raw(func(d interface{}) error {
65+
rows, err = d.(driver.QueryerContext).QueryContext(ctx, query, nil)
66+
return err
67+
})
68+
69+
if err != nil {
70+
log.Fatalf("unable to run the query. err: %v", err)
71+
}
72+
defer rows.Close()
73+
74+
ctx2, cancel2 := context.WithTimeout(context.Background(), 30*time.Second)
75+
defer cancel2()
76+
77+
batches, err := rows.(dbsqlrows.Rows).GetArrowBatches(ctx2)
78+
if err != nil {
79+
log.Fatalf("unable to get arrow batches. err: %v", err)
80+
}
81+
82+
var iBatch, nRows int
83+
for batches.HasNext() {
84+
b, err := batches.Next()
85+
if err != nil {
86+
log.Fatalf("Failure retrieving batch. err: %v", err)
87+
}
88+
89+
log.Printf("batch %v: nRecords=%v\n", iBatch, b.NumRows())
90+
iBatch += 1
91+
nRows += int(b.NumRows())
92+
}
93+
log.Printf("NRows: %v\n", nRows)
94+
}
95+
96+
func loopWithNext(db *sql.DB) {
97+
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
98+
defer cancel()
99+
100+
conn, _ := db.Conn(ctx)
101+
defer conn.Close()
102+
103+
query := `select * from main.default.diamonds`
104+
105+
var rows driver.Rows
106+
var err error
107+
108+
err = conn.Raw(func(d interface{}) error {
109+
rows, err = d.(driver.QueryerContext).QueryContext(ctx, query, nil)
110+
return err
111+
})
112+
if err != nil {
113+
log.Fatalf("unable to run the query. err: %v", err)
114+
}
115+
defer rows.Close()
116+
117+
ctx2, cancel2 := context.WithTimeout(context.Background(), 30*time.Second)
118+
defer cancel2()
119+
120+
batches, err := rows.(dbsqlrows.Rows).GetArrowBatches(ctx2)
121+
if err != nil {
122+
log.Fatalf("unable to get arrow batches. err: %v", err)
123+
}
124+
125+
var iBatch, nRows int
126+
var b arrow.Record
127+
for b, err = batches.Next(); err == nil; b, err = batches.Next() {
128+
log.Printf("batch %v: nRecords=%v\n", iBatch, b.NumRows())
129+
iBatch += 1
130+
nRows += int(b.NumRows())
131+
}
132+
133+
log.Printf("NRows: %v\n", nRows)
134+
if err == io.EOF {
135+
log.Println("normal loop termination")
136+
} else {
137+
log.Printf("loop terminated with error: %v", err)
138+
}
139+
}

0 commit comments

Comments
 (0)