Skip to content

Commit 3bd3aef

Browse files
fix merge artifacts
Signed-off-by: varun-edachali-dbx <[email protected]>
1 parent 77e23d3 commit 3bd3aef

File tree

5 files changed

+30
-21
lines changed

5 files changed

+30
-21
lines changed

src/databricks/sql/backend/thrift_backend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -605,7 +605,7 @@ def open_session(self, session_configuration, catalog, schema) -> SessionId:
605605
session_id = SessionId.from_thrift_handle(
606606
response.sessionHandle, properties
607607
)
608-
self._session_id_hex = session_id.hex_guid
608+
self._session_id_hex = session_id.guid_hex
609609
return session_id
610610
except:
611611
self._transport.close()

src/databricks/sql/backend/types.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ def __str__(self) -> str:
161161
if isinstance(self.secret, bytes)
162162
else str(self.secret)
163163
)
164-
return f"{self.hex_guid}|{secret_hex}"
164+
return f"{self.guid_hex}|{secret_hex}"
165165
return str(self.guid)
166166

167167
@classmethod
@@ -240,7 +240,7 @@ def to_sea_session_id(self):
240240
return self.guid
241241

242242
@property
243-
def hex_guid(self) -> str:
243+
def guid_hex(self) -> str:
244244
"""
245245
Get a hexadecimal string representation of the session ID.
246246

src/databricks/sql/session.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def open(self):
131131

132132
@staticmethod
133133
def get_protocol_version(session_id: SessionId):
134-
return session_id.get_protocol_version()
134+
return session_id.protocol_version
135135

136136
@staticmethod
137137
def server_parameterized_queries_enabled(protocolVersion):

tests/unit/test_client.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -98,13 +98,9 @@ def test_closing_connection_closes_commands(self, mock_thrift_client_class):
9898
mock_thrift_client_class: Mock for ThriftBackend class
9999
"""
100100

101+
# Test once with has_been_closed_server side, once without
101102
for closed in (True, False):
102103
with self.subTest(closed=closed):
103-
# Set initial state based on whether the command is already closed
104-
initial_state = (
105-
CommandState.CLOSED if closed else CommandState.SUCCEEDED
106-
)
107-
108104
# Mock the execute response with controlled state
109105
mock_execute_response = Mock(spec=ExecuteResponse)
110106

@@ -114,11 +110,14 @@ def test_closing_connection_closes_commands(self, mock_thrift_client_class):
114110
)
115111
mock_execute_response.has_been_closed_server_side = closed
116112
mock_execute_response.is_staging_operation = False
117-
mock_execute_response.command_id = Mock(spec=CommandId)
113+
mock_execute_response.description = []
118114

119-
# Mock the backend that will be used
115+
# Mock the backend that will be used by the real ThriftResultSet
120116
mock_backend = Mock(spec=ThriftDatabricksClient)
121117
mock_backend.staging_allowed_local_path = None
118+
mock_backend.fetch_results.return_value = (Mock(), False)
119+
120+
# Configure the decorator's mock to return our specific mock_backend
122121
mock_thrift_client_class.return_value = mock_backend
123122

124123
# Create connection and cursor
@@ -137,16 +136,22 @@ def test_closing_connection_closes_commands(self, mock_thrift_client_class):
137136
# Execute a command - this should set cursor.active_result_set to our real result set
138137
cursor.execute("SELECT 1")
139138

139+
# Verify that cursor.execute() set up the result set correctly
140+
self.assertIsInstance(cursor.active_result_set, ThriftResultSet)
141+
self.assertEqual(
142+
cursor.active_result_set.has_been_closed_server_side, closed
143+
)
144+
140145
# Close the connection - this should trigger the real close chain:
141146
# connection.close() -> cursor.close() -> result_set.close()
142147
connection.close()
143148

144149
# Verify the REAL close logic worked through the chain:
145150
# 1. has_been_closed_server_side should always be True after close()
146-
assert real_result_set.has_been_closed_server_side is True
151+
self.assertTrue(real_result_set.has_been_closed_server_side)
147152

148-
# 2. op_state should always be CLOSED after close()
149-
assert real_result_set.op_state == CommandState.CLOSED
153+
# 2. status should always be CLOSED after close()
154+
self.assertEqual(real_result_set.status, CommandState.CLOSED)
150155

151156
# 3. Backend close_command should be called appropriately
152157
if not closed:
@@ -183,6 +188,7 @@ def test_arraysize_buffer_size_passthrough(
183188
def test_closing_result_set_with_closed_connection_soft_closes_commands(self):
184189
mock_connection = Mock()
185190
mock_backend = Mock()
191+
mock_backend.fetch_results.return_value = (Mock(), False)
186192

187193
result_set = ThriftResultSet(
188194
connection=mock_connection,
@@ -209,6 +215,7 @@ def test_closing_result_set_hard_closes_commands(self):
209215
mock_session.open = True
210216
type(mock_connection).session = PropertyMock(return_value=mock_session)
211217

218+
mock_thrift_backend.fetch_results.return_value = (Mock(), False)
212219
result_set = ThriftResultSet(
213220
mock_connection, mock_results_response, mock_thrift_backend
214221
)

tests/unit/test_session.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -62,18 +62,18 @@ def test_auth_args(self, mock_client_class):
6262

6363
for args in connection_args:
6464
connection = databricks.sql.connect(**args)
65-
host, port, http_path, *_ = mock_client_class.call_args[0]
66-
assert args["server_hostname"] == host
67-
assert args["http_path"] == http_path
65+
call_kwargs = mock_client_class.call_args[1]
66+
assert args["server_hostname"] == call_kwargs["server_hostname"]
67+
assert args["http_path"] == call_kwargs["http_path"]
6868
connection.close()
6969

7070
@patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME)
7171
def test_http_header_passthrough(self, mock_client_class):
7272
http_headers = [("foo", "bar")]
7373
databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS, http_headers=http_headers)
7474

75-
call_args = mock_client_class.call_args[0][3]
76-
assert ("foo", "bar") in call_args
75+
call_kwargs = mock_client_class.call_args[1]
76+
assert ("foo", "bar") in call_kwargs["http_headers"]
7777

7878
@patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME)
7979
def test_tls_arg_passthrough(self, mock_client_class):
@@ -95,7 +95,8 @@ def test_tls_arg_passthrough(self, mock_client_class):
9595
def test_useragent_header(self, mock_client_class):
9696
databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS)
9797

98-
http_headers = mock_client_class.call_args[0][3]
98+
call_kwargs = mock_client_class.call_args[1]
99+
http_headers = call_kwargs["http_headers"]
99100
user_agent_header = (
100101
"User-Agent",
101102
"{}/{}".format(databricks.sql.USER_AGENT_NAME, databricks.sql.__version__),
@@ -109,7 +110,8 @@ def test_useragent_header(self, mock_client_class):
109110
databricks.sql.USER_AGENT_NAME, databricks.sql.__version__, "foobar"
110111
),
111112
)
112-
http_headers = mock_client_class.call_args[0][3]
113+
call_kwargs = mock_client_class.call_args[1]
114+
http_headers = call_kwargs["http_headers"]
113115
assert user_agent_header_with_entry in http_headers
114116

115117
@patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME)

0 commit comments

Comments
 (0)