Skip to content

Commit 15cea64

Browse files
authored
Merge pull request #3 from jneprz/personal/johnli1/test_coverage
Personal/johnli1/test coverage
2 parents 1ff741f + fe610c3 commit 15cea64

File tree

9 files changed

+254
-248
lines changed

9 files changed

+254
-248
lines changed

src/sftp/azext_sftp/_help.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,11 @@
4646
The certificate can be used with 'az sftp connect' or with standard SFTP clients.
4747
examples:
4848
- name: Generate a certificate using an existing public key
49-
text: az sftp cert --public-key-file ~/.ssh/id_rsa.pub --output-file ~/my_cert.pub
49+
text: az sftp cert --public-key-file ~/.ssh/id_rsa.pub --file ~/my_cert.pub
5050
- name: Generate a certificate and create a new key pair in the same directory
51-
text: az sftp cert --output-file ~/my_cert.pub
51+
text: az sftp cert --file ~/my_cert.pub
5252
- name: Generate a certificate with custom SSH client folder
53-
text: az sftp cert --output-file ~/my_cert.pub --ssh-client-folder "C:\\Program Files\\OpenSSH"
53+
text: az sftp cert --file ~/my_cert.pub --ssh-client-folder "C:\\Program Files\\OpenSSH"
5454
"""
5555

5656
helps['sftp connect'] = """

src/sftp/azext_sftp/_params.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,11 @@
88
def load_arguments(self, _):
99

1010
with self.argument_context('sftp cert') as c:
11-
c.argument('cert_path', options_list=['--output-file', '-o'],
11+
c.argument('cert_path', options_list=['--file', '-f'],
1212
help='The file path to write the SSH cert to, defaults to public key path with -aadcert.pub appended')
1313
c.argument('public_key_file', options_list=['--public-key-file', '-p'],
1414
help='The RSA public key file path. If not provided, '
15-
'generated key pair is stored in the same directory as --output-file.')
15+
'generated key pair is stored in the same directory as --file.')
1616
c.argument('ssh_client_folder', options_list=['--ssh-client-folder'],
1717
help='Folder path that contains ssh executables (ssh-keygen, ssh). '
1818
'Default to ssh executables in your PATH or C:\\Windows\\System32\\OpenSSH on Windows.')

src/sftp/azext_sftp/custom.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def sftp_cert(cmd, cert_path=None, public_key_file=None, ssh_client_folder=None)
2424
logger.debug("Starting SFTP certificate generation")
2525

2626
if not cert_path and not public_key_file:
27-
raise azclierror.RequiredArgumentMissingError("--output-file or --public-key-file must be provided.")
27+
raise azclierror.RequiredArgumentMissingError("--file or --public-key-file must be provided.")
2828

2929
if cert_path:
3030
cert_path = os.path.expanduser(cert_path)
@@ -52,14 +52,14 @@ def sftp_cert(cmd, cert_path=None, public_key_file=None, ssh_client_folder=None)
5252
logger.debug("Will generate key pair in: %s", keys_folder)
5353

