Skip to content

Commit b46e066

Browse files
committed
Fix context loss in polling and connection close operations
Signed-off-by: Diego Giagio <[email protected]>
1 parent 9cef5b7 commit b46e066

File tree

3 files changed

+24
-28
lines changed

3 files changed

+24
-28
lines changed

connection.go

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ func (c *conn) Close() error {
5555

5656
if err != nil {
5757
log.Err(err).Msg("databricks: failed to close connection")
58-
return dbsqlerrint.NewRequestError(ctx, dbsqlerr.ErrCloseConnection, err)
58+
return dbsqlerrint.NewBadConnectionError(err)
5959
}
6060
return nil
6161
}
@@ -168,9 +168,7 @@ func (c *conn) QueryContext(ctx context.Context, query string, args []driver.Nam
168168
return nil, dbsqlerrint.NewExecutionError(ctx, dbsqlerr.ErrQueryExecution, err, opStatusResp)
169169
}
170170

171-
corrId := driverctx.CorrelationIdFromContext(ctx)
172-
rows, err := rows.NewRows(c.id, corrId, exStmtResp.OperationHandle, c.client, c.cfg, exStmtResp.DirectResults)
173-
171+
rows, err := rows.NewRows(ctx, exStmtResp.OperationHandle, c.client, c.cfg, exStmtResp.DirectResults)
174172
return rows, err
175173

