Skip to content

Commit 00bd733

Browse files
Close operation after executing statement
Signed-off-by: Raymond Cypher <[email protected]>
1 parent 4488f73 commit 00bd733

File tree

4 files changed

+200
-20
lines changed

4 files changed

+200
-20
lines changed

connection.go

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,21 +90,37 @@ func (c *conn) IsValid() bool {
9090
func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
9191
log := logger.WithContext(c.id, driverctx.CorrelationIdFromContext(ctx), "")
9292
msg, start := logger.Track("ExecContext")
93+
defer log.Duration(msg, start)
94+
9395
ctx = driverctx.NewContextWithConnId(ctx, c.id)
9496
if len(args) > 0 {
9597
return nil, errors.New(ErrParametersNotSupported)
9698
}
9799
exStmtResp, opStatusResp, err := c.runQuery(ctx, query, args)
98100

99101
if exStmtResp != nil && exStmtResp.OperationHandle != nil {
102+
// we have an operation id so update the logger
100103
log = logger.WithContext(c.id, driverctx.CorrelationIdFromContext(ctx), client.SprintGuid(exStmtResp.OperationHandle.OperationId.GUID))
104+
105+
// since we have an operation handle we can close the operation
106+
if opStatusResp == nil || opStatusResp.GetOperationState() != cli_service.TOperationState_CLOSED_STATE {
107+
_, err1 := c.client.CloseOperation(ctx, &cli_service.TCloseOperationReq{
108+
OperationHandle: exStmtResp.OperationHandle,
109+
})
110+
if err1 != nil {
111+
log.Err(err1).Msg("failed to close operation after executing statement")
112+
}
113+
}
101114
}
102-
defer log.Duration(msg, start)
103115

104116
if err != nil {
117+
// TODO: are there error situations in which the operation still needs to be closed?
118+
// Currently if there is an error we never get back a TExecuteStatementResponse so
119+
// can't try to close.
105120
log.Err(err).Msgf("databricks: failed to execute query: query %s", query)
106121
return nil, wrapErrf(err, "failed to execute query")
107122
}
123+
108124
res := result{AffectedRows: opStatusResp.GetNumModifiedRows()}
109125

110126
return &res, nil
@@ -261,10 +277,12 @@ func (c *conn) executeStatement(ctx context.Context, query string, args []driver
261277
if err != nil {
262278
return nil, err
263279
}
280+
264281
exStmtResp, ok := res.(*cli_service.TExecuteStatementResp)
265282
if !ok {
266283
return exStmtResp, errors.New("databricks: invalid execute statement response")
267284
}
285+
268286
return exStmtResp, err
269287
}
270288

connection_test.go

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,89 @@ func TestConn_executeStatement(t *testing.T) {
8686
assert.NoError(t, err)
8787
assert.Equal(t, 1, executeStatementCount)
8888
})
89+
90+
t.Run("ExecStatement should close operation on success", func(t *testing.T) {
91+
var executeStatementCount, closeOperationCount int
92+
executeStatementResp := &cli_service.TExecuteStatementResp{
93+
Status: &cli_service.TStatus{
94+
StatusCode: cli_service.TStatusCode_SUCCESS_STATUS,
95+
},
96+
OperationHandle: &cli_service.TOperationHandle{
97+
OperationId: &cli_service.THandleIdentifier{
98+
GUID: []byte{1, 2, 3, 4, 2, 23, 4, 2, 3, 1, 2, 3, 4, 4, 223, 34, 54},
99+
Secret: []byte("b"),
100+
},
101+
},
102+
DirectResults: &cli_service.TSparkDirectResults{
103+
OperationStatus: &cli_service.TGetOperationStatusResp{
104+
Status: &cli_service.TStatus{
105+
StatusCode: cli_service.TStatusCode_SUCCESS_STATUS,
106+
},
107+
OperationState: cli_service.TOperationStatePtr(cli_service.TOperationState_ERROR_STATE),
108+
ErrorMessage: strPtr("error message"),
109+
DisplayMessage: strPtr("display message"),
110+
},
111+
ResultSetMetadata: &cli_service.TGetResultSetMetadataResp{
112+
Status: &cli_service.TStatus{
113+
StatusCode: cli_service.TStatusCode_SUCCESS_STATUS,
114+
},
115+
},
116+
ResultSet: &cli_service.TFetchResultsResp{
117+
Status: &cli_service.TStatus{
118+
StatusCode: cli_service.TStatusCode_SUCCESS_STATUS,
119+
},
120+
},
121+
},
122+
}
123+
124+
testClient := &client.TestClient{
125+
FnExecuteStatement: func(ctx context.Context, req *cli_service.TExecuteStatementReq) (r *cli_service.TExecuteStatementResp, err error) {
126+
executeStatementCount++
127+
return executeStatementResp, nil
128+
},
129+
FnCloseOperation: func(ctx context.Context, req *cli_service.TCloseOperationReq) (_r *cli_service.TCloseOperationResp, _err error) {
130+
closeOperationCount++
131+
return &cli_service.TCloseOperationResp{}, nil
132+
},
133+
}
134+
135+
testConn := &conn{
136+
session: getTestSession(),
137+
client: testClient,
138+
cfg: config.WithDefaults(),
139+
}
140+
141+
type opStateTest struct {
142+
state cli_service.TOperationState
143+
err string
144+
closeOperationCount int
145+
}
146+
147+
// test behaviour with all terminal operation states
148+
operationStateTests := []opStateTest{
149+
{state: cli_service.TOperationState_ERROR_STATE, err: "error state", closeOperationCount: 1},
150+
{state: cli_service.TOperationState_FINISHED_STATE, err: "", closeOperationCount: 1},
151+
{state: cli_service.TOperationState_CANCELED_STATE, err: "cancelled state", closeOperationCount: 1},
152+
{state: cli_service.TOperationState_CLOSED_STATE, err: "closed state", closeOperationCount: 0},
153+
{state: cli_service.TOperationState_TIMEDOUT_STATE, err: "timeout state", closeOperationCount: 1},
154+
}
155+
156+
for _, opTest := range operationStateTests {
157+
closeOperationCount = 0
158+
executeStatementCount = 0
159+
executeStatementResp.DirectResults.OperationStatus.OperationState = &opTest.state
160+
executeStatementResp.DirectResults.OperationStatus.DisplayMessage = &opTest.err
161+
_, err := testConn.ExecContext(context.Background(), "select 1", []driver.NamedValue{})
162+
if opTest.err == "" {
163+
assert.NoError(t, err)
164+
} else {
165+
assert.EqualError(t, err, opTest.err)
166+
}
167+
assert.Equal(t, 1, executeStatementCount)
168+
assert.Equal(t, opTest.closeOperationCount, closeOperationCount)
169+
}
170+
})
171+
89172
}
90173

