Skip to content

Commit f4763da

Browse files
authored
handle thrift protocol version
1 parent 644d6da commit f4763da

File tree

7 files changed

+376
-9
lines changed

7 files changed

+376
-9
lines changed

connection.go

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
dbsqlerrint "github.com/databricks/databricks-sql-go/internal/errors"
2121
"github.com/databricks/databricks-sql-go/internal/rows"
2222
"github.com/databricks/databricks-sql-go/internal/sentinel"
23+
"github.com/databricks/databricks-sql-go/internal/thrift_protocol"
2324
"github.com/databricks/databricks-sql-go/logger"
2425
"github.com/pkg/errors"
2526
)
@@ -285,14 +286,30 @@ func (c *conn) executeStatement(ctx context.Context, query string, args []driver
285286
Statement: query,
286287
RunAsync: true,
287288
QueryTimeout: int64(c.cfg.QueryTimeout / time.Second),
288-
GetDirectResults: &cli_service.TSparkGetDirectResults{
289+
}
290+
291+
// Check protocol version for feature support
292+
serverProtocolVersion := c.session.ServerProtocolVersion
293+
294+
// Add direct results if supported
295+
if thrift_protocol.SupportsDirectResults(serverProtocolVersion) {
296+
req.GetDirectResults = &cli_service.TSparkGetDirectResults{
289297
MaxRows: int64(c.cfg.MaxRows),
290-
},
291-
CanDecompressLZ4Result_: &c.cfg.UseLz4Compression,
292-
Parameters: parameters,
298+
}
293299
}
294300

