@@ -69,7 +69,8 @@ def location(self) -> str | None:
6969 def _parse_location (cls , location : str | None ) -> list [str ]:
7070 if not location :
7171 return []
72- parse_result = cls ._parse_url (location .rstrip ("/" ))
72+ location = ExternalLocations .clean_location (location )
73+ parse_result = cls ._parse_url (location )
7374 if not parse_result :
7475 return []
7576 parts = [parse_result .scheme , parse_result .netloc ]
@@ -154,17 +155,66 @@ def __init__(
154155 schema : str ,
155156 tables_crawler : TablesCrawler ,
156157 mounts_crawler : 'MountsCrawler' ,
158+ enable_hms_federation : bool = False ,
157159 ):
158160 super ().__init__ (sql_backend , "hive_metastore" , schema , "external_locations" , ExternalLocation )
159161 self ._ws = ws
160162 self ._tables_crawler = tables_crawler
161163 self ._mounts_crawler = mounts_crawler
164+ self ._enable_hms_federation = enable_hms_federation
162165
163166 @cached_property
164167 def _mounts_snapshot (self ) -> list ['Mount' ]:
165168 """Returns all mounts, sorted by longest prefixes first."""
166169 return sorted (self ._mounts_crawler .snapshot (), key = lambda _ : (len (_ .name ), _ .name ), reverse = True )
167170
@staticmethod
def clean_location(location: str) -> str:
    """Normalise a storage location string.

    A leading ``s3a:/`` scheme prefix is rewritten to ``s3:/`` and all
    trailing slashes are removed. The two schemes are treated as the same
    location; keeping them apart would cause issues when looking for
    overlapping locations.
    """
    legacy_prefix = "s3a:/"
    if location.startswith(legacy_prefix):
        location = "s3:/" + location[len(legacy_prefix):]
    return location.rstrip("/")
176+
def external_locations_with_root(self) -> list[ExternalLocation]:
    """List the snapshot's external locations, plus the DBFS root if available.

    The locations come from ``snapshot()``. When ``_get_dbfs_root()`` yields
    a location it is placed at the end of the list. Used for HMS Federation.

    Returns:
        List of ExternalLocation objects
    """
    locations = [*self.snapshot()]
    root_location = self._get_dbfs_root()
    return [*locations, root_location] if root_location else locations
192+
193+ def _get_dbfs_root (self ) -> ExternalLocation | None :
194+ """
195+ Get the root location of the DBFS only if HMS Fed is enabled.
196+ Utilizes an undocumented Databricks API call
197+
198+ Returns:
199+ Cloud storage root location for dbfs
200+
201+ """
202+ if not self ._enable_hms_federation :
203+ return None
204+ logger .debug ("Retrieving DBFS root location" )
205+ try :
206+ response = self ._ws .api_client .do ("GET" , "/api/2.0/dbfs/resolve-path" , query = {"path" : "dbfs:/" })
207+ if isinstance (response , dict ):
208+ resolved_path = response .get ("resolved_path" )
209+ if resolved_path :
210+ path = f"{ self .clean_location (resolved_path )} /user/hive/warehouse"
211+ return ExternalLocation (path , 0 )
212+ except NotFound :
213+ # Couldn't retrieve the DBFS root location
214+ logger .warning ("DBFS root location not found" )
215+ return None
216+ return None
217+
168218 def _external_locations (self ) -> Iterable [ExternalLocation ]:
169219 trie = LocationTrie ()
170220 for table in self ._tables_crawler .snapshot ():
@@ -356,11 +406,9 @@ def __init__(
356406 sql_backend : SqlBackend ,
357407 ws : WorkspaceClient ,
358408 inventory_database : str ,
359- enable_hms_federation : bool = False ,
360409 ):
361410 super ().__init__ (sql_backend , "hive_metastore" , inventory_database , "mounts" , Mount )
362411 self ._dbutils = ws .dbutils
363- self ._enable_hms_federation = enable_hms_federation
364412
365413 @staticmethod
366414 def _deduplicate_mounts (mounts : list ) -> list :
@@ -389,6 +437,7 @@ def _jvm(self):
389437 return None
390438
391439 def _resolve_dbfs_root (self ) -> Mount | None :
440+ # TODO: Consider deprecating this method and rely on the new API call
392441 # pylint: disable=broad-exception-caught,too-many-try-statements
393442 try :
394443 jvm = self ._jvm
@@ -412,12 +461,6 @@ def _crawl(self) -> Iterable[Mount]:
412461 try :
413462 for mount_point , source , _ in self ._dbutils .fs .mounts ():
414463 mounts .append (Mount (mount_point , source ))
415- if self ._enable_hms_federation :
416- root_mount = self ._resolve_dbfs_root ()
417- if root_mount :
418- # filter out DatabricksRoot, otherwise ExternalLocations.resolve_mount() won't work
419- mounts = list (filter (lambda _ : _ .source != 'DatabricksRoot' , mounts ))
420- mounts .append (root_mount )
421464 except Exception as error : # pylint: disable=broad-except
422465 if "com.databricks.backend.daemon.dbutils.DBUtilsCore.mounts() is not whitelisted" in str (error ):
423466 logger .warning (
0 commit comments