33import logging
44import math
55import time
6- import uuid
76import threading
87from typing import List , Union , Any , TYPE_CHECKING
98
109if TYPE_CHECKING :
1110 from databricks .sql .client import Cursor
1211
13- from databricks .sql .thrift_api .TCLIService .ttypes import TOperationState
1412from databricks .sql .backend .types import (
1513 CommandState ,
1614 SessionId ,
1715 CommandId ,
18- BackendType ,
19- guid_to_hex_id ,
2016 ExecuteResponse ,
2117)
18+ from databricks .sql .backend .utils import guid_to_hex_id
19+
2220
2321try :
2422 import pyarrow
@@ -759,11 +757,13 @@ def _results_message_to_execute_response(self, resp, operation_state):
759757 )
760758 direct_results = resp .directResults
761759 has_been_closed_server_side = direct_results and direct_results .closeOperation
760+
762761 has_more_rows = (
763762 (not direct_results )
764763 or (not direct_results .resultSet )
765764 or direct_results .resultSet .hasMoreRows
766765 )
766+
767767 description = self ._hive_schema_to_description (
768768 t_result_set_metadata_resp .schema
769769 )
@@ -779,43 +779,25 @@ def _results_message_to_execute_response(self, resp, operation_state):
779779 schema_bytes = None
780780
781781 lz4_compressed = t_result_set_metadata_resp .lz4Compressed
782- is_staging_operation = t_result_set_metadata_resp .isStagingOperation
783- if direct_results and direct_results .resultSet :
784- assert direct_results .resultSet .results .startRowOffset == 0
785- assert direct_results .resultSetMetadata
786-
787- arrow_queue_opt = ThriftResultSetQueueFactory .build_queue (
788- row_set_type = t_result_set_metadata_resp .resultFormat ,
789- t_row_set = direct_results .resultSet .results ,
790- arrow_schema_bytes = schema_bytes ,
791- max_download_threads = self .max_download_threads ,
792- lz4_compressed = lz4_compressed ,
793- description = description ,
794- ssl_options = self ._ssl_options ,
795- )
796- else :
797- arrow_queue_opt = None
798-
799782 command_id = CommandId .from_thrift_handle (resp .operationHandle )
800783
801784 status = CommandState .from_thrift_state (operation_state )
802785 if status is None :
803786 raise ValueError (f"Unknown command state: { operation_state } " )
804787
805- return (
806- ExecuteResponse (
807- command_id = command_id ,
808- status = status ,
809- description = description ,
810- has_more_rows = has_more_rows ,
811- results_queue = arrow_queue_opt ,
812- has_been_closed_server_side = has_been_closed_server_side ,
813- lz4_compressed = lz4_compressed ,
814- is_staging_operation = is_staging_operation ,
815- ),
816- schema_bytes ,
788+ execute_response = ExecuteResponse (
789+ command_id = command_id ,
790+ status = status ,
791+ description = description ,
792+ has_been_closed_server_side = has_been_closed_server_side ,
793+ lz4_compressed = lz4_compressed ,
794+ is_staging_operation = t_result_set_metadata_resp .isStagingOperation ,
795+ arrow_schema_bytes = schema_bytes ,
796+ result_format = t_result_set_metadata_resp .resultFormat ,
817797 )
818798
799+ return execute_response , has_more_rows
800+
819801 def get_execution_result (
820802 self , command_id : CommandId , cursor : "Cursor"
821803 ) -> "ResultSet" :
@@ -840,9 +822,6 @@ def get_execution_result(
840822
841823 t_result_set_metadata_resp = resp .resultSetMetadata
842824
843- lz4_compressed = t_result_set_metadata_resp .lz4Compressed
844- is_staging_operation = t_result_set_metadata_resp .isStagingOperation
845- has_more_rows = resp .hasMoreRows
846825 description = self ._hive_schema_to_description (
847826 t_result_set_metadata_resp .schema
848827 )
@@ -857,27 +836,21 @@ def get_execution_result(
857836 else :
858837 schema_bytes = None
859838
860- queue = ThriftResultSetQueueFactory .build_queue (
861- row_set_type = resp .resultSetMetadata .resultFormat ,
862- t_row_set = resp .results ,
863- arrow_schema_bytes = schema_bytes ,
864- max_download_threads = self .max_download_threads ,
865- lz4_compressed = lz4_compressed ,
866- description = description ,
867- ssl_options = self ._ssl_options ,
868- )
839+ lz4_compressed = t_result_set_metadata_resp .lz4Compressed
840+ is_staging_operation = t_result_set_metadata_resp .isStagingOperation
841+ has_more_rows = resp .hasMoreRows
869842
870843 status = self .get_query_state (command_id )
871844
872845 execute_response = ExecuteResponse (
873846 command_id = command_id ,
874847 status = status ,
875848 description = description ,
876- has_more_rows = has_more_rows ,
877- results_queue = queue ,
878849 has_been_closed_server_side = False ,
879850 lz4_compressed = lz4_compressed ,
880851 is_staging_operation = is_staging_operation ,
852+ arrow_schema_bytes = schema_bytes ,
853+ result_format = t_result_set_metadata_resp .resultFormat ,
881854 )
882855
883856 return ThriftResultSet (
@@ -887,7 +860,10 @@ def get_execution_result(
887860 buffer_size_bytes = cursor .buffer_size_bytes ,
888861 arraysize = cursor .arraysize ,
889862 use_cloud_fetch = cursor .connection .use_cloud_fetch ,
890- arrow_schema_bytes = schema_bytes ,
863+ t_row_set = resp .results ,
864+ max_download_threads = self .max_download_threads ,
865+ ssl_options = self ._ssl_options ,
866+ has_more_rows = has_more_rows ,
891867 )
892868
893869 def _wait_until_command_done (self , op_handle , initial_operation_status_resp ):
@@ -918,7 +894,7 @@ def get_query_state(self, command_id: CommandId) -> CommandState:
918894 self ._check_command_not_in_error_or_closed_state (thrift_handle , poll_resp )
919895 state = CommandState .from_thrift_state (operation_state )
920896 if state is None :
921- raise ValueError (f"Invalid operation state: { operation_state } " )
897+ raise ValueError (f"Unknown command state: { operation_state } " )
922898 return state
923899
924900 @staticmethod
@@ -1000,18 +976,25 @@ def execute_command(
1000976 self ._handle_execute_response_async (resp , cursor )
1001977 return None
1002978 else :
1003- execute_response , arrow_schema_bytes = self ._handle_execute_response (
979+ execute_response , has_more_rows = self ._handle_execute_response (
1004980 resp , cursor
1005981 )
1006982
983+ t_row_set = None
984+ if resp .directResults and resp .directResults .resultSet :
985+ t_row_set = resp .directResults .resultSet .results
986+
1007987 return ThriftResultSet (
1008988 connection = cursor .connection ,
1009989 execute_response = execute_response ,
1010990 thrift_client = self ,
1011991 buffer_size_bytes = max_bytes ,
1012992 arraysize = max_rows ,
1013993 use_cloud_fetch = use_cloud_fetch ,
1014- arrow_schema_bytes = arrow_schema_bytes ,
994+ t_row_set = t_row_set ,
995+ max_download_threads = self .max_download_threads ,
996+ ssl_options = self ._ssl_options ,
997+ has_more_rows = has_more_rows ,
1015998 )
1016999
10171000 def get_catalogs (
@@ -1033,9 +1016,11 @@ def get_catalogs(
10331016 )
10341017 resp = self .make_request (self ._client .GetCatalogs , req )
10351018
1036- execute_response , arrow_schema_bytes = self ._handle_execute_response (
1037- resp , cursor
1038- )
1019+ execute_response , has_more_rows = self ._handle_execute_response (resp , cursor )
1020+
1021+ t_row_set = None
1022+ if resp .directResults and resp .directResults .resultSet :
1023+ t_row_set = resp .directResults .resultSet .results
10391024
10401025 return ThriftResultSet (
10411026 connection = cursor .connection ,
@@ -1044,7 +1029,10 @@ def get_catalogs(
10441029 buffer_size_bytes = max_bytes ,
10451030 arraysize = max_rows ,
10461031 use_cloud_fetch = cursor .connection .use_cloud_fetch ,
1047- arrow_schema_bytes = arrow_schema_bytes ,
1032+ t_row_set = t_row_set ,
1033+ max_download_threads = self .max_download_threads ,
1034+ ssl_options = self ._ssl_options ,
1035+ has_more_rows = has_more_rows ,
10481036 )
10491037
10501038 def get_schemas (
@@ -1070,9 +1058,11 @@ def get_schemas(
10701058 )
10711059 resp = self .make_request (self ._client .GetSchemas , req )
10721060
1073- execute_response , arrow_schema_bytes = self ._handle_execute_response (
1074- resp , cursor
1075- )
1061+ execute_response , has_more_rows = self ._handle_execute_response (resp , cursor )
1062+
1063+ t_row_set = None
1064+ if resp .directResults and resp .directResults .resultSet :
1065+ t_row_set = resp .directResults .resultSet .results
10761066
10771067 return ThriftResultSet (
10781068 connection = cursor .connection ,
@@ -1081,7 +1071,10 @@ def get_schemas(
10811071 buffer_size_bytes = max_bytes ,
10821072 arraysize = max_rows ,
10831073 use_cloud_fetch = cursor .connection .use_cloud_fetch ,
1084- arrow_schema_bytes = arrow_schema_bytes ,
1074+ t_row_set = t_row_set ,
1075+ max_download_threads = self .max_download_threads ,
1076+ ssl_options = self ._ssl_options ,
1077+ has_more_rows = has_more_rows ,
10851078 )
10861079
10871080 def get_tables (
@@ -1111,9 +1104,11 @@ def get_tables(
11111104 )
11121105 resp = self .make_request (self ._client .GetTables , req )
11131106
1114- execute_response , arrow_schema_bytes = self ._handle_execute_response (
1115- resp , cursor
1116- )
1107+ execute_response , has_more_rows = self ._handle_execute_response (resp , cursor )
1108+
1109+ t_row_set = None
1110+ if resp .directResults and resp .directResults .resultSet :
1111+ t_row_set = resp .directResults .resultSet .results
11171112
11181113 return ThriftResultSet (
11191114 connection = cursor .connection ,
@@ -1122,7 +1117,10 @@ def get_tables(
11221117 buffer_size_bytes = max_bytes ,
11231118 arraysize = max_rows ,
11241119 use_cloud_fetch = cursor .connection .use_cloud_fetch ,
1125- arrow_schema_bytes = arrow_schema_bytes ,
1120+ t_row_set = t_row_set ,
1121+ max_download_threads = self .max_download_threads ,
1122+ ssl_options = self ._ssl_options ,
1123+ has_more_rows = has_more_rows ,
11261124 )
11271125
11281126 def get_columns (
@@ -1152,9 +1150,11 @@ def get_columns(
11521150 )
11531151 resp = self .make_request (self ._client .GetColumns , req )
11541152
1155- execute_response , arrow_schema_bytes = self ._handle_execute_response (
1156- resp , cursor
1157- )
1153+ execute_response , has_more_rows = self ._handle_execute_response (resp , cursor )
1154+
1155+ t_row_set = None
1156+ if resp .directResults and resp .directResults .resultSet :
1157+ t_row_set = resp .directResults .resultSet .results
11581158
11591159 return ThriftResultSet (
11601160 connection = cursor .connection ,
@@ -1163,7 +1163,10 @@ def get_columns(
11631163 buffer_size_bytes = max_bytes ,
11641164 arraysize = max_rows ,
11651165 use_cloud_fetch = cursor .connection .use_cloud_fetch ,
1166- arrow_schema_bytes = arrow_schema_bytes ,
1166+ t_row_set = t_row_set ,
1167+ max_download_threads = self .max_download_threads ,
1168+ ssl_options = self ._ssl_options ,
1169+ has_more_rows = has_more_rows ,
11671170 )
11681171
11691172 def _handle_execute_response (self , resp , cursor ):
@@ -1177,11 +1180,7 @@ def _handle_execute_response(self, resp, cursor):
11771180 resp .directResults and resp .directResults .operationStatus ,
11781181 )
11791182
1180- (
1181- execute_response ,
1182- arrow_schema_bytes ,
1183- ) = self ._results_message_to_execute_response (resp , final_operation_state )
1184- return execute_response , arrow_schema_bytes
1183+ return self ._results_message_to_execute_response (resp , final_operation_state )
11851184
11861185 def _handle_execute_response_async (self , resp , cursor ):
11871186 command_id = CommandId .from_thrift_handle (resp .operationHandle )
@@ -1225,7 +1224,9 @@ def fetch_results(
12251224 )
12261225 )
12271226
1228- queue = ThriftResultSetQueueFactory .build_queue (
1227+ from databricks .sql .utils import ResultSetQueueFactory
1228+
1229+ queue = ResultSetQueueFactory .build_queue (
12291230 row_set_type = resp .resultSetMetadata .resultFormat ,
12301231 t_row_set = resp .results ,
12311232 arrow_schema_bytes = arrow_schema_bytes ,
0 commit comments