Skip to content

Commit 967110e

Browse files
committed
Add protocol feature tests
1 parent fbc2af2 commit 967110e

File tree

3 files changed

+299
-17
lines changed

3 files changed

+299
-17
lines changed

connection_test.go

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414
"github.com/databricks/databricks-sql-go/internal/cli_service"
1515
"github.com/databricks/databricks-sql-go/internal/client"
1616
"github.com/databricks/databricks-sql-go/internal/config"
17+
"github.com/databricks/databricks-sql-go/internal/thrift_protocol"
1718
"github.com/stretchr/testify/assert"
1819
)
1920

@@ -331,6 +332,167 @@ func TestConn_executeStatement(t *testing.T) {
331332

332333
}
333334

335+
func TestConn_executeStatement_ProtocolFeatures(t *testing.T) {
336+
t.Parallel()
337+
338+
protocols := []cli_service.TProtocolVersion{
339+
cli_service.TProtocolVersion_SPARK_CLI_SERVICE_PROTOCOL_V1,
340+
cli_service.TProtocolVersion_SPARK_CLI_SERVICE_PROTOCOL_V2,
341+
cli_service.TProtocolVersion_SPARK_CLI_SERVICE_PROTOCOL_V3,
342+
cli_service.TProtocolVersion_SPARK_CLI_SERVICE_PROTOCOL_V4,
343+
cli_service.TProtocolVersion_SPARK_CLI_SERVICE_PROTOCOL_V5,
344+
cli_service.TProtocolVersion_SPARK_CLI_SERVICE_PROTOCOL_V6,
345+
cli_service.TProtocolVersion_SPARK_CLI_SERVICE_PROTOCOL_V7,
346+
cli_service.TProtocolVersion_SPARK_CLI_SERVICE_PROTOCOL_V8,
347+
}
348+
349+
testCases := []struct {
350+
cfg *config.Config
351+
supportsDirectResults func(version cli_service.TProtocolVersion) bool
352+
supportsLz4Compression func(version cli_service.TProtocolVersion) bool
353+
supportsCloudFetch func(version cli_service.TProtocolVersion) bool
354+
supportsArrow func(version cli_service.TProtocolVersion) bool
355+
supportsParameterizedQueries func(version cli_service.TProtocolVersion) bool
356+
hasParameters bool
357+
}{
358+
{
359+
cfg: func() *config.Config {
360+
cfg := config.WithDefaults()
361+
cfg.UseLz4Compression = true
362+
cfg.UseCloudFetch = true
363+
cfg.UseArrowBatches = true
364+
cfg.UseArrowNativeDecimal = true
365+
cfg.UseArrowNativeTimestamp = true
366+
cfg.UseArrowNativeComplexTypes = true
367+
cfg.UseArrowNativeIntervalTypes = true
368+
return cfg
369+
}(),
370+
supportsDirectResults: thrift_protocol.SupportsDirectResults,
371+
supportsLz4Compression: thrift_protocol.SupportsLz4Compression,
372+
supportsCloudFetch: thrift_protocol.SupportsCloudFetch,
373+
supportsArrow: thrift_protocol.SupportsArrow,
374+
supportsParameterizedQueries: thrift_protocol.SupportsParameterizedQueries,
375+
hasParameters: true,
376+
},
377+
{
378+
cfg: func() *config.Config {
379+
cfg := config.WithDefaults()
380+
cfg.UseLz4Compression = false
381+
cfg.UseCloudFetch = false
382+
cfg.UseArrowBatches = false
383+
return cfg
384+
}(),
385+
supportsDirectResults: thrift_protocol.SupportsDirectResults,
386+
supportsLz4Compression: thrift_protocol.SupportsLz4Compression,
387+
supportsCloudFetch: thrift_protocol.SupportsCloudFetch,
388+
supportsArrow: thrift_protocol.SupportsArrow,
389+
supportsParameterizedQueries: thrift_protocol.SupportsParameterizedQueries,
390+
hasParameters: false,
391+
},
392+
}
393+
394+
for _, tc := range testCases {
395+
for _, version := range protocols {
396+
t.Run(fmt.Sprintf("protocol_v%d_withParams_%v", version, tc.hasParameters), func(t *testing.T) {
397+
var capturedReq *cli_service.TExecuteStatementReq
398+
executeStatement := func(ctx context.Context, req *cli_service.TExecuteStatementReq) (r *cli_service.TExecuteStatementResp, err error) {
399+
capturedReq = req
400+
executeStatementResp := &cli_service.TExecuteStatementResp{
401+
Status: &cli_service.TStatus{
402+
StatusCode: cli_service.TStatusCode_SUCCESS_STATUS,
403+
},
404+
OperationHandle: &cli_service.TOperationHandle{
405+
OperationId: &cli_service.THandleIdentifier{
406+
GUID: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
407+
Secret: []byte("secret"),
408+
},
409+
},
410+
DirectResults: &cli_service.TSparkDirectResults{
411+
OperationStatus: &cli_service.TGetOperationStatusResp{
412+
Status: &cli_service.TStatus{
413+
StatusCode: cli_service.TStatusCode_SUCCESS_STATUS,
414+
},
415+
OperationState: cli_service.TOperationStatePtr(cli_service.TOperationState_FINISHED_STATE),
416+
},
417+
},
418+
}
419+
return executeStatementResp, nil
420+
}
421+
422+
session := getTestSession()
423+
session.ServerProtocolVersion = version
424+
425+
testClient := &client.TestClient{
426+
FnExecuteStatement: executeStatement,
427+
}
428+
429+
testConn := &conn{
430+
session: session,
431+
client: testClient,
432+
cfg: tc.cfg,
433+
}
434+
435+
var args []driver.NamedValue
436+
if tc.hasParameters {
437+
args = []driver.NamedValue{
438+
{Name: "param1", Value: "value1"},
439+
}
440+
}
441+
442+
_, err := testConn.executeStatement(context.Background(), "SELECT 1", args)
443+
assert.NoError(t, err)
444+
445+
// Verify direct results
446+
hasDirectResults := tc.supportsDirectResults(version)
447+
assert.Equal(t, hasDirectResults, capturedReq.GetDirectResults != nil, "Direct results should be enabled if protocol supports it")
448+
449+
// Verify LZ4 compression
450+
shouldHaveLz4 := tc.supportsLz4Compression(version) && tc.cfg.UseLz4Compression
451+
if shouldHaveLz4 {
452+
assert.NotNil(t, capturedReq.CanDecompressLZ4Result_)
453+
assert.True(t, *capturedReq.CanDecompressLZ4Result_)
454+
} else {
455+
assert.Nil(t, capturedReq.CanDecompressLZ4Result_)
456+
}
457+
458+
// Verify cloud fetch
459+
shouldHaveCloudFetch := tc.supportsCloudFetch(version) && tc.cfg.UseCloudFetch
460+
if shouldHaveCloudFetch {
461+
assert.NotNil(t, capturedReq.CanDownloadResult_)
462+
assert.True(t, *capturedReq.CanDownloadResult_)
463+
} else {
464+
assert.Nil(t, capturedReq.CanDownloadResult_)
465+
}
466+
467+
// Verify Arrow support
468+
shouldHaveArrow := tc.supportsArrow(version) && tc.cfg.UseArrowBatches
469+
if shouldHaveArrow {
470+
assert.NotNil(t, capturedReq.CanReadArrowResult_)
471+
assert.True(t, *capturedReq.CanReadArrowResult_)
472+
assert.NotNil(t, capturedReq.UseArrowNativeTypes)
473+
assert.Equal(t, tc.cfg.UseArrowNativeDecimal, *capturedReq.UseArrowNativeTypes.DecimalAsArrow)
474+
assert.Equal(t, tc.cfg.UseArrowNativeTimestamp, *capturedReq.UseArrowNativeTypes.TimestampAsArrow)
475+
assert.Equal(t, tc.cfg.UseArrowNativeComplexTypes, *capturedReq.UseArrowNativeTypes.ComplexTypesAsArrow)
476+
assert.Equal(t, tc.cfg.UseArrowNativeIntervalTypes, *capturedReq.UseArrowNativeTypes.IntervalTypesAsArrow)
477+
} else {
478+
assert.Nil(t, capturedReq.CanReadArrowResult_)
479+
assert.Nil(t, capturedReq.UseArrowNativeTypes)
480+
}
481+
482+
// Verify parameters
483+
shouldHaveParams := tc.supportsParameterizedQueries(version) && tc.hasParameters
484+
if shouldHaveParams {
485+
assert.NotNil(t, capturedReq.Parameters)
486+
assert.Len(t, capturedReq.Parameters, 1)
487+
} else if tc.hasParameters {
488+
// Even if we have parameters but protocol doesn't support it, we shouldn't set them
489+
assert.Nil(t, capturedReq.Parameters)
490+
}
491+
})
492+
}
493+
}
494+
}
495+
334496
func TestConn_pollOperation(t *testing.T) {
335497
t.Parallel()
336498
t.Run("pollOperation returns finished state response when query finishes", func(t *testing.T) {

internal/client/testclient.go

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,6 @@ import (
1010
var ErrNotImplemented = errors.New("databricks: not implemented")
1111

1212
type TestClient struct {
13-
// Default server protocol version to use in tests
14-
ServerProtocolVersion cli_service.TProtocolVersion
15-
1613
FnOpenSession func(ctx context.Context, req *cli_service.TOpenSessionReq) (_r *cli_service.TOpenSessionResp, _err error)
1714
FnCloseSession func(ctx context.Context, req *cli_service.TCloseSessionReq) (_r *cli_service.TCloseSessionResp, _err error)
1815
FnGetInfo func(ctx context.Context, req *cli_service.TGetInfoReq) (_r *cli_service.TGetInfoResp, _err error)
@@ -42,20 +39,7 @@ func (c *TestClient) OpenSession(ctx context.Context, req *cli_service.TOpenSess
4239
if c.FnOpenSession != nil {
4340
return c.FnOpenSession(ctx, req)
4441
}
45-
46-
// Default implementation for test client
47-
resp := &cli_service.TOpenSessionResp{
48-
Status: &cli_service.TStatus{StatusCode: cli_service.TStatusCode_SUCCESS_STATUS},
49-
ServerProtocolVersion: c.ServerProtocolVersion,
50-
SessionHandle: &cli_service.TSessionHandle{
51-
SessionId: &cli_service.THandleIdentifier{
52-
GUID: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
53-
Secret: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
54-
},
55-
},
56-
}
57-
58-
return resp, nil
42+
return nil, ErrNotImplemented
5943
}
6044
func (c *TestClient) CloseSession(ctx context.Context, req *cli_service.TCloseSessionReq) (_r *cli_service.TCloseSessionResp, _err error) {
6145
if c.FnCloseSession != nil {
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
package thrift_protocol
2+
3+
import (
4+
"testing"
5+
6+
"github.com/databricks/databricks-sql-go/internal/cli_service"
7+
"github.com/stretchr/testify/assert"
8+
)
9+
10+
func TestProtocolFeatureSupport(t *testing.T) {
11+
testCases := []struct {
12+
name string
13+
version cli_service.TProtocolVersion
14+
expectDirectResults bool
15+
expectLz4Compression bool
16+
expectCloudFetch bool
17+
expectArrow bool
18+
expectCompressedArrow bool
19+
expectParameterizedQueries bool
20+
expectMultipleCatalogs bool
21+
}{
22+
{
23+
name: "Protocol V1",
24+
version: cli_service.TProtocolVersion_SPARK_CLI_SERVICE_PROTOCOL_V1,
25+
expectDirectResults: true,
26+
expectLz4Compression: false,
27+
expectCloudFetch: false,
28+
expectArrow: false,
29+
expectCompressedArrow: false,
30+
expectParameterizedQueries: false,
31+
expectMultipleCatalogs: false,
32+
},
33+
{
34+
name: "Protocol V2",
35+
version: cli_service.TProtocolVersion_SPARK_CLI_SERVICE_PROTOCOL_V2,
36+
expectDirectResults: true,
37+
expectLz4Compression: false,
38+
expectCloudFetch: false,
39+
expectArrow: false,
40+
expectCompressedArrow: false,
41+
expectParameterizedQueries: false,
42+
expectMultipleCatalogs: false,
43+
},
44+
{
45+
name: "Protocol V3",
46+
version: cli_service.TProtocolVersion_SPARK_CLI_SERVICE_PROTOCOL_V3,
47+
expectDirectResults: true,
48+
expectLz4Compression: false,
49+
expectCloudFetch: true,
50+
expectArrow: false,
51+
expectCompressedArrow: false,
52+
expectParameterizedQueries: false,
53+
expectMultipleCatalogs: false,
54+
},
55+
{
56+
name: "Protocol V4",
57+
version: cli_service.TProtocolVersion_SPARK_CLI_SERVICE_PROTOCOL_V4,
58+
expectDirectResults: true,
59+
expectLz4Compression: false,
60+
expectCloudFetch: true,
61+
expectArrow: false,
62+
expectCompressedArrow: false,
63+
expectParameterizedQueries: false,
64+
expectMultipleCatalogs: true,
65+
},
66+
{
67+
name: "Protocol V5",
68+
version: cli_service.TProtocolVersion_SPARK_CLI_SERVICE_PROTOCOL_V5,
69+
expectDirectResults: true,
70+
expectLz4Compression: false,
71+
expectCloudFetch: true,
72+
expectArrow: true,
73+
expectCompressedArrow: false,
74+
expectParameterizedQueries: false,
75+
expectMultipleCatalogs: true,
76+
},
77+
{
78+
name: "Protocol V6",
79+
version: cli_service.TProtocolVersion_SPARK_CLI_SERVICE_PROTOCOL_V6,
80+
expectDirectResults: true,
81+
expectLz4Compression: true,
82+
expectCloudFetch: true,
83+
expectArrow: true,
84+
expectCompressedArrow: true,
85+
expectParameterizedQueries: false,
86+
expectMultipleCatalogs: true,
87+
},
88+
{
89+
name: "Protocol V7",
90+
version: cli_service.TProtocolVersion_SPARK_CLI_SERVICE_PROTOCOL_V7,
91+
expectDirectResults: true,
92+
expectLz4Compression: true,
93+
expectCloudFetch: true,
94+
expectArrow: true,
95+
expectCompressedArrow: true,
96+
expectParameterizedQueries: false,
97+
expectMultipleCatalogs: true,
98+
},
99+
{
100+
name: "Protocol V8",
101+
version: cli_service.TProtocolVersion_SPARK_CLI_SERVICE_PROTOCOL_V8,
102+
expectDirectResults: true,
103+
expectLz4Compression: true,
104+
expectCloudFetch: true,
105+
expectArrow: true,
106+
expectCompressedArrow: true,
107+
expectParameterizedQueries: true,
108+
expectMultipleCatalogs: true,
109+
},
110+
}
111+
112+
for _, tc := range testCases {
113+
t.Run(tc.name, func(t *testing.T) {
114+
assert.Equal(t, tc.expectDirectResults, SupportsDirectResults(tc.version),
115+
"DirectResults support check failed for %s", tc.name)
116+
117+
assert.Equal(t, tc.expectLz4Compression, SupportsLz4Compression(tc.version),
118+
"LZ4Compression support check failed for %s", tc.name)
119+
120+
assert.Equal(t, tc.expectCloudFetch, SupportsCloudFetch(tc.version),
121+
"CloudFetch support check failed for %s", tc.name)
122+
123+
assert.Equal(t, tc.expectArrow, SupportsArrow(tc.version),
124+
"Arrow support check failed for %s", tc.name)
125+
126+
assert.Equal(t, tc.expectCompressedArrow, SupportsCompressedArrow(tc.version),
127+
"CompressedArrow support check failed for %s", tc.name)
128+
129+
assert.Equal(t, tc.expectParameterizedQueries, SupportsParameterizedQueries(tc.version),
130+
"ParameterizedQueries support check failed for %s", tc.name)
131+
132+
assert.Equal(t, tc.expectMultipleCatalogs, SupportsMultipleCatalogs(tc.version),
133+
"MultipleCatalogs support check failed for %s", tc.name)
134+
})
135+
}
136+
}

0 commit comments

Comments
 (0)