Skip to content

Commit 688cb7f

Browse files
committed
Fix context loss in polling and connection close operations
Signed-off-by: Diego Giagio <[email protected]>
1 parent d5c68a7 commit 688cb7f

File tree

6 files changed

+96
-57
lines changed

6 files changed

+96
-57
lines changed

connection.go

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ func (c *conn) Close() error {
5555

5656
if err != nil {
5757
log.Err(err).Msg("databricks: failed to close connection")
58-
return dbsqlerrint.NewRequestError(ctx, dbsqlerr.ErrCloseConnection, err)
58+
return dbsqlerrint.NewBadConnectionError(err)
5959
}
6060
return nil
6161
}
@@ -168,9 +168,7 @@ func (c *conn) QueryContext(ctx context.Context, query string, args []driver.Nam
168168
return nil, dbsqlerrint.NewExecutionError(ctx, dbsqlerr.ErrQueryExecution, err, opStatusResp)
169169
}
170170

171-
corrId := driverctx.CorrelationIdFromContext(ctx)
172-
rows, err := rows.NewRows(c.id, corrId, exStmtResp.OperationHandle, c.client, c.cfg, exStmtResp.DirectResults)
173-
171+
rows, err := rows.NewRows(ctx, exStmtResp.OperationHandle, c.client, c.cfg, exStmtResp.DirectResults)
174172
return rows, err
175173

