Skip to content

Commit 64fb9b2

Browse files
align test_session with pytest instead of unittest
Signed-off-by: varun-edachali-dbx <[email protected]>
1 parent 61dfc4d commit 64fb9b2

File tree

1 file changed

+23
-29
lines changed

1 file changed

+23
-29
lines changed

tests/unit/test_session.py

Lines changed: 23 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import unittest
1+
import pytest
22
from unittest.mock import patch, MagicMock, Mock, PropertyMock
33
import gc
44

@@ -12,7 +12,7 @@
1212
import databricks.sql
1313

1414

15-
class SessionTestSuite(unittest.TestCase):
15+
class TestSession:
1616
"""
1717
Unit tests for Session functionality
1818
"""
@@ -37,8 +37,8 @@ def test_close_uses_the_correct_session_id(self, mock_client_class):
3737

3838
# Check that close_session was called with the correct SessionId
3939
close_session_call_args = instance.close_session.call_args[0][0]
40-
self.assertEqual(close_session_call_args.guid, b"\x22")
41-
self.assertEqual(close_session_call_args.secret, b"\x33")
40+
assert close_session_call_args.guid == b"\x22"
41+
assert close_session_call_args.secret == b"\x33"
4242

4343
@patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME)
4444
def test_auth_args(self, mock_client_class):
@@ -63,8 +63,8 @@ def test_auth_args(self, mock_client_class):
6363
for args in connection_args:
6464
connection = databricks.sql.connect(**args)
6565
host, port, http_path, *_ = mock_client_class.call_args[0]
66-
self.assertEqual(args["server_hostname"], host)
67-
self.assertEqual(args["http_path"], http_path)
66+
assert args["server_hostname"] == host
67+
assert args["http_path"] == http_path
6868
connection.close()
6969

7070
@patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME)
@@ -73,7 +73,7 @@ def test_http_header_passthrough(self, mock_client_class):
7373
databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS, http_headers=http_headers)
7474

7575
call_args = mock_client_class.call_args[0][3]
76-
self.assertIn(("foo", "bar"), call_args)
76+
assert ("foo", "bar") in call_args
7777

7878
@patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME)
7979
def test_tls_arg_passthrough(self, mock_client_class):
@@ -86,10 +86,10 @@ def test_tls_arg_passthrough(self, mock_client_class):
8686
)
8787

8888
kwargs = mock_client_class.call_args[1]
89-
self.assertEqual(kwargs["_tls_verify_hostname"], "hostname")
90-
self.assertEqual(kwargs["_tls_trusted_ca_file"], "trusted ca file")
91-
self.assertEqual(kwargs["_tls_client_cert_key_file"], "trusted client cert")
92-
self.assertEqual(kwargs["_tls_client_cert_key_password"], "key password")
89+
assert kwargs["_tls_verify_hostname"] == "hostname"
90+
assert kwargs["_tls_trusted_ca_file"] == "trusted ca file"
91+
assert kwargs["_tls_client_cert_key_file"] == "trusted client cert"
92+
assert kwargs["_tls_client_cert_key_password"] == "key password"
9393

9494
@patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME)
9595
def test_useragent_header(self, mock_client_class):
@@ -100,7 +100,7 @@ def test_useragent_header(self, mock_client_class):
100100
"User-Agent",
101101
"{}/{}".format(databricks.sql.USER_AGENT_NAME, databricks.sql.__version__),
102102
)
103-
self.assertIn(user_agent_header, http_headers)
103+
assert user_agent_header in http_headers
104104

105105
databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS, user_agent_entry="foobar")
106106
user_agent_header_with_entry = (
@@ -110,7 +110,7 @@ def test_useragent_header(self, mock_client_class):
110110
),
111111
)
112112
http_headers = mock_client_class.call_args[0][3]
113-
self.assertIn(user_agent_header_with_entry, http_headers)
113+
assert user_agent_header_with_entry in http_headers
114114

115115
@patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME)
116116
def test_context_manager_closes_connection(self, mock_client_class):
@@ -125,13 +125,13 @@ def test_context_manager_closes_connection(self, mock_client_class):
125125

126126
# Check that close_session was called with the correct SessionId
127127
close_session_call_args = instance.close_session.call_args[0][0]
128-
self.assertEqual(close_session_call_args.guid, b"\x22")
129-
self.assertEqual(close_session_call_args.secret, b"\x33")
128+
assert close_session_call_args.guid == b"\x22"
129+
assert close_session_call_args.secret == b"\x33"
130130

131131
connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS)
132132
connection.close = Mock()
133133
try:
134-
with self.assertRaises(KeyboardInterrupt):
134+
with pytest.raises(KeyboardInterrupt):
135135
with connection:
136136
raise KeyboardInterrupt("Simulated interrupt")
137137
finally:
@@ -143,14 +143,12 @@ def test_max_number_of_retries_passthrough(self, mock_client_class):
143143
_retry_stop_after_attempts_count=54, **self.DUMMY_CONNECTION_ARGS
144144
)
145145

146-
self.assertEqual(
147-
mock_client_class.call_args[1]["_retry_stop_after_attempts_count"], 54
148-
)
146+
assert mock_client_class.call_args[1]["_retry_stop_after_attempts_count"] == 54
149147

150148
@patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME)
151149
def test_socket_timeout_passthrough(self, mock_client_class):
152150
databricks.sql.connect(_socket_timeout=234, **self.DUMMY_CONNECTION_ARGS)
153-
self.assertEqual(mock_client_class.call_args[1]["_socket_timeout"], 234)
151+
assert mock_client_class.call_args[1]["_socket_timeout"] == 234
154152

155153
@patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME)
156154
def test_configuration_passthrough(self, mock_client_class):
@@ -160,7 +158,7 @@ def test_configuration_passthrough(self, mock_client_class):
160158
)
161159

162160
call_kwargs = mock_client_class.return_value.open_session.call_args[1]
163-
self.assertEqual(call_kwargs["session_configuration"], mock_session_config)
161+
assert call_kwargs["session_configuration"] == mock_session_config
164162

165163
@patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME)
166164
def test_initial_namespace_passthrough(self, mock_client_class):
@@ -171,8 +169,8 @@ def test_initial_namespace_passthrough(self, mock_client_class):
171169
)
172170

173171
call_kwargs = mock_client_class.return_value.open_session.call_args[1]
174-
self.assertEqual(call_kwargs["catalog"], mock_cat)
175-
self.assertEqual(call_kwargs["schema"], mock_schem)
172+
assert call_kwargs["catalog"] == mock_cat
173+
assert call_kwargs["schema"] == mock_schem
176174

177175
@patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME)
178176
def test_finalizer_closes_abandoned_connection(self, mock_client_class):
@@ -188,9 +186,5 @@ def test_finalizer_closes_abandoned_connection(self, mock_client_class):
188186

189187
# Check that close_session was called with the correct SessionId
190188
close_session_call_args = instance.close_session.call_args[0][0]
191-
self.assertEqual(close_session_call_args.guid, b"\x22")
192-
self.assertEqual(close_session_call_args.secret, b"\x33")
193-
194-
195-
if __name__ == "__main__":
196-
unittest.main()
189+
assert close_session_call_args.guid == b"\x22"
190+
assert close_session_call_args.secret == b"\x33"

0 commit comments

Comments
 (0)