176174
}
@@ -367,7 +365,7 @@ func (c *conn) pollOperation(ctx context.Context, opHandle *cli_service.TOperati
367365
log := logger.WithContext(c.id, corrId, client.SprintGuid(opHandle.OperationId.GUID))
368366
var statusResp *cli_service.TGetOperationStatusResp
369367
ctx = driverctx.NewContextWithConnId(ctx, c.id)
370-
newCtx := driverctx.NewContextWithCorrelationId(driverctx.NewContextWithConnId(context.Background(), c.id), corrId)
368+
newCtx := driverctx.NewContextWithCorrelationId(ctx, corrId)
371369
pollSentinel := sentinel.Sentinel{
372370
OnDoneFn: func(statusResp any) (any, error) {
373371
return statusResp, nil
@@ -566,7 +564,6 @@ func (c *conn) execStagingOperation(
566564
return nil
567565
}
568566

569-
corrId := driverctx.CorrelationIdFromContext(ctx)
570567
var row driver.Rows
571568
var err error
572569

@@ -589,7 +586,7 @@ func (c *conn) execStagingOperation(
589586
}
590587

591588
if len(driverctx.StagingPathsFromContext(ctx)) != 0 {
592-
row, err = rows.NewRows(c.id, corrId, exStmtResp.OperationHandle, c.client, c.cfg, exStmtResp.DirectResults)
589+
row, err = rows.NewRows(ctx, exStmtResp.OperationHandle, c.client, c.cfg, exStmtResp.DirectResults)
593590
if err != nil {
594591
return dbsqlerrint.NewDriverError(ctx, "error reading row.", err)
595592
}

internal/rows/rows.go

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -67,22 +67,22 @@ var _ driver.RowsColumnTypeLength = (*rows)(nil)
6767
var _ dbsqlrows.Rows = (*rows)(nil)
6868

6969
func NewRows(
70-
connId string,
71-
correlationId string,
70+
ctx context.Context,
7271
opHandle *cli_service.TOperationHandle,
7372
client cli_service.TCLIService,
7473
config *config.Config,
7574
directResults *cli_service.TSparkDirectResults,
7675
) (driver.Rows, dbsqlerr.DBError) {
7776

77+
connId := driverctx.ConnIdFromContext(ctx)
78+
correlationId := driverctx.CorrelationIdFromContext(ctx)
79+
7880
var logger *dbsqllog.DBSQLLogger
79-
var ctx context.Context
8081
if opHandle != nil {
8182
logger = dbsqllog.WithContext(connId, correlationId, dbsqlclient.SprintGuid(opHandle.OperationId.GUID))
82-
ctx = driverctx.NewContextWithQueryId(driverctx.NewContextWithCorrelationId(driverctx.NewContextWithConnId(context.Background(), connId), correlationId), dbsqlclient.SprintGuid(opHandle.OperationId.GUID))
83+
ctx = driverctx.NewContextWithQueryId(ctx, dbsqlclient.SprintGuid(opHandle.OperationId.GUID))
8384
} else {
8485
logger = dbsqllog.WithContext(connId, correlationId, "")
85-
ctx = driverctx.NewContextWithCorrelationId(driverctx.NewContextWithConnId(context.Background(), connId), correlationId)
8686
}
8787

8888
if client == nil {
@@ -140,13 +140,12 @@ func NewRows(
140140
// the operations.
141141
closedOnServer := directResults != nil && directResults.CloseOperation != nil
142142
r.ResultPageIterator = rowscanner.NewResultPageIterator(
143+
ctx,
143144
d,
144145
pageSize,
145146
opHandle,
146147
closedOnServer,
147148
client,
148-
connId,
149-
correlationId,
150149
r.logger(),
151150
)
152151

@@ -417,9 +416,8 @@ func (r *rows) getResultSetSchema() (*cli_service.TTableSchema, dbsqlerr.DBError
417416
req := cli_service.TGetResultSetMetadataReq{
418417
OperationHandle: r.opHandle,
419418
}
420-
ctx := driverctx.NewContextWithCorrelationId(driverctx.NewContextWithConnId(context.Background(), r.connId), r.correlationId)
421419

422-
resp, err2 := r.client.GetResultSetMetadata(ctx, &req)
420+
resp, err2 := r.client.GetResultSetMetadata(r.ctx, &req)
423421
if err2 != nil {
424422
r.logger().Err(err2).Msg(err2.Error())
425423
return nil, dbsqlerr_int.NewRequestError(r.ctx, errRowsMetadataFetchFailed, err)

internal/rows/rowscanner/resultPageIterator.go

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -45,32 +45,34 @@ func (d Direction) String() string {
4545

4646
// Create a new result page iterator.
4747
func NewResultPageIterator(
48+
ctx context.Context,
4849
delimiter Delimiter,
4950
maxPageSize int64,
5051
opHandle *cli_service.TOperationHandle,
5152
closedOnServer bool,
5253
client cli_service.TCLIService,
53-
connectionId string,
54-
correlationId string,
5554
logger *dbsqllog.DBSQLLogger,
5655
) ResultPageIterator {
5756

5857
// delimiter and hasMoreRows are used to set up the point in the paginated
5958
// result set that this iterator starts from.
6059
return &resultPageIterator{
60+
ctx: ctx,
6161
Delimiter: delimiter,
6262
isFinished: closedOnServer,
6363
maxPageSize: maxPageSize,
6464
opHandle: opHandle,
6565
closedOnServer: closedOnServer,
6666
client: client,
67-
connectionId: connectionId,
68-
correlationId: correlationId,
67+
connectionId: driverctx.ConnIdFromContext(ctx),
68+
correlationId: driverctx.CorrelationIdFromContext(ctx),
6969
logger: logger,
7070
}
7171
}
7272

7373
type resultPageIterator struct {
74+
ctx context.Context
75+
7476
// Gives the parameters of the current result page
7577
Delimiter
7678

@@ -167,15 +169,14 @@ func (rpf *resultPageIterator) getNextPage() (*cli_service.TFetchResultsResp, er
167169
nextPageStartRow := rpf.Start() + rpf.Count()
168170

169171
rpf.logger.Debug().Msgf("databricks: fetching result page for row %d", nextPageStartRow)
170-
ctx := driverctx.NewContextWithCorrelationId(driverctx.NewContextWithConnId(context.Background(), rpf.connectionId), rpf.correlationId)
171172

172173
// Keep fetching in the appropriate direction until we have the expected page.
173174
var fetchResult *cli_service.TFetchResultsResp
174175
var b bool
175176
for b = rpf.Contains(nextPageStartRow); !b; b = rpf.Contains(nextPageStartRow) {
176177

177178
direction := rpf.Direction(nextPageStartRow)
178-
err := rpf.checkDirectionValid(ctx, direction)
179+
err := rpf.checkDirectionValid(direction)
179180
if err != nil {
180181
return nil, err
181182
}
@@ -190,10 +191,10 @@ func (rpf *resultPageIterator) getNextPage() (*cli_service.TFetchResultsResp, er
190191
IncludeResultSetMetadata: &includeResultSetMetadata,
191192
}
192193

193-
fetchResult, err = rpf.client.FetchResults(ctx, &req)
194+
fetchResult, err = rpf.client.FetchResults(rpf.ctx, &req)
194195
if err != nil {
195196
rpf.logger.Err(err).Msg("databricks: Rows instance failed to retrieve results")
196-
return nil, dbsqlerrint.NewRequestError(ctx, errRowsResultFetchFailed, err)
197+
return nil, dbsqlerrint.NewRequestError(rpf.ctx, errRowsResultFetchFailed, err)
197198
}
198199

199200
rpf.Delimiter = NewDelimiter(fetchResult.Results.StartRowOffset, CountRows(fetchResult.Results))
@@ -218,7 +219,7 @@ func (rpf *resultPageIterator) Close() (err error) {
218219
OperationHandle: rpf.opHandle,
219220
}
220221

221-
_, err = rpf.client.CloseOperation(context.Background(), &req)
222+
_, err = rpf.client.CloseOperation(rpf.ctx, &req)
222223
return err
223224
}
224225
}
@@ -283,11 +284,11 @@ func CountRows(rowSet *cli_service.TRowSet) int64 {
283284
}
284285

285286
// Check if trying to fetch in the specified direction creates an error condition.
286-
func (rpf *resultPageIterator) checkDirectionValid(ctx context.Context, direction Direction) error {
287+
func (rpf *resultPageIterator) checkDirectionValid(direction Direction) error {
287288
if direction == DirBack {
288289
// can't fetch rows previous to the start
289290
if rpf.Start() == 0 {
290-
return dbsqlerrint.NewDriverError(ctx, ErrRowsFetchPriorToStart, nil)
291+
return dbsqlerrint.NewDriverError(rpf.ctx, ErrRowsFetchPriorToStart, nil)
291292
}
292293
} else if direction == DirForward {
293294
// can't fetch past the end of the query results
@@ -296,7 +297,7 @@ func (rpf *resultPageIterator) checkDirectionValid(ctx context.Context, directio
296297
}
297298
} else {
298299
rpf.logger.Error().Msgf(errRowsUnandledFetchDirection(direction.String()))
299-
return dbsqlerrint.NewDriverError(ctx, errRowsUnandledFetchDirection(direction.String()), nil)
300+
return dbsqlerrint.NewDriverError(rpf.ctx, errRowsUnandledFetchDirection(direction.String()), nil)
300301
}
301302
return nil
302303
}

0 commit comments

Comments
 (0)