@@ -34,14 +34,11 @@ func (c *conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, e
3434func (c * conn ) Close () error {
3535 log := logger .WithContext (c .id , "" , "" )
3636 ctx := driverctx .NewContextWithConnId (context .Background (), c .id )
37- sentinel := sentinel.Sentinel {
38- OnDoneFn : func (statusResp any ) (any , error ) {
39- return c .client .CloseSession (ctx , & cli_service.TCloseSessionReq {
40- SessionHandle : c .session .SessionHandle ,
41- })
42- },
43- }
44- _ , _ , err := sentinel .Watch (ctx , c .cfg .PollInterval , 15 * time .Second )
37+
38+ _ , err := c .client .CloseSession (ctx , & cli_service.TCloseSessionReq {
39+ SessionHandle : c .session .SessionHandle ,
40+ })
41+
4542 if err != nil {
4643 log .Err (err ).Msg ("databricks: failed to close connection" )
4744 return wrapErr (err , "failed to close connection" )
@@ -62,7 +59,7 @@ func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e
6259func (c * conn ) Ping (ctx context.Context ) error {
6360 log := logger .WithContext (c .id , driverctx .CorrelationIdFromContext (ctx ), "" )
6461 ctx = driverctx .NewContextWithConnId (ctx , c .id )
65- ctx1 , cancel := context .WithTimeout (ctx , 15 * time .Second )
62+ ctx1 , cancel := context .WithTimeout (ctx , 60 * time .Second )
6663 defer cancel ()
6764 _ , err := c .QueryContext (ctx1 , "select 1" , nil )
6865 if err != nil {
@@ -113,7 +110,6 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
113110 }
114111 }
115112 }
116-
117113 if err != nil {
118114 // TODO: are there error situations in which the operation still needs to be closed?
119115 // Currently if there is an error we never get back a TExecuteStatementResponse so
@@ -151,6 +147,7 @@ func (c *conn) QueryContext(ctx context.Context, query string, args []driver.Nam
151147 defer log .Duration (msg , start )
152148
153149 if err != nil {
150+ // gotta also think about close operation here
154151 log .Err (err ).Msgf ("databricks: failed to run query: query %s" , query )
155152 return nil , wrapErrf (err , "failed to run query" )
156153 }
@@ -175,7 +172,10 @@ func (c *conn) runQuery(ctx context.Context, query string, args []driver.NamedVa
175172 // hold on to the operation handle
176173 opHandle := exStmtResp .OperationHandle
177174 if opHandle != nil && opHandle .OperationId != nil {
178- log = logger .WithContext (c .id , driverctx .CorrelationIdFromContext (ctx ), client .SprintGuid (opHandle .OperationId .GUID ))
175+ log = logger .WithContext (
176+ c .id ,
177+ driverctx .CorrelationIdFromContext (ctx ), client .SprintGuid (opHandle .OperationId .GUID ),
178+ )
179179 }
180180
181181 if exStmtResp .DirectResults != nil {
@@ -188,12 +188,17 @@ func (c *conn) runQuery(ctx context.Context, query string, args []driver.NamedVa
188188 // return results
189189 return exStmtResp , opStatus , nil
190190 // bad
191- case cli_service .TOperationState_CANCELED_STATE , cli_service .TOperationState_CLOSED_STATE , cli_service .TOperationState_ERROR_STATE , cli_service .TOperationState_TIMEDOUT_STATE :
191+ case cli_service .TOperationState_CANCELED_STATE ,
192+ cli_service .TOperationState_CLOSED_STATE ,
193+ cli_service .TOperationState_ERROR_STATE ,
194+ cli_service .TOperationState_TIMEDOUT_STATE :
192195 // do we need to close the operation in these cases?
193196 logBadQueryState (log , opStatus )
194197 return exStmtResp , opStatus , errors .New (opStatus .GetDisplayMessage ())
195198 // live states
196- case cli_service .TOperationState_INITIALIZED_STATE , cli_service .TOperationState_PENDING_STATE , cli_service .TOperationState_RUNNING_STATE :
199+ case cli_service .TOperationState_INITIALIZED_STATE ,
200+ cli_service .TOperationState_PENDING_STATE ,
201+ cli_service .TOperationState_RUNNING_STATE :
197202 statusResp , err := c .pollOperation (ctx , opHandle )
198203 if err != nil {
199204 return exStmtResp , statusResp , err
@@ -205,7 +210,10 @@ func (c *conn) runQuery(ctx context.Context, query string, args []driver.NamedVa
205210 // return handle to fetch results later
206211 return exStmtResp , opStatus , nil
207212 // bad
208- case cli_service .TOperationState_CANCELED_STATE , cli_service .TOperationState_CLOSED_STATE , cli_service .TOperationState_ERROR_STATE , cli_service .TOperationState_TIMEDOUT_STATE :
213+ case cli_service .TOperationState_CANCELED_STATE ,
214+ cli_service .TOperationState_CLOSED_STATE ,
215+ cli_service .TOperationState_ERROR_STATE ,
216+ cli_service .TOperationState_TIMEDOUT_STATE :
209217 logBadQueryState (log , statusResp )
210218 return exStmtResp , opStatus , errors .New (statusResp .GetDisplayMessage ())
211219 // live states
@@ -231,7 +239,10 @@ func (c *conn) runQuery(ctx context.Context, query string, args []driver.NamedVa
231239 // return handle to fetch results later
232240 return exStmtResp , statusResp , nil
233241 // bad
234- case cli_service .TOperationState_CANCELED_STATE , cli_service .TOperationState_CLOSED_STATE , cli_service .TOperationState_ERROR_STATE , cli_service .TOperationState_TIMEDOUT_STATE :
242+ case cli_service .TOperationState_CANCELED_STATE ,
243+ cli_service .TOperationState_CLOSED_STATE ,
244+ cli_service .TOperationState_ERROR_STATE ,
245+ cli_service .TOperationState_TIMEDOUT_STATE :
235246 logBadQueryState (log , statusResp )
236247 return exStmtResp , statusResp , errors .New (statusResp .GetDisplayMessage ())
237248 // live states
@@ -250,41 +261,56 @@ func logBadQueryState(log *logger.DBSQLLogger, opStatus *cli_service.TGetOperati
250261func (c * conn ) executeStatement (ctx context.Context , query string , args []driver.NamedValue ) (* cli_service.TExecuteStatementResp , error ) {
251262 corrId := driverctx .CorrelationIdFromContext (ctx )
252263 log := logger .WithContext (c .id , corrId , "" )
253- sentinel := sentinel.Sentinel {
254- OnDoneFn : func (statusResp any ) (any , error ) {
255- req := cli_service.TExecuteStatementReq {
256- SessionHandle : c .session .SessionHandle ,
257- Statement : query ,
258- RunAsync : c .cfg .RunAsync ,
259- QueryTimeout : int64 (c .cfg .QueryTimeout / time .Second ),
260- // this is specific for databricks. It shortcuts server roundtrips
261- GetDirectResults : & cli_service.TSparkGetDirectResults {
262- MaxRows : int64 (c .cfg .MaxRows ),
263- },
264- // CanReadArrowResult_: &t,
265- // CanDecompressLZ4Result_: &f,
266- // CanDownloadResult_: &t,
267- }
268- ctx = driverctx .NewContextWithConnId (ctx , c .id )
269- resp , err := c .client .ExecuteStatement (ctx , & req )
270- return resp , wrapErr (err , "failed to execute statement" )
271- },
272- OnCancelFn : func () (any , error ) {
273- log .Warn ().Msg ("databricks: execute statement canceled while creation operation" )
274- return nil , nil
264+
265+ req := cli_service.TExecuteStatementReq {
266+ SessionHandle : c .session .SessionHandle ,
267+ Statement : query ,
268+ RunAsync : c .cfg .RunAsync ,
269+ QueryTimeout : int64 (c .cfg .QueryTimeout / time .Second ),
270+ // this is specific for databricks. It shortcuts server round trips
271+ GetDirectResults : & cli_service.TSparkGetDirectResults {
272+ MaxRows : int64 (c .cfg .MaxRows ),
275273 },
274+ // CanReadArrowResult_: &t,
275+ // CanDecompressLZ4Result_: &f,
276+ // CanDownloadResult_: &t,
276277 }
277- _ , res , err := sentinel .Watch (ctx , c .cfg .PollInterval , c .cfg .QueryTimeout )
278- if err != nil {
279- return nil , err
278+
279+ ctx = driverctx .NewContextWithConnId (ctx , c .id )
280+ resp , err := c .client .ExecuteStatement (ctx , & req )
281+
282+ var shouldCancel = func (resp * cli_service.TExecuteStatementResp ) bool {
283+ if resp == nil {
284+ return false
285+ }
286+ hasHandle := resp .OperationHandle != nil
287+ isOpen := resp .DirectResults != nil && resp .DirectResults .CloseOperation == nil
288+ return hasHandle && isOpen
280289 }
281290
282- exStmtResp , ok := res .(* cli_service.TExecuteStatementResp )
283- if ! ok {
284- return exStmtResp , errors .New ("databricks: invalid execute statement response" )
291+ select {
292+ default :
293+ case <- ctx .Done ():
294+ newCtx := driverctx .NewContextWithCorrelationId (driverctx .NewContextWithConnId (context .Background (), c .id ), corrId )
295+ // in case context is done, we need to cancel the operation if necessary
296+ if err == nil && shouldCancel (resp ) {
297+ log .Debug ().Msg ("databricks: canceling query" )
298+ _ , err1 := c .client .CancelOperation (newCtx , & cli_service.TCancelOperationReq {
299+ OperationHandle : resp .GetOperationHandle (),
300+ })
301+
302+ if err1 != nil {
303+ log .Err (err ).Msgf ("databricks: cancel failed" )
304+ }
305+ log .Debug ().Msgf ("databricks: cancel success" )
306+
307+ } else {
308+ log .Debug ().Msg ("databricks: query did not need cancellation" )
309+ }
310+ return nil , ctx .Err ()
285311 }
286312
287- return exStmtResp , err
313+ return resp , err
288314}
289315
290316func (c * conn ) pollOperation (ctx context.Context , opHandle * cli_service.TOperationHandle ) (* cli_service.TGetOperationStatusResp , error ) {
@@ -312,7 +338,9 @@ func (c *conn) pollOperation(ctx context.Context, opHandle *cli_service.TOperati
312338 return true
313339 }
314340 switch statusResp .GetOperationState () {
315- case cli_service .TOperationState_INITIALIZED_STATE , cli_service .TOperationState_PENDING_STATE , cli_service .TOperationState_RUNNING_STATE :
341+ case cli_service .TOperationState_INITIALIZED_STATE ,
342+ cli_service .TOperationState_PENDING_STATE ,
343+ cli_service .TOperationState_RUNNING_STATE :
316344 return false
317345 default :
318346 log .Debug ().Msg ("databricks: polling done" )
0 commit comments