Skip to content

Commit 80f8fea

Browse files
committed
simplify
1 parent 5df8b3a commit 80f8fea

File tree

9 files changed

+146
-315
lines changed

9 files changed

+146
-315
lines changed

src/sftp/azext_sftp/__init__.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,11 @@ def __init__(self, cli_ctx=None):
3333
from azure.cli.core.commands import CliCommandType
3434
from azext_sftp._client_factory import cf_sftp
3535

36-
sftp_custom = CliCommandType(
37-
operations_tmpl='azext_sftp.custom#{}',
38-
client_factory=cf_sftp)
39-
40-
super(SftpCommandsLoader, self).__init__(
36+
super().__init__(
4137
cli_ctx=cli_ctx,
42-
custom_command_type=sftp_custom)
38+
custom_command_type=CliCommandType(
39+
operations_tmpl='azext_sftp.custom#{}',
40+
client_factory=cf_sftp))
4341

4442
def load_command_table(self, args):
4543
"""Load the command table for SFTP commands."""

src/sftp/azext_sftp/connectivity_utils.py

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,18 +12,15 @@
1212

1313

1414
def format_relay_info_string(relay_info):
15-
relay_info_string = json.dumps(
16-
{
17-
"relay": {
18-
"namespaceName": relay_info['namespaceName'],
19-
"namespaceNameSuffix": relay_info['namespaceNameSuffix'],
20-
"hybridConnectionName": relay_info['hybridConnectionName'],
21-
"accessKey": relay_info['accessKey'],
22-
"expiresOn": relay_info['expiresOn'],
23-
"serviceConfigurationToken": relay_info['serviceConfigurationToken']
24-
}
25-
})
26-
result_bytes = relay_info_string.encode("ascii")
27-
enc = base64.b64encode(result_bytes)
28-
base64_result_string = enc.decode("ascii")
29-
return base64_result_string
15+
"""Format relay information as base64-encoded JSON string."""
16+
relay_data = {
17+
"relay": {
18+
"namespaceName": relay_info['namespaceName'],
19+
"namespaceNameSuffix": relay_info['namespaceNameSuffix'],
20+
"hybridConnectionName": relay_info['hybridConnectionName'],
21+
"accessKey": relay_info['accessKey'],
22+
"expiresOn": relay_info['expiresOn'],
23+
"serviceConfigurationToken": relay_info['serviceConfigurationToken']
24+
}
25+
}
26+
return base64.b64encode(json.dumps(relay_data).encode("ascii")).decode("ascii")

src/sftp/azext_sftp/constants.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,8 @@
88
# File system constants
99
WINDOWS_INVALID_FOLDERNAME_CHARS = "\\/*:<>?\"|"
1010

11-
# Default SSH/SFTP configuration
12-
DEFAULT_SSH_PORT = 22
13-
DEFAULT_SFTP_PORT = 22
14-
AZURE_STORAGE_SFTP_PORT = 22
11+
# Default ports
12+
DEFAULT_SSH_PORT = DEFAULT_SFTP_PORT = AZURE_STORAGE_SFTP_PORT = 22
1513

1614
# SSH/SFTP client configuration
1715
SSH_CONNECT_TIMEOUT = 30

src/sftp/azext_sftp/custom.py

Lines changed: 27 additions & 131 deletions
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,14 @@
44
# --------------------------------------------------------------------------------------------
55

66
import os
7-
import hashlib
8-
import json
97
import tempfile
10-
import time
118
import shutil
12-
import oschmod
139

1410
from knack import log
1511
from azure.cli.core import azclierror
16-
from azure.cli.core import telemetry
1712
from azure.cli.core.style import Style, print_styled_text
1813
from azure.cli.core._profile import Profile
1914

