diff --git a/src/service_name.json b/src/service_name.json index bbc5685fb42..c4dc229473a 100644 --- a/src/service_name.json +++ b/src/service_name.json @@ -539,6 +539,11 @@ "AzureServiceName": "Virtual Machines", "URL": "https://learn.microsoft.com/azure/virtual-machines/shared-image-galleries" }, + { + "Command": "az sftp", + "AzureServiceName": "Storage", + "URL": "https://learn.microsoft.com/en-us/azure/storage/blobs/secure-file-transfer-protocol-support" + }, { "Command": "az spatial-anchors-account", "AzureServiceName": "Mixed Reality", diff --git a/src/sftp/HISTORY.rst b/src/sftp/HISTORY.rst new file mode 100644 index 00000000000..7b7fa262c06 --- /dev/null +++ b/src/sftp/HISTORY.rst @@ -0,0 +1,8 @@ +.. :changelog: + +Release History +=============== + +1.0.0b1 ++++++++ +* Initial preview release with SFTP connection and certificate generation support. \ No newline at end of file diff --git a/src/sftp/README.rst b/src/sftp/README.rst new file mode 100644 index 00000000000..542d136de3c --- /dev/null +++ b/src/sftp/README.rst @@ -0,0 +1,6 @@ +Azure CLI SFTP Commands +======================== + +Secure connections to Azure Storage via SFTP with SSH certificates. + +Commands include certificate generation and SFTP connection management. \ No newline at end of file diff --git a/src/sftp/azext_sftp/__init__.py b/src/sftp/azext_sftp/__init__.py new file mode 100644 index 00000000000..4e709dfe838 --- /dev/null +++ b/src/sftp/azext_sftp/__init__.py @@ -0,0 +1,52 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +""" +Azure CLI SFTP Extension + +This extension provides secure SFTP connectivity to Azure Storage Accounts +with automatic Azure AD authentication and certificate management. + +Key Features: +- Fully managed SSH certificate generation using Azure AD +- Support for existing SSH keys and certificates +- Interactive and batch SFTP operations +- Automatic credential cleanup for security +- Integration with Azure Storage SFTP endpoints + +Commands: +- az sftp cert: Generate SSH certificates for SFTP authentication +- az sftp connect: Connect to Azure Storage Account via SFTP +""" + +from azure.cli.core import AzCommandsLoader + +from azext_sftp._help import helps # pylint: disable=unused-import + + +class SftpCommandsLoader(AzCommandsLoader): + """Command loader for the SFTP extension.""" + + def __init__(self, cli_ctx=None): + from azure.cli.core.commands import CliCommandType + + super().__init__( + cli_ctx=cli_ctx, + custom_command_type=CliCommandType( + operations_tmpl='azext_sftp.custom#{}')) + + def load_command_table(self, args): + """Load the command table for SFTP commands.""" + from azext_sftp.commands import load_command_table + load_command_table(self, args) + return self.command_table + + def load_arguments(self, command): + """Load arguments for SFTP commands.""" + from azext_sftp._params import load_arguments + load_arguments(self, command) + + +COMMAND_LOADER_CLS = SftpCommandsLoader diff --git a/src/sftp/azext_sftp/_help.py b/src/sftp/azext_sftp/_help.py new file mode 100644 index 00000000000..59b952719fe --- /dev/null +++ b/src/sftp/azext_sftp/_help.py @@ -0,0 +1,98 @@ +# coding=utf-8 +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +from knack.help_files import helps # pylint: disable=unused-import + + +helps['sftp'] = """ + type: group + short-summary: Generate SSH certificates and access Azure Storage blob data via SFTP + long-summary: | + These commands allow you to generate certificates and connect to Azure Storage Accounts using SFTP. + + PREREQUISITES: + - Azure Storage Account with SFTP enabled + - Appropriate RBAC permissions (Storage Blob Data Contributor or similar) + - Azure CLI authentication (az login) + - Network connectivity to Azure Storage endpoints + + The SFTP extension provides two main capabilities: + 1. Certificate generation using Azure AD authentication (similar to 'az ssh cert') + 2. Fully managed SFTP connections to Azure Storage with automatic credential handling + + AUTHENTICATION MODES: + - Fully managed: No credentials needed - automatically generates SSH certificate + - Certificate-based: Use existing SSH certificate file + - Key-based: Use SSH public/private key pair (generates certificate automatically) + + This extension closely follows the patterns established by the SSH extension. +""" + +helps['sftp cert'] = """ + type: command + short-summary: Generate SSH certificate for SFTP authentication + long-summary: | + Generate an SSH certificate that can be used for authenticating to Azure Storage SFTP endpoints. + This uses Azure AD authentication to generate a certificate similar to 'az ssh cert'. + + CERTIFICATE NAMING: + - Generated certificates have '-aadcert.pub' suffix (e.g., id_rsa-aadcert.pub) + - Certificates are valid for a limited time (typically 1 hour) + - Private keys are generated with 'id_rsa' name when key pair is created + + The certificate can be used with 'az sftp connect' or with standard SFTP clients. + examples: + - name: Generate a certificate using an existing public key + text: az sftp cert --public-key-file ~/.ssh/id_rsa.pub --file ~/my_cert.pub + - name: Generate a certificate and create a new key pair in the same directory + text: az sftp cert --file ~/my_cert.pub + - name: Generate a certificate with custom SSH client folder + text: az sftp cert --file ~/my_cert.pub --ssh-client-folder "C:\\Program Files\\OpenSSH" +""" + +helps['sftp connect'] = """ + type: command + short-summary: Access Azure Storage blob data via SFTP + long-summary: | + Establish an SFTP connection to an Azure Storage Account. + + AUTHENTICATION MODES: + 1. Fully managed (RECOMMENDED): Run without credentials - automatically generates SSH certificate + and establishes connection. Credentials are cleaned up after use. + + 2. Certificate-based: Use existing SSH certificate file. Certificate must be generated with + 'az sftp cert' or compatible with Azure AD authentication. + + 3. Key-based: Provide SSH keys - command will generate certificate automatically from your keys. + + CONNECTION DETAILS: + - Username format: {storage-account}.{azure-username} + - Port: Uses SSH default (typically 22) unless specified with --port + - Endpoints resolved automatically based on Azure cloud environment: + * Azure Public: {storage-account}.blob.core.windows.net + * Azure China: {storage-account}.blob.core.chinacloudapi.cn + * Azure Government: {storage-account}.blob.core.usgovcloudapi.net + + SECURITY: + - Generated credentials are automatically cleaned up after connection + - Temporary files stored in secure temporary directories + - OpenSSH handles certificate validation during connection + examples: + - name: Connect with automatic certificate generation (fully managed - RECOMMENDED) + text: az sftp connect --storage-account mystorageaccount + - name: Connect to storage account with existing certificate + text: az sftp connect --storage-account mystorageaccount --certificate-file ~/my_cert.pub + - name: Connect with existing SSH key pair + text: az sftp connect --storage-account mystorageaccount --public-key-file ~/.ssh/id_rsa.pub --private-key-file ~/.ssh/id_rsa + - name: Connect with custom port + text: az sftp connect --storage-account mystorageaccount --port 2222 + - name: Connect with additional SFTP arguments for debugging + text: az sftp connect --storage-account mystorageaccount --sftp-args="-v" + - name: Connect with custom SSH client folder (Windows) + text: az sftp connect --storage-account mystorageaccount --ssh-client-folder "C:\\Program Files\\OpenSSH" + - name: Connect with custom connection timeout + text: az sftp connect --storage-account mystorageaccount --sftp-args="-o ConnectTimeout=30" +""" diff --git a/src/sftp/azext_sftp/_params.py b/src/sftp/azext_sftp/_params.py new file mode 100644 index 00000000000..789c20169cb --- /dev/null +++ b/src/sftp/azext_sftp/_params.py @@ -0,0 +1,42 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- +# pylint: disable=line-too-long + + +def load_arguments(self, _): + + with self.argument_context('sftp cert') as c: + c.argument('cert_path', options_list=['--file', '-f'], + help='The file path to write the SSH cert to, defaults to public key path with -aadcert.pub appended') + c.argument('public_key_file', options_list=['--public-key-file', '-p'], + help='The RSA public key file path. If not provided, ' + 'generated key pair is stored in the same directory as --file.') + c.argument('ssh_client_folder', options_list=['--ssh-client-folder'], + help='Folder path that contains ssh executables (ssh-keygen, ssh). ' + 'Default to ssh executables in your PATH or C:\\Windows\\System32\\OpenSSH on Windows.') + + with self.argument_context('sftp connect') as c: + c.argument('storage_account', options_list=['--storage-account', '-s'], + help='Azure Storage Account name for SFTP connection. Must have SFTP enabled.') + c.argument('port', options_list=['--port'], + help='SFTP port. If not specified, uses SSH default port (typically 22).', + type=int) + c.argument('cert_file', options_list=['--certificate-file', '-c'], + help='Path to SSH certificate file for authentication. ' + 'Must be generated with "az sftp cert" or compatible Azure AD certificate. ' + 'If not provided, certificate will be generated automatically.') + c.argument('private_key_file', options_list=['--private-key-file', '-i'], + help='Path to RSA private key file. If provided without certificate, ' + 'a certificate will be generated automatically from this key.') + c.argument('public_key_file', options_list=['--public-key-file', '-p'], + help='Path to RSA public key file. If provided without certificate, ' + 'a certificate will be generated automatically from this key.') + c.argument('sftp_args', options_list=['--sftp-args'], + help='Additional arguments to pass to the SFTP client. ' + 'Example: "-v" for verbose output, "-b batchfile.txt" for batch commands, ' + '"-o ConnectTimeout=30" for custom timeout.') + c.argument('ssh_client_folder', options_list=['--ssh-client-folder'], + help='Path to folder containing SSH client executables (ssh, sftp, ssh-keygen). ' + 'Default: Uses executables from PATH or C:\\Windows\\System32\\OpenSSH on Windows.') diff --git a/src/sftp/azext_sftp/azext_metadata.json b/src/sftp/azext_sftp/azext_metadata.json new file mode 100644 index 00000000000..0e6e6f5a0e2 --- /dev/null +++ b/src/sftp/azext_sftp/azext_metadata.json @@ -0,0 +1,4 @@ +{ + "azext.isPreview": true, + "azext.minCliCoreVersion": "2.75.0" +} diff --git a/src/sftp/azext_sftp/commands.py b/src/sftp/azext_sftp/commands.py new file mode 100644 index 00000000000..41cd795c61c --- /dev/null +++ b/src/sftp/azext_sftp/commands.py @@ -0,0 +1,13 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +"""Command definitions for the Azure CLI SFTP extension.""" + + +def load_command_table(self, _): + """Load command table for SFTP extension.""" + with self.command_group('sftp') as g: + g.custom_command('cert', 'sftp_cert') + g.custom_command('connect', 'sftp_connect') diff --git a/src/sftp/azext_sftp/constants.py b/src/sftp/azext_sftp/constants.py new file mode 100644 index 00000000000..3666831f90e --- /dev/null +++ b/src/sftp/azext_sftp/constants.py @@ -0,0 +1,37 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +from colorama import Fore, Style + +# File system constants +WINDOWS_INVALID_FOLDERNAME_CHARS = "\\/*:<>?\"|" + +# Default ports +DEFAULT_SSH_PORT = DEFAULT_SFTP_PORT = AZURE_STORAGE_SFTP_PORT = 22 + +# SSH/SFTP client configuration +SSH_CONNECT_TIMEOUT = 30 +SSH_SERVER_ALIVE_INTERVAL = 60 +SSH_SERVER_ALIVE_COUNT_MAX = 3 + +# Certificate and key file naming +SSH_PRIVATE_KEY_NAME = "id_rsa" +SSH_PUBLIC_KEY_NAME = "id_rsa.pub" +SSH_CERT_SUFFIX = "-aadcert.pub" + +# Error messages and recommendations +RECOMMENDATION_SSH_CLIENT_NOT_FOUND = ( + Fore.YELLOW + + "Ensure OpenSSH is installed correctly.\n" + "Alternatively, use --ssh-client-folder to provide OpenSSH folder path." + + Style.RESET_ALL +) + +RECOMMENDATION_STORAGE_ACCOUNT_SFTP = ( + Fore.YELLOW + + "Ensure your Azure Storage Account has SFTP enabled.\n" + "Verify your account permissions include Storage Blob Data Contributor or similar." + + Style.RESET_ALL +) diff --git a/src/sftp/azext_sftp/custom.py b/src/sftp/azext_sftp/custom.py new file mode 100644 index 00000000000..035608e01a5 --- /dev/null +++ b/src/sftp/azext_sftp/custom.py @@ -0,0 +1,231 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +import os +import tempfile +import shutil + +from knack import log +from azure.cli.core import azclierror +from azure.cli.core.style import Style, print_styled_text +from azure.cli.core._profile import Profile + +from . import sftp_info +from . import sftp_utils +from . import file_utils + +logger = log.get_logger(__name__) + + +def sftp_cert(cmd, cert_path=None, public_key_file=None, ssh_client_folder=None): + """Generate SSH certificate for SFTP authentication using Azure AD.""" + logger.debug("Starting SFTP certificate generation") + + if not cert_path and not public_key_file: + raise azclierror.RequiredArgumentMissingError("--file or --public-key-file must be provided.") + + if cert_path: + cert_path = os.path.expanduser(cert_path) + if public_key_file: + public_key_file = os.path.expanduser(public_key_file) + if ssh_client_folder: + ssh_client_folder = os.path.expanduser(ssh_client_folder) + + if cert_path and not os.path.isdir(os.path.dirname(cert_path)): + raise azclierror.InvalidArgumentValueError(f"{os.path.dirname(cert_path)} folder doesn't exist") + + if public_key_file: + public_key_file = os.path.abspath(public_key_file) + logger.debug("Using public key file: %s", public_key_file) + if cert_path: + cert_path = os.path.abspath(cert_path) + logger.debug("Certificate will be written to: %s", cert_path) + if ssh_client_folder: + ssh_client_folder = os.path.abspath(ssh_client_folder) + logger.debug("Using SSH client folder: %s", ssh_client_folder) + + keys_folder = None + if not public_key_file: + keys_folder = os.path.dirname(cert_path) + logger.debug("Will generate key pair in: %s", keys_folder) + + try: + public_key_file, _, delete_keys = file_utils.check_or_create_public_private_files( + public_key_file, None, keys_folder, ssh_client_folder) + cert_file, _ = file_utils.get_and_write_certificate(cmd, public_key_file, cert_path, ssh_client_folder) + except Exception as e: + logger.debug("Certificate generation failed: %s", str(e)) + raise + + if keys_folder and delete_keys: + logger.warning("%s contains sensitive information (id_rsa, id_rsa.pub). " + "Please delete once this certificate is no longer being used.", keys_folder) + + # pylint: disable=broad-except + try: + cert_expiration = sftp_utils.get_certificate_start_and_end_times(cert_file, ssh_client_folder)[1] + print_styled_text((Style.SUCCESS, + f"Generated SSH certificate {cert_file} is valid until {cert_expiration} in local time.")) + except Exception as e: + logger.warning("Couldn't determine certificate validity. Error: %s", str(e)) + print_styled_text((Style.SUCCESS, f"Generated SSH certificate {cert_file}.")) + + +def sftp_connect(cmd, storage_account, port=None, cert_file=None, private_key_file=None, + public_key_file=None, sftp_args=None, ssh_client_folder=None): + """Connect to Azure Storage Account via SFTP with automatic certificate generation if needed.""" + logger.debug("Starting SFTP connection to storage account: %s", storage_account) + + if cert_file: + cert_file = os.path.expanduser(cert_file) + if private_key_file: + private_key_file = os.path.expanduser(private_key_file) + if public_key_file: + public_key_file = os.path.expanduser(public_key_file) + + _assert_args(storage_account, cert_file, public_key_file, private_key_file) + + auto_generate_cert = False + delete_keys = False + delete_cert = False + credentials_folder = None + + if not cert_file and not public_key_file and not private_key_file: + logger.info("Fully managed mode: No credentials provided") + auto_generate_cert = True + delete_cert = True + delete_keys = True + credentials_folder = tempfile.mkdtemp(prefix="aadsftp") + + try: + profile = Profile(cli_ctx=cmd.cli_ctx) + profile.get_subscription() + except Exception: + if credentials_folder and os.path.isdir(credentials_folder): + shutil.rmtree(credentials_folder) + raise + + print_styled_text((Style.ACTION, "Generating temporary credentials...")) + + if cert_file and public_key_file: + print_styled_text((Style.WARNING, "Using certificate file (ignoring public key).")) + + try: + if auto_generate_cert: + public_key_file, private_key_file, _ = file_utils.check_or_create_public_private_files( + None, None, credentials_folder, ssh_client_folder) + cert_file, user = file_utils.get_and_write_certificate(cmd, public_key_file, None, ssh_client_folder) + elif not cert_file: + profile = Profile(cli_ctx=cmd.cli_ctx) + profile.get_subscription() + + public_key_file, private_key_file, _ = file_utils.check_or_create_public_private_files( + public_key_file, private_key_file, None, ssh_client_folder) + print_styled_text((Style.ACTION, "Generating certificate...")) + cert_file, user = file_utils.get_and_write_certificate(cmd, public_key_file, None, ssh_client_folder) + delete_cert = True + else: + logger.debug("Using provided certificate file...") + if not os.path.isfile(cert_file): + raise azclierror.FileOperationError(f"Certificate file {cert_file} not found.") + + user = sftp_utils.get_ssh_cert_principals(cert_file, ssh_client_folder)[0].lower() + + if '@' in user: + user = user.split('@')[0] + + username = f"{storage_account}.{user}" + + storage_suffix = _get_storage_endpoint_suffix(cmd) + hostname = f"{storage_account}.{storage_suffix}" + + sftp_session = sftp_info.SFTPSession( + storage_account=storage_account, + username=username, + host=hostname, + port=port, + cert_file=cert_file, + private_key_file=private_key_file, + sftp_args=sftp_args, + ssh_client_folder=ssh_client_folder, + ssh_proxy_folder=None, + credentials_folder=credentials_folder, + yes_without_prompt=False + ) + + sftp_session.local_user = user + sftp_session.resolve_connection_info() + + if port is not None: + print_styled_text((Style.PRIMARY, f"Connecting to {username}@{hostname}:{port}")) + else: + print_styled_text((Style.PRIMARY, f"Connecting to {username}@{hostname}")) + + _do_sftp_op(sftp_session, sftp_utils.start_sftp_connection) + + except Exception as e: + if delete_keys or delete_cert: + logger.debug("An error occurred. Cleaning up generated credentials.") + _cleanup_credentials(delete_keys, delete_cert, credentials_folder, cert_file, + private_key_file, public_key_file) + raise e + finally: + if delete_keys or delete_cert: + _cleanup_credentials(delete_keys, delete_cert, credentials_folder, cert_file, + private_key_file, public_key_file) + + +def _assert_args(storage_account, cert_file, public_key_file, private_key_file): + """Validate SFTP connection arguments.""" + if not storage_account: + raise azclierror.RequiredArgumentMissingError("Storage account name is required.") + + # Check file existence for provided files + files_to_check = [ + (cert_file, "Certificate"), + (public_key_file, "Public key"), + (private_key_file, "Private key") + ] + + for file_path, file_type in files_to_check: + if file_path: + expanded_path = os.path.expanduser(file_path) + if not os.path.isfile(expanded_path): + raise azclierror.FileOperationError(f"{file_type} file {file_path} not found.") + + +def _do_sftp_op(sftp_session, op_call): + """Execute SFTP operation with session.""" + sftp_session.validate_session() + return op_call(sftp_session) + + +def _cleanup_credentials(delete_keys, delete_cert, credentials_folder, cert_file, private_key_file, public_key_file): + """Clean up generated credentials.""" + try: + if delete_cert and cert_file and os.path.isfile(cert_file): + file_utils.delete_file(cert_file, f"Deleting generated certificate {cert_file}", warning=False) + + if delete_keys: + for key_file, key_type in [(private_key_file, "private"), (public_key_file, "public")]: + if key_file and os.path.isfile(key_file): + file_utils.delete_file(key_file, f"Deleting generated {key_type} key {key_file}", warning=False) + + if credentials_folder and os.path.isdir(credentials_folder): + logger.debug("Deleting credentials folder %s", credentials_folder) + shutil.rmtree(credentials_folder) + + except OSError as e: + logger.warning("Failed to clean up credentials: %s", str(e)) + + +def _get_storage_endpoint_suffix(cmd): + """Get the appropriate storage endpoint suffix based on Azure cloud environment.""" + cloud_suffixes = { + "azurecloud": "blob.core.windows.net", + "azurechinacloud": "blob.core.chinacloudapi.cn", + "azureusgovernment": "blob.core.usgovcloudapi.net" + } + return cloud_suffixes.get(cmd.cli_ctx.cloud.name.lower(), "blob.core.windows.net") diff --git a/src/sftp/azext_sftp/file_utils.py b/src/sftp/azext_sftp/file_utils.py new file mode 100644 index 00000000000..776b979db06 --- /dev/null +++ b/src/sftp/azext_sftp/file_utils.py @@ -0,0 +1,154 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +import os +import hashlib +import json +import tempfile +import time + +from azure.cli.core import azclierror +from azure.cli.core import telemetry +from azure.cli.core._profile import Profile +from knack import log + +from . import rsa_parser +from . import sftp_utils + +logger = log.get_logger(__name__) + + +def delete_file(file_path, message, warning=False): + """Delete a file with error handling.""" + if os.path.isfile(file_path): + # pylint: disable=broad-except + try: + os.remove(file_path) + except Exception as e: + if warning: + logger.warning(message) + else: + raise azclierror.FileOperationError(f"{message}Error: {str(e)}") from e + + +def check_or_create_public_private_files(public_key_file, private_key_file, credentials_folder, ssh_client_folder=None): + """Check for existing key files or create new ones if needed.""" + delete_keys = False + + if not public_key_file and not private_key_file: + if not credentials_folder: + credentials_folder = tempfile.mkdtemp(prefix="aadsftpcert") + else: + if not os.path.isdir(credentials_folder): + os.makedirs(credentials_folder) + + public_key_file = os.path.join(credentials_folder, "id_rsa.pub") + private_key_file = os.path.join(credentials_folder, "id_rsa") + + # Check if existing keys are present before generating new ones + if not (os.path.isfile(public_key_file) and os.path.isfile(private_key_file)): + # Only generate new keys if both don't exist + sftp_utils.create_ssh_keyfile(private_key_file, ssh_client_folder) + # Only set delete_keys to True when we actually create new keys + delete_keys = True + # If existing keys are found, delete_keys remains False + + if not public_key_file: + if private_key_file: + public_key_file = str(private_key_file) + ".pub" + else: + raise azclierror.RequiredArgumentMissingError("Public key file not specified") + + if not os.path.isfile(public_key_file): + raise azclierror.FileOperationError(f"Public key file {public_key_file} not found") + + if private_key_file: + if not os.path.isfile(private_key_file): + raise azclierror.FileOperationError(f"Private key file {private_key_file} not found") + + if not private_key_file: + if public_key_file.endswith(".pub"): + private_key_file = public_key_file[:-4] if os.path.isfile(public_key_file[:-4]) else None + + return public_key_file, private_key_file, delete_keys + + +def get_and_write_certificate(cmd, public_key_file, cert_file, ssh_client_folder): + """Generate and write an SSH certificate using Azure AD authentication.""" + cloud_scopes = { + "azurecloud": "https://pas.windows.net/CheckMyAccess/Linux/.default", + "azurechinacloud": "https://pas.chinacloudapi.cn/CheckMyAccess/Linux/.default", + "azureusgovernment": "https://pasff.usgovcloudapi.net/CheckMyAccess/Linux/.default" + } + + scope = cloud_scopes.get(cmd.cli_ctx.cloud.name.lower()) + if not scope: + raise azclierror.InvalidArgumentValueError( + f"Unsupported cloud {cmd.cli_ctx.cloud.name.lower()}", + "Supported clouds include azurecloud,azurechinacloud,azureusgovernment") + + data = _prepare_jwk_data(public_key_file) + profile = Profile(cli_ctx=cmd.cli_ctx) + t0 = time.time() + + if hasattr(profile, "get_msal_token"): + _, certificate = profile.get_msal_token([scope], data) + else: + credential, _, _ = profile.get_login_credentials(subscription_id=profile.get_subscription()["id"]) + certificatedata = credential.get_token(scope, data=data) + certificate = certificatedata.token + + time_elapsed = time.time() - t0 + telemetry.add_extension_event('sftp', {'Context.Default.AzureCLI.SFTPGetCertificateTime': time_elapsed}) + + if not cert_file: + cert_file = str(public_key_file.removesuffix(".pub")) + "-aadcert.pub" + + logger.debug("Generating certificate %s", cert_file) + _write_cert_file(certificate, cert_file) + username = sftp_utils.get_ssh_cert_principals(cert_file, ssh_client_folder)[0] + + return cert_file, username.lower() + + +def _prepare_jwk_data(public_key_file): + """Prepare JWK data for certificate request.""" + modulus, exponent = _get_modulus_exponent(public_key_file) + key_hash = hashlib.sha256() + key_hash.update(modulus.encode('utf-8')) + key_hash.update(exponent.encode('utf-8')) + key_id = key_hash.hexdigest() + + jwk = {"kty": "RSA", "n": modulus, "e": exponent, "kid": key_id} + + return { + "token_type": "ssh-cert", + "req_cnf": json.dumps(jwk), + "key_id": key_id + } + + +def _write_cert_file(certificate_contents, cert_file): + """Write SSH certificate to file.""" + with open(cert_file, 'w', encoding='utf-8') as f: + f.write(f"ssh-rsa-cert-v01@openssh.com {certificate_contents}") + return cert_file + + +def _get_modulus_exponent(public_key_file): + """Extract modulus and exponent from RSA public key file.""" + if not os.path.isfile(public_key_file): + raise azclierror.FileOperationError(f"Public key file '{public_key_file}' was not found") + + with open(public_key_file, 'r', encoding='utf-8') as f: + public_key_text = f.read() + + parser = rsa_parser.RSAParser() + try: + parser.parse(public_key_text) + except Exception as e: + raise azclierror.FileOperationError(f"Could not parse public key. Error: {str(e)}") + + return parser.modulus, parser.exponent diff --git a/src/sftp/azext_sftp/rsa_parser.py b/src/sftp/azext_sftp/rsa_parser.py new file mode 100644 index 00000000000..b7df98fbdfc --- /dev/null +++ b/src/sftp/azext_sftp/rsa_parser.py @@ -0,0 +1,54 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +import base64 +import struct + + +# pylint: disable=too-few-public-methods +class RSAParser: + RSAAlgorithm = 'ssh-rsa' + + def __init__(self): + self.algorithm = '' + self.modulus = '' + self.exponent = '' + self._key_length_big_endian = True + + def parse(self, public_key_text): + text_parts = public_key_text.split(' ') + if len(text_parts) < 2: + raise ValueError("Incorrectly formatted public key. " + "Key must be format ' '") + + algorithm = text_parts[0] + if algorithm != self.RSAAlgorithm: + raise ValueError(f"Public key is not ssh-rsa algorithm ({algorithm})") + + key_bytes = base64.b64decode(text_parts[1]) + fields = list(self._get_fields(key_bytes)) + if len(fields) < 3: + raise ValueError("Incorrectly encoded public key. " + "Encoded key must be base64 encoded ") + + encoded_algorithm = fields[0].decode("ascii") + if encoded_algorithm != self.RSAAlgorithm: + raise ValueError(f"Encoded public key is not ssh-rsa algorithm ({encoded_algorithm})") + + self.algorithm = encoded_algorithm + self.exponent = base64.urlsafe_b64encode(fields[1]).decode("ascii") + self.modulus = base64.urlsafe_b64encode(fields[2]).decode("ascii") + + def _get_fields(self, key_bytes): + read = 0 + while read < len(key_bytes): + length = struct.unpack(self._get_struct_format(), key_bytes[read:read + 4])[0] + read = read + 4 + data = key_bytes[read:read + length] + read = read + length + yield data + + def _get_struct_format(self): + return ">" + "L" if self._key_length_big_endian else "= retry_attempts_allowed: + raise azclierror.UnclassifiedUserFault(error_msg, const.RECOMMENDATION_SSH_CLIENT_NOT_FOUND) + logger.warning("%s. Retrying...", error_msg) + + if duration is not None: + logger.debug("Connection attempt %d duration: %.2f seconds", attempt + 1, duration) + if attempt < retry_attempts_allowed: + time.sleep(1) + + raise azclierror.UnclassifiedUserFault( + "Failed to establish SFTP connection after multiple attempts.", + "Please check your network connection, credentials, and that the SFTP server is accessible." + ) + + except KeyboardInterrupt: + logger.info("SFTP connection interrupted by user") + print("\nSFTP session exited cleanly.") + + +def create_ssh_keyfile(private_key_file, ssh_client_folder=None): + """Create an SSH key file using ssh-keygen.""" + sshkeygen_path = get_ssh_client_path("ssh-keygen", ssh_client_folder) + command = [sshkeygen_path, "-f", private_key_file, "-t", "rsa", "-q", "-N", ""] + logger.debug("Running ssh-keygen command %s", ' '.join(command)) + try: + subprocess.call(command) + except OSError as e: + colorama.init() + raise azclierror.BadRequestError(f"Failed to create ssh key file with error: {str(e)}.", + const.RECOMMENDATION_SSH_CLIENT_NOT_FOUND) + + +def get_ssh_cert_principals(cert_file, ssh_client_folder=None): + """Extract principals from SSH certificate.""" + info = get_ssh_cert_info(cert_file, ssh_client_folder) + principals = [] + in_principal = False + for line in info: + if ":" in line: + in_principal = False + if "Principals:" in line: + in_principal = True + continue + if in_principal: + principals.append(line.strip()) + return principals + + +def get_ssh_cert_info(cert_file, ssh_client_folder=None): + """Get SSH certificate information using ssh-keygen.""" + sshkeygen_path = get_ssh_client_path("ssh-keygen", ssh_client_folder) + command = [sshkeygen_path, "-L", "-f", cert_file] + logger.debug("Running ssh-keygen command %s", ' '.join(command)) + try: + return subprocess.check_output(command).decode().splitlines() + except OSError as e: + colorama.init() + raise azclierror.BadRequestError(f"Failed to get certificate info with error: {str(e)}.", + const.RECOMMENDATION_SSH_CLIENT_NOT_FOUND) + + +_warned_ssh_client_folders = set() + + +def get_ssh_client_path(ssh_command="ssh", ssh_client_folder=None): + """Get the path to an SSH client executable.""" + if ssh_client_folder: + ssh_path = os.path.join(ssh_client_folder, ssh_command) + if platform.system() == 'Windows': + ssh_path += '.exe' + if os.path.isfile(ssh_path): + logger.debug("Attempting to run %s from path %s", ssh_command, ssh_path) + return ssh_path + warn_key = (ssh_command, os.path.abspath(ssh_client_folder)) + if warn_key not in _warned_ssh_client_folders: + logger.warning("Could not find %s in provided --ssh-client-folder %s. " + "Attempting to get pre-installed OpenSSH bits.", ssh_command, ssh_client_folder) + _warned_ssh_client_folders.add(warn_key) + + if platform.system() != 'Windows': + return ssh_command + + # Windows-specific logic + machine = platform.machine() + if not machine.endswith(('64', '86')): + if machine == '': + raise azclierror.BadRequestError("Couldn't identify the OS architecture.") + raise azclierror.BadRequestError(f"Unsupported OS architecture: {machine} is not currently supported") + + # Determine system path + is_64bit = machine.endswith('64') + is_32bit_python = platform.architecture()[0] == '32bit' + sys_path = 'SysNative' if is_64bit and is_32bit_python else 'System32' + + system_root = os.environ['SystemRoot'] + ssh_path = os.path.join(system_root, sys_path, "openSSH", f"{ssh_command}.exe") + + logger.debug("Platform architecture: %s", platform.architecture()[0]) + logger.debug("OS architecture: %s", '64bit' if is_64bit else '32bit') + logger.debug("System Root: %s", system_root) + logger.debug("Attempting to run %s from path %s", ssh_command, ssh_path) + + if not os.path.isfile(ssh_path): + raise azclierror.UnclassifiedUserFault( + f"Could not find {ssh_command}.exe on path {ssh_path}. ", + colorama.Fore.YELLOW + "Make sure OpenSSH is installed correctly: " + "https://docs.microsoft.com/en-us/windows-server/administration/openssh/openssh_install_firstuse . " + "Or use --ssh-client-folder to provide folder path with ssh executables. " + colorama.Style.RESET_ALL) + + return ssh_path + + +def get_certificate_start_and_end_times(cert_file, ssh_client_folder=None): + """Get start and end times from SSH certificate validity.""" + validity_str = _get_ssh_cert_validity(cert_file, ssh_client_folder) + times = None + if validity_str and "Valid: from " in validity_str and " to " in validity_str: + try: + times = validity_str.replace("Valid: from ", "").split(" to ") + t0 = datetime.datetime.strptime(times[0], '%Y-%m-%dT%X') + t1 = datetime.datetime.strptime(times[1], '%Y-%m-%dT%X') + times = (t0, t1) + except (ValueError, TypeError, IndexError): + # Invalid date format or parsing error + times = None + return times + + +def _get_ssh_cert_validity(cert_file, ssh_client_folder=None): + """Get validity line from SSH certificate info.""" + if cert_file: + info = get_ssh_cert_info(cert_file, ssh_client_folder) + for line in info: + if "Valid:" in line: + return line.strip() + return None diff --git a/src/sftp/azext_sftp/tests/__init__.py b/src/sftp/azext_sftp/tests/__init__.py new file mode 100644 index 00000000000..2dcf9bb68b3 --- /dev/null +++ b/src/sftp/azext_sftp/tests/__init__.py @@ -0,0 +1,5 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# ----------------------------------------------------------------------------- \ No newline at end of file diff --git a/src/sftp/azext_sftp/tests/latest/__init__.py b/src/sftp/azext_sftp/tests/latest/__init__.py new file mode 100644 index 00000000000..2dcf9bb68b3 --- /dev/null +++ b/src/sftp/azext_sftp/tests/latest/__init__.py @@ -0,0 +1,5 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# ----------------------------------------------------------------------------- \ No newline at end of file diff --git a/src/sftp/azext_sftp/tests/latest/test_custom.py b/src/sftp/azext_sftp/tests/latest/test_custom.py new file mode 100644 index 00000000000..fc1aa4645fe --- /dev/null +++ b/src/sftp/azext_sftp/tests/latest/test_custom.py @@ -0,0 +1,685 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- +import io +import unittest +import pytest +import json +from unittest import mock +from azext_sftp import custom + +from azure.cli.core import azclierror +import tempfile +import os +import shutil + + +class SftpCustomCommandTest(unittest.TestCase): + """Test suite for SFTP custom commands.""" + + def setUp(self): + """Set up test fixtures before each test method.""" + super().setUp() + # Set up temporary directory for test files + self.temp_dir = tempfile.mkdtemp(prefix="sftp_test_") + self.mock_cert_file = os.path.join(self.temp_dir, "test_cert.pub") + self.mock_private_key = os.path.join(self.temp_dir, "test_key") + self.mock_public_key = os.path.join(self.temp_dir, "test_key.pub") + + # Create mock files + with open(self.mock_cert_file, 'w') as f: + f.write("ssh-rsa-cert-v01@openssh.com MOCK_CERT_DATA") + with open(self.mock_private_key, 'w') as f: + f.write("-----BEGIN OPENSSH PRIVATE KEY-----\nMOCK_PRIVATE_KEY\n-----END OPENSSH PRIVATE KEY-----") + with open(self.mock_public_key, 'w') as f: + f.write("ssh-rsa AAAAB3NzaC1yc2EAAA mock@test.com") + + def tearDown(self): + """Tear down test fixtures after each test method.""" + super().tearDown() + # Clean up temporary directory + if os.path.exists(self.temp_dir): + shutil.rmtree(self.temp_dir) + + def test_sftp_cert_basic_error_cases(self): + """Test basic sftp cert error cases with parameterized inputs.""" + basic_error_cases = [ + # (description, exception_type, cert_path, public_key_file, setup_mocks) + ("no arguments provided", azclierror.RequiredArgumentMissingError, None, None, {}), + ("certificate directory doesn't exist", azclierror.InvalidArgumentValueError, "cert", None, {"isdir_return": False}), + ] + + for description, exception_type, cert_path, public_key_file, setup_mocks in basic_error_cases: + with self.subTest(case=description): + cmd = mock.Mock() + + # Apply setup mocks + patches = [] + if "isdir_return" in setup_mocks: + patches.append(mock.patch('os.path.isdir', return_value=setup_mocks["isdir_return"])) + + for patch in patches: + patch.start() + + try: + with self.assertRaises(exception_type): + custom.sftp_cert(cmd, cert_path=cert_path, public_key_file=public_key_file) + finally: + for patch in patches: + patch.stop() + + @mock.patch('os.path.isdir') + @mock.patch('os.path.abspath') + @mock.patch('azext_sftp.file_utils.check_or_create_public_private_files') + @mock.patch('azext_sftp.file_utils.get_and_write_certificate') + def test_sftp_cert(self, mock_write_cert, mock_get_keys, mock_abspath, mock_isdir): + """Test successful certificate generation.""" + cmd = mock.Mock() + mock_isdir.return_value = True + mock_abspath.side_effect = ['/pubkey/path', '/cert/path', '/client/path'] + mock_get_keys.return_value = "pubkey", "privkey", False + mock_write_cert.return_value = "cert", "username" + + custom.sftp_cert(cmd, "cert", "pubkey") + + mock_get_keys.assert_called_once_with('/pubkey/path', None, None, None) + mock_write_cert.assert_called_once_with(cmd, 'pubkey', '/cert/path', None) + + @mock.patch('azext_sftp.custom._do_sftp_op') + @mock.patch('azext_sftp.sftp_utils.get_ssh_cert_principals') + def test_sftp_connect_certificate_scenarios(self, mock_get_principals, mock_do_sftp): + """Test connect with various certificate scenarios.""" + # Test cases: (description, cert_file, public_key_file, private_key_file, expected_calls) + cert_test_cases = [ + ("valid cert provided", self.mock_cert_file, None, self.mock_private_key, "cert_used"), + ("cert and public key both provided", self.mock_cert_file, self.mock_public_key, self.mock_private_key, "cert_used"), + ("existing cert with existing private key", self.mock_cert_file, None, self.mock_private_key, "cert_used"), + ] + + for description, cert_file, public_key_file, private_key_file, expected_calls in cert_test_cases: + with self.subTest(case=description): + cmd = mock.Mock() + cmd.cli_ctx.cloud.name = "azurecloud" + mock_get_principals.return_value = ["testuser@domain.com"] + mock_do_sftp.return_value = None + + # Reset mocks for each test case + mock_get_principals.reset_mock() + mock_do_sftp.reset_mock() + + custom.sftp_connect( + cmd=cmd, + storage_account="teststorage", + port=22, + cert_file=cert_file, + public_key_file=public_key_file, + private_key_file=private_key_file, + sftp_args=['-b', '/dev/stdin'] # Use sftp_args for batch mode + ) + + # Verify certificate was used + if expected_calls == "cert_used": + mock_get_principals.assert_called_once() + mock_do_sftp.assert_called_once() + + + + @mock.patch('azext_sftp.custom._assert_args') + @mock.patch('azext_sftp.custom.Profile') + @mock.patch('azext_sftp.custom._get_storage_endpoint_suffix') + @mock.patch('azext_sftp.custom._do_sftp_op') + @mock.patch('azext_sftp.file_utils.get_and_write_certificate') + @mock.patch('azext_sftp.file_utils.check_or_create_public_private_files') + @mock.patch('tempfile.mkdtemp') + def test_sftp_connect_key_generation_scenarios(self, mock_mkdtemp, mock_create_keys, mock_gen_cert, mock_do_sftp, mock_get_suffix, mock_profile, mock_assert_args): + """Test connect with various key generation scenarios.""" + # Test cases: (description, public_key_file, private_key_file, keys_generated, expected_create_keys_args_template) + key_gen_test_cases = [ + ("no cert auto generate", None, None, True, (None, None, "{temp_dir}", None)), + ("public key provided no cert", self.mock_public_key, None, False, (self.mock_public_key, None, None, None)), + ("private key provided no cert", None, self.mock_private_key, False, (None, self.mock_private_key, None, None)), + ] + + for description, public_key_file, private_key_file, keys_generated, expected_create_keys_args_template in key_gen_test_cases: + with self.subTest(case=description): + cmd = mock.Mock() + cmd.cli_ctx.cloud.name = "azurecloud" + + # Mock Profile and subscription check + mock_profile_instance = mock.Mock() + mock_profile.return_value = mock_profile_instance + mock_profile_instance.get_subscription.return_value = {"id": "test-subscription-id"} + + # Reset mocks for each test case + mock_mkdtemp.reset_mock() + mock_create_keys.reset_mock() + mock_gen_cert.reset_mock() + mock_do_sftp.reset_mock() + + # Setup mocks + mock_assert_args.return_value = None # Skip argument validation + mock_mkdtemp.return_value = self.temp_dir + mock_create_keys.return_value = (self.mock_public_key, self.mock_private_key, keys_generated) + mock_gen_cert.return_value = (self.mock_cert_file, "testuser") + mock_do_sftp.return_value = None + mock_get_suffix.return_value = "blob.core.windows.net" + + custom.sftp_connect( + cmd=cmd, + storage_account="teststorage", + port=22, + public_key_file=public_key_file, + private_key_file=private_key_file, + sftp_args=['-b', '/dev/stdin'] # Use sftp_args for batch mode + ) + + # Build expected args with actual temp_dir value + expected_create_keys_args = tuple( + self.temp_dir if arg == "{temp_dir}" else arg + for arg in expected_create_keys_args_template + ) + + # Verify function calls + mock_create_keys.assert_called_once_with(*expected_create_keys_args) + mock_gen_cert.assert_called_once() + mock_do_sftp.assert_called_once() + + def test_sftp_connect_error_cases(self): + """Test connect error cases with parameterized inputs.""" + error_test_cases = [ + # (description, exception_type, kwargs) + ("invalid/missing private key file", azclierror.FileOperationError, {"private_key_file": "/nonexistent/key"}), + ("invalid/missing public key file", azclierror.FileOperationError, {"public_key_file": "/nonexistent/key.pub"}), + ("invalid/missing certificate file", azclierror.FileOperationError, {"cert_file": "/nonexistent/cert.pub", "private_key_file": self.mock_private_key}), + ("missing storage account", azclierror.RequiredArgumentMissingError, {"storage_account": None}), + ] + + for description, exception_type, kwargs in error_test_cases: + with self.subTest(case=description): + cmd = mock.Mock() + cmd.cli_ctx.cloud.name = "azurecloud" + + base_kwargs = {"cmd": cmd, "storage_account": "teststorage", "port": 22} + base_kwargs.update(kwargs) + + with self.assertRaises(exception_type): + custom.sftp_connect(**base_kwargs) + + @mock.patch('azext_sftp.custom._do_sftp_op') + @mock.patch('azext_sftp.sftp_utils.get_ssh_cert_principals') + def test_sftp_connect_port_configurations(self, mock_get_principals, mock_do_sftp): + """Test connect with different port configurations.""" + port_test_cases = [ + # (description, port_value, expected_port) + ("default port (None)", None, None), + ("custom port", 2222, 2222), + ] + + for description, port_value, expected_port in port_test_cases: + with self.subTest(case=description): + cmd = mock.Mock() + cmd.cli_ctx.cloud.name = "azurecloud" + mock_get_principals.return_value = ["testuser@domain.com"] + mock_do_sftp.return_value = None + + # Reset mock for each test case + mock_do_sftp.reset_mock() + + custom.sftp_connect( + cmd=cmd, + storage_account="teststorage", + port=port_value, + cert_file=self.mock_cert_file, + private_key_file=self.mock_private_key, + sftp_args=['-b', '/dev/stdin'] # Use sftp_args for batch mode + ) + + # Verify the session was created with expected port + mock_do_sftp.assert_called_once() + call_args = mock_do_sftp.call_args[0] + sftp_session = call_args[0] # First argument is the SFTP session + self.assertEqual(sftp_session.port, expected_port) + + @mock.patch('azext_sftp.custom._do_sftp_op') + @mock.patch('azext_sftp.sftp_utils.get_ssh_cert_principals') + def test_sftp_connect_argument_combinations(self, mock_get_principals, mock_do_sftp): + """Test sftp_connect with various argument combinations to ensure comprehensive coverage.""" + # Test cases: (description, kwargs, expected_behavior) + test_cases = [ + ("minimal args with cert", {"cert_file": self.mock_cert_file, "private_key_file": self.mock_private_key}, "success"), + ("with custom port", {"cert_file": self.mock_cert_file, "private_key_file": self.mock_private_key, "port": 2222}, "success"), + ("with sftp_args", {"cert_file": self.mock_cert_file, "private_key_file": self.mock_private_key, "sftp_args": "-v"}, "success"), + ("with ssh_client_folder", {"cert_file": self.mock_cert_file, "private_key_file": self.mock_private_key, "ssh_client_folder": "ssh_folder"}, "success"), + ("with sftp_args for batch", {"cert_file": self.mock_cert_file, "private_key_file": self.mock_private_key, "sftp_args": ["-b", "batchfile.txt"]}, "success"), + ("all args combined", { + "cert_file": self.mock_cert_file, + "private_key_file": self.mock_private_key, + "port": 2222, + "sftp_args": ["-v", "-o", "StrictHostKeyChecking=no", "-b", "batchfile.txt"], + "ssh_client_folder": "ssh_folder" + }, "success"), + ] + + for description, kwargs, expected_behavior in test_cases: + with self.subTest(case=description): + cmd = mock.Mock() + cmd.cli_ctx.cloud.name = "azurecloud" + mock_get_principals.return_value = ["testuser@domain.com"] + mock_do_sftp.return_value = None + + # Reset mocks for each test case + mock_do_sftp.reset_mock() + + base_kwargs = { + "cmd": cmd, + "storage_account": "teststorage" + } + base_kwargs.update(kwargs) + + if expected_behavior == "success": + custom.sftp_connect(**base_kwargs) + mock_do_sftp.assert_called_once() + + # Verify specific arguments were passed correctly + call_args = mock_do_sftp.call_args[0] + sftp_session = call_args[0] + + # Check that the session object has the expected properties + if "port" in kwargs: + self.assertEqual(sftp_session.port, kwargs["port"]) + if "sftp_args" in kwargs: + self.assertEqual(sftp_session.sftp_args, kwargs["sftp_args"]) + if "ssh_client_folder" in kwargs: + # Just check that ssh_client_folder was set - path may be normalized + self.assertIsNotNone(sftp_session.ssh_client_folder) + self.assertIn("ssh_folder", sftp_session.ssh_client_folder) + + def test_sftp_connect_sftp_args_variations(self): + """Test different sftp_args formats and common SSH options.""" + sftp_args_cases = [ + # (description, sftp_args) + ("None", None), + ("verbose flag", "-v"), + ("multiple flags", "-v -o StrictHostKeyChecking=no"), + ("compression", "-C"), + ("custom identity file", "-i /path/to/custom/key"), + ("timeout setting", "-o ConnectTimeout=30"), + ("complex args", "-v -C -o StrictHostKeyChecking=no -o ConnectTimeout=30"), + ] + + for description, sftp_args in sftp_args_cases: + with self.subTest(case=description): + cmd = mock.Mock() + cmd.cli_ctx.cloud.name = "azurecloud" + + with mock.patch('azext_sftp.custom._do_sftp_op') as mock_do_sftp, \ + mock.patch('azext_sftp.sftp_utils.get_ssh_cert_principals', return_value=["testuser@domain.com"]): + + mock_do_sftp.return_value = None + + custom.sftp_connect( + cmd=cmd, + storage_account="teststorage", + cert_file=self.mock_cert_file, + private_key_file=self.mock_private_key, + sftp_args=sftp_args + ) + + mock_do_sftp.assert_called_once() + call_args = mock_do_sftp.call_args[0] + sftp_session = call_args[0] + + # Verify sftp_args are set correctly + if sftp_args is None: + self.assertEqual(sftp_session.sftp_args, []) # None becomes empty list + else: + self.assertEqual(sftp_session.sftp_args, sftp_args) + + def test_sftp_connect_ssh_client_folder_variations(self): + """Test different ssh_client_folder path formats.""" + ssh_folder_cases = [ + # (description, ssh_client_folder) + ("None", None), + ("relative path", "ssh_client"), + ("absolute path", "/tmp/ssh"), + ] + + for description, ssh_client_folder in ssh_folder_cases: + with self.subTest(case=description): + cmd = mock.Mock() + cmd.cli_ctx.cloud.name = "azurecloud" + + with mock.patch('azext_sftp.custom._do_sftp_op') as mock_do_sftp, \ + mock.patch('azext_sftp.sftp_utils.get_ssh_cert_principals', return_value=["testuser@domain.com"]): + + mock_do_sftp.return_value = None + + custom.sftp_connect( + cmd=cmd, + storage_account="teststorage", + cert_file=self.mock_cert_file, + private_key_file=self.mock_private_key, + ssh_client_folder=ssh_client_folder + ) + + mock_do_sftp.assert_called_once() + call_args = mock_do_sftp.call_args[0] + sftp_session = call_args[0] + + # Verify ssh_client_folder is set correctly + if ssh_client_folder is None: + self.assertIsNone(sftp_session.ssh_client_folder) + else: + # Path will be converted to absolute path, so just check it's not None + self.assertIsNotNone(sftp_session.ssh_client_folder) + + @mock.patch('azext_sftp.file_utils.get_and_write_certificate') + @mock.patch('azext_sftp.file_utils.check_or_create_public_private_files') + @mock.patch('os.path.isdir') + @mock.patch('os.path.abspath') + def test_sftp_cert_key_generation_warning(self, mock_abspath, mock_isdir, mock_check_files, mock_write_cert): + """Test that warning is displayed when keys are generated. + + When keys are generated, user should be warned about sensitive information. + """ + # Setup mocks + cmd = mock.Mock() + mock_isdir.return_value = True + mock_abspath.side_effect = lambda x: x # Return input unchanged + mock_check_files.return_value = (self.mock_public_key, self.mock_private_key, True) + mock_write_cert.return_value = (self.mock_cert_file, "testuser@domain.com") + + # Mock logger to capture warning + with mock.patch('azext_sftp.custom.logger') as mock_logger: + custom.sftp_cert(cmd, cert_path=self.mock_cert_file) + + # Verify warning is logged when keys are generated + mock_logger.warning.assert_called() + # Check all warning calls to find the sensitive information one + warning_calls = [call[0][0] for call in mock_logger.warning.call_args_list] + sensitive_info_warning = next((call for call in warning_calls if "contains sensitive information" in call), None) + self.assertIsNotNone(sensitive_info_warning, "Sensitive information warning not found") + self.assertIn("id_rsa", sensitive_info_warning) + + @mock.patch('azext_sftp.file_utils.check_or_create_public_private_files') + @mock.patch('os.path.isdir') + @mock.patch('os.path.abspath') + def test_sftp_cert_certificate_generation_failure(self, mock_abspath, mock_isdir, mock_check_files): + """Test proper error handling when certificate generation fails.""" + # Setup mocks + cmd = mock.Mock() + mock_isdir.return_value = True + mock_abspath.side_effect = lambda x: x # Return input unchanged + mock_check_files.return_value = (self.mock_public_key, None, False) + + # Mock certificate generation to fail + with mock.patch('azext_sftp.file_utils.get_and_write_certificate') as mock_write_cert: + mock_write_cert.side_effect = Exception("Certificate generation failed") + + with mock.patch('azext_sftp.custom.logger') as mock_logger: + with self.assertRaises(Exception): + custom.sftp_cert(cmd, cert_path=self.mock_cert_file, + public_key_file=self.mock_public_key) + + # Verify error is logged - certificate generation failed so exception should be raised + # The debug logging might not be called depending on where the exception occurs + + @mock.patch('azext_sftp.file_utils.get_and_write_certificate') + @mock.patch('azext_sftp.file_utils.check_or_create_public_private_files') + @mock.patch('os.path.isdir') + @mock.patch('os.path.abspath') + def test_sftp_cert_parameter_combinations(self, mock_abspath, mock_isdir, mock_check_files, mock_write_cert): + """Test sftp cert with all valid parameter combinations using subTest.""" + # Test cases: (cert_path, public_key_file, ssh_client_folder, expected_keys_folder, description) + test_cases = [ + (None, "pubkey.pub", None, None, "public_key_only"), + ("cert.pub", None, None, "cert_dir", "cert_path_only"), + ("cert.pub", None, "/ssh", "cert_dir", "cert_path_with_ssh_client"), + (None, "pubkey.pub", "/ssh", None, "public_key_with_ssh_client"), + ("cert.pub", "pubkey.pub", None, None, "cert_path_with_public_key"), + ("cert.pub", "pubkey.pub", "/ssh", None, "all_parameters"), + ] + + for cert_path, public_key_file, ssh_client_folder, expected_keys_folder, description in test_cases: + with self.subTest(case=description): + # Reset mocks and setup + mock_check_files.reset_mock() + mock_write_cert.reset_mock() + cmd = mock.Mock() + mock_isdir.return_value = True + mock_abspath.side_effect = lambda x: x + + # Configure mocks based on test case + keys_generated = public_key_file is None + effective_public_key = public_key_file or self.mock_public_key + mock_check_files.return_value = (effective_public_key, self.mock_private_key if keys_generated else None, keys_generated) + + expected_cert = (effective_public_key + "-aadcert.pub") if cert_path is None else cert_path + mock_write_cert.return_value = (expected_cert, "testuser@domain.com") + + # Execute test + custom.sftp_cert(cmd, cert_path=cert_path, public_key_file=public_key_file, ssh_client_folder=ssh_client_folder) + + # Verify calls + expected_keys_dir = os.path.dirname(cert_path) if expected_keys_folder == "cert_dir" else expected_keys_folder + mock_check_files.assert_called_once_with(public_key_file, None, expected_keys_dir, ssh_client_folder) + mock_write_cert.assert_called_once_with(cmd, effective_public_key, cert_path, ssh_client_folder) + + def test_sftp_cert_error_cases(self): + """Test sftp cert error handling with invalid argument combinations.""" + # Test cases: (cert_path, public_key_file, setup_mocks, expected_exception, expected_message, description) + test_cases = [ + (None, None, {}, azclierror.RequiredArgumentMissingError, "--file or --public-key-file must be provided", "no_arguments"), + ("/bad/cert.pub", None, {"expanduser_return": "/bad/cert.pub", "isdir_return": False}, azclierror.InvalidArgumentValueError, "folder doesn't exist", "invalid_directory"), + ] + + for cert_path, public_key_file, setup_mocks, expected_exception, expected_message, description in test_cases: + with self.subTest(case=description): + cmd = mock.Mock() + patches = [] + + # Apply setup mocks + if "expanduser_return" in setup_mocks: + patches.append(mock.patch('os.path.expanduser', return_value=setup_mocks["expanduser_return"])) + if "isdir_return" in setup_mocks: + patches.append(mock.patch('os.path.isdir', return_value=setup_mocks["isdir_return"])) + + for patch in patches: + patch.start() + + try: + with self.assertRaises(expected_exception) as context: + custom.sftp_cert(cmd, cert_path=cert_path, public_key_file=public_key_file) + self.assertIn(expected_message, str(context.exception)) + finally: + for patch in patches: + patch.stop() + + @mock.patch('os.path.expanduser') + @mock.patch('os.path.abspath') + def test_sftp_cert_path_expansion(self, mock_abspath, mock_expanduser): + """Test that all path arguments are properly expanded from ~ to full paths.""" + # Test cases: (cert_path, public_key_file, ssh_client_folder, description) + test_cases = [ + ("~/cert.pub", "~/.ssh/id_rsa.pub", "~/ssh_client", "all_paths_with_tilde"), + (None, "~/.ssh/id_rsa.pub", "~/ssh_client", "public_key_and_ssh_client_with_tilde"), + ("~/cert.pub", None, "~/ssh_client", "cert_path_and_ssh_client_with_tilde"), + ] + + for cert_path, public_key_file, ssh_client_folder, description in test_cases: + with self.subTest(case=description): + mock_expanduser.reset_mock() + mock_abspath.reset_mock() + cmd = mock.Mock() + + # Setup mocks + mock_expanduser.side_effect = lambda x: x.replace('~', '/home/user') if x else x + mock_abspath.side_effect = lambda x: '/absolute' + x if x else x + + # Mock dependencies + with mock.patch('os.path.isdir', return_value=True), \ + mock.patch('azext_sftp.file_utils.check_or_create_public_private_files') as mock_check_files, \ + mock.patch('azext_sftp.file_utils.get_and_write_certificate') as mock_write_cert: + + mock_check_files.return_value = ("/absolute/home/user/.ssh/id_rsa.pub", None, False) + mock_write_cert.return_value = ("/absolute/home/user/cert.pub", "user@domain.com") + + # Execute test + custom.sftp_cert(cmd, cert_path=cert_path, public_key_file=public_key_file, ssh_client_folder=ssh_client_folder) + + # Verify path expansion for tilde paths + expected_calls = [mock.call(path) for path in [cert_path, public_key_file, ssh_client_folder] if path and path.startswith('~')] + if expected_calls: + mock_expanduser.assert_has_calls(expected_calls, any_order=True) + self.assertTrue(mock_abspath.called) + + def test_sftp_cert_valid_minimal_call(self): + """Test that a minimal valid call works correctly.""" + cmd = mock.Mock() + + with mock.patch('os.path.expanduser', side_effect=lambda x: x), \ + mock.patch('os.path.abspath', side_effect=lambda x: x), \ + mock.patch('os.path.isdir', return_value=True), \ + mock.patch('azext_sftp.file_utils.check_or_create_public_private_files') as mock_check_files, \ + mock.patch('azext_sftp.file_utils.get_and_write_certificate') as mock_write_cert: + + mock_check_files.return_value = ("pubkey.pub", None, False) + mock_write_cert.return_value = ("pubkey.pub-aadcert.pub", "user@domain.com") + + # Should not raise any exception + custom.sftp_cert(cmd, public_key_file="pubkey.pub") + + # Verify function was called correctly + mock_check_files.assert_called_once_with("pubkey.pub", None, None, None) + + # Additional tests for private helper functions + + def test_assert_args_validation(self): + """Test _assert_args function with various input combinations.""" + # Test cases: (storage_account, cert_file, public_key_file, private_key_file, expected_exception, description) + test_cases = [ + (None, None, None, None, azclierror.RequiredArgumentMissingError, "missing storage account"), + ("test", "/nonexistent/cert.pub", None, None, azclierror.FileOperationError, "invalid cert file"), + ("test", None, "/nonexistent/key.pub", None, azclierror.FileOperationError, "invalid public key file"), + ("test", None, None, "/nonexistent/key", azclierror.FileOperationError, "invalid private key file"), + ("test", self.mock_cert_file, self.mock_public_key, self.mock_private_key, None, "all valid files"), + ] + + for storage_account, cert_file, public_key_file, private_key_file, expected_exception, description in test_cases: + with self.subTest(case=description): + if expected_exception: + with self.assertRaises(expected_exception): + custom._assert_args(storage_account, cert_file, public_key_file, private_key_file) + else: + # Should not raise any exception + custom._assert_args(storage_account, cert_file, public_key_file, private_key_file) + + def test_do_sftp_op_execution(self): + """Test _do_sftp_op function with mock session and operation.""" + mock_session = mock.Mock() + mock_session.validate_session.return_value = None + + mock_operation = mock.Mock() + mock_operation.return_value = "operation_result" + + result = custom._do_sftp_op(mock_session, mock_operation) + + mock_session.validate_session.assert_called_once() + mock_operation.assert_called_once_with(mock_session) + self.assertEqual(result, "operation_result") + + def test_cleanup_credentials_selective_cleanup(self): + """Test _cleanup_credentials with different cleanup scenarios.""" + # Create test files for cleanup testing + temp_cert = os.path.join(self.temp_dir, "test_cert.pub") + temp_private = os.path.join(self.temp_dir, "test_private_key") + temp_public = os.path.join(self.temp_dir, "test_public_key.pub") + temp_credentials_dir = os.path.join(self.temp_dir, "credentials") + + # Create the files and directory + with open(temp_cert, 'w') as f: + f.write("cert content") + with open(temp_private, 'w') as f: + f.write("private key") + with open(temp_public, 'w') as f: + f.write("public key") + os.makedirs(temp_credentials_dir) + + # Test cases: (delete_keys, delete_cert, credentials_folder, description) + cleanup_cases = [ + (True, True, temp_credentials_dir, "cleanup all"), + (True, False, None, "cleanup keys only"), + (False, True, None, "cleanup cert only"), + (False, False, temp_credentials_dir, "cleanup folder only"), + (False, False, None, "cleanup nothing"), + ] + + for delete_keys, delete_cert, credentials_folder, description in cleanup_cases: + with self.subTest(case=description): + # Recreate files for each test + if not os.path.exists(temp_cert): + with open(temp_cert, 'w') as f: + f.write("cert content") + if not os.path.exists(temp_private): + with open(temp_private, 'w') as f: + f.write("private key") + if not os.path.exists(temp_public): + with open(temp_public, 'w') as f: + f.write("public key") + if not os.path.exists(temp_credentials_dir): + os.makedirs(temp_credentials_dir) + + # Mock file_utils.delete_file to avoid actual deletion but track calls + with mock.patch('azext_sftp.file_utils.delete_file') as mock_delete: + custom._cleanup_credentials( + delete_keys, delete_cert, credentials_folder, + temp_cert if delete_cert else None, + temp_private if delete_keys else None, + temp_public if delete_keys else None + ) + + # Verify expected file deletions were called + expected_calls = [] + if delete_cert: + expected_calls.append(mock.call(temp_cert, mock.ANY, warning=False)) + if delete_keys: + expected_calls.append(mock.call(temp_private, mock.ANY, warning=False)) + expected_calls.append(mock.call(temp_public, mock.ANY, warning=False)) + + if expected_calls: + mock_delete.assert_has_calls(expected_calls, any_order=True) + else: + mock_delete.assert_not_called() + + def test_cleanup_credentials_error_handling(self): + """Test _cleanup_credentials handles errors gracefully.""" + with mock.patch('azext_sftp.file_utils.delete_file', side_effect=OSError("Permission denied")): + with mock.patch('azext_sftp.custom.logger') as mock_logger: + custom._cleanup_credentials( + delete_keys=True, delete_cert=True, credentials_folder=None, + cert_file=self.mock_cert_file, private_key_file=self.mock_private_key, + public_key_file=self.mock_public_key + ) + + # Should log warning but not raise exception + mock_logger.warning.assert_called_once() + + def test_get_storage_endpoint_suffix_cloud_variants(self): + """Test _get_storage_endpoint_suffix for different Azure clouds.""" + cloud_test_cases = [ + # (cloud_name, expected_suffix, description) + ("azurecloud", "blob.core.windows.net", "public cloud"), + ("AZURECLOUD", "blob.core.windows.net", "public cloud uppercase"), + ("azurechinacloud", "blob.core.chinacloudapi.cn", "china cloud"), + ("azureusgovernment", "blob.core.usgovcloudapi.net", "us government cloud"), + ("unknowncloud", "blob.core.windows.net", "unknown cloud defaults to public"), + ] + + for cloud_name, expected_suffix, description in cloud_test_cases: + with self.subTest(case=description): + cmd = mock.Mock() + cmd.cli_ctx.cloud.name = cloud_name + + result = custom._get_storage_endpoint_suffix(cmd) + self.assertEqual(result, expected_suffix) diff --git a/src/sftp/azext_sftp/tests/latest/test_file_utils.py b/src/sftp/azext_sftp/tests/latest/test_file_utils.py new file mode 100644 index 00000000000..75ee1143337 --- /dev/null +++ b/src/sftp/azext_sftp/tests/latest/test_file_utils.py @@ -0,0 +1,375 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +import unittest +import tempfile +import os +import shutil +from unittest import mock + +from azext_sftp import file_utils +from azure.cli.core import azclierror + + +class SftpFileUtilsTest(unittest.TestCase): + """Test suite for SFTP file utilities. + + Owner: johnli1 + """ + + def setUp(self): + """Set up test fixtures before each test method.""" + super().setUp() + self.temp_dir = tempfile.mkdtemp(prefix="sftp_file_utils_test_") + + def tearDown(self): + """Tear down test fixtures after each test method.""" + super().tearDown() + if os.path.exists(self.temp_dir): + shutil.rmtree(self.temp_dir) + + def test_delete_file_removes_existing_file(self): + """Test delete_file removes an existing file.""" + # Arrange + test_file = os.path.join(self.temp_dir, "test_delete.txt") + with open(test_file, 'w') as f: + f.write("test content") + + # Act + file_utils.delete_file(test_file, "Test deletion message") + + # Assert + self.assertFalse(os.path.isfile(test_file)) + + def test_delete_file_with_nonexistent_file(self): + """Test delete_file with nonexistent file does nothing.""" + # Arrange + nonexistent_file = os.path.join(self.temp_dir, "nonexistent.txt") + + # Act & Assert - Should not raise an exception + file_utils.delete_file(nonexistent_file, "Test deletion message") + + @mock.patch('os.remove') + def test_delete_file_handles_removal_error_with_warning(self, mock_remove): + """Test delete_file handles removal errors with warning flag.""" + # Arrange + test_file = os.path.join(self.temp_dir, "test_file.txt") + with open(test_file, 'w') as f: + f.write("test") + mock_remove.side_effect = OSError("Permission denied") + + # Act & Assert - Should not raise exception when warning=True + with mock.patch('azext_sftp.file_utils.logger') as mock_logger: + file_utils.delete_file(test_file, "Test message", warning=True) + mock_logger.warning.assert_called_once() + + @mock.patch('os.remove') + def test_delete_file_raises_error_without_warning(self, mock_remove): + """Test delete_file raises error when warning=False.""" + # Arrange + test_file = os.path.join(self.temp_dir, "test_file.txt") + with open(test_file, 'w') as f: + f.write("test") + mock_remove.side_effect = OSError("Permission denied") + + # Act & Assert + with self.assertRaises(azclierror.FileOperationError): + file_utils.delete_file(test_file, "Test message", warning=False) + +class SftpFileUtilsCertificateTest(unittest.TestCase): + """Test suite for SFTP file utilities certificate-related functions. + + Owner: johnli1 + """ + + def setUp(self): + """Set up test fixtures before each test method.""" + super().setUp() + self.temp_dir = tempfile.mkdtemp(prefix="sftp_cert_test_") + self.mock_public_key = os.path.join(self.temp_dir, "test_key.pub") + + # Create a mock public key file + with open(self.mock_public_key, 'w') as f: + f.write("ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQ test@example.com") + + def tearDown(self): + """Tear down test fixtures after each test method.""" + super().tearDown() + if os.path.exists(self.temp_dir): + shutil.rmtree(self.temp_dir) + + @mock.patch('tempfile.mkdtemp') + @mock.patch('azext_sftp.sftp_utils.create_ssh_keyfile') + def test_check_or_create_public_private_files_generates_keys(self, mock_create_keyfile, mock_mkdtemp): + """Test check_or_create_public_private_files generates new keys when none provided.""" + # Arrange + mock_mkdtemp.return_value = self.temp_dir + expected_public_key = os.path.join(self.temp_dir, "id_rsa.pub") + expected_private_key = os.path.join(self.temp_dir, "id_rsa") + + # Mock the create_ssh_keyfile to create the files when called + def create_key_files(private_key_path, ssh_client_folder): + with open(private_key_path, 'w') as f: + f.write("-----BEGIN OPENSSH PRIVATE KEY-----\ntest\n-----END OPENSSH PRIVATE KEY-----") + with open(private_key_path + ".pub", 'w') as f: + f.write("ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQ test@example.com") + + mock_create_keyfile.side_effect = create_key_files + + # Act + public_key, private_key, delete_keys = file_utils.check_or_create_public_private_files( + None, None, None) + + # Assert + self.assertEqual(public_key, expected_public_key) + self.assertEqual(private_key, expected_private_key) + self.assertTrue(delete_keys) + mock_create_keyfile.assert_called_once_with(expected_private_key, None) + + def test_check_or_create_public_private_files_with_credentials_folder(self): + """Test key generation in specified credentials folder. + + This verifies SSH extension pattern for controlled key placement. + """ + with mock.patch('azext_sftp.sftp_utils.create_ssh_keyfile') as mock_create_keyfile: + with mock.patch('os.makedirs') as mock_makedirs: + with mock.patch('os.path.isdir', return_value=False): + + # Mock the create_ssh_keyfile to actually create the files + def create_key_files(private_key_path, passphrase): + with open(private_key_path, 'w') as f: + f.write("mock private key") + with open(private_key_path + ".pub", 'w') as f: + f.write("mock public key") + + mock_create_keyfile.side_effect = create_key_files + + # Test with credentials folder that doesn't exist + public_key, private_key, delete_keys = file_utils.check_or_create_public_private_files( + None, None, self.temp_dir, None) + + # Verify keys are generated in the specified folder + self.assertTrue(public_key.startswith(self.temp_dir)) + self.assertTrue(private_key.startswith(self.temp_dir)) + self.assertTrue(delete_keys) # Should be marked for deletion + + # Verify key generation was called with correct path + mock_create_keyfile.assert_called_once_with(private_key, None) + + @mock.patch('azext_sftp.sftp_utils.create_ssh_keyfile') + @mock.patch('tempfile.mkdtemp') + def test_check_or_create_public_private_files_with_existing_credentials_folder(self, mock_mkdtemp, mock_create_keyfile): + """Test key generation with existing credentials folder. + + This verifies SSH extension pattern where keys are generated in existing folder. + """ + # Create the credentials folder + os.makedirs(self.temp_dir, exist_ok=True) + + # Mock the create_ssh_keyfile to actually create the files + def create_key_files(private_key_path, passphrase): + with open(private_key_path, 'w') as f: + f.write("mock private key") + with open(private_key_path + ".pub", 'w') as f: + f.write("mock public key") + + mock_create_keyfile.side_effect = create_key_files + + with mock.patch('os.path.isdir', return_value=True): + public_key, private_key, delete_keys = file_utils.check_or_create_public_private_files( + None, None, self.temp_dir, None) + + # Verify keys are generated in the specified folder + self.assertTrue(public_key.startswith(self.temp_dir)) + self.assertTrue(private_key.startswith(self.temp_dir)) + self.assertTrue(delete_keys) # Should be marked for deletion + + def test_check_or_create_public_private_files_with_existing_files(self): + """Test check_or_create_public_private_files with existing key files.""" + # Arrange + private_key = os.path.join(self.temp_dir, "existing_key") + with open(private_key, 'w') as f: + f.write("-----BEGIN OPENSSH PRIVATE KEY-----\ntest\n-----END OPENSSH PRIVATE KEY-----") + + # Act + public_key, returned_private_key, delete_keys = file_utils.check_or_create_public_private_files( + self.mock_public_key, private_key, None) + + # Assert + self.assertEqual(public_key, self.mock_public_key) + self.assertEqual(returned_private_key, private_key) + self.assertFalse(delete_keys) + + def test_check_or_create_public_private_files_missing_public_key_error(self): + """Test check_or_create_public_private_files raises error when public key file missing.""" + # Arrange + nonexistent_public_key = os.path.join(self.temp_dir, "nonexistent.pub") + + # Act & Assert + with self.assertRaises(azclierror.FileOperationError) as context: + file_utils.check_or_create_public_private_files(nonexistent_public_key, None, None) + + self.assertIn("not found", str(context.exception)) + + def test_check_or_create_public_private_files_missing_private_key_error(self): + """Test check_or_create_public_private_files raises error when private key file missing.""" + # Arrange + nonexistent_private_key = os.path.join(self.temp_dir, "nonexistent") + + # Act & Assert + with self.assertRaises(azclierror.FileOperationError) as context: + file_utils.check_or_create_public_private_files( + self.mock_public_key, nonexistent_private_key, None) + + self.assertIn("not found", str(context.exception)) + + @mock.patch('azext_sftp.sftp_utils.create_ssh_keyfile') + @mock.patch('tempfile.mkdtemp') + def test_check_or_create_public_private_files_error_handling_during_keygen(self, mock_mkdtemp, mock_create_keyfile): + """Test error handling during key generation process. + + This verifies SSH extension pattern for robust error handling. + """ + mock_mkdtemp.return_value = self.temp_dir + + # Mock key generation to raise error + with mock.patch('azext_sftp.sftp_utils.create_ssh_keyfile', side_effect=Exception("Key generation error")): + with self.assertRaises(Exception) as context: + file_utils.check_or_create_public_private_files(None, None, None, None) + + # Verify error message + self.assertIn("Key generation error", str(context.exception)) + + @mock.patch('azext_sftp.file_utils.Profile') + @mock.patch('azext_sftp.file_utils._prepare_jwk_data') + @mock.patch('azext_sftp.file_utils._write_cert_file') + @mock.patch('azext_sftp.sftp_utils.get_ssh_cert_principals') + def test_get_and_write_certificate_success(self, mock_get_principals, + mock_write_cert, mock_prepare_jwk, mock_profile): + """Test get_and_write_certificate successfully generates certificate.""" + # Arrange + cmd = mock.Mock() + cmd.cli_ctx.cloud.name = "azurecloud" + + mock_profile_instance = mock.Mock() + mock_profile.return_value = mock_profile_instance + mock_profile_instance.get_msal_token.return_value = (None, "test_certificate_data") + + mock_prepare_jwk.return_value = {"test": "data"} + mock_get_principals.return_value = ["testuser@domain.com"] + + # Set up the cert file path that the function will generate + expected_cert_file = str(self.mock_public_key.removesuffix(".pub")) + "-aadcert.pub" + mock_write_cert.return_value = expected_cert_file + + # Act + result_cert_file, username = file_utils.get_and_write_certificate( + cmd, self.mock_public_key, None, None) + + # Assert + self.assertEqual(result_cert_file, expected_cert_file) + self.assertEqual(username, "testuser@domain.com") + mock_prepare_jwk.assert_called_once_with(self.mock_public_key) + mock_write_cert.assert_called_once_with("test_certificate_data", expected_cert_file) + + @mock.patch('azure.cli.core._profile.Profile') + def test_get_and_write_certificate_unsupported_cloud(self, mock_profile): + """Test get_and_write_certificate raises error for unsupported cloud.""" + # Arrange + cmd = mock.Mock() + cmd.cli_ctx.cloud.name = "unsupportedcloud" + + # Act & Assert + with self.assertRaises(azclierror.InvalidArgumentValueError) as context: + file_utils.get_and_write_certificate(cmd, self.mock_public_key, None, None) + + self.assertIn("Unsupported cloud", str(context.exception)) + + @mock.patch('azext_sftp.file_utils._get_modulus_exponent') + @mock.patch('hashlib.sha256') + @mock.patch('json.dumps') + def test_prepare_jwk_data_creates_correct_structure(self, mock_dumps, mock_sha256, mock_get_mod_exp): + """Test _prepare_jwk_data creates correct JWK structure.""" + # Arrange + mock_get_mod_exp.return_value = ("test_modulus", "test_exponent") + mock_hash = mock.Mock() + mock_hash.hexdigest.return_value = "test_key_id" + mock_sha256.return_value = mock_hash + mock_dumps.return_value = "test_jwk_json" + + # Act + result = file_utils._prepare_jwk_data(self.mock_public_key) + + # Assert + self.assertEqual(result["token_type"], "ssh-cert") + self.assertEqual(result["req_cnf"], "test_jwk_json") + self.assertEqual(result["key_id"], "test_key_id") + + def test_write_cert_file_creates_certificate(self): + """Test _write_cert_file creates certificate file with correct format.""" + # Arrange + cert_contents = "test_certificate_data" + cert_file = os.path.join(self.temp_dir, "test_cert.pub") + + # Act + result = file_utils._write_cert_file(cert_contents, cert_file) + + # Assert + self.assertEqual(result, cert_file) + self.assertTrue(os.path.isfile(cert_file)) + + with open(cert_file, 'r') as f: + content = f.read() + self.assertEqual(content, f"ssh-rsa-cert-v01@openssh.com {cert_contents}") + + @mock.patch('azext_sftp.rsa_parser.RSAParser') + def test_get_modulus_exponent_success(self, mock_parser_class): + """Test _get_modulus_exponent successfully extracts modulus and exponent.""" + # Arrange + mock_parser = mock.Mock() + mock_parser.modulus = "test_modulus" + mock_parser.exponent = "test_exponent" + mock_parser_class.return_value = mock_parser + + # Act + modulus, exponent = file_utils._get_modulus_exponent(self.mock_public_key) + + # Assert + self.assertEqual(modulus, "test_modulus") + self.assertEqual(exponent, "test_exponent") + + # Verify parser was called with file contents + with open(self.mock_public_key, 'r') as f: + expected_content = f.read() + mock_parser.parse.assert_called_once_with(expected_content) + + def test_get_modulus_exponent_file_not_found(self): + """Test _get_modulus_exponent handles missing file.""" + # Arrange + nonexistent_file = os.path.join(self.temp_dir, "nonexistent.pub") + + # Act & Assert + with self.assertRaises(azclierror.FileOperationError) as context: + file_utils._get_modulus_exponent(nonexistent_file) + + self.assertIn("was not found", str(context.exception)) + + @mock.patch('azext_sftp.rsa_parser.RSAParser') + def test_get_modulus_exponent_parse_error(self, mock_parser_class): + """Test _get_modulus_exponent handles parsing errors.""" + # Arrange + mock_parser = mock.Mock() + mock_parser.parse.side_effect = ValueError("Invalid key format") + mock_parser_class.return_value = mock_parser + + # Act & Assert + with self.assertRaises(azclierror.FileOperationError) as context: + file_utils._get_modulus_exponent(self.mock_public_key) + + self.assertIn("Could not parse public key", str(context.exception)) + + +if __name__ == '__main__': + unittest.main() diff --git a/src/sftp/azext_sftp/tests/latest/test_rsa_parser.py b/src/sftp/azext_sftp/tests/latest/test_rsa_parser.py new file mode 100644 index 00000000000..9dc45125924 --- /dev/null +++ b/src/sftp/azext_sftp/tests/latest/test_rsa_parser.py @@ -0,0 +1,169 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +import unittest +from unittest import mock + +from azext_sftp import rsa_parser + + +class RSAParserTest(unittest.TestCase): + """Test suite for RSAParser class. + + Owner: johnli1 + """ + + def test_rsa_parser_success(self): + """Test successful parsing of a valid RSA public key.""" + public_key_text = 'ssh-rsa ' + self._get_good_key() + parser = rsa_parser.RSAParser() + + parser.parse(public_key_text) + + self.assertEqual('ssh-rsa', parser.algorithm) + self.assertEqual(self._get_good_modulus(), parser.modulus) + self.assertEqual(self._get_good_exponent(), parser.exponent) + + def test_rsa_parser_too_few_public_key_text_fields(self): + """Test error when public key text has insufficient fields.""" + public_key_text = 'algo' + parser = rsa_parser.RSAParser() + + with self.assertRaises(ValueError) as context: + parser.parse(public_key_text) + + self.assertIn("Incorrectly formatted public key", str(context.exception)) + + def test_rsa_parser_wrong_algorithm(self): + """Test error when public key uses wrong algorithm.""" + public_key_text = 'wrongalgo key' + parser = rsa_parser.RSAParser() + + with self.assertRaises(ValueError) as context: + parser.parse(public_key_text) + + self.assertIn("Public key is not ssh-rsa algorithm", str(context.exception)) + + @mock.patch('base64.b64decode') + def test_rsa_parser_algorithm_mismatch(self, mock_decode): + """Test error when decoded algorithm doesn't match ssh-rsa.""" + public_key_text = 'ssh-rsa key' + parser = rsa_parser.RSAParser() + + with mock.patch.object(parser, '_get_fields') as mock_get_fields: + mock_get_fields.return_value = [b'otheralgo', b'exp', b'mod'] + + with self.assertRaises(ValueError) as context: + parser.parse(public_key_text) + + self.assertIn("Encoded public key is not ssh-rsa algorithm", str(context.exception)) + + mock_decode.assert_called_once_with('key') + mock_get_fields.assert_called_once_with(mock_decode.return_value) + + @mock.patch('base64.b64decode') + def test_rsa_parser_too_few_encoded_fields(self, mock_decode): + """Test error when decoded key has too few fields.""" + public_key_text = 'ssh-rsa key' + mock_decode.return_value = b'decodedkey' + parser = rsa_parser.RSAParser() + + with mock.patch.object(parser, '_get_fields') as mock_get_fields: + mock_get_fields.return_value = [b'ssh-rsa', b'exp'] + + with self.assertRaises(ValueError) as context: + parser.parse(public_key_text) + + self.assertIn("Incorrectly encoded public key", str(context.exception)) + + mock_decode.assert_called_once_with('key') + mock_get_fields.assert_called_once_with(mock_decode.return_value) + + def test_rsa_parser_initialization(self): + """Test proper initialization of RSAParser.""" + parser = rsa_parser.RSAParser() + + self.assertEqual('', parser.algorithm) + self.assertEqual('', parser.modulus) + self.assertEqual('', parser.exponent) + self.assertTrue(parser._key_length_big_endian) + + def test_get_struct_format_big_endian(self): + """Test struct format for big endian.""" + parser = rsa_parser.RSAParser() + parser._key_length_big_endian = True + + format_str = parser._get_struct_format() + + self.assertEqual(">L", format_str) + + def test_get_struct_format_little_endian(self): + """Test struct format for little endian.""" + parser = rsa_parser.RSAParser() + parser._key_length_big_endian = False + + format_str = parser._get_struct_format() + + self.assertEqual("