91174
func TestConn_pollOperation(t *testing.T) {

examples/createdrop/main.go

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
package main
2+
3+
import (
4+
"context"
5+
"database/sql"
6+
"log"
7+
"os"
8+
"strconv"
9+
"time"
10+
11+
dbsql "github.com/databricks/databricks-sql-go"
12+
dbsqlctx "github.com/databricks/databricks-sql-go/driverctx"
13+
dbsqllog "github.com/databricks/databricks-sql-go/logger"
14+
"github.com/joho/godotenv"
15+
)
16+
17+
func main() {
18+
// use this package to set up logging. By default logging level is `warn`. If you want to disable logging, use `disabled`
19+
if err := dbsqllog.SetLogLevel("debug"); err != nil {
20+
log.Fatal(err)
21+
}
22+
// sets the logging output. By default it will use os.Stderr. If running in terminal, it will use ConsoleWriter to make it pretty
23+
// dbsqllog.SetLogOutput(os.Stdout)
24+
25+
// this is just to make it easy to load all variables
26+
if err := godotenv.Load(); err != nil {
27+
log.Fatal(err)
28+
}
29+
port, err := strconv.Atoi(os.Getenv("DATABRICKS_PORT"))
30+
if err != nil {
31+
log.Fatal(err)
32+
}
33+
34+
// programmatically initializes the connector
35+
// another way is to use a DNS. In this case the equivalent DNS would be:
36+
// "token:<my_token>@hostname:port/http_path?catalog=hive_metastore&schema=default&timeout=60&maxRows=10&&timezone=America/Sao_Paulo&ANSI_MODE=true"
37+
connector, err := dbsql.NewConnector(
38+
// minimum configuration
39+
dbsql.WithServerHostname(os.Getenv("DATABRICKS_HOST")),
40+
dbsql.WithPort(port),
41+
dbsql.WithHTTPPath(os.Getenv("DATABRICKS_HTTPPATH")),
42+
dbsql.WithAccessToken(os.Getenv("DATABRICKS_ACCESSTOKEN")),
43+
//optional configuration
44+
dbsql.WithSessionParams(map[string]string{"timezone": "America/Sao_Paulo", "ansi_mode": "true"}),
45+
dbsql.WithUserAgentEntry("workflow-example"),
46+
dbsql.WithInitialNamespace("hive_metastore", "default"),
47+
dbsql.WithTimeout(time.Minute), // defaults to no timeout. Global timeout. Any query will be canceled if taking more than this time.
48+
dbsql.WithMaxRows(10), // defaults to 10000
49+
)
50+
if err != nil {
51+
// This will not be a connection error, but a DSN parse error or
52+
// another initialization error.
53+
log.Fatal(err)
54+
55+
}
56+
// Opening a driver typically will not attempt to connect to the database.
57+
db := sql.OpenDB(connector)
58+
// make sure to close it later
59+
defer db.Close()
60+
61+
ogCtx := dbsqlctx.NewContextWithCorrelationId(context.Background(), "createdrop-example")
62+
63+
// sets the timeout to 30 seconds. More than that we ping will fail. The default is 15 seconds
64+
ctx1, cancel := context.WithTimeout(ogCtx, 30*time.Second)
65+
defer cancel()
66+
if err := db.PingContext(ctx1); err != nil {
67+
log.Fatal(err)
68+
}
69+
70+
// create a table with some data. This has no context timeout, it will follow the timeout of one minute set for the connection.
71+
if _, err := db.ExecContext(ogCtx, `CREATE TABLE IF NOT EXISTS diamonds USING CSV LOCATION '/databricks-datasets/Rdatasets/data-001/csv/ggplot2/diamonds.csv' options (header = true, inferSchema = true)`); err != nil {
72+
log.Fatal(err)
73+
}
74+
75+
if _, err := db.ExecContext(ogCtx, `DROP TABLE diamonds `); err != nil {
76+
log.Fatal(err)
77+
}
78+
}

examples/workflow/main.go

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"database/sql"
66
"fmt"
7+
"log"
78
"os"
89
"strconv"
910
"time"
@@ -17,18 +18,18 @@ import (
1718
func main() {
1819
// use this package to set up logging. By default logging level is `warn`. If you want to disable logging, use `disabled`
1920
if err := dbsqllog.SetLogLevel("debug"); err != nil {
20-
panic(err)
21+
log.Fatal(err)
2122
}
2223
// sets the logging output. By default it will use os.Stderr. If running in terminal, it will use ConsoleWriter to make it pretty
2324
// dbsqllog.SetLogOutput(os.Stdout)
2425

2526
// this is just to make it easy to load all variables
2627
if err := godotenv.Load(); err != nil {
27-
panic(err)
28+
log.Fatal(err)
2829
}
2930
port, err := strconv.Atoi(os.Getenv("DATABRICKS_PORT"))
3031
if err != nil {
31-
panic(err)
32+
log.Fatal(err)
3233
}
3334

3435
// programmatically initializes the connector
@@ -50,7 +51,7 @@ func main() {
5051
if err != nil {
5152
// This will not be a connection error, but a DSN parse error or
5253
// another initialization error.
53-
panic(err)
54+
log.Fatal(err)
5455

5556
}
5657
// Opening a driver typically will not attempt to connect to the database.
@@ -88,18 +89,18 @@ func main() {
8889
ctx1, cancel := context.WithTimeout(ogCtx, 30*time.Second)
8990
defer cancel()
9091
if err := db.PingContext(ctx1); err != nil {
91-
panic(err)
92+
log.Fatal(err)
9293
}
9394

9495
// create a table with some data. This has no context timeout, it will follow the timeout of one minute set for the connection.
9596
if _, err := db.ExecContext(ogCtx, `CREATE TABLE IF NOT EXISTS diamonds USING CSV LOCATION '/databricks-datasets/Rdatasets/data-001/csv/ggplot2/diamonds.csv' options (header = true, inferSchema = true)`); err != nil {
96-
panic(err)
97+
log.Fatal(err)
9798
}
9899

99100
// QueryRowContext is a shortcut function to get a single value
100101
var max float64
101102
if err := db.QueryRowContext(ogCtx, `select max(carat) from diamonds`).Scan(&max); err != nil {
102-
panic(err)
103+
log.Fatal(err)
103104
} else {
104105
fmt.Printf("max carat in dataset is: %f\n", max)
105106
}
@@ -109,7 +110,7 @@ func main() {
109110
defer cancel()
110111

111112
if rows, err := db.QueryContext(ctx2, "select * from diamonds limit 19"); err != nil {
112-
panic(err)
113+
log.Fatal(err)
113114
} else {
114115
type row struct {
115116
_c0 int
@@ -127,11 +128,11 @@ func main() {
127128

128129
cols, err := rows.Columns()
129130
if err != nil {
130-
panic(err)
131+
log.Fatal(err)
131132
}
132133
types, err := rows.ColumnTypes()
133134
if err != nil {
134-
panic(err)
135+
log.Fatal(err)
135136
}
136137
for i, c := range cols {
137138
fmt.Printf("column %d is %s and has type %v\n", i, c, types[i].DatabaseTypeName())
@@ -141,7 +142,7 @@ func main() {
141142
// After row 10 this will cause one fetch call, as 10 rows (maxRows config) will come from the first execute statement call.
142143
r := row{}
143144
if err := rows.Scan(&r._c0, &r.carat, &r.cut, &r.color, &r.clarity, &r.depth, &r.table, &r.price, &r.x, &r.y, &r.z); err != nil {
144-
panic(err)
145+
log.Fatal(err)
145146
}
146147
res = append(res, r)
147148
}
@@ -156,7 +157,7 @@ func main() {
156157
var curTimezone string
157158

158159
if err := db.QueryRowContext(ogCtx, `select current_date(), current_timestamp(), current_timezone()`).Scan(&curDate, &curTimestamp, &curTimezone); err != nil {
159-
panic(err)
160+
log.Fatal(err)
160161
} else {
161162
// this will print now at timezone America/Sao_Paulo is: 2022-11-16 20:25:15.282 -0300 -03
162163
fmt.Printf("current timestamp at timezone %s is: %s\n", curTimezone, curTimestamp)
@@ -170,11 +171,11 @@ func main() {
170171
array_col array < int >,
171172
map_col map < string, int >,
172173
struct_col struct < string_field string, array_field array < int > >)`); err != nil {
173-
panic(err)
174+
log.Fatal(err)
174175
}
175176
var numRows int
176177
if err := db.QueryRowContext(ogCtx, `select count(*) from array_map_struct`).Scan(&numRows); err != nil {
177-
panic(err)
178+
log.Fatal(err)
178179
} else {
179180
fmt.Printf("table has %d rows\n", numRows)
180181
}
@@ -186,7 +187,7 @@ func main() {
186187
array(1, 2, 3),
187188
map('key1', 1),
188189
struct('string_val', array(4, 5, 6)))`); err != nil {
189-
panic(err)
190+
log.Fatal(err)
190191
} else {
191192
i, err1 := res.RowsAffected()
192193
if err1 != nil {
@@ -197,7 +198,7 @@ func main() {
197198
}
198199

199200
if rows, err := db.QueryContext(ogCtx, "select * from array_map_struct"); err != nil {
200-
panic(err)
201+
log.Fatal(err)
201202
} else {
202203
// complex data types are returned as string
203204
type row struct {
@@ -208,11 +209,11 @@ func main() {
208209
res := []row{}
209210
cols, err := rows.Columns()
210211
if err != nil {
211-
panic(err)
212+
log.Fatal(err)
212213
}
213214
types, err := rows.ColumnTypes()
214215
if err != nil {
215-
panic(err)
216+
log.Fatal(err)
216217
}
217218
for i, c := range cols {
218219
fmt.Printf("column %d is %s and has type %v\n", i, c, types[i].DatabaseTypeName())
@@ -221,7 +222,7 @@ func main() {
221222
for rows.Next() {
222223
r := row{}
223224
if err := rows.Scan(&r.arrayVal, &r.mapVal, &r.structVal); err != nil {
224-
panic(err)
225+
log.Fatal(err)
225226
}
226227
res = append(res, r)
227228
}

0 commit comments

Comments
 (0)