Skip to content

Commit c219440

Browse files
committed
handle thrift protocol version
1 parent 644d6da commit c219440

File tree

6 files changed

+97
-10
lines changed

6 files changed

+97
-10
lines changed

connection.go

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import (
1818
"github.com/databricks/databricks-sql-go/internal/client"
1919
"github.com/databricks/databricks-sql-go/internal/config"
2020
dbsqlerrint "github.com/databricks/databricks-sql-go/internal/errors"
21+
"github.com/databricks/databricks-sql-go/internal/thrift_protocol"
2122
"github.com/databricks/databricks-sql-go/internal/rows"
2223
"github.com/databricks/databricks-sql-go/internal/sentinel"
2324
"github.com/databricks/databricks-sql-go/logger"
@@ -285,14 +286,30 @@ func (c *conn) executeStatement(ctx context.Context, query string, args []driver
285286
Statement: query,
286287
RunAsync: true,
287288
QueryTimeout: int64(c.cfg.QueryTimeout / time.Second),
288-
GetDirectResults: &cli_service.TSparkGetDirectResults{
289+
}
290+
291+
// Check protocol version for feature support
292+
serverProtocolVersion := c.session.ServerProtocolVersion
293+
294+
// Add direct results if supported
295+
if thrift_protocol.SupportsDirectResults(serverProtocolVersion) {
296+
req.GetDirectResults = &cli_service.TSparkGetDirectResults{
289297
MaxRows: int64(c.cfg.MaxRows),
290-
},
291-
CanDecompressLZ4Result_: &c.cfg.UseLz4Compression,
292-
Parameters: parameters,
298+
}
293299
}
294300

295-
if c.cfg.UseArrowBatches {
301+
// Add LZ4 compression if supported and enabled
302+
if thrift_protocol.SupportsLz4Compression(serverProtocolVersion) && c.cfg.UseLz4Compression {
303+
req.CanDecompressLZ4Result_ = &c.cfg.UseLz4Compression
304+
}
305+
306+
// Add cloud fetch if supported and enabled
307+
if thrift_protocol.SupportsCloudFetch(serverProtocolVersion) && c.cfg.UseCloudFetch {
308+
req.CanDownloadResult_ = &c.cfg.UseCloudFetch
309+
}
310+
311+
// Add Arrow support if supported and enabled
312+
if thrift_protocol.SupportsArrow(serverProtocolVersion) && c.cfg.UseArrowBatches {
296313
req.CanReadArrowResult_ = &c.cfg.UseArrowBatches
297314
req.UseArrowNativeTypes = &cli_service.TSparkArrowTypes{
298315
DecimalAsArrow: &c.cfg.UseArrowNativeDecimal,
@@ -302,8 +319,9 @@ func (c *conn) executeStatement(ctx context.Context, query string, args []driver
302319
}
303320
}
304321

305-
if c.cfg.UseCloudFetch {
306-
req.CanDownloadResult_ = &c.cfg.UseCloudFetch
322+
// Add parameters if supported and provided
323+
if thrift_protocol.SupportsParameterizedQueries(serverProtocolVersion) && len(parameters) > 0 {
324+
req.Parameters = parameters
307325
}
308326

309327
resp, err := c.client.ExecuteStatement(ctx, &req)

connector.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
6363
}
6464
log := logger.WithContext(conn.id, driverctx.CorrelationIdFromContext(ctx), "")
6565

66-
log.Info().Msgf("connect: host=%s port=%d httpPath=%s", c.cfg.Host, c.cfg.Port, c.cfg.HTTPPath)
66+
log.Info().Msgf("connect: host=%s port=%d httpPath=%s serverProtocolVersion=0x%X", c.cfg.Host, c.cfg.Port, c.cfg.HTTPPath, session.ServerProtocolVersion)
6767

6868
for k, v := range c.cfg.SessionParams {
6969
setStmt := fmt.Sprintf("SET `%s` = `%s`;", k, v)

internal/client/client.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,11 @@ func (tsc *ThriftServiceClient) OpenSession(ctx context.Context, req *cli_servic
9494
return resp, err
9595
}
9696

97+
// Log the server protocol version
98+
if resp != nil {
99+
log.Debug().Msgf("Server protocol version: 0x%X", resp.ServerProtocolVersion)
100+
}
101+
97102
recordResult(ctx, resp)
98103

99104
return resp, CheckStatus(resp)

internal/client/testclient.go

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@ import (
1010
var ErrNotImplemented = errors.New("databricks: not implemented")
1111

1212
type TestClient struct {
13+
// Default server protocol version to use in tests
14+
ServerProtocolVersion cli_service.TProtocolVersion
15+
1316
FnOpenSession func(ctx context.Context, req *cli_service.TOpenSessionReq) (_r *cli_service.TOpenSessionResp, _err error)
1417
FnCloseSession func(ctx context.Context, req *cli_service.TCloseSessionReq) (_r *cli_service.TCloseSessionResp, _err error)
1518
FnGetInfo func(ctx context.Context, req *cli_service.TGetInfoReq) (_r *cli_service.TGetInfoResp, _err error)
@@ -39,7 +42,20 @@ func (c *TestClient) OpenSession(ctx context.Context, req *cli_service.TOpenSess
3942
if c.FnOpenSession != nil {
4043
return c.FnOpenSession(ctx, req)
4144
}
42-
return nil, ErrNotImplemented
45+
46+
// Default implementation for test client
47+
resp := &cli_service.TOpenSessionResp{
48+
Status: &cli_service.TStatus{StatusCode: cli_service.TStatusCode_SUCCESS_STATUS},
49+
ServerProtocolVersion: c.ServerProtocolVersion,
50+
SessionHandle: &cli_service.TSessionHandle{
51+
SessionId: &cli_service.THandleIdentifier{
52+
GUID: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
53+
Secret: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
54+
},
55+
},
56+
}
57+
58+
return resp, nil
4359
}
4460
func (c *TestClient) CloseSession(ctx context.Context, req *cli_service.TCloseSessionReq) (_r *cli_service.TCloseSessionResp, _err error) {
4561
if c.FnCloseSession != nil {

internal/config/config_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -648,7 +648,7 @@ func TestConfig_DeepCopy(t *testing.T) {
648648
DriverVersion: "0.9.0",
649649
ThriftProtocol: "binary",
650650
ThriftTransport: "http",
651-
ThriftProtocolVersion: cli_service.TProtocolVersion_SPARK_CLI_SERVICE_PROTOCOL_V6,
651+
ThriftProtocolVersion: cli_service.TProtocolVersion_SPARK_CLI_SERVICE_PROTOCOL_V8,
652652
ThriftDebugClientProtocol: false,
653653
}
654654

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
package thrift_protocol
2+
3+
import "github.com/databricks/databricks-sql-go/internal/cli_service"
4+
5+
6+
7+
// Feature checks
8+
// SupportsDirectResults checks if the server protocol version supports direct results
9+
// Supported in SPARK_CLI_SERVICE_PROTOCOL_V1 and above
10+
func SupportsDirectResults(version cli_service.TProtocolVersion) bool {
11+
return version >= cli_service.TProtocolVersion_SPARK_CLI_SERVICE_PROTOCOL_V1
12+
}
13+
14+
// SupportsLz4Compression checks if the server protocol version supports LZ4 compression
15+
// Supported in SPARK_CLI_SERVICE_PROTOCOL_V6 and above
16+
func SupportsLz4Compression(version cli_service.TProtocolVersion) bool {
17+
return version >= cli_service.TProtocolVersion_SPARK_CLI_SERVICE_PROTOCOL_V6
18+
}
19+
20+
// SupportsCloudFetch checks if the server protocol version supports cloud fetch
21+
// Supported in SPARK_CLI_SERVICE_PROTOCOL_V3 and above
22+
func SupportsCloudFetch(version cli_service.TProtocolVersion) bool {
23+
return version >= cli_service.TProtocolVersion_SPARK_CLI_SERVICE_PROTOCOL_V3
24+
}
25+
26+
// SupportsArrow checks if the server protocol version supports Arrow format
27+
// Supported in SPARK_CLI_SERVICE_PROTOCOL_V5 and above
28+
func SupportsArrow(version cli_service.TProtocolVersion) bool {
29+
return version >= cli_service.TProtocolVersion_SPARK_CLI_SERVICE_PROTOCOL_V5
30+
}
31+
32+
// SupportsCompressedArrow checks if the server protocol version supports compressed Arrow format
33+
// Supported in SPARK_CLI_SERVICE_PROTOCOL_V6 and above
34+
func SupportsCompressedArrow(version cli_service.TProtocolVersion) bool {
35+
return version >= cli_service.TProtocolVersion_SPARK_CLI_SERVICE_PROTOCOL_V6
36+
}
37+
38+
// SupportsParameterizedQueries checks if the server protocol version supports parameterized queries
39+
// Supported in SPARK_CLI_SERVICE_PROTOCOL_V8 and above
40+
func SupportsParameterizedQueries(version cli_service.TProtocolVersion) bool {
41+
return version >= cli_service.TProtocolVersion_SPARK_CLI_SERVICE_PROTOCOL_V8
42+
}
43+
44+
// SupportsMultipleCatalogs checks if the server protocol version supports multiple catalogs
45+
// Supported in SPARK_CLI_SERVICE_PROTOCOL_V4 and above
46+
func SupportsMultipleCatalogs(version cli_service.TProtocolVersion) bool {
47+
return version >= cli_service.TProtocolVersion_SPARK_CLI_SERVICE_PROTOCOL_V4
48+
}

0 commit comments

Comments
 (0)