44# --------------------------------------------------------------------------------------------
55
66import os
7- import hashlib
8- import json
97import tempfile
10- import time
118import shutil
12- import oschmod
139
1410from knack import log
1511from azure .cli .core import azclierror
16- from azure .cli .core import telemetry
1712from azure .cli .core .style import Style , print_styled_text
1813from azure .cli .core ._profile import Profile
1914
20- from . import rsa_parser
2115from . import sftp_info
2216from . import sftp_utils
2317from . import file_utils
@@ -185,143 +179,48 @@ def sftp_connect(cmd, storage_account, port=None, cert_file=None, private_key_fi
185179
186180def _check_or_create_public_private_files (public_key_file , private_key_file , credentials_folder ,
187181 ssh_client_folder = None ):
188- delete_keys = False
189- if not public_key_file and not private_key_file :
190- delete_keys = True
191- if not credentials_folder :
192- credentials_folder = tempfile .mkdtemp (prefix = "aadsshcert" )
193- else :
194- if not os .path .isdir (credentials_folder ):
195- os .makedirs (credentials_folder )
196- public_key_file = os .path .join (credentials_folder , "id_rsa.pub" )
197- private_key_file = os .path .join (credentials_folder , "id_rsa" )
198- sftp_utils .create_ssh_keyfile (private_key_file , ssh_client_folder )
199-
200- if not public_key_file :
201- if private_key_file :
202- public_key_file = str (private_key_file ) + ".pub"
203- else :
204- raise azclierror .RequiredArgumentMissingError ("Public key file not specified" )
205-
206- if not os .path .isfile (public_key_file ):
207- raise azclierror .FileOperationError (f"Public key file { public_key_file } not found" )
208-
209- if private_key_file :
210- if not os .path .isfile (private_key_file ):
211- raise azclierror .FileOperationError (f"Private key file { private_key_file } not found" )
212-
213- if not private_key_file :
214- if public_key_file .endswith (".pub" ):
215- private_key_file = public_key_file [:- 4 ] if os .path .isfile (public_key_file [:- 4 ]) else None
216-
217- return public_key_file , private_key_file , delete_keys
182+ """Check for existing key files or create new ones if needed."""
183+ return file_utils .check_or_create_public_private_files (
184+ public_key_file , private_key_file , credentials_folder , ssh_client_folder )
218185
219186
220187def _get_and_write_certificate (cmd , public_key_file , cert_file , ssh_client_folder ):
221- cloudtoscope = {
222- "azurecloud" : "https://pas.windows.net/CheckMyAccess/Linux/.default" ,
223- "azurechinacloud" : "https://pas.chinacloudapi.cn/CheckMyAccess/Linux/.default" ,
224- "azureusgovernment" : "https://pasff.usgovcloudapi.net/CheckMyAccess/Linux/.default"
225- }
226- scope = cloudtoscope .get (cmd .cli_ctx .cloud .name .lower (), None )
227- if not scope :
228- raise azclierror .InvalidArgumentValueError (
229- f"Unsupported cloud { cmd .cli_ctx .cloud .name .lower ()} " ,
230- "Supported clouds include azurecloud,azurechinacloud,azureusgovernment" )
231-
232- scopes = [scope ]
233- data = _prepare_jwk_data (public_key_file )
234- profile = Profile (cli_ctx = cmd .cli_ctx )
235-
236- t0 = time .time ()
237- if hasattr (profile , "get_msal_token" ):
238- _ , certificate = profile .get_msal_token (scopes , data )
239- else :
240- credential , _ , _ = profile .get_login_credentials (subscription_id = profile .get_subscription ()["id" ])
241- certificatedata = credential .get_token (* scopes , data = data )
242- certificate = certificatedata .token
243-
244- time_elapsed = time .time () - t0
245- telemetry .add_extension_event ('sftp' ,
246- {'Context.Default.AzureCLI.SftpGetCertificateTime' : time_elapsed })
247-
248- if not cert_file :
249- base_name = os .path .splitext (str (public_key_file ))[0 ]
250- cert_file = base_name + "-aadcert.pub"
251-
252- logger .debug ("Generating certificate %s" , cert_file )
253- _write_cert_file (certificate , cert_file )
254- username = sftp_utils .get_ssh_cert_principals (cert_file , ssh_client_folder )[0 ]
255- oschmod .set_mode (cert_file , 0o600 )
256-
257- return cert_file , username .lower ()
188+ """Generate and write an SSH certificate using Azure AD authentication."""
189+ return file_utils .get_and_write_certificate (cmd , public_key_file , cert_file , ssh_client_folder )
258190
259191
260192def _prepare_jwk_data (public_key_file ):
261- modulus , exponent = _get_modulus_exponent (public_key_file )
262- key_hash = hashlib .sha256 ()
263- key_hash .update (modulus .encode ('utf-8' ))
264- key_hash .update (exponent .encode ('utf-8' ))
265- key_id = key_hash .hexdigest ()
266- jwk = {
267- "kty" : "RSA" ,
268- "n" : modulus ,
269- "e" : exponent ,
270- "kid" : key_id
271- }
272- json_jwk = json .dumps (jwk )
273- data = {
274- "token_type" : "ssh-cert" ,
275- "req_cnf" : json_jwk ,
276- "key_id" : key_id
277- }
278- return data
193+ """Prepare JWK data for certificate request."""
194+ return file_utils ._prepare_jwk_data (public_key_file ) # pylint: disable=protected-access
279195
280196
281197def _write_cert_file (certificate_contents , cert_file ):
282- with open (cert_file , 'w' , encoding = 'utf-8' ) as f :
283- f .write (f"ssh-rsa-cert-v01@openssh.com { certificate_contents } " )
284- oschmod .set_mode (cert_file , 0o644 )
285- return cert_file
198+ """Write SSH certificate to file."""
199+ return file_utils ._write_cert_file (certificate_contents , cert_file ) # pylint: disable=protected-access
286200
287201
288202def _get_modulus_exponent (public_key_file ):
289- if not os .path .isfile (public_key_file ):
290- raise azclierror .FileOperationError (f"Public key file '{ public_key_file } ' was not found" )
291-
292- with open (public_key_file , 'r' , encoding = 'utf-8' ) as f :
293- public_key_text = f .read ()
294-
295- parser = rsa_parser .RSAParser ()
296- try :
297- parser .parse (public_key_text )
298- except Exception as e :
299- raise azclierror .FileOperationError (f"Could not parse public key. Error: { str (e )} " )
300- modulus = parser .modulus
301- exponent = parser .exponent
302-
303- return modulus , exponent
203+ """Extract modulus and exponent from RSA public key file."""
204+ return file_utils ._get_modulus_exponent (public_key_file ) # pylint: disable=protected-access
304205
305206
306207def _assert_args (storage_account , cert_file , public_key_file , private_key_file ):
307208 """Validate SFTP connection arguments."""
308209 if not storage_account :
309210 raise azclierror .RequiredArgumentMissingError ("Storage account name is required." )
310211
311- if cert_file :
312- expanded_cert_file = os .path .expanduser (cert_file )
313- if not os .path .isfile (expanded_cert_file ):
314- raise azclierror .FileOperationError (f"Certificate file { cert_file } not found." )
212+ # Check file existence for provided files
213+ files_to_check = [
214+ (cert_file , "Certificate" ),
215+ (public_key_file , "Public key" ),
216+ (private_key_file , "Private key" )
217+ ]
315218
316- if public_key_file :
317- expanded_public_key_file = os .path .expanduser (public_key_file )
318- if not os .path .isfile (expanded_public_key_file ):
319- raise azclierror .FileOperationError (f"Public key file { public_key_file } not found." )
320-
321- if private_key_file :
322- expanded_private_key_file = os .path .expanduser (private_key_file )
323- if not os .path .isfile (expanded_private_key_file ):
324- raise azclierror .FileOperationError (f"Private key file { private_key_file } not found." )
219+ for file_path , file_type in files_to_check :
220+ if file_path :
221+ expanded_path = os .path .expanduser (file_path )
222+ if not os .path .isfile (expanded_path ):
223+ raise azclierror .FileOperationError (f"{ file_type } file { file_path } not found." )
325224
326225
327226def _do_sftp_op (sftp_session , op_call ):
@@ -337,12 +236,9 @@ def _cleanup_credentials(delete_keys, delete_cert, credentials_folder, cert_file
337236 file_utils .delete_file (cert_file , f"Deleting generated certificate { cert_file } " , warning = False )
338237
339238 if delete_keys :
340- if private_key_file and os .path .isfile (private_key_file ):
341- file_utils .delete_file (private_key_file ,
342- f"Deleting generated private key { private_key_file } " , warning = False )
343- if public_key_file and os .path .isfile (public_key_file ):
344- file_utils .delete_file (public_key_file ,
345- f"Deleting generated public key { public_key_file } " , warning = False )
239+ for key_file , key_type in [(private_key_file , "private" ), (public_key_file , "public" )]:
240+ if key_file and os .path .isfile (key_file ):
241+ file_utils .delete_file (key_file , f"Deleting generated { key_type } key { key_file } " , warning = False )
346242
347243 if credentials_folder and os .path .isdir (credentials_folder ):
348244 logger .debug ("Deleting credentials folder %s" , credentials_folder )
@@ -354,9 +250,9 @@ def _cleanup_credentials(delete_keys, delete_cert, credentials_folder, cert_file
354250
355251def _get_storage_endpoint_suffix (cmd ):
356252 """Get the appropriate storage endpoint suffix based on Azure cloud environment."""
357- cloud_to_storage_suffix = {
253+ cloud_suffixes = {
358254 "azurecloud" : "blob.core.windows.net" ,
359255 "azurechinacloud" : "blob.core.chinacloudapi.cn" ,
360256 "azureusgovernment" : "blob.core.usgovcloudapi.net"
361257 }
362- return cloud_to_storage_suffix .get (cmd .cli_ctx .cloud .name .lower (), "blob.core.windows.net" )
258+ return cloud_suffixes .get (cmd .cli_ctx .cloud .name .lower (), "blob.core.windows.net" )
0 commit comments