176174
}
@@ -367,7 +365,7 @@ func (c *conn) pollOperation(ctx context.Context, opHandle *cli_service.TOperati
367365
log := logger.WithContext(c.id, corrId, client.SprintGuid(opHandle.OperationId.GUID))
368366
var statusResp *cli_service.TGetOperationStatusResp
369367
ctx = driverctx.NewContextWithConnId(ctx, c.id)
370-
newCtx := driverctx.NewContextWithCorrelationId(driverctx.NewContextWithConnId(context.Background(), c.id), corrId)
368+
newCtx := context.WithoutCancel(ctx)
371369
pollSentinel := sentinel.Sentinel{
372370
OnDoneFn: func(statusResp any) (any, error) {
373371
return statusResp, nil
@@ -566,7 +564,6 @@ func (c *conn) execStagingOperation(
566564
return nil
567565
}
568566

569-
corrId := driverctx.CorrelationIdFromContext(ctx)
570567
var row driver.Rows
571568
var err error
572569

@@ -589,7 +586,7 @@ func (c *conn) execStagingOperation(
589586
}
590587

591588
if len(driverctx.StagingPathsFromContext(ctx)) != 0 {
592-
row, err = rows.NewRows(c.id, corrId, exStmtResp.OperationHandle, c.client, c.cfg, exStmtResp.DirectResults)
589+
row, err = rows.NewRows(ctx, exStmtResp.OperationHandle, c.client, c.cfg, exStmtResp.DirectResults)
593590
if err != nil {
594591
return dbsqlerrint.NewDriverError(ctx, "error reading row.", err)
595592
}

internal/rows/arrowbased/arrowRecordIterator_test.go

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"os"
99
"testing"
1010

11+
"github.com/databricks/databricks-sql-go/driverctx"
1112
"github.com/databricks/databricks-sql-go/internal/cli_service"
1213
"github.com/databricks/databricks-sql-go/internal/client"
1314
"github.com/databricks/databricks-sql-go/internal/config"
@@ -32,15 +33,17 @@ func TestArrowRecordIterator(t *testing.T) {
3233

3334
var fetchesInfo []fetchResultsInfo
3435

36+
ctx := driverctx.NewContextWithConnId(context.Background(), "connectionId")
37+
ctx = driverctx.NewContextWithCorrelationId(ctx, "correlationId")
38+
3539
simpleClient := getSimpleClient(&fetchesInfo, []cli_service.TFetchResultsResp{fetchResp1, fetchResp2})
3640
rpi := rowscanner.NewResultPageIterator(
41+
ctx,
3742
rowscanner.NewDelimiter(0, 7311),
3843
5000,
3944
nil,
4045
false,
4146
simpleClient,
42-
"connectionId",
43-
"correlationId",
4447
logger,
4548
)
4649

@@ -126,17 +129,19 @@ func TestArrowRecordIterator(t *testing.T) {
126129
fetchResp3 := cli_service.TFetchResultsResp{}
127130
loadTestData2(t, "multipleFetch/FetchResults3.json", &fetchResp3)
128131

132+
ctx := driverctx.NewContextWithConnId(context.Background(), "connectionId")
133+
ctx = driverctx.NewContextWithCorrelationId(ctx, "correlationId")
134+
129135
var fetchesInfo []fetchResultsInfo
130136

131137
simpleClient := getSimpleClient(&fetchesInfo, []cli_service.TFetchResultsResp{fetchResp1, fetchResp2, fetchResp3})
132138
rpi := rowscanner.NewResultPageIterator(
139+
ctx,
133140
rowscanner.NewDelimiter(0, 0),
134141
5000,
135142
nil,
136143
false,
137144
simpleClient,
138-
"connectionId",
139-
"correlationId",
140145
logger,
141146
)
142147

@@ -199,16 +204,18 @@ func TestArrowRecordIteratorSchema(t *testing.T) {
199204
fetchResp1 := cli_service.TFetchResultsResp{}
200205
loadTestData2(t, "directResultsMultipleFetch/FetchResults1.json", &fetchResp1)
201206

207+
ctx := driverctx.NewContextWithConnId(context.Background(), "connectionId")
208+
ctx = driverctx.NewContextWithCorrelationId(ctx, "correlationId")
209+
202210
var fetchesInfo []fetchResultsInfo
203211
simpleClient := getSimpleClient(&fetchesInfo, []cli_service.TFetchResultsResp{fetchResp1})
204212
rpi := rowscanner.NewResultPageIterator(
213+
ctx,
205214
rowscanner.NewDelimiter(0, 0),
206215
5000,
207216
nil,
208217
false,
209218
simpleClient,
210-
"connectionId",
211-
"correlationId",
212219
logger,
213220
)
214221

@@ -251,16 +258,18 @@ func TestArrowRecordIteratorSchema(t *testing.T) {
251258
fetchResp1 := cli_service.TFetchResultsResp{}
252259
loadTestData2(t, "multipleFetch/FetchResults1.json", &fetchResp1)
253260

261+
ctx := driverctx.NewContextWithConnId(context.Background(), "connectionId")
262+
ctx = driverctx.NewContextWithCorrelationId(ctx, "correlationId")
263+
254264
var fetchesInfo []fetchResultsInfo
255265
simpleClient := getSimpleClient(&fetchesInfo, []cli_service.TFetchResultsResp{fetchResp1})
256266
rpi := rowscanner.NewResultPageIterator(
267+
ctx,
257268
rowscanner.NewDelimiter(0, 0),
258269
5000,
259270
nil,
260271
false,
261272
simpleClient,
262-
"connectionId",
263-
"correlationId",
264273
logger,
265274
)
266275

@@ -293,14 +302,16 @@ func TestArrowRecordIteratorSchema(t *testing.T) {
293302
},
294303
}
295304

305+
ctx := driverctx.NewContextWithConnId(context.Background(), "connectionId")
306+
ctx = driverctx.NewContextWithCorrelationId(ctx, "correlationId")
307+
296308
rpi := rowscanner.NewResultPageIterator(
309+
ctx,
297310
rowscanner.NewDelimiter(0, 0),
298311
5000,
299312
nil,
300313
false,
301314
failingClient,
302-
"connectionId",
303-
"correlationId",
304315
logger,
305316
)
306317

internal/rows/arrowbased/arrowRows_test.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313

1414
"github.com/apache/arrow/go/v12/arrow"
1515
"github.com/apache/arrow/go/v12/arrow/array"
16+
"github.com/databricks/databricks-sql-go/driverctx"
1617
dbsqlerr "github.com/databricks/databricks-sql-go/errors"
1718
"github.com/databricks/databricks-sql-go/internal/cli_service"
1819
"github.com/databricks/databricks-sql-go/internal/config"
@@ -1525,18 +1526,20 @@ func TestArrowRowScanner(t *testing.T) {
15251526
fetchResp2 := cli_service.TFetchResultsResp{}
15261527
loadTestData2(t, "directResultsMultipleFetch/FetchResults2.json", &fetchResp2)
15271528

1529+
ctx := driverctx.NewContextWithConnId(context.Background(), "connectionId")
1530+
ctx = driverctx.NewContextWithCorrelationId(ctx, "correlationId")
1531+
15281532
var fetchesInfo []fetchResultsInfo
15291533
client := getSimpleClient(&fetchesInfo, []cli_service.TFetchResultsResp{fetchResp1, fetchResp2})
15301534
logger := dbsqllog.WithContext("connectionId", "correlationId", "")
15311535

15321536
rpi := rowscanner.NewResultPageIterator(
1537+
ctx,
15331538
rowscanner.NewDelimiter(0, 7311),
15341539
5000,
15351540
nil,
15361541
false,
15371542
client,
1538-
"connectionId",
1539-
"correlationId",
15401543
logger)
15411544

15421545
cfg := config.WithDefaults()

internal/rows/rows.go

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -67,22 +67,22 @@ var _ driver.RowsColumnTypeLength = (*rows)(nil)
6767
var _ dbsqlrows.Rows = (*rows)(nil)
6868

6969
func NewRows(
70-
connId string,
71-
correlationId string,
70+
ctx context.Context,
7271
opHandle *cli_service.TOperationHandle,
7372
client cli_service.TCLIService,
7473
config *config.Config,
7574
directResults *cli_service.TSparkDirectResults,
7675
) (driver.Rows, dbsqlerr.DBError) {
7776

77+
connId := driverctx.ConnIdFromContext(ctx)
78+
correlationId := driverctx.CorrelationIdFromContext(ctx)
79+
7880
var logger *dbsqllog.DBSQLLogger
79-
var ctx context.Context
8081
if opHandle != nil {
8182
logger = dbsqllog.WithContext(connId, correlationId, dbsqlclient.SprintGuid(opHandle.OperationId.GUID))
82-
ctx = driverctx.NewContextWithQueryId(driverctx.NewContextWithCorrelationId(driverctx.NewContextWithConnId(context.Background(), connId), correlationId), dbsqlclient.SprintGuid(opHandle.OperationId.GUID))
83+
ctx = driverctx.NewContextWithQueryId(ctx, dbsqlclient.SprintGuid(opHandle.OperationId.GUID))
8384
} else {
8485
logger = dbsqllog.WithContext(connId, correlationId, "")
85-
ctx = driverctx.NewContextWithCorrelationId(driverctx.NewContextWithConnId(context.Background(), connId), correlationId)
8686
}
8787

8888
if client == nil {
@@ -140,13 +140,12 @@ func NewRows(
140140
// the operations.
141141
closedOnServer := directResults != nil && directResults.CloseOperation != nil
142142
r.ResultPageIterator = rowscanner.NewResultPageIterator(
143+
ctx,
143144
d,
144145
pageSize,
145146
opHandle,
146147
closedOnServer,
147148
client,
148-
connId,
149-
correlationId,
150149
r.logger(),
151150
)
152151

@@ -417,9 +416,8 @@ func (r *rows) getResultSetSchema() (*cli_service.TTableSchema, dbsqlerr.DBError
417416
req := cli_service.TGetResultSetMetadataReq{
418417
OperationHandle: r.opHandle,
419418
}
420-
ctx := driverctx.NewContextWithCorrelationId(driverctx.NewContextWithConnId(context.Background(), r.connId), r.correlationId)
421419

422-
resp, err2 := r.client.GetResultSetMetadata(ctx, &req)
420+
resp, err2 := r.client.GetResultSetMetadata(r.ctx, &req)
423421
if err2 != nil {
424422
r.logger().Err(err2).Msg(err2.Error())
425423
return nil, dbsqlerr_int.NewRequestError(r.ctx, errRowsMetadataFetchFailed, err)

0 commit comments

Comments
 (0)