5454
try:
55-
public_key_file, _, _ = _check_or_create_public_private_files(
55+
public_key_file, _, delete_keys = _check_or_create_public_private_files(
5656
public_key_file, None, keys_folder, ssh_client_folder)
5757
cert_file, _ = _get_and_write_certificate(cmd, public_key_file, cert_path, ssh_client_folder)
5858
except Exception as e:
5959
logger.debug("Certificate generation failed: %s", str(e))
6060
raise
6161

62-
if keys_folder:
62+
if keys_folder and delete_keys:
6363
logger.warning("%s contains sensitive information (id_rsa, id_rsa.pub). "
6464
"Please delete once this certificate is no longer being used.", keys_folder)
6565

src/sftp/azext_sftp/file_utils.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,15 +96,22 @@ def check_or_create_public_private_files(public_key_file, private_key_file, cred
9696
delete_keys = False
9797

9898
if not public_key_file and not private_key_file:
99-
delete_keys = True
10099
if not credentials_folder:
101100
credentials_folder = tempfile.mkdtemp(prefix="aadsftpcert")
102101
else:
103102
if not os.path.isdir(credentials_folder):
104103
os.makedirs(credentials_folder)
104+
105105
public_key_file = os.path.join(credentials_folder, "id_rsa.pub")
106106
private_key_file = os.path.join(credentials_folder, "id_rsa")
107-
sftp_utils.create_ssh_keyfile(private_key_file, ssh_client_folder)
107+
108+
# Check if existing keys are present before generating new ones
109+
if not (os.path.isfile(public_key_file) and os.path.isfile(private_key_file)):
110+
# Only generate new keys if both don't exist
111+
sftp_utils.create_ssh_keyfile(private_key_file, ssh_client_folder)
112+
# Only set delete_keys to True when we actually create new keys
113+
delete_keys = True
114+
# If existing keys are found, delete_keys remains False
108115

109116
if not public_key_file:
110117
if private_key_file:
@@ -155,7 +162,7 @@ def get_and_write_certificate(cmd, public_key_file, cert_file, ssh_client_folder
155162
telemetry.add_extension_event('sftp', {'Context.Default.AzureCLI.SFTPGetCertificateTime': time_elapsed})
156163

157164
if not cert_file:
158-
cert_file = str(public_key_file) + "-aadcert.pub"
165+
cert_file = str(public_key_file.removesuffix(".pub")) + "-aadcert.pub"
159166

160167
logger.debug("Generating certificate %s", cert_file)
161168
_write_cert_file(certificate, cert_file)

src/sftp/azext_sftp/sftp_utils.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,9 @@ def get_ssh_cert_info(cert_file, ssh_client_folder=None):
175175
const.RECOMMENDATION_SSH_CLIENT_NOT_FOUND)
176176

177177

178+
_warned_ssh_client_folders = set()
179+
180+
178181
def get_ssh_client_path(ssh_command="ssh", ssh_client_folder=None):
179182
"""Get the path to an SSH client executable."""
180183
if ssh_client_folder:
@@ -184,8 +187,11 @@ def get_ssh_client_path(ssh_command="ssh", ssh_client_folder=None):
184187
if os.path.isfile(ssh_path):
185188
logger.debug("Attempting to run %s from path %s", ssh_command, ssh_path)
186189
return ssh_path
187-
logger.warning("Could not find %s in provided --ssh-client-folder %s. "
188-
"Attempting to get pre-installed OpenSSH bits.", ssh_command, ssh_client_folder)
190+
warn_key = (ssh_command, os.path.abspath(ssh_client_folder))
191+
if warn_key not in _warned_ssh_client_folders:
192+
logger.warning("Could not find %s in provided --ssh-client-folder %s. "
193+
"Attempting to get pre-installed OpenSSH bits.", ssh_command, ssh_client_folder)
194+
_warned_ssh_client_folders.add(warn_key)
189195

190196
if platform.system() != 'Windows':
191197
return ssh_command

src/sftp/azext_sftp/tests/latest/test_custom.py

Lines changed: 1 addition & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -86,31 +86,6 @@ def test_sftp_cert(self, mock_write_cert, mock_get_keys, mock_abspath, mock_isdi
8686
mock_get_keys.assert_called_once_with('/pubkey/path', None, None, None)
8787
mock_write_cert.assert_called_once_with(cmd, 'pubkey', '/cert/path', None)
8888

89-
def test_sftp_connect_preprod(self):
90-
"""Test SFTP connection to preprod environment.
91-
92-
Owner: johnli1
93-
"""
94-
cmd = mock.Mock()
95-
cmd.cli_ctx = mock.Mock()
96-
cmd.cli_ctx.cloud = mock.Mock()
97-
cmd.cli_ctx.cloud.name = "azurecloud"
98-
99-
# Create a temporary batch file for automated testing
100-
batch_file = os.path.join(self.temp_dir, "test_batch.txt")
101-
with open(batch_file, 'w') as f:
102-
f.write("pwd\nls\nexit\n")
103-
104-
# Use batch file to avoid interactive prompt
105-
custom.sftp_connect(
106-
cmd=cmd,
107-
storage_account='johnli1canary',
108-
port=22,
109-
cert_file='C:\\Users\\johnli1\\.ssh\\id_rsa-aadcert.pub',
110-
sftp_args=['-b', batch_file] # Use actual batch file
111-
)
112-
self.assertTrue(True)
113-
11489
@mock.patch('azext_sftp.custom._do_sftp_op')
11590
@mock.patch('azext_sftp.sftp_utils.get_ssh_cert_principals')
11691
def test_sftp_connect_certificate_scenarios(self, mock_get_principals, mock_do_sftp):
@@ -713,7 +688,7 @@ def test_sftp_cert_error_cases(self):
713688
"""Test sftp cert error handling with invalid argument combinations."""
714689
# Test cases: (cert_path, public_key_file, setup_mocks, expected_exception, expected_message, description)
715690
test_cases = [
716-
(None, None, {}, azclierror.RequiredArgumentMissingError, "--output-file or --public-key-file must be provided", "no_arguments"),
691+
(None, None, {}, azclierror.RequiredArgumentMissingError, "--file or --public-key-file must be provided", "no_arguments"),
717692
("/bad/cert.pub", None, {"expanduser_return": "/bad/cert.pub", "isdir_return": False}, azclierror.InvalidArgumentValueError, "folder doesn't exist", "invalid_directory"),
718693
]
719694

src/sftp/azext_sftp/tests/latest/test_file_utils.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -379,11 +379,14 @@ def test_check_or_create_public_private_files_generates_keys(self, mock_create_k
379379
expected_public_key = os.path.join(self.temp_dir, "id_rsa.pub")
380380
expected_private_key = os.path.join(self.temp_dir, "id_rsa")
381381

382-
# Create the expected files so they exist for the function
383-
with open(expected_public_key, 'w') as f:
384-
f.write("ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQ [email protected]")
385-
with open(expected_private_key, 'w') as f:
386-
f.write("-----BEGIN OPENSSH PRIVATE KEY-----\ntest\n-----END OPENSSH PRIVATE KEY-----")
382+
# Mock the create_ssh_keyfile to create the files when called
383+
def create_key_files(private_key_path, ssh_client_folder):
384+
with open(private_key_path, 'w') as f:
385+
f.write("-----BEGIN OPENSSH PRIVATE KEY-----\ntest\n-----END OPENSSH PRIVATE KEY-----")
386+
with open(private_key_path + ".pub", 'w') as f:
387+
f.write("ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQ [email protected]")
388+
389+
mock_create_keyfile.side_effect = create_key_files
387390

388391
# Act
389392
public_key, private_key, delete_keys = file_utils.check_or_create_public_private_files(
@@ -454,7 +457,7 @@ def test_get_and_write_certificate_success(self, mock_set_mode, mock_get_princip
454457
mock_get_principals.return_value = ["[email protected]"]
455458

456459
# Set up the cert file path that the function will generate
457-
expected_cert_file = str(self.mock_public_key) + "-aadcert.pub"
460+
expected_cert_file = str(self.mock_public_key.removesuffix(".pub")) + "-aadcert.pub"
458461
mock_write_cert.return_value = expected_cert_file
459462

460463
# Act

src/sftp/azext_sftp/tests/latest/test_sftp_info.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,23 @@ def test_invalid_port_handling(self):
190190
# Port should be converted to string
191191
self.assertIsInstance(command_args[port_index + 1], str)
192192

193+
def test_extension_sftp_session_creation(self):
194+
"""Test that the extension creates SFTP session with correct parameters."""
195+
session = sftp_info.SFTPSession(
196+
storage_account=self.test_storage_account,
197+
username=self.test_username,
198+
host=self.test_host,
199+
port=self.test_port,
200+
cert_file=self.test_cert_file,
201+
private_key_file=self.test_private_key_file
202+
)
203+
204+
self.assertEqual(session.storage_account, self.test_storage_account)
205+
self.assertEqual(session.username, self.test_username)
206+
self.assertEqual(session.host, self.test_host)
207+
self.assertEqual(session.port, self.test_port)
208+
self.assertEqual(session.cert_file, os.path.abspath(self.test_cert_file))
209+
self.assertEqual(session.private_key_file, os.path.abspath(self.test_private_key_file))
193210

194211
if __name__ == '__main__':
195212
unittest.main()

0 commit comments

Comments
 (0)