295-
if c.cfg.UseArrowBatches {
301+
// Add LZ4 compression if supported and enabled
302+
if thrift_protocol.SupportsLz4Compression(serverProtocolVersion) && c.cfg.UseLz4Compression {
303+
req.CanDecompressLZ4Result_ = &c.cfg.UseLz4Compression
304+
}
305+
306+
// Add cloud fetch if supported and enabled
307+
if thrift_protocol.SupportsCloudFetch(serverProtocolVersion) && c.cfg.UseCloudFetch {
308+
req.CanDownloadResult_ = &c.cfg.UseCloudFetch
309+
}
310+
311+
// Add Arrow support if supported and enabled
312+
if thrift_protocol.SupportsArrow(serverProtocolVersion) && c.cfg.UseArrowBatches {
296313
req.CanReadArrowResult_ = &c.cfg.UseArrowBatches
297314
req.UseArrowNativeTypes = &cli_service.TSparkArrowTypes{
298315
DecimalAsArrow: &c.cfg.UseArrowNativeDecimal,
@@ -302,8 +319,9 @@ func (c *conn) executeStatement(ctx context.Context, query string, args []driver
302319
}
303320
}
304321

305-
if c.cfg.UseCloudFetch {
306-
req.CanDownloadResult_ = &c.cfg.UseCloudFetch
322+
// Add parameters if supported and provided
323+
if thrift_protocol.SupportsParameterizedQueries(serverProtocolVersion) && len(parameters) > 0 {
324+
req.Parameters = parameters
307325
}
308326

309327
resp, err := c.client.ExecuteStatement(ctx, &req)

connection_test.go

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414
"github.com/databricks/databricks-sql-go/internal/cli_service"
1515
"github.com/databricks/databricks-sql-go/internal/client"
1616
"github.com/databricks/databricks-sql-go/internal/config"
17+
"github.com/databricks/databricks-sql-go/internal/thrift_protocol"
1718
"github.com/stretchr/testify/assert"
1819
)
1920

@@ -331,6 +332,167 @@ func TestConn_executeStatement(t *testing.T) {
331332

332333
}
333334

335+
func TestConn_executeStatement_ProtocolFeatures(t *testing.T) {
336+
t.Parallel()
337+
338+
protocols := []cli_service.TProtocolVersion{
339+
cli_service.TProtocolVersion_SPARK_CLI_SERVICE_PROTOCOL_V1,
340+
cli_service.TProtocolVersion_SPARK_CLI_SERVICE_PROTOCOL_V2,
341+
cli_service.TProtocolVersion_SPARK_CLI_SERVICE_PROTOCOL_V3,
342+
cli_service.TProtocolVersion_SPARK_CLI_SERVICE_PROTOCOL_V4,
343+
cli_service.TProtocolVersion_SPARK_CLI_SERVICE_PROTOCOL_V5,
344+
cli_service.TProtocolVersion_SPARK_CLI_SERVICE_PROTOCOL_V6,
345+
cli_service.TProtocolVersion_SPARK_CLI_SERVICE_PROTOCOL_V7,
346+
cli_service.TProtocolVersion_SPARK_CLI_SERVICE_PROTOCOL_V8,
347+
}
348+
349+
testCases := []struct {
350+
cfg *config.Config
351+
supportsDirectResults func(version cli_service.TProtocolVersion) bool
352+
supportsLz4Compression func(version cli_service.TProtocolVersion) bool
353+
supportsCloudFetch func(version cli_service.TProtocolVersion) bool
354+
supportsArrow func(version cli_service.TProtocolVersion) bool
355+
supportsParameterizedQueries func(version cli_service.TProtocolVersion) bool
356+
hasParameters bool
357+
}{
358+
{
359+
cfg: func() *config.Config {
360+
cfg := config.WithDefaults()
361+
cfg.UseLz4Compression = true
362+
cfg.UseCloudFetch = true
363+
cfg.UseArrowBatches = true
364+
cfg.UseArrowNativeDecimal = true
365+
cfg.UseArrowNativeTimestamp = true
366+
cfg.UseArrowNativeComplexTypes = true
367+
cfg.UseArrowNativeIntervalTypes = true
368+
return cfg
369+
}(),
370+
supportsDirectResults: thrift_protocol.SupportsDirectResults,
371+
supportsLz4Compression: thrift_protocol.SupportsLz4Compression,
372+
supportsCloudFetch: thrift_protocol.SupportsCloudFetch,
373+
supportsArrow: thrift_protocol.SupportsArrow,
374+
supportsParameterizedQueries: thrift_protocol.SupportsParameterizedQueries,
375+
hasParameters: true,
376+
},
377+
{
378+
cfg: func() *config.Config {
379+
cfg := config.WithDefaults()
380+
cfg.UseLz4Compression = false
381+
cfg.UseCloudFetch = false
382+
cfg.UseArrowBatches = false
383+
return cfg
384+
}(),
385+
supportsDirectResults: thrift_protocol.SupportsDirectResults,
386+
supportsLz4Compression: thrift_protocol.SupportsLz4Compression,
387+
supportsCloudFetch: thrift_protocol.SupportsCloudFetch,
388+
supportsArrow: thrift_protocol.SupportsArrow,
389+
supportsParameterizedQueries: thrift_protocol.SupportsParameterizedQueries,
390+
hasParameters: false,
391+
},
392+
}
393+
394+
for _, tc := range testCases {
395+
for _, version := range protocols {
396+
t.Run(fmt.Sprintf("protocol_v%d_withParams_%v", version, tc.hasParameters), func(t *testing.T) {
397+
var capturedReq *cli_service.TExecuteStatementReq
398+
executeStatement := func(ctx context.Context, req *cli_service.TExecuteStatementReq) (r *cli_service.TExecuteStatementResp, err error) {
399+
capturedReq = req
400+
executeStatementResp := &cli_service.TExecuteStatementResp{
401+
Status: &cli_service.TStatus{
402+
StatusCode: cli_service.TStatusCode_SUCCESS_STATUS,
403+
},
404+
OperationHandle: &cli_service.TOperationHandle{
405+
OperationId: &cli_service.THandleIdentifier{
406+
GUID: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
407+
Secret: []byte("secret"),
408+
},
409+
},
410+
DirectResults: &cli_service.TSparkDirectResults{
411+
OperationStatus: &cli_service.TGetOperationStatusResp{
412+
Status: &cli_service.TStatus{
413+
StatusCode: cli_service.TStatusCode_SUCCESS_STATUS,
414+
},
415+
OperationState: cli_service.TOperationStatePtr(cli_service.TOperationState_FINISHED_STATE),
416+
},
417+
},
418+
}
419+
return executeStatementResp, nil
420+
}
421+
422+
session := getTestSession()
423+
session.ServerProtocolVersion = version
424+
425+
testClient := &client.TestClient{
426+
FnExecuteStatement: executeStatement,
427+
}
428+
429+
testConn := &conn{
430+
session: session,
431+
client: testClient,
432+
cfg: tc.cfg,
433+
}
434+
435+
var args []driver.NamedValue
436+
if tc.hasParameters {
437+
args = []driver.NamedValue{
438+
{Name: "param1", Value: "value1"},
439+
}
440+
}
441+
442+
_, err := testConn.executeStatement(context.Background(), "SELECT 1", args)
443+
assert.NoError(t, err)
444+
445+
// Verify direct results
446+
hasDirectResults := tc.supportsDirectResults(version)
447+
assert.Equal(t, hasDirectResults, capturedReq.GetDirectResults != nil, "Direct results should be enabled if protocol supports it")
448+
449+
// Verify LZ4 compression
450+
shouldHaveLz4 := tc.supportsLz4Compression(version) && tc.cfg.UseLz4Compression
451+
if shouldHaveLz4 {
452+
assert.NotNil(t, capturedReq.CanDecompressLZ4Result_)
453+
assert.True(t, *capturedReq.CanDecompressLZ4Result_)
454+
} else {
455+
assert.Nil(t, capturedReq.CanDecompressLZ4Result_)
456+
}
457+
458+
// Verify cloud fetch
459+
shouldHaveCloudFetch := tc.supportsCloudFetch(version) && tc.cfg.UseCloudFetch
460+
if shouldHaveCloudFetch {
461+
assert.NotNil(t, capturedReq.CanDownloadResult_)
462+
assert.True(t, *capturedReq.CanDownloadResult_)
463+
} else {
464+
assert.Nil(t, capturedReq.CanDownloadResult_)
465+
}
466+
467+
// Verify Arrow support
468+
shouldHaveArrow := tc.supportsArrow(version) && tc.cfg.UseArrowBatches
469+
if shouldHaveArrow {
470+
assert.NotNil(t, capturedReq.CanReadArrowResult_)
471+
assert.True(t, *capturedReq.CanReadArrowResult_)
472+
assert.NotNil(t, capturedReq.UseArrowNativeTypes)
473+
assert.Equal(t, tc.cfg.UseArrowNativeDecimal, *capturedReq.UseArrowNativeTypes.DecimalAsArrow)
474+
assert.Equal(t, tc.cfg.UseArrowNativeTimestamp, *capturedReq.UseArrowNativeTypes.TimestampAsArrow)
475+
assert.Equal(t, tc.cfg.UseArrowNativeComplexTypes, *capturedReq.UseArrowNativeTypes.ComplexTypesAsArrow)
476+
assert.Equal(t, tc.cfg.UseArrowNativeIntervalTypes, *capturedReq.UseArrowNativeTypes.IntervalTypesAsArrow)
477+
} else {
478+
assert.Nil(t, capturedReq.CanReadArrowResult_)
479+
assert.Nil(t, capturedReq.UseArrowNativeTypes)
480+
}
481+
482+
// Verify parameters
483+
shouldHaveParams := tc.supportsParameterizedQueries(version) && tc.hasParameters
484+
if shouldHaveParams {
485+
assert.NotNil(t, capturedReq.Parameters)
486+
assert.Len(t, capturedReq.Parameters, 1)
487+
} else if tc.hasParameters {
488+
// Even if we have parameters but protocol doesn't support it, we shouldn't set them
489+
assert.Nil(t, capturedReq.Parameters)
490+
}
491+
})
492+
}
493+
}
494+
}
495+
334496
func TestConn_pollOperation(t *testing.T) {
335497
t.Parallel()
336498
t.Run("pollOperation returns finished state response when query finishes", func(t *testing.T) {

connector.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
6363
}
6464
log := logger.WithContext(conn.id, driverctx.CorrelationIdFromContext(ctx), "")
6565

66-
log.Info().Msgf("connect: host=%s port=%d httpPath=%s", c.cfg.Host, c.cfg.Port, c.cfg.HTTPPath)
66+
log.Info().Msgf("connect: host=%s port=%d httpPath=%s serverProtocolVersion=0x%X", c.cfg.Host, c.cfg.Port, c.cfg.HTTPPath, session.ServerProtocolVersion)
6767

6868
for k, v := range c.cfg.SessionParams {
6969
setStmt := fmt.Sprintf("SET `%s` = `%s`;", k, v)

internal/client/client.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,11 @@ func (tsc *ThriftServiceClient) OpenSession(ctx context.Context, req *cli_servic
9494
return resp, err
9595
}
9696

97+
// Log the server protocol version
98+
if resp != nil {
99+
log.Debug().Msgf("Server protocol version: 0x%X", resp.ServerProtocolVersion)
100+
}
101+
97102
recordResult(ctx, resp)
98103

99104
return resp, CheckStatus(resp)

internal/config/config_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -648,7 +648,7 @@ func TestConfig_DeepCopy(t *testing.T) {
648648
DriverVersion: "0.9.0",
649649
ThriftProtocol: "binary",
650650
ThriftTransport: "http",
651-
ThriftProtocolVersion: cli_service.TProtocolVersion_SPARK_CLI_SERVICE_PROTOCOL_V6,
651+
ThriftProtocolVersion: cli_service.TProtocolVersion_SPARK_CLI_SERVICE_PROTOCOL_V8,
652652
ThriftDebugClientProtocol: false,
653653
}
654654

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
package thrift_protocol
2+
3+
import "github.com/databricks/databricks-sql-go/internal/cli_service"
4+
5+
// Feature checks
6+
// SupportsDirectResults checks if the server protocol version supports direct results
7+
// Supported in SPARK_CLI_SERVICE_PROTOCOL_V1 and above
8+
func SupportsDirectResults(version cli_service.TProtocolVersion) bool {
9+
return version >= cli_service.TProtocolVersion_SPARK_CLI_SERVICE_PROTOCOL_V1
10+
}
11+
12+
// SupportsLz4Compression checks if the server protocol version supports LZ4 compression
13+
// Supported in SPARK_CLI_SERVICE_PROTOCOL_V6 and above
14+
func SupportsLz4Compression(version cli_service.TProtocolVersion) bool {
15+
return version >= cli_service.TProtocolVersion_SPARK_CLI_SERVICE_PROTOCOL_V6
16+
}
17+
18+
// SupportsCloudFetch checks if the server protocol version supports cloud fetch
19+
// Supported in SPARK_CLI_SERVICE_PROTOCOL_V3 and above
20+
func SupportsCloudFetch(version cli_service.TProtocolVersion) bool {
21+
return version >= cli_service.TProtocolVersion_SPARK_CLI_SERVICE_PROTOCOL_V3
22+
}
23+
24+
// SupportsArrow checks if the server protocol version supports Arrow format
25+
// Supported in SPARK_CLI_SERVICE_PROTOCOL_V5 and above
26+
func SupportsArrow(version cli_service.TProtocolVersion) bool {
27+
return version >= cli_service.TProtocolVersion_SPARK_CLI_SERVICE_PROTOCOL_V5
28+
}
29+
30+
// SupportsCompressedArrow checks if the server protocol version supports compressed Arrow format
31+
// Supported in SPARK_CLI_SERVICE_PROTOCOL_V6 and above
32+
func SupportsCompressedArrow(version cli_service.TProtocolVersion) bool {
33+
return version >= cli_service.TProtocolVersion_SPARK_CLI_SERVICE_PROTOCOL_V6
34+
}
35+
36+
// SupportsParameterizedQueries checks if the server protocol version supports parameterized queries
37+
// Supported in SPARK_CLI_SERVICE_PROTOCOL_V8 and above
38+
func SupportsParameterizedQueries(version cli_service.TProtocolVersion) bool {
39+
return version >= cli_service.TProtocolVersion_SPARK_CLI_SERVICE_PROTOCOL_V8
40+
}
41+
42+
// SupportsMultipleCatalogs checks if the server protocol version supports multiple catalogs
43+
// Supported in SPARK_CLI_SERVICE_PROTOCOL_V4 and above
44+
func SupportsMultipleCatalogs(version cli_service.TProtocolVersion) bool {
45+
return version >= cli_service.TProtocolVersion_SPARK_CLI_SERVICE_PROTOCOL_V4
46+
}

0 commit comments

Comments
 (0)