Skip to content

Commit 0b7f22f

Browse files
P&R improvements (#60)
Added a retryable HTTP client. Clients are now reused across connections, so there is no need to create one client per connection. Fixed a problem where queries would wait intervalTime before firing. Fixed a problem where the query would not be cancelled when the context was done. Signed-off-by: Andre Furlan <[email protected]>
1 parent 2397470 commit 0b7f22f

File tree

13 files changed

+292
-127
lines changed

13 files changed

+292
-127
lines changed

connection.go

Lines changed: 73 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -34,14 +34,11 @@ func (c *conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, e
3434
func (c *conn) Close() error {
3535
log := logger.WithContext(c.id, "", "")
3636
ctx := driverctx.NewContextWithConnId(context.Background(), c.id)
37-
sentinel := sentinel.Sentinel{
38-
OnDoneFn: func(statusResp any) (any, error) {
39-
return c.client.CloseSession(ctx, &cli_service.TCloseSessionReq{
40-
SessionHandle: c.session.SessionHandle,
41-
})
42-
},
43-
}
44-
_, _, err := sentinel.Watch(ctx, c.cfg.PollInterval, 15*time.Second)
37+
38+
_, err := c.client.CloseSession(ctx, &cli_service.TCloseSessionReq{
39+
SessionHandle: c.session.SessionHandle,
40+
})
41+
4542
if err != nil {
4643
log.Err(err).Msg("databricks: failed to close connection")
4744
return wrapErr(err, "failed to close connection")
@@ -62,7 +59,7 @@ func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e
6259
func (c *conn) Ping(ctx context.Context) error {
6360
log := logger.WithContext(c.id, driverctx.CorrelationIdFromContext(ctx), "")
6461
ctx = driverctx.NewContextWithConnId(ctx, c.id)
65-
ctx1, cancel := context.WithTimeout(ctx, 15*time.Second)
62+
ctx1, cancel := context.WithTimeout(ctx, 60*time.Second)
6663
defer cancel()
6764
_, err := c.QueryContext(ctx1, "select 1", nil)
6865
if err != nil {
@@ -113,7 +110,6 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
113110
}
114111
}
115112
}
116-
117113
if err != nil {
118114
// TODO: are there error situations in which the operation still needs to be closed?
119115
// Currently if there is an error we never get back a TExecuteStatementResponse so
@@ -151,6 +147,7 @@ func (c *conn) QueryContext(ctx context.Context, query string, args []driver.Nam
151147
defer log.Duration(msg, start)
152148

153149
if err != nil {
150+
// gotta also think about close operation here
154151
log.Err(err).Msgf("databricks: failed to run query: query %s", query)
155152
return nil, wrapErrf(err, "failed to run query")
156153
}
@@ -175,7 +172,10 @@ func (c *conn) runQuery(ctx context.Context, query string, args []driver.NamedVa
175172
// hold on to the operation handle
176173
opHandle := exStmtResp.OperationHandle
177174
if opHandle != nil && opHandle.OperationId != nil {
178-
log = logger.WithContext(c.id, driverctx.CorrelationIdFromContext(ctx), client.SprintGuid(opHandle.OperationId.GUID))
175+
log = logger.WithContext(
176+
c.id,
177+
driverctx.CorrelationIdFromContext(ctx), client.SprintGuid(opHandle.OperationId.GUID),
178+
)
179179
}
180180

181181
if exStmtResp.DirectResults != nil {
@@ -188,12 +188,17 @@ func (c *conn) runQuery(ctx context.Context, query string, args []driver.NamedVa
188188
// return results
189189
return exStmtResp, opStatus, nil
190190
// bad
191-
case cli_service.TOperationState_CANCELED_STATE, cli_service.TOperationState_CLOSED_STATE, cli_service.TOperationState_ERROR_STATE, cli_service.TOperationState_TIMEDOUT_STATE:
191+
case cli_service.TOperationState_CANCELED_STATE,
192+
cli_service.TOperationState_CLOSED_STATE,
193+
cli_service.TOperationState_ERROR_STATE,
194+
cli_service.TOperationState_TIMEDOUT_STATE:
192195
// do we need to close the operation in these cases?
193196
logBadQueryState(log, opStatus)
194197
return exStmtResp, opStatus, errors.New(opStatus.GetDisplayMessage())
195198
// live states
196-
case cli_service.TOperationState_INITIALIZED_STATE, cli_service.TOperationState_PENDING_STATE, cli_service.TOperationState_RUNNING_STATE:
199+
case cli_service.TOperationState_INITIALIZED_STATE,
200+
cli_service.TOperationState_PENDING_STATE,
201+
cli_service.TOperationState_RUNNING_STATE:
197202
statusResp, err := c.pollOperation(ctx, opHandle)
198203
if err != nil {
199204
return exStmtResp, statusResp, err
@@ -205,7 +210,10 @@ func (c *conn) runQuery(ctx context.Context, query string, args []driver.NamedVa
205210
// return handle to fetch results later
206211
return exStmtResp, opStatus, nil
207212
// bad
208-
case cli_service.TOperationState_CANCELED_STATE, cli_service.TOperationState_CLOSED_STATE, cli_service.TOperationState_ERROR_STATE, cli_service.TOperationState_TIMEDOUT_STATE:
213+
case cli_service.TOperationState_CANCELED_STATE,
214+
cli_service.TOperationState_CLOSED_STATE,
215+
cli_service.TOperationState_ERROR_STATE,
216+
cli_service.TOperationState_TIMEDOUT_STATE:
209217
logBadQueryState(log, statusResp)
210218
return exStmtResp, opStatus, errors.New(statusResp.GetDisplayMessage())
211219
// live states
@@ -231,7 +239,10 @@ func (c *conn) runQuery(ctx context.Context, query string, args []driver.NamedVa
231239
// return handle to fetch results later
232240
return exStmtResp, statusResp, nil
233241
// bad
234-
case cli_service.TOperationState_CANCELED_STATE, cli_service.TOperationState_CLOSED_STATE, cli_service.TOperationState_ERROR_STATE, cli_service.TOperationState_TIMEDOUT_STATE:
242+
case cli_service.TOperationState_CANCELED_STATE,
243+
cli_service.TOperationState_CLOSED_STATE,
244+
cli_service.TOperationState_ERROR_STATE,
245+
cli_service.TOperationState_TIMEDOUT_STATE:
235246
logBadQueryState(log, statusResp)
236247
return exStmtResp, statusResp, errors.New(statusResp.GetDisplayMessage())
237248
// live states
@@ -250,41 +261,56 @@ func logBadQueryState(log *logger.DBSQLLogger, opStatus *cli_service.TGetOperati
250261
func (c *conn) executeStatement(ctx context.Context, query string, args []driver.NamedValue) (*cli_service.TExecuteStatementResp, error) {
251262
corrId := driverctx.CorrelationIdFromContext(ctx)
252263
log := logger.WithContext(c.id, corrId, "")
253-
sentinel := sentinel.Sentinel{
254-
OnDoneFn: func(statusResp any) (any, error) {
255-
req := cli_service.TExecuteStatementReq{
256-
SessionHandle: c.session.SessionHandle,
257-
Statement: query,
258-
RunAsync: c.cfg.RunAsync,
259-
QueryTimeout: int64(c.cfg.QueryTimeout / time.Second),
260-
// this is specific for databricks. It shortcuts server roundtrips
261-
GetDirectResults: &cli_service.TSparkGetDirectResults{
262-
MaxRows: int64(c.cfg.MaxRows),
263-
},
264-
// CanReadArrowResult_: &t,
265-
// CanDecompressLZ4Result_: &f,
266-
// CanDownloadResult_: &t,
267-
}
268-
ctx = driverctx.NewContextWithConnId(ctx, c.id)
269-
resp, err := c.client.ExecuteStatement(ctx, &req)
270-
return resp, wrapErr(err, "failed to execute statement")
271-
},
272-
OnCancelFn: func() (any, error) {
273-
log.Warn().Msg("databricks: execute statement canceled while creation operation")
274-
return nil, nil
264+
265+
req := cli_service.TExecuteStatementReq{
266+
SessionHandle: c.session.SessionHandle,
267+
Statement: query,
268+
RunAsync: c.cfg.RunAsync,
269+
QueryTimeout: int64(c.cfg.QueryTimeout / time.Second),
270+
// this is specific for databricks. It shortcuts server round trips
271+
GetDirectResults: &cli_service.TSparkGetDirectResults{
272+
MaxRows: int64(c.cfg.MaxRows),
275273
},
274+
// CanReadArrowResult_: &t,
275+
// CanDecompressLZ4Result_: &f,
276+
// CanDownloadResult_: &t,
276277
}
277-
_, res, err := sentinel.Watch(ctx, c.cfg.PollInterval, c.cfg.QueryTimeout)
278-
if err != nil {
279-
return nil, err
278+
279+
ctx = driverctx.NewContextWithConnId(ctx, c.id)
280+
resp, err := c.client.ExecuteStatement(ctx, &req)
281+
282+
var shouldCancel = func(resp *cli_service.TExecuteStatementResp) bool {
283+
if resp == nil {
284+
return false
285+
}
286+
hasHandle := resp.OperationHandle != nil
287+
isOpen := resp.DirectResults != nil && resp.DirectResults.CloseOperation == nil
288+
return hasHandle && isOpen
280289
}
281290

282-
exStmtResp, ok := res.(*cli_service.TExecuteStatementResp)
283-
if !ok {
284-
return exStmtResp, errors.New("databricks: invalid execute statement response")
291+
select {
292+
default:
293+
case <-ctx.Done():
294+
newCtx := driverctx.NewContextWithCorrelationId(driverctx.NewContextWithConnId(context.Background(), c.id), corrId)
295+
// in case context is done, we need to cancel the operation if necessary
296+
if err == nil && shouldCancel(resp) {
297+
log.Debug().Msg("databricks: canceling query")
298+
_, err1 := c.client.CancelOperation(newCtx, &cli_service.TCancelOperationReq{
299+
OperationHandle: resp.GetOperationHandle(),
300+
})
301+
302+
if err1 != nil {
303+
log.Err(err).Msgf("databricks: cancel failed")
304+
}
305+
log.Debug().Msgf("databricks: cancel success")
306+
307+
} else {
308+
log.Debug().Msg("databricks: query did not need cancellation")
309+
}
310+
return nil, ctx.Err()
285311
}
286312

287-
return exStmtResp, err
313+
return resp, err
288314
}
289315

290316
func (c *conn) pollOperation(ctx context.Context, opHandle *cli_service.TOperationHandle) (*cli_service.TGetOperationStatusResp, error) {
@@ -312,7 +338,9 @@ func (c *conn) pollOperation(ctx context.Context, opHandle *cli_service.TOperati
312338
return true
313339
}
314340
switch statusResp.GetOperationState() {
315-
case cli_service.TOperationState_INITIALIZED_STATE, cli_service.TOperationState_PENDING_STATE, cli_service.TOperationState_RUNNING_STATE:
341+
case cli_service.TOperationState_INITIALIZED_STATE,
342+
cli_service.TOperationState_PENDING_STATE,
343+
cli_service.TOperationState_RUNNING_STATE:
316344
return false
317345
default:
318346
log.Debug().Msg("databricks: polling done")

connection_test.go

Lines changed: 141 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,6 @@ func TestConn_executeStatement(t *testing.T) {
131131
return &cli_service.TCloseOperationResp{}, nil
132132
},
133133
}
134-
135134
testConn := &conn{
136135
session: getTestSession(),
137136
client: testClient,
@@ -181,6 +180,147 @@ func TestConn_executeStatement(t *testing.T) {
181180
assert.Equal(t, 0, closeOperationCount)
182181
})
183182

183+
t.Run("executeStatement should not call cancel if not needed", func(t *testing.T) {
184+
var executeStatementCount int
185+
var cancelOperationCount int
186+
var cancel context.CancelFunc
187+
executeStatement := func(ctx context.Context, req *cli_service.TExecuteStatementReq) (r *cli_service.TExecuteStatementResp, err error) {
188+
executeStatementCount++
189+
cancel()
190+
executeStatementResp := &cli_service.TExecuteStatementResp{
191+
Status: &cli_service.TStatus{
192+
StatusCode: cli_service.TStatusCode_SUCCESS_STATUS,
193+
},
194+
OperationHandle: &cli_service.TOperationHandle{
195+
OperationId: &cli_service.THandleIdentifier{
196+
GUID: []byte{1, 2, 3, 4, 2, 23, 4, 2, 3, 1, 2, 3, 4, 4, 223, 34, 54},
197+
Secret: []byte("b"),
198+
},
199+
},
200+
DirectResults: &cli_service.TSparkDirectResults{
201+
OperationStatus: &cli_service.TGetOperationStatusResp{
202+
Status: &cli_service.TStatus{
203+
StatusCode: cli_service.TStatusCode_SUCCESS_STATUS,
204+
},
205+
OperationState: cli_service.TOperationStatePtr(cli_service.TOperationState_FINISHED_STATE),
206+
ErrorMessage: strPtr("error message"),
207+
DisplayMessage: strPtr("display message"),
208+
},
209+
ResultSetMetadata: &cli_service.TGetResultSetMetadataResp{
210+
Status: &cli_service.TStatus{
211+
StatusCode: cli_service.TStatusCode_SUCCESS_STATUS,
212+
},
213+
},
214+
ResultSet: &cli_service.TFetchResultsResp{
215+
Status: &cli_service.TStatus{
216+
StatusCode: cli_service.TStatusCode_SUCCESS_STATUS,
217+
},
218+
},
219+
CloseOperation: &cli_service.TCloseOperationResp{
220+
Status: &cli_service.TStatus{
221+
StatusCode: cli_service.TStatusCode_SUCCESS_STATUS,
222+
},
223+
},
224+
},
225+
}
226+
return executeStatementResp, nil
227+
}
228+
cancelOperation := func(ctx context.Context, req *cli_service.TCancelOperationReq) (r *cli_service.TCancelOperationResp, err error) {
229+
cancelOperationCount++
230+
cancelOperationResp := &cli_service.TCancelOperationResp{
231+
Status: &cli_service.TStatus{
232+
StatusCode: cli_service.TStatusCode_SUCCESS_STATUS,
233+
},
234+
}
235+
return cancelOperationResp, nil
236+
}
237+
testClient := &client.TestClient{
238+
FnExecuteStatement: executeStatement,
239+
FnCancelOperation: cancelOperation,
240+
}
241+
testConn := &conn{
242+
session: getTestSession(),
243+
client: testClient,
244+
cfg: config.WithDefaults(),
245+
}
246+
247+
ctx := context.Background()
248+
ctx, cancel = context.WithCancel(ctx)
249+
defer cancel()
250+
_, err := testConn.executeStatement(ctx, "select 1", []driver.NamedValue{})
251+
252+
assert.Error(t, err)
253+
assert.Equal(t, 1, executeStatementCount)
254+
assert.Equal(t, 0, cancelOperationCount)
255+
})
256+
t.Run("executeStatement should call cancel if needed", func(t *testing.T) {
257+
var executeStatementCount int
258+
var cancelOperationCount int
259+
var cancel context.CancelFunc
260+
executeStatement := func(ctx context.Context, req *cli_service.TExecuteStatementReq) (r *cli_service.TExecuteStatementResp, err error) {
261+
executeStatementCount++
262+
cancel()
263+
executeStatementResp := &cli_service.TExecuteStatementResp{
264+
Status: &cli_service.TStatus{
265+
StatusCode: cli_service.TStatusCode_SUCCESS_STATUS,
266+
},
267+
OperationHandle: &cli_service.TOperationHandle{
268+
OperationId: &cli_service.THandleIdentifier{
269+
GUID: []byte{1, 2, 3, 4, 2, 23, 4, 2, 3, 1, 2, 3, 4, 4, 223, 34, 54},
270+
Secret: []byte("b"),
271+
},
272+
},
273+
DirectResults: &cli_service.TSparkDirectResults{
274+
OperationStatus: &cli_service.TGetOperationStatusResp{
275+
Status: &cli_service.TStatus{
276+
StatusCode: cli_service.TStatusCode_SUCCESS_STATUS,
277+
},
278+
OperationState: cli_service.TOperationStatePtr(cli_service.TOperationState_FINISHED_STATE),
279+
ErrorMessage: strPtr("error message"),
280+
DisplayMessage: strPtr("display message"),
281+
},
282+
ResultSetMetadata: &cli_service.TGetResultSetMetadataResp{
283+
Status: &cli_service.TStatus{
284+
StatusCode: cli_service.TStatusCode_SUCCESS_STATUS,
285+
},
286+
},
287+
ResultSet: &cli_service.TFetchResultsResp{
288+
Status: &cli_service.TStatus{
289+
StatusCode: cli_service.TStatusCode_SUCCESS_STATUS,
290+
},
291+
},
292+
},
293+
}
294+
return executeStatementResp, nil
295+
}
296+
cancelOperation := func(ctx context.Context, req *cli_service.TCancelOperationReq) (r *cli_service.TCancelOperationResp, err error) {
297+
cancelOperationCount++
298+
cancelOperationResp := &cli_service.TCancelOperationResp{
299+
Status: &cli_service.TStatus{
300+
StatusCode: cli_service.TStatusCode_SUCCESS_STATUS,
301+
},
302+
}
303+
return cancelOperationResp, nil
304+
}
305+
testClient := &client.TestClient{
306+
FnExecuteStatement: executeStatement,
307+
FnCancelOperation: cancelOperation,
308+
}
309+
testConn := &conn{
310+
session: getTestSession(),
311+
client: testClient,
312+
cfg: config.WithDefaults(),
313+
}
314+
ctx := context.Background()
315+
ctx, cancel = context.WithCancel(ctx)
316+
defer cancel()
317+
_, err := testConn.executeStatement(ctx, "select 1", []driver.NamedValue{})
318+
319+
assert.Error(t, err)
320+
assert.Equal(t, 1, executeStatementCount)
321+
assert.Equal(t, 1, cancelOperationCount)
322+
})
323+
184324
}
185325

186326
func TestConn_pollOperation(t *testing.T) {

0 commit comments

Comments
 (0)