44
55from databricks .sql .thrift_api .TCLIService .ttypes import (
66 TOpenSessionResp ,
7+ TSessionHandle ,
8+ THandleIdentifier ,
79)
10+ from databricks .sql .ids import SessionId , BackendType
811
912import databricks .sql
1013
@@ -25,16 +28,17 @@ class SessionTestSuite(unittest.TestCase):
2528 def test_close_uses_the_correct_session_id (self , mock_client_class ):
2629 instance = mock_client_class .return_value
2730
28- mock_open_session_resp = MagicMock ( spec = TOpenSessionResp )()
29- mock_open_session_resp . sessionHandle . sessionId = b"\x22 "
30- instance .open_session .return_value = mock_open_session_resp
31+ # Create a mock SessionId that will be returned by open_session
32+ mock_session_id = SessionId ( BackendType . THRIFT , b"\x22 " , b" \x33 " )
33+ instance .open_session .return_value = mock_session_id
3134
3235 connection = databricks .sql .connect (** self .DUMMY_CONNECTION_ARGS )
3336 connection .close ()
3437
35- # Check the close session request has an id of x22
36- close_session_id = instance .close_session .call_args [0 ][0 ].sessionId
37- self .assertEqual (close_session_id , b"\x22 " )
38+ # Check that close_session was called with the correct SessionId
39+ close_session_call_args = instance .close_session .call_args [0 ][0 ]
40+ self .assertEqual (close_session_call_args .guid , b"\x22 " )
41+ self .assertEqual (close_session_call_args .secret , b"\x33 " )
3842
3943 @patch ("%s.session.ThriftDatabricksClient" % PACKAGE_NAME )
4044 def test_auth_args (self , mock_client_class ):
@@ -112,16 +116,17 @@ def test_useragent_header(self, mock_client_class):
112116 def test_context_manager_closes_connection (self , mock_client_class ):
113117 instance = mock_client_class .return_value
114118
115- mock_open_session_resp = MagicMock ( spec = TOpenSessionResp )()
116- mock_open_session_resp . sessionHandle . sessionId = b"\x22 "
117- instance .open_session .return_value = mock_open_session_resp
119+ # Create a mock SessionId that will be returned by open_session
120+ mock_session_id = SessionId ( BackendType . THRIFT , b"\x22 " , b" \x33 " )
121+ instance .open_session .return_value = mock_session_id
118122
119123 with databricks .sql .connect (** self .DUMMY_CONNECTION_ARGS ) as connection :
120124 pass
121125
122- # Check the close session request has an id of x22
123- close_session_id = instance .close_session .call_args [0 ][0 ].sessionId
124- self .assertEqual (close_session_id , b"\x22 " )
126+ # Check that close_session was called with the correct SessionId
127+ close_session_call_args = instance .close_session .call_args [0 ][0 ]
128+ self .assertEqual (close_session_call_args .guid , b"\x22 " )
129+ self .assertEqual (close_session_call_args .secret , b"\x33 " )
125130
126131 @patch ("%s.session.ThriftDatabricksClient" % PACKAGE_NAME )
127132 def test_max_number_of_retries_passthrough (self , mock_client_class ):
@@ -141,46 +146,54 @@ def test_socket_timeout_passthrough(self, mock_client_class):
141146 @patch ("%s.session.ThriftDatabricksClient" % PACKAGE_NAME )
142147 def test_configuration_passthrough (self , mock_client_class ):
143148 mock_session_config = Mock ()
149+
150+ # Create a mock SessionId that will be returned by open_session
151+ mock_session_id = SessionId (BackendType .THRIFT , b"\x22 " , b"\x33 " )
152+ mock_client_class .return_value .open_session .return_value = mock_session_id
153+
144154 databricks .sql .connect (
145155 session_configuration = mock_session_config , ** self .DUMMY_CONNECTION_ARGS
146156 )
147157
148- self .assertEqual (
149- mock_client_class .return_value .open_session .call_args [0 ][0 ],
150- mock_session_config ,
151- )
158+ # Check that open_session was called with the correct session_configuration
159+ call_kwargs = mock_client_class .return_value .open_session .call_args [1 ]
160+ self .assertEqual (call_kwargs ["session_configuration" ], mock_session_config )
152161
153162 @patch ("%s.session.ThriftDatabricksClient" % PACKAGE_NAME )
154163 def test_initial_namespace_passthrough (self , mock_client_class ):
155164 mock_cat = Mock ()
156165 mock_schem = Mock ()
157166
167+ # Create a mock SessionId that will be returned by open_session
168+ mock_session_id = SessionId (BackendType .THRIFT , b"\x22 " , b"\x33 " )
169+ mock_client_class .return_value .open_session .return_value = mock_session_id
170+
158171 databricks .sql .connect (
159172 ** self .DUMMY_CONNECTION_ARGS , catalog = mock_cat , schema = mock_schem
160173 )
161- self .assertEqual (
162- mock_client_class .return_value .open_session .call_args [0 ][1 ], mock_cat
163- )
164- self .assertEqual (
165- mock_client_class .return_value .open_session .call_args [0 ][2 ], mock_schem
166- )
174+
175+ # Check that open_session was called with the correct catalog and schema
176+ call_kwargs = mock_client_class .return_value .open_session .call_args [1 ]
177+ self .assertEqual (call_kwargs ["catalog" ], mock_cat )
178+ self .assertEqual (call_kwargs ["schema" ], mock_schem )
167179
168180 @patch ("%s.session.ThriftDatabricksClient" % PACKAGE_NAME )
169181 def test_finalizer_closes_abandoned_connection (self , mock_client_class ):
170182 instance = mock_client_class .return_value
171183
172- mock_open_session_resp = MagicMock ( spec = TOpenSessionResp )()
173- mock_open_session_resp . sessionHandle . sessionId = b"\x22 "
174- instance .open_session .return_value = mock_open_session_resp
184+ # Create a mock SessionId that will be returned by open_session
185+ mock_session_id = SessionId ( BackendType . THRIFT , b"\x22 " , b" \x33 " )
186+ instance .open_session .return_value = mock_session_id
175187
176188 databricks .sql .connect (** self .DUMMY_CONNECTION_ARGS )
177189
178190 # not strictly necessary as the refcount is 0, but just to be sure
179191 gc .collect ()
180192
181- # Check the close session request has an id of x22
182- close_session_id = instance .close_session .call_args [0 ][0 ].sessionId
183- self .assertEqual (close_session_id , b"\x22 " )
193+ # Check that close_session was called with the correct SessionId
194+ close_session_call_args = instance .close_session .call_args [0 ][0 ]
195+ self .assertEqual (close_session_call_args .guid , b"\x22 " )
196+ self .assertEqual (close_session_call_args .secret , b"\x33 " )
184197
185198
186199if __name__ == "__main__" :
0 commit comments