@@ -14,6 +14,7 @@ import (
1414 "github.com/databricks/databricks-sql-go/internal/cli_service"
1515 "github.com/databricks/databricks-sql-go/internal/client"
1616 "github.com/databricks/databricks-sql-go/internal/config"
17+ "github.com/databricks/databricks-sql-go/internal/thrift_protocol"
1718 "github.com/stretchr/testify/assert"
1819)
1920
@@ -331,6 +332,167 @@ func TestConn_executeStatement(t *testing.T) {
331332
332333}
333334
335+ func TestConn_executeStatement_ProtocolFeatures (t * testing.T ) {
336+ t .Parallel ()
337+
338+ protocols := []cli_service.TProtocolVersion {
339+ cli_service .TProtocolVersion_SPARK_CLI_SERVICE_PROTOCOL_V1 ,
340+ cli_service .TProtocolVersion_SPARK_CLI_SERVICE_PROTOCOL_V2 ,
341+ cli_service .TProtocolVersion_SPARK_CLI_SERVICE_PROTOCOL_V3 ,
342+ cli_service .TProtocolVersion_SPARK_CLI_SERVICE_PROTOCOL_V4 ,
343+ cli_service .TProtocolVersion_SPARK_CLI_SERVICE_PROTOCOL_V5 ,
344+ cli_service .TProtocolVersion_SPARK_CLI_SERVICE_PROTOCOL_V6 ,
345+ cli_service .TProtocolVersion_SPARK_CLI_SERVICE_PROTOCOL_V7 ,
346+ cli_service .TProtocolVersion_SPARK_CLI_SERVICE_PROTOCOL_V8 ,
347+ }
348+
349+ testCases := []struct {
350+ cfg * config.Config
351+ supportsDirectResults func (version cli_service.TProtocolVersion ) bool
352+ supportsLz4Compression func (version cli_service.TProtocolVersion ) bool
353+ supportsCloudFetch func (version cli_service.TProtocolVersion ) bool
354+ supportsArrow func (version cli_service.TProtocolVersion ) bool
355+ supportsParameterizedQueries func (version cli_service.TProtocolVersion ) bool
356+ hasParameters bool
357+ }{
358+ {
359+ cfg : func () * config.Config {
360+ cfg := config .WithDefaults ()
361+ cfg .UseLz4Compression = true
362+ cfg .UseCloudFetch = true
363+ cfg .UseArrowBatches = true
364+ cfg .UseArrowNativeDecimal = true
365+ cfg .UseArrowNativeTimestamp = true
366+ cfg .UseArrowNativeComplexTypes = true
367+ cfg .UseArrowNativeIntervalTypes = true
368+ return cfg
369+ }(),
370+ supportsDirectResults : thrift_protocol .SupportsDirectResults ,
371+ supportsLz4Compression : thrift_protocol .SupportsLz4Compression ,
372+ supportsCloudFetch : thrift_protocol .SupportsCloudFetch ,
373+ supportsArrow : thrift_protocol .SupportsArrow ,
374+ supportsParameterizedQueries : thrift_protocol .SupportsParameterizedQueries ,
375+ hasParameters : true ,
376+ },
377+ {
378+ cfg : func () * config.Config {
379+ cfg := config .WithDefaults ()
380+ cfg .UseLz4Compression = false
381+ cfg .UseCloudFetch = false
382+ cfg .UseArrowBatches = false
383+ return cfg
384+ }(),
385+ supportsDirectResults : thrift_protocol .SupportsDirectResults ,
386+ supportsLz4Compression : thrift_protocol .SupportsLz4Compression ,
387+ supportsCloudFetch : thrift_protocol .SupportsCloudFetch ,
388+ supportsArrow : thrift_protocol .SupportsArrow ,
389+ supportsParameterizedQueries : thrift_protocol .SupportsParameterizedQueries ,
390+ hasParameters : false ,
391+ },
392+ }
393+
394+ for _ , tc := range testCases {
395+ for _ , version := range protocols {
396+ t .Run (fmt .Sprintf ("protocol_v%d_withParams_%v" , version , tc .hasParameters ), func (t * testing.T ) {
397+ var capturedReq * cli_service.TExecuteStatementReq
398+ executeStatement := func (ctx context.Context , req * cli_service.TExecuteStatementReq ) (r * cli_service.TExecuteStatementResp , err error ) {
399+ capturedReq = req
400+ executeStatementResp := & cli_service.TExecuteStatementResp {
401+ Status : & cli_service.TStatus {
402+ StatusCode : cli_service .TStatusCode_SUCCESS_STATUS ,
403+ },
404+ OperationHandle : & cli_service.TOperationHandle {
405+ OperationId : & cli_service.THandleIdentifier {
406+ GUID : []byte {1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 , 11 , 12 , 13 , 14 , 15 , 16 },
407+ Secret : []byte ("secret" ),
408+ },
409+ },
410+ DirectResults : & cli_service.TSparkDirectResults {
411+ OperationStatus : & cli_service.TGetOperationStatusResp {
412+ Status : & cli_service.TStatus {
413+ StatusCode : cli_service .TStatusCode_SUCCESS_STATUS ,
414+ },
415+ OperationState : cli_service .TOperationStatePtr (cli_service .TOperationState_FINISHED_STATE ),
416+ },
417+ },
418+ }
419+ return executeStatementResp , nil
420+ }
421+
422+ session := getTestSession ()
423+ session .ServerProtocolVersion = version
424+
425+ testClient := & client.TestClient {
426+ FnExecuteStatement : executeStatement ,
427+ }
428+
429+ testConn := & conn {
430+ session : session ,
431+ client : testClient ,
432+ cfg : tc .cfg ,
433+ }
434+
435+ var args []driver.NamedValue
436+ if tc .hasParameters {
437+ args = []driver.NamedValue {
438+ {Name : "param1" , Value : "value1" },
439+ }
440+ }
441+
442+ _ , err := testConn .executeStatement (context .Background (), "SELECT 1" , args )
443+ assert .NoError (t , err )
444+
445+ // Verify direct results
446+ hasDirectResults := tc .supportsDirectResults (version )
447+ assert .Equal (t , hasDirectResults , capturedReq .GetDirectResults != nil , "Direct results should be enabled if protocol supports it" )
448+
449+ // Verify LZ4 compression
450+ shouldHaveLz4 := tc .supportsLz4Compression (version ) && tc .cfg .UseLz4Compression
451+ if shouldHaveLz4 {
452+ assert .NotNil (t , capturedReq .CanDecompressLZ4Result_ )
453+ assert .True (t , * capturedReq .CanDecompressLZ4Result_ )
454+ } else {
455+ assert .Nil (t , capturedReq .CanDecompressLZ4Result_ )
456+ }
457+
458+ // Verify cloud fetch
459+ shouldHaveCloudFetch := tc .supportsCloudFetch (version ) && tc .cfg .UseCloudFetch
460+ if shouldHaveCloudFetch {
461+ assert .NotNil (t , capturedReq .CanDownloadResult_ )
462+ assert .True (t , * capturedReq .CanDownloadResult_ )
463+ } else {
464+ assert .Nil (t , capturedReq .CanDownloadResult_ )
465+ }
466+
467+ // Verify Arrow support
468+ shouldHaveArrow := tc .supportsArrow (version ) && tc .cfg .UseArrowBatches
469+ if shouldHaveArrow {
470+ assert .NotNil (t , capturedReq .CanReadArrowResult_ )
471+ assert .True (t , * capturedReq .CanReadArrowResult_ )
472+ assert .NotNil (t , capturedReq .UseArrowNativeTypes )
473+ assert .Equal (t , tc .cfg .UseArrowNativeDecimal , * capturedReq .UseArrowNativeTypes .DecimalAsArrow )
474+ assert .Equal (t , tc .cfg .UseArrowNativeTimestamp , * capturedReq .UseArrowNativeTypes .TimestampAsArrow )
475+ assert .Equal (t , tc .cfg .UseArrowNativeComplexTypes , * capturedReq .UseArrowNativeTypes .ComplexTypesAsArrow )
476+ assert .Equal (t , tc .cfg .UseArrowNativeIntervalTypes , * capturedReq .UseArrowNativeTypes .IntervalTypesAsArrow )
477+ } else {
478+ assert .Nil (t , capturedReq .CanReadArrowResult_ )
479+ assert .Nil (t , capturedReq .UseArrowNativeTypes )
480+ }
481+
482+ // Verify parameters
483+ shouldHaveParams := tc .supportsParameterizedQueries (version ) && tc .hasParameters
484+ if shouldHaveParams {
485+ assert .NotNil (t , capturedReq .Parameters )
486+ assert .Len (t , capturedReq .Parameters , 1 )
487+ } else if tc .hasParameters {
488+ // Even if we have parameters but protocol doesn't support it, we shouldn't set them
489+ assert .Nil (t , capturedReq .Parameters )
490+ }
491+ })
492+ }
493+ }
494+ }
495+
334496func TestConn_pollOperation (t * testing.T ) {
335497 t .Parallel ()
336498 t .Run ("pollOperation returns finished state response when query finishes" , func (t * testing.T ) {
0 commit comments