Skip to content

Commit 5df8b3a

Browse files
committed
unit tests
1 parent 39c96a9 commit 5df8b3a

File tree

10 files changed

+1154
-10
lines changed

10 files changed

+1154
-10
lines changed

src/sftp/azext_sftp/_validators.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ def storage_account_name_or_id_validator(cmd, namespace):
1313
Validator for storage account name or resource ID.
1414
Converts storage account name to full resource ID if needed.
1515
"""
16-
if namespace.storage_account:
16+
if hasattr(namespace, 'storage_account') and namespace.storage_account:
1717
if not is_valid_resource_id(namespace.storage_account):
1818
if not hasattr(namespace, 'resource_group_name') or not namespace.resource_group_name:
1919
raise azclierror.RequiredArgumentMissingError(

src/sftp/azext_sftp/custom.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ def sftp_cert(cmd, cert_path=None, public_key_file=None, ssh_client_folder=None)
6969
logger.warning("%s contains sensitive information (id_rsa, id_rsa.pub). "
7070
"Please delete once this certificate is no longer being used.", keys_folder)
7171

72+
# pylint: disable=broad-except
7273
try:
7374
cert_expiration = sftp_utils.get_certificate_start_and_end_times(cert_file, ssh_client_folder)[1]
7475
print_styled_text((Style.SUCCESS,
@@ -103,15 +104,15 @@ def sftp_connect(cmd, storage_account, port=None, cert_file=None, private_key_fi
103104
delete_cert = True
104105
delete_keys = True
105106
credentials_folder = tempfile.mkdtemp(prefix="aadsftp")
106-
107+
107108
try:
108109
profile = Profile(cli_ctx=cmd.cli_ctx)
109110
profile.get_subscription()
110111
except Exception:
111112
if credentials_folder and os.path.isdir(credentials_folder):
112113
shutil.rmtree(credentials_folder)
113114
raise
114-
115+
115116
print_styled_text((Style.ACTION, "Generating temporary credentials..."))
116117

117118
if cert_file and public_key_file:
@@ -123,12 +124,9 @@ def sftp_connect(cmd, storage_account, port=None, cert_file=None, private_key_fi
123124
None, None, credentials_folder, ssh_client_folder)
124125
cert_file, user = _get_and_write_certificate(cmd, public_key_file, None, ssh_client_folder)
125126
elif not cert_file:
126-
try:
127-
profile = Profile(cli_ctx=cmd.cli_ctx)
128-
profile.get_subscription()
129-
except Exception:
130-
raise
131-
127+
profile = Profile(cli_ctx=cmd.cli_ctx)
128+
profile.get_subscription()
129+
132130
public_key_file, private_key_file, _ = _check_or_create_public_private_files(
133131
public_key_file, private_key_file, None, ssh_client_folder)
134132
print_styled_text((Style.ACTION, "Generating certificate..."))

src/sftp/azext_sftp/file_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def mkdir_p(path):
4040

4141
def delete_file(file_path, message, warning=False):
4242
if os.path.isfile(file_path):
43+
# pylint: disable=broad-except
4344
try:
4445
os.remove(file_path)
4546
except Exception as e:
@@ -51,6 +52,7 @@ def delete_file(file_path, message, warning=False):
5152

5253
def delete_folder(dir_path, message, warning=False):
5354
if os.path.isdir(dir_path):
55+
# pylint: disable=broad-except
5456
try:
5557
os.rmdir(dir_path)
5658
except Exception as e:

src/sftp/azext_sftp/rsa_parser.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import struct
88

99

10+
# pylint: disable=too-few-public-methods
1011
class RSAParser():
1112
RSAAlgorithm = 'ssh-rsa'
1213

src/sftp/azext_sftp/sftp_info.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
logger = log.get_logger(__name__)
1111

1212

13+
# pylint: disable=too-few-public-methods
1314
class ConnectionInfo:
1415
"""Encapsulates connection-specific information."""
1516

@@ -20,6 +21,7 @@ def __init__(self, storage_account, username=None, host=None, port=None):
2021
self.port = port
2122

2223

24+
# pylint: disable=too-few-public-methods
2325
class AuthenticationFiles:
2426
"""Encapsulates authentication file paths."""
2527

@@ -29,6 +31,7 @@ def __init__(self, public_key_file=None, private_key_file=None, cert_file=None):
2931
self.cert_file = os.path.abspath(os.path.expanduser(cert_file)) if cert_file else None
3032

3133

34+
# pylint: disable=too-few-public-methods
3235
class SessionConfiguration:
3336
"""Encapsulates session configuration options."""
3437

@@ -41,6 +44,7 @@ def __init__(self, sftp_args=None, ssh_client_folder=None, ssh_proxy_folder=None
4144
self.yes_without_prompt = yes_without_prompt
4245

4346

47+
# pylint: disable=too-few-public-methods
4448
class RuntimeState:
4549
"""Encapsulates runtime state information."""
4650

src/sftp/azext_sftp/sftp_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def _execute_sftp_process(command, env, creationflags):
6969
return sftp_process, None
7070

7171

72-
def _attempt_connection(command, env, creationflags, op_info, attempt_num):
72+
def _attempt_connection(command, env, creationflags, op_info, attempt_num): # pylint: disable=unused-argument
7373
"""Attempt a single SFTP connection."""
7474
connection_start_time = time.time()
7575
try:
Lines changed: 227 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,227 @@
1+
# --------------------------------------------------------------------------------------------
2+
# Copyright (c) Microsoft Corporation. All rights reserved.
3+
# Licensed under the MIT License. See License.txt in the project root for license information.
4+
# --------------------------------------------------------------------------------------------
5+
6+
import unittest
7+
import json
8+
import base64
9+
from unittest import mock
10+
11+
from azext_sftp import connectivity_utils
12+
13+
14+
class SftpConnectivityUtilsTest(unittest.TestCase):
15+
"""Test suite for SFTP connectivity utilities.
16+
17+
Owner: johnli1
18+
"""
19+
20+
def test_format_relay_info_string_success(self):
21+
"""Test format_relay_info_string with valid relay information."""
22+
# Arrange
23+
relay_info = {
24+
'namespaceName': 'test-namespace',
25+
'namespaceNameSuffix': 'servicebus.windows.net',
26+
'hybridConnectionName': 'test-connection',
27+
'accessKey': 'test-access-key',
28+
'expiresOn': '2025-07-04T10:00:00Z',
29+
'serviceConfigurationToken': 'test-token'
30+
}
31+
32+
# Act
33+
result = connectivity_utils.format_relay_info_string(relay_info)
34+
35+
# Assert
36+
# Decode the base64 result to verify the structure
37+
decoded_bytes = base64.b64decode(result.encode('ascii'))
38+
decoded_string = decoded_bytes.decode('ascii')
39+
parsed_result = json.loads(decoded_string)
40+
41+
self.assertIn('relay', parsed_result)
42+
relay_data = parsed_result['relay']
43+
44+
self.assertEqual(relay_data['namespaceName'], 'test-namespace')
45+
self.assertEqual(relay_data['namespaceNameSuffix'], 'servicebus.windows.net')
46+
self.assertEqual(relay_data['hybridConnectionName'], 'test-connection')
47+
self.assertEqual(relay_data['accessKey'], 'test-access-key')
48+
self.assertEqual(relay_data['expiresOn'], '2025-07-04T10:00:00Z')
49+
self.assertEqual(relay_data['serviceConfigurationToken'], 'test-token')
50+
51+
def test_format_relay_info_string_minimal_data(self):
52+
"""Test format_relay_info_string with minimal required fields."""
53+
# Arrange
54+
relay_info = {
55+
'namespaceName': 'minimal',
56+
'namespaceNameSuffix': 'suffix',
57+
'hybridConnectionName': 'connection',
58+
'accessKey': 'key',
59+
'expiresOn': '2025-01-01T00:00:00Z',
60+
'serviceConfigurationToken': 'token'
61+
}
62+
63+
# Act
64+
result = connectivity_utils.format_relay_info_string(relay_info)
65+
66+
# Assert
67+
self.assertIsInstance(result, str)
68+
self.assertTrue(len(result) > 0)
69+
70+
# Verify it's valid base64
71+
try:
72+
decoded = base64.b64decode(result.encode('ascii'))
73+
json.loads(decoded.decode('ascii'))
74+
except (ValueError, json.JSONDecodeError):
75+
self.fail("Result is not valid base64-encoded JSON")
76+
77+
def test_format_relay_info_string_special_characters(self):
78+
"""Test format_relay_info_string with special characters in data."""
79+
# Arrange
80+
relay_info = {
81+
'namespaceName': 'test-namespace-with-dashes',
82+
'namespaceNameSuffix': 'test.suffix.com',
83+
'hybridConnectionName': 'connection_with_underscores',
84+
'accessKey': 'key+with/special=chars',
85+
'expiresOn': '2025-12-31T23:59:59.999Z',
86+
'serviceConfigurationToken': 'token-with-various_special.chars'
87+
}
88+
89+
# Act
90+
result = connectivity_utils.format_relay_info_string(relay_info)
91+
92+
# Assert
93+
decoded_bytes = base64.b64decode(result.encode('ascii'))
94+
decoded_string = decoded_bytes.decode('ascii')
95+
parsed_result = json.loads(decoded_string)
96+
97+
relay_data = parsed_result['relay']
98+
self.assertEqual(relay_data['namespaceName'], 'test-namespace-with-dashes')
99+
self.assertEqual(relay_data['accessKey'], 'key+with/special=chars')
100+
101+
def test_format_relay_info_string_unicode_characters(self):
102+
"""Test format_relay_info_string with unicode characters."""
103+
# Arrange
104+
relay_info = {
105+
'namespaceName': 'test-unicode-αβγ',
106+
'namespaceNameSuffix': 'suffix-δεζ',
107+
'hybridConnectionName': 'connection-ηθι',
108+
'accessKey': 'key-κλμ',
109+
'expiresOn': '2025-06-15T12:30:45Z',
110+
'serviceConfigurationToken': 'token-νξο'
111+
}
112+
113+
# Act
114+
result = connectivity_utils.format_relay_info_string(relay_info)
115+
116+
# Assert
117+
self.assertIsInstance(result, str)
118+
119+
# Verify roundtrip encoding/decoding works
120+
decoded_bytes = base64.b64decode(result.encode('ascii'))
121+
decoded_string = decoded_bytes.decode('ascii')
122+
parsed_result = json.loads(decoded_string)
123+
124+
relay_data = parsed_result['relay']
125+
self.assertEqual(relay_data['namespaceName'], 'test-unicode-αβγ')
126+
127+
def test_format_relay_info_string_empty_values(self):
128+
"""Test format_relay_info_string with empty string values."""
129+
# Arrange
130+
relay_info = {
131+
'namespaceName': '',
132+
'namespaceNameSuffix': '',
133+
'hybridConnectionName': '',
134+
'accessKey': '',
135+
'expiresOn': '',
136+
'serviceConfigurationToken': ''
137+
}
138+
139+
# Act
140+
result = connectivity_utils.format_relay_info_string(relay_info)
141+
142+
# Assert
143+
decoded_bytes = base64.b64decode(result.encode('ascii'))
144+
decoded_string = decoded_bytes.decode('ascii')
145+
parsed_result = json.loads(decoded_string)
146+
147+
relay_data = parsed_result['relay']
148+
for key in relay_data:
149+
self.assertEqual(relay_data[key], '')
150+
151+
def test_format_relay_info_string_large_values(self):
152+
"""Test format_relay_info_string with large string values."""
153+
# Arrange
154+
large_string = 'x' * 1000 # 1000 character string
155+
relay_info = {
156+
'namespaceName': large_string,
157+
'namespaceNameSuffix': large_string,
158+
'hybridConnectionName': large_string,
159+
'accessKey': large_string,
160+
'expiresOn': '2025-07-04T10:00:00Z',
161+
'serviceConfigurationToken': large_string
162+
}
163+
164+
# Act
165+
result = connectivity_utils.format_relay_info_string(relay_info)
166+
167+
# Assert
168+
self.assertIsInstance(result, str)
169+
decoded_bytes = base64.b64decode(result.encode('ascii'))
170+
decoded_string = decoded_bytes.decode('ascii')
171+
parsed_result = json.loads(decoded_string)
172+
173+
relay_data = parsed_result['relay']
174+
self.assertEqual(relay_data['namespaceName'], large_string)
175+
176+
@mock.patch('json.dumps')
177+
def test_format_relay_info_string_json_error_handling(self, mock_dumps):
178+
"""Test format_relay_info_string handles JSON serialization errors."""
179+
# Arrange
180+
mock_dumps.side_effect = TypeError("Object not serializable")
181+
relay_info = {
182+
'namespaceName': 'test',
183+
'namespaceNameSuffix': 'test',
184+
'hybridConnectionName': 'test',
185+
'accessKey': 'test',
186+
'expiresOn': 'test',
187+
'serviceConfigurationToken': 'test'
188+
}
189+
190+
# Act & Assert
191+
with self.assertRaises(TypeError):
192+
connectivity_utils.format_relay_info_string(relay_info)
193+
194+
def test_format_relay_info_string_structure_validation(self):
195+
"""Test that format_relay_info_string produces the expected JSON structure."""
196+
# Arrange
197+
relay_info = {
198+
'namespaceName': 'test-namespace',
199+
'namespaceNameSuffix': 'servicebus.windows.net',
200+
'hybridConnectionName': 'test-connection',
201+
'accessKey': 'test-access-key',
202+
'expiresOn': '2025-07-04T10:00:00Z',
203+
'serviceConfigurationToken': 'test-token'
204+
}
205+
206+
# Act
207+
result = connectivity_utils.format_relay_info_string(relay_info)
208+
209+
# Assert - Verify the exact structure matches expectations
210+
decoded_bytes = base64.b64decode(result.encode('ascii'))
211+
decoded_string = decoded_bytes.decode('ascii')
212+
parsed_result = json.loads(decoded_string)
213+
214+
# Check top-level structure
215+
self.assertEqual(list(parsed_result.keys()), ['relay'])
216+
217+
# Check relay object has all required fields
218+
relay_data = parsed_result['relay']
219+
expected_keys = {
220+
'namespaceName', 'namespaceNameSuffix', 'hybridConnectionName',
221+
'accessKey', 'expiresOn', 'serviceConfigurationToken'
222+
}
223+
self.assertEqual(set(relay_data.keys()), expected_keys)
224+
225+
226+
if __name__ == '__main__':
227+
unittest.main()

0 commit comments

Comments
 (0)