1- import unittest
1+ import pytest
22from unittest .mock import patch , MagicMock , Mock , PropertyMock
33import gc
44
1212import databricks .sql
1313
1414
15- class SessionTestSuite ( unittest . TestCase ) :
15+ class TestSession :
1616 """
1717 Unit tests for Session functionality
1818 """
@@ -37,8 +37,8 @@ def test_close_uses_the_correct_session_id(self, mock_client_class):
3737
3838 # Check that close_session was called with the correct SessionId
3939 close_session_call_args = instance .close_session .call_args [0 ][0 ]
40- self . assertEqual ( close_session_call_args .guid , b"\x22 " )
41- self . assertEqual ( close_session_call_args .secret , b"\x33 " )
40+ assert close_session_call_args .guid == b"\x22 "
41+ assert close_session_call_args .secret == b"\x33 "
4242
4343 @patch ("%s.session.ThriftDatabricksClient" % PACKAGE_NAME )
4444 def test_auth_args (self , mock_client_class ):
@@ -63,8 +63,8 @@ def test_auth_args(self, mock_client_class):
6363 for args in connection_args :
6464 connection = databricks .sql .connect (** args )
6565 host , port , http_path , * _ = mock_client_class .call_args [0 ]
66- self . assertEqual ( args ["server_hostname" ], host )
67- self . assertEqual ( args ["http_path" ], http_path )
66+ assert args ["server_hostname" ] == host
67+ assert args ["http_path" ] == http_path
6868 connection .close ()
6969
7070 @patch ("%s.session.ThriftDatabricksClient" % PACKAGE_NAME )
@@ -73,7 +73,7 @@ def test_http_header_passthrough(self, mock_client_class):
7373 databricks .sql .connect (** self .DUMMY_CONNECTION_ARGS , http_headers = http_headers )
7474
7575 call_args = mock_client_class .call_args [0 ][3 ]
76- self . assertIn (( "foo" , "bar" ), call_args )
76+ assert ( "foo" , "bar" ) in call_args
7777
7878 @patch ("%s.session.ThriftDatabricksClient" % PACKAGE_NAME )
7979 def test_tls_arg_passthrough (self , mock_client_class ):
@@ -86,10 +86,10 @@ def test_tls_arg_passthrough(self, mock_client_class):
8686 )
8787
8888 kwargs = mock_client_class .call_args [1 ]
89- self . assertEqual ( kwargs ["_tls_verify_hostname" ], "hostname" )
90- self . assertEqual ( kwargs ["_tls_trusted_ca_file" ], "trusted ca file" )
91- self . assertEqual ( kwargs ["_tls_client_cert_key_file" ], "trusted client cert" )
92- self . assertEqual ( kwargs ["_tls_client_cert_key_password" ], "key password" )
89+ assert kwargs ["_tls_verify_hostname" ] == "hostname"
90+ assert kwargs ["_tls_trusted_ca_file" ] == "trusted ca file"
91+ assert kwargs ["_tls_client_cert_key_file" ] == "trusted client cert"
92+ assert kwargs ["_tls_client_cert_key_password" ] == "key password"
9393
9494 @patch ("%s.session.ThriftDatabricksClient" % PACKAGE_NAME )
9595 def test_useragent_header (self , mock_client_class ):
@@ -100,7 +100,7 @@ def test_useragent_header(self, mock_client_class):
100100 "User-Agent" ,
101101 "{}/{}" .format (databricks .sql .USER_AGENT_NAME , databricks .sql .__version__ ),
102102 )
103- self . assertIn ( user_agent_header , http_headers )
103+ assert user_agent_header in http_headers
104104
105105 databricks .sql .connect (** self .DUMMY_CONNECTION_ARGS , user_agent_entry = "foobar" )
106106 user_agent_header_with_entry = (
@@ -110,7 +110,7 @@ def test_useragent_header(self, mock_client_class):
110110 ),
111111 )
112112 http_headers = mock_client_class .call_args [0 ][3 ]
113- self . assertIn ( user_agent_header_with_entry , http_headers )
113+ assert user_agent_header_with_entry in http_headers
114114
115115 @patch ("%s.session.ThriftDatabricksClient" % PACKAGE_NAME )
116116 def test_context_manager_closes_connection (self , mock_client_class ):
@@ -125,13 +125,13 @@ def test_context_manager_closes_connection(self, mock_client_class):
125125
126126 # Check that close_session was called with the correct SessionId
127127 close_session_call_args = instance .close_session .call_args [0 ][0 ]
128- self . assertEqual ( close_session_call_args .guid , b"\x22 " )
129- self . assertEqual ( close_session_call_args .secret , b"\x33 " )
128+ assert close_session_call_args .guid == b"\x22 "
129+ assert close_session_call_args .secret == b"\x33 "
130130
131131 connection = databricks .sql .connect (** self .DUMMY_CONNECTION_ARGS )
132132 connection .close = Mock ()
133133 try :
134- with self . assertRaises (KeyboardInterrupt ):
134+ with pytest . raises (KeyboardInterrupt ):
135135 with connection :
136136 raise KeyboardInterrupt ("Simulated interrupt" )
137137 finally :
@@ -143,14 +143,12 @@ def test_max_number_of_retries_passthrough(self, mock_client_class):
143143 _retry_stop_after_attempts_count = 54 , ** self .DUMMY_CONNECTION_ARGS
144144 )
145145
146- self .assertEqual (
147- mock_client_class .call_args [1 ]["_retry_stop_after_attempts_count" ], 54
148- )
146+ assert mock_client_class .call_args [1 ]["_retry_stop_after_attempts_count" ] == 54
149147
150148 @patch ("%s.session.ThriftDatabricksClient" % PACKAGE_NAME )
151149 def test_socket_timeout_passthrough (self , mock_client_class ):
152150 databricks .sql .connect (_socket_timeout = 234 , ** self .DUMMY_CONNECTION_ARGS )
153- self . assertEqual ( mock_client_class .call_args [1 ]["_socket_timeout" ], 234 )
151+ assert mock_client_class .call_args [1 ]["_socket_timeout" ] == 234
154152
155153 @patch ("%s.session.ThriftDatabricksClient" % PACKAGE_NAME )
156154 def test_configuration_passthrough (self , mock_client_class ):
@@ -160,7 +158,7 @@ def test_configuration_passthrough(self, mock_client_class):
160158 )
161159
162160 call_kwargs = mock_client_class .return_value .open_session .call_args [1 ]
163- self . assertEqual ( call_kwargs ["session_configuration" ], mock_session_config )
161+ assert call_kwargs ["session_configuration" ] == mock_session_config
164162
165163 @patch ("%s.session.ThriftDatabricksClient" % PACKAGE_NAME )
166164 def test_initial_namespace_passthrough (self , mock_client_class ):
@@ -171,8 +169,8 @@ def test_initial_namespace_passthrough(self, mock_client_class):
171169 )
172170
173171 call_kwargs = mock_client_class .return_value .open_session .call_args [1 ]
174- self . assertEqual ( call_kwargs ["catalog" ], mock_cat )
175- self . assertEqual ( call_kwargs ["schema" ], mock_schem )
172+ assert call_kwargs ["catalog" ] == mock_cat
173+ assert call_kwargs ["schema" ] == mock_schem
176174
177175 @patch ("%s.session.ThriftDatabricksClient" % PACKAGE_NAME )
178176 def test_finalizer_closes_abandoned_connection (self , mock_client_class ):
@@ -188,9 +186,5 @@ def test_finalizer_closes_abandoned_connection(self, mock_client_class):
188186
189187 # Check that close_session was called with the correct SessionId
190188 close_session_call_args = instance .close_session .call_args [0 ][0 ]
191- self .assertEqual (close_session_call_args .guid , b"\x22 " )
192- self .assertEqual (close_session_call_args .secret , b"\x33 " )
193-
194-
195- if __name__ == "__main__" :
196- unittest .main ()
189+ assert close_session_call_args .guid == b"\x22 "
190+ assert close_session_call_args .secret == b"\x33 "
0 commit comments