1515try :
1616 from azure .core .exceptions import ResourceNotFoundError
1717 from azure .core .credentials import AzureNamedKeyCredential
18+
1819 from azure .storage .blob import (
1920 BlobPrefix ,
2021 BlobSasPermissions ,
2425 generate_blob_sas ,
2526 )
2627
28+ from azure .storage .blob ._shared .authentication import (
29+ SharedKeyCredentialPolicy as BlobSharedKeyCredentialPolicy ,
30+ )
31+
2732 from azure .storage .filedatalake import DataLakeServiceClient , FileProperties
33+ from azure .storage .filedatalake ._shared .authentication import (
34+ SharedKeyCredentialPolicy as DataLakeSharedKeyCredentialPolicy ,
35+ )
36+
2837except ModuleNotFoundError :
2938 implementation_registry ["azure" ].dependencies_loaded = False
3039
@@ -104,19 +113,29 @@ def __init__(
104113 if connection_string is None :
105114 connection_string = os .getenv ("AZURE_STORAGE_CONNECTION_STRING" , None )
106115
107- self .data_lake_client = None # only needs to end up being set if HNS is enabled
116+ self .data_lake_client : Optional [DataLakeServiceClient ] = (
117+ None # only needs to end up being set if HNS is enabled
118+ )
108119
109120 if blob_service_client is not None :
110121 self .service_client = blob_service_client
111122
112123 # create from blob service client if not passed
113124 if data_lake_client is None :
114- self .data_lake_client = DataLakeServiceClient (
115- account_url = self .service_client .url .replace (".blob." , ".dfs." , 1 ),
116- credential = AzureNamedKeyCredential (
125+ credential = (
126+ blob_service_client .credential
127+ if not isinstance (
128+ blob_service_client .credential , BlobSharedKeyCredentialPolicy
129+ )
130+ else AzureNamedKeyCredential (
117131 blob_service_client .credential .account_name ,
118132 blob_service_client .credential .account_key ,
119- ),
133+ )
134+ )
135+
136+ self .data_lake_client = DataLakeServiceClient (
137+ account_url = self .service_client .url .replace (".blob." , ".dfs." , 1 ),
138+ credential = credential ,
120139 )
121140 else :
122141 self .data_lake_client = data_lake_client
@@ -125,12 +144,21 @@ def __init__(
125144 self .data_lake_client = data_lake_client
126145
127146 if blob_service_client is None :
128- self .service_client = BlobServiceClient (
129- account_url = self .data_lake_client .url .replace (".dfs." , ".blob." , 1 ),
130- credential = AzureNamedKeyCredential (
147+
148+ credential = (
149+ data_lake_client .credential
150+ if not isinstance (
151+ data_lake_client .credential , DataLakeSharedKeyCredentialPolicy
152+ )
153+ else AzureNamedKeyCredential (
131154 data_lake_client .credential .account_name ,
132155 data_lake_client .credential .account_key ,
133- ),
156+ )
157+ )
158+
159+ self .service_client = BlobServiceClient (
160+ account_url = self .data_lake_client .url .replace (".dfs." , ".blob." , 1 ),
161+ credential = credential ,
134162 )
135163
136164 elif connection_string is not None :
@@ -167,19 +195,31 @@ def __init__(
167195 "Credentials are required; see docs for options."
168196 )
169197
170- self ._hns_enabled = None
198+ self ._hns_enabled : Optional [ bool ] = None
171199
172- def _check_hns (self ) -> Optional [bool ]:
200+ def _check_hns (self , cloud_path : AzureBlobPath ) -> Optional [bool ]:
173201 if self ._hns_enabled is None :
174- account_info = self .service_client .get_account_information () # type: ignore
175- self ._hns_enabled = account_info .get ("is_hns_enabled" , False ) # type: ignore
202+ try :
203+ account_info = self .service_client .get_account_information () # type: ignore
204+ self ._hns_enabled = account_info .get ("is_hns_enabled" , False ) # type: ignore
205+ except ResourceNotFoundError :
206+ # get_account_information() not supported with this credential; we have to fallback to
207+ # checking if the root directory exists and is a has 'metadata': {'hdi_isfolder': 'true'}
208+ root_dir = self .service_client .get_blob_client (
209+ container = cloud_path .container , blob = "/"
210+ )
211+ self ._hns_enabled = (
212+ root_dir .exists ()
213+ and root_dir .get_blob_properties ().metadata .get ("hdi_isfolder" , False )
214+ == "true"
215+ )
176216
177217 return self ._hns_enabled
178218
179219 def _get_metadata (
180220 self , cloud_path : AzureBlobPath
181221 ) -> Union ["BlobProperties" , "FileProperties" , Dict [str , Any ]]:
182- if self ._check_hns ():
222+ if self ._check_hns (cloud_path ):
183223
184224 # works on both files and directories
185225 fsc = self .data_lake_client .get_file_system_client (cloud_path .container ) # type: ignore
@@ -292,7 +332,7 @@ def _list_dir(
292332 if prefix and not prefix .endswith ("/" ):
293333 prefix += "/"
294334
295- if self ._check_hns ():
335+ if self ._check_hns (cloud_path ):
296336 file_system_client = self .data_lake_client .get_file_system_client (cloud_path .container ) # type: ignore
297337 paths = file_system_client .get_paths (path = cloud_path .blob , recursive = recursive )
298338
@@ -334,7 +374,7 @@ def _move_file(
334374 )
335375
336376 # we can use rename API when the same account on adls gen2
337- elif remove_src and (src .client is dst .client ) and self ._check_hns ():
377+ elif remove_src and (src .client is dst .client ) and self ._check_hns (src ):
338378 fsc = self .data_lake_client .get_file_system_client (src .container ) # type: ignore
339379
340380 if src .is_dir ():
@@ -358,7 +398,7 @@ def _move_file(
358398 def _mkdir (
359399 self , cloud_path : AzureBlobPath , parents : bool = False , exist_ok : bool = False
360400 ) -> None :
361- if self ._check_hns ():
401+ if self ._check_hns (cloud_path ):
362402 file_system_client = self .data_lake_client .get_file_system_client (cloud_path .container ) # type: ignore
363403 directory_client = file_system_client .get_directory_client (cloud_path .blob )
364404
@@ -379,7 +419,7 @@ def _mkdir(
379419 def _remove (self , cloud_path : AzureBlobPath , missing_ok : bool = True ) -> None :
380420 file_or_dir = self ._is_file_or_dir (cloud_path )
381421 if file_or_dir == "dir" :
382- if self ._check_hns ():
422+ if self ._check_hns (cloud_path ):
383423 _hns_rmtree (self .data_lake_client , cloud_path .container , cloud_path .blob )
384424 return
385425
@@ -432,7 +472,7 @@ def _generate_presigned_url(
432472 self , cloud_path : AzureBlobPath , expire_seconds : int = 60 * 60
433473 ) -> str :
434474 sas_token = generate_blob_sas (
435- self .service_client .account_name ,
475+ self .service_client .account_name , # type: ignore[arg-type]
436476 container_name = cloud_path .container ,
437477 blob_name = cloud_path .blob ,
438478 account_key = self .service_client .credential .account_key ,
0 commit comments