20-
from . import rsa_parser
2115
from . import sftp_info
2216
from . import sftp_utils
2317
from . import file_utils
@@ -185,143 +179,48 @@ def sftp_connect(cmd, storage_account, port=None, cert_file=None, private_key_fi
185179

186180
def _check_or_create_public_private_files(public_key_file, private_key_file, credentials_folder,
187181
ssh_client_folder=None):
188-
delete_keys = False
189-
if not public_key_file and not private_key_file:
190-
delete_keys = True
191-
if not credentials_folder:
192-
credentials_folder = tempfile.mkdtemp(prefix="aadsshcert")
193-
else:
194-
if not os.path.isdir(credentials_folder):
195-
os.makedirs(credentials_folder)
196-
public_key_file = os.path.join(credentials_folder, "id_rsa.pub")
197-
private_key_file = os.path.join(credentials_folder, "id_rsa")
198-
sftp_utils.create_ssh_keyfile(private_key_file, ssh_client_folder)
199-
200-
if not public_key_file:
201-
if private_key_file:
202-
public_key_file = str(private_key_file) + ".pub"
203-
else:
204-
raise azclierror.RequiredArgumentMissingError("Public key file not specified")
205-
206-
if not os.path.isfile(public_key_file):
207-
raise azclierror.FileOperationError(f"Public key file {public_key_file} not found")
208-
209-
if private_key_file:
210-
if not os.path.isfile(private_key_file):
211-
raise azclierror.FileOperationError(f"Private key file {private_key_file} not found")
212-
213-
if not private_key_file:
214-
if public_key_file.endswith(".pub"):
215-
private_key_file = public_key_file[:-4] if os.path.isfile(public_key_file[:-4]) else None
216-
217-
return public_key_file, private_key_file, delete_keys
182+
"""Check for existing key files or create new ones if needed."""
183+
return file_utils.check_or_create_public_private_files(
184+
public_key_file, private_key_file, credentials_folder, ssh_client_folder)
218185

219186

220187
def _get_and_write_certificate(cmd, public_key_file, cert_file, ssh_client_folder):
221-
cloudtoscope = {
222-
"azurecloud": "https://pas.windows.net/CheckMyAccess/Linux/.default",
223-
"azurechinacloud": "https://pas.chinacloudapi.cn/CheckMyAccess/Linux/.default",
224-
"azureusgovernment": "https://pasff.usgovcloudapi.net/CheckMyAccess/Linux/.default"
225-
}
226-
scope = cloudtoscope.get(cmd.cli_ctx.cloud.name.lower(), None)
227-
if not scope:
228-
raise azclierror.InvalidArgumentValueError(
229-
f"Unsupported cloud {cmd.cli_ctx.cloud.name.lower()}",
230-
"Supported clouds include azurecloud,azurechinacloud,azureusgovernment")
231-
232-
scopes = [scope]
233-
data = _prepare_jwk_data(public_key_file)
234-
profile = Profile(cli_ctx=cmd.cli_ctx)
235-
236-
t0 = time.time()
237-
if hasattr(profile, "get_msal_token"):
238-
_, certificate = profile.get_msal_token(scopes, data)
239-
else:
240-
credential, _, _ = profile.get_login_credentials(subscription_id=profile.get_subscription()["id"])
241-
certificatedata = credential.get_token(*scopes, data=data)
242-
certificate = certificatedata.token
243-
244-
time_elapsed = time.time() - t0
245-
telemetry.add_extension_event('sftp',
246-
{'Context.Default.AzureCLI.SftpGetCertificateTime': time_elapsed})
247-
248-
if not cert_file:
249-
base_name = os.path.splitext(str(public_key_file))[0]
250-
cert_file = base_name + "-aadcert.pub"
251-
252-
logger.debug("Generating certificate %s", cert_file)
253-
_write_cert_file(certificate, cert_file)
254-
username = sftp_utils.get_ssh_cert_principals(cert_file, ssh_client_folder)[0]
255-
oschmod.set_mode(cert_file, 0o600)
256-
257-
return cert_file, username.lower()
188+
"""Generate and write an SSH certificate using Azure AD authentication."""
189+
return file_utils.get_and_write_certificate(cmd, public_key_file, cert_file, ssh_client_folder)
258190

259191

260192
def _prepare_jwk_data(public_key_file):
261-
modulus, exponent = _get_modulus_exponent(public_key_file)
262-
key_hash = hashlib.sha256()
263-
key_hash.update(modulus.encode('utf-8'))
264-
key_hash.update(exponent.encode('utf-8'))
265-
key_id = key_hash.hexdigest()
266-
jwk = {
267-
"kty": "RSA",
268-
"n": modulus,
269-
"e": exponent,
270-
"kid": key_id
271-
}
272-
json_jwk = json.dumps(jwk)
273-
data = {
274-
"token_type": "ssh-cert",
275-
"req_cnf": json_jwk,
276-
"key_id": key_id
277-
}
278-
return data
193+
"""Prepare JWK data for certificate request."""
194+
return file_utils._prepare_jwk_data(public_key_file) # pylint: disable=protected-access
279195

280196

281197
def _write_cert_file(certificate_contents, cert_file):
282-
with open(cert_file, 'w', encoding='utf-8') as f:
283-
f.write(f"ssh-rsa-cert-v01@openssh.com {certificate_contents}")
284-
oschmod.set_mode(cert_file, 0o644)
285-
return cert_file
198+
"""Write SSH certificate to file."""
199+
return file_utils._write_cert_file(certificate_contents, cert_file) # pylint: disable=protected-access
286200

287201

288202
def _get_modulus_exponent(public_key_file):
289-
if not os.path.isfile(public_key_file):
290-
raise azclierror.FileOperationError(f"Public key file '{public_key_file}' was not found")
291-
292-
with open(public_key_file, 'r', encoding='utf-8') as f:
293-
public_key_text = f.read()
294-
295-
parser = rsa_parser.RSAParser()
296-
try:
297-
parser.parse(public_key_text)
298-
except Exception as e:
299-
raise azclierror.FileOperationError(f"Could not parse public key. Error: {str(e)}")
300-
modulus = parser.modulus
301-
exponent = parser.exponent
302-
303-
return modulus, exponent
203+
"""Extract modulus and exponent from RSA public key file."""
204+
return file_utils._get_modulus_exponent(public_key_file) # pylint: disable=protected-access
304205

305206

306207
def _assert_args(storage_account, cert_file, public_key_file, private_key_file):
307208
"""Validate SFTP connection arguments."""
308209
if not storage_account:
309210
raise azclierror.RequiredArgumentMissingError("Storage account name is required.")
310211

311-
if cert_file:
312-
expanded_cert_file = os.path.expanduser(cert_file)
313-
if not os.path.isfile(expanded_cert_file):
314-
raise azclierror.FileOperationError(f"Certificate file {cert_file} not found.")
212+
# Check file existence for provided files
213+
files_to_check = [
214+
(cert_file, "Certificate"),
215+
(public_key_file, "Public key"),
216+
(private_key_file, "Private key")
217+
]
315218

316-
if public_key_file:
317-
expanded_public_key_file = os.path.expanduser(public_key_file)
318-
if not os.path.isfile(expanded_public_key_file):
319-
raise azclierror.FileOperationError(f"Public key file {public_key_file} not found.")
320-
321-
if private_key_file:
322-
expanded_private_key_file = os.path.expanduser(private_key_file)
323-
if not os.path.isfile(expanded_private_key_file):
324-
raise azclierror.FileOperationError(f"Private key file {private_key_file} not found.")
219+
for file_path, file_type in files_to_check:
220+
if file_path:
221+
expanded_path = os.path.expanduser(file_path)
222+
if not os.path.isfile(expanded_path):
223+
raise azclierror.FileOperationError(f"{file_type} file {file_path} not found.")
325224

326225

327226
def _do_sftp_op(sftp_session, op_call):
@@ -337,12 +236,9 @@ def _cleanup_credentials(delete_keys, delete_cert, credentials_folder, cert_file
337236
file_utils.delete_file(cert_file, f"Deleting generated certificate {cert_file}", warning=False)
338237

339238
if delete_keys:
340-
if private_key_file and os.path.isfile(private_key_file):
341-
file_utils.delete_file(private_key_file,
342-
f"Deleting generated private key {private_key_file}", warning=False)
343-
if public_key_file and os.path.isfile(public_key_file):
344-
file_utils.delete_file(public_key_file,
345-
f"Deleting generated public key {public_key_file}", warning=False)
239+
for key_file, key_type in [(private_key_file, "private"), (public_key_file, "public")]:
240+
if key_file and os.path.isfile(key_file):
241+
file_utils.delete_file(key_file, f"Deleting generated {key_type} key {key_file}", warning=False)
346242

347243
if credentials_folder and os.path.isdir(credentials_folder):
348244
logger.debug("Deleting credentials folder %s", credentials_folder)
@@ -354,9 +250,9 @@ def _cleanup_credentials(delete_keys, delete_cert, credentials_folder, cert_file
354250

355251
def _get_storage_endpoint_suffix(cmd):
356252
"""Get the appropriate storage endpoint suffix based on Azure cloud environment."""
357-
cloud_to_storage_suffix = {
253+
cloud_suffixes = {
358254
"azurecloud": "blob.core.windows.net",
359255
"azurechinacloud": "blob.core.chinacloudapi.cn",
360256
"azureusgovernment": "blob.core.usgovcloudapi.net"
361257
}
362-
return cloud_to_storage_suffix.get(cmd.cli_ctx.cloud.name.lower(), "blob.core.windows.net")
258+
return cloud_suffixes.get(cmd.cli_ctx.cloud.name.lower(), "blob.core.windows.net")

0 commit comments

Comments
 (0)