Skip to content

Commit 938e4c5

Browse files
authored
DBFS Root Support for HMS Federation (#3425)
Closes #3406 Add external location for DBFS root location
1 parent 4e0507b commit 938e4c5

File tree

13 files changed

+114
-60
lines changed

13 files changed

+114
-60
lines changed

src/databricks/labs/ucx/aws/access.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ def _get_role_access_task(self, arn: str, role_name: str):
208208
return policy_actions
209209

210210
def _identify_missing_paths(self):
211-
external_locations = self._locations.snapshot()
211+
external_locations = self._locations.external_locations_with_root()
212212
compatible_roles = self.load_uc_compatible_roles()
213213
missing_paths = set()
214214
for external_location in external_locations:
@@ -226,13 +226,14 @@ def get_roles_to_migrate(self) -> list[AWSCredentialCandidate]:
226226
"""
227227
Identify the roles that need to be migrated to UC from the UC compatible roles list.
228228
"""
229-
external_locations = self._locations.snapshot()
229+
external_locations = self._locations.external_locations_with_root()
230+
logger.info(f"Found {len(external_locations)} external locations")
230231
compatible_roles = self.load_uc_compatible_roles()
231232
roles: dict[str, AWSCredentialCandidate] = {}
232233
for external_location in external_locations:
233234
path = PurePath(external_location.location)
234235
for role in compatible_roles:
235-
if not (path.match(role.resource_path) or path.match(role.resource_path + "/*")):
236+
if not (PurePath(role.resource_path) in path.parents or path.match(role.resource_path)):
236237
continue
237238
if role.role_arn not in roles:
238239
roles[role.role_arn] = AWSCredentialCandidate(
@@ -323,7 +324,7 @@ def _create_uber_instance_profile(self, iam_role_name: str, iam_policy_name: str
323324

324325
def create_uber_principal(self, prompts: Prompts):
325326
config = self._installation.load(WorkspaceConfig)
326-
s3_paths = {loc.location for loc in self._locations.snapshot()}
327+
s3_paths = {loc.location for loc in self._locations.external_locations_with_root()}
327328
if len(s3_paths) == 0:
328329
logger.info("No S3 paths to migrate found")
329330
return
@@ -374,11 +375,16 @@ def create_uber_principal(self, prompts: Prompts):
374375
logger.error(f"Failed to assign instance profile to cluster policy {iam_role_name}")
375376
self._aws_resources.delete_instance_profile(iam_role_name, iam_role_name)
376377

378+
@classmethod
379+
def _clean_external_location_name(cls, location: str) -> str:
380+
# Remove leading s3:// s3a:// and trailing /
381+
return location.replace("s3://", "").replace("s3a://", "").replace("/", "_").replace(":", "_").replace(".", "_")
382+
377383
def _generate_role_name(self, single_role: bool, role_name: str, location: str) -> str:
378384
if single_role:
379385
metastore_id = self._ws.metastores.current().as_dict()["metastore_id"]
380386
return f"{role_name}_{metastore_id}"
381-
return f"{role_name}_{location[5:]}"
387+
return f"{role_name}_{self._clean_external_location_name(location)}"
382388

383389
def delete_uc_role(self, role_name: str):
384390
self._aws_resources.delete_role(role_name)

src/databricks/labs/ucx/aws/locations.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def __init__(
2929
self._principal_acl = principal_acl
3030
# When HMS federation is enabled, the fallback bit is set for all the
3131
# locations which are created by UCX.
32-
self._enable_fallback_mode = enable_hms_federation
32+
self._enable_hms_federation = enable_hms_federation
3333

3434
def run(self) -> None:
3535
"""
@@ -38,7 +38,7 @@ def run(self) -> None:
3838
Create external location for the path using the credential identified
3939
"""
4040
credential_dict = self._get_existing_credentials_dict()
41-
external_locations = self._external_locations.snapshot()
41+
external_locations = list(self._external_locations.external_locations_with_root())
4242
existing_external_locations = self._ws.external_locations.list()
4343
existing_paths = []
4444
for external_location in existing_external_locations:
@@ -56,7 +56,7 @@ def run(self) -> None:
5656
path,
5757
credential_dict[role_arn],
5858
skip_validation=True,
59-
fallback=self._enable_fallback_mode,
59+
fallback=self._enable_hms_federation,
6060
)
6161
self._principal_acl.apply_location_acl()
6262

@@ -91,7 +91,7 @@ def _identify_missing_external_locations(
9191
path = role.resource_path
9292
if path.endswith("/*"):
9393
path = path[:-2]
94-
if new_path.match(path + "/*") or new_path.match(path):
94+
if PurePath(path) in new_path.parents or new_path.match(path):
9595
matching_role = role.role_arn
9696
continue
9797
if matching_role:

src/databricks/labs/ucx/azure/access.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -690,7 +690,7 @@ def load(self):
690690
return self._installation.load(list[StoragePermissionMapping], filename=self.FILENAME)
691691

692692
def _get_storage_accounts(self) -> list[StorageAccount]:
693-
external_locations = self._locations.snapshot()
693+
external_locations = self._locations.external_locations_with_root()
694694
used_storage_accounts = []
695695
for location in external_locations:
696696
if location.location.startswith("abfss://"):

src/databricks/labs/ucx/azure/locations.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def __init__(
2727
self._resource_permissions = resource_permissions
2828
self._azurerm = azurerm
2929
self._principal_acl = principal_acl
30-
self._enable_fallback_mode = enable_hms_federation
30+
self._enable_hms_federation = enable_hms_federation
3131

3232
def _app_id_credential_name_mapping(self) -> tuple[dict[str, str], dict[str, str]]:
3333
# list all storage credentials.
@@ -128,7 +128,7 @@ def _create_external_location_helper(
128128
comment=comment,
129129
read_only=read_only,
130130
skip_validation=skip_validation,
131-
fallback=self._enable_fallback_mode,
131+
fallback=self._enable_hms_federation,
132132
)
133133
return url
134134
except InvalidParameterValue as invalid:

src/databricks/labs/ucx/contexts/application.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,6 @@ def mounts_crawler(self) -> MountsCrawler:
397397
self.sql_backend,
398398
self.workspace_client,
399399
self.inventory_database,
400-
self.config.enable_hms_federation,
401400
)
402401

403402
@cached_property
@@ -416,6 +415,7 @@ def external_locations(self) -> ExternalLocations:
416415
self.inventory_database,
417416
self.tables_crawler,
418417
self.mounts_crawler,
418+
self.config.enable_hms_federation,
419419
)
420420

421421
@cached_property

src/databricks/labs/ucx/hive_metastore/federation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,8 @@ def _get_authorized_paths(self) -> str:
8585
current_user = self._workspace_client.current_user.me()
8686
if not current_user.user_name:
8787
raise NotFound('Current user not found')
88-
for external_location_info in self._external_locations.snapshot():
89-
location = external_location_info.location.rstrip('/').replace('s3a://', 's3://')
88+
for external_location_info in self._external_locations.external_locations_with_root():
89+
location = ExternalLocations.clean_location(external_location_info.location)
9090
existing_location = existing.get(location)
9191
if not existing_location:
9292
logger.warning(f'External location {location} not found')

src/databricks/labs/ucx/hive_metastore/locations.py

Lines changed: 52 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,8 @@ def location(self) -> str | None:
6969
def _parse_location(cls, location: str | None) -> list[str]:
7070
if not location:
7171
return []
72-
parse_result = cls._parse_url(location.rstrip("/"))
72+
location = ExternalLocations.clean_location(location)
73+
parse_result = cls._parse_url(location)
7374
if not parse_result:
7475
return []
7576
parts = [parse_result.scheme, parse_result.netloc]
@@ -154,17 +155,66 @@ def __init__(
154155
schema: str,
155156
tables_crawler: TablesCrawler,
156157
mounts_crawler: 'MountsCrawler',
158+
enable_hms_federation: bool = False,
157159
):
158160
super().__init__(sql_backend, "hive_metastore", schema, "external_locations", ExternalLocation)
159161
self._ws = ws
160162
self._tables_crawler = tables_crawler
161163
self._mounts_crawler = mounts_crawler
164+
self._enable_hms_federation = enable_hms_federation
162165

163166
@cached_property
164167
def _mounts_snapshot(self) -> list['Mount']:
165168
"""Returns all mounts, sorted by longest prefixes first."""
166169
return sorted(self._mounts_crawler.snapshot(), key=lambda _: (len(_.name), _.name), reverse=True)
167170

171+
@staticmethod
172+
def clean_location(location: str) -> str:
173+
# remove the s3a scheme and replace it with s3 as these can be considered the same and will be treated as such
174+
# Having s3a and s3 as separate locations will cause issues when trying to find overlapping locations
175+
return re.sub(r"^s3a:/", r"s3:/", location).rstrip("/")
176+
177+
def external_locations_with_root(self) -> list[ExternalLocation]:
178+
"""
179+
Produces a list of external locations with the DBFS root location appended to the list.
180+
Utilizes the snapshot method.
181+
Used for HMS Federation.
182+
183+
Returns:
184+
List of ExternalLocation objects
185+
"""
186+
187+
external_locations = list(self.snapshot())
188+
dbfs_root = self._get_dbfs_root()
189+
if dbfs_root:
190+
external_locations.append(dbfs_root)
191+
return external_locations
192+
193+
def _get_dbfs_root(self) -> ExternalLocation | None:
194+
"""
195+
Get the root location of the DBFS only if HMS Fed is enabled.
196+
Utilizes an undocumented Databricks API call
197+
198+
Returns:
199+
Cloud storage root location for dbfs
200+
201+
"""
202+
if not self._enable_hms_federation:
203+
return None
204+
logger.debug("Retrieving DBFS root location")
205+
try:
206+
response = self._ws.api_client.do("GET", "/api/2.0/dbfs/resolve-path", query={"path": "dbfs:/"})
207+
if isinstance(response, dict):
208+
resolved_path = response.get("resolved_path")
209+
if resolved_path:
210+
path = f"{self.clean_location(resolved_path)}/user/hive/warehouse"
211+
return ExternalLocation(path, 0)
212+
except NotFound:
213+
# Couldn't retrieve the DBFS root location
214+
logger.warning("DBFS root location not found")
215+
return None
216+
return None
217+
168218
def _external_locations(self) -> Iterable[ExternalLocation]:
169219
trie = LocationTrie()
170220
for table in self._tables_crawler.snapshot():
@@ -356,11 +406,9 @@ def __init__(
356406
sql_backend: SqlBackend,
357407
ws: WorkspaceClient,
358408
inventory_database: str,
359-
enable_hms_federation: bool = False,
360409
):
361410
super().__init__(sql_backend, "hive_metastore", inventory_database, "mounts", Mount)
362411
self._dbutils = ws.dbutils
363-
self._enable_hms_federation = enable_hms_federation
364412

365413
@staticmethod
366414
def _deduplicate_mounts(mounts: list) -> list:
@@ -389,6 +437,7 @@ def _jvm(self):
389437
return None
390438

391439
def _resolve_dbfs_root(self) -> Mount | None:
440+
# TODO: Consider deprecating this method and rely on the new API call
392441
# pylint: disable=broad-exception-caught,too-many-try-statements
393442
try:
394443
jvm = self._jvm
@@ -412,12 +461,6 @@ def _crawl(self) -> Iterable[Mount]:
412461
try:
413462
for mount_point, source, _ in self._dbutils.fs.mounts():
414463
mounts.append(Mount(mount_point, source))
415-
if self._enable_hms_federation:
416-
root_mount = self._resolve_dbfs_root()
417-
if root_mount:
418-
# filter out DatabricksRoot, otherwise ExternalLocations.resolve_mount() won't work
419-
mounts = list(filter(lambda _: _.source != 'DatabricksRoot', mounts))
420-
mounts.append(root_mount)
421464
except Exception as error: # pylint: disable=broad-except
422465
if "com.databricks.backend.daemon.dbutils.DBUtilsCore.mounts() is not whitelisted" in str(error):
423466
logger.warning(

tests/unit/aws/test_access.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -389,7 +389,7 @@ def test_create_uber_principal_no_storage(mock_ws, mock_installation, locations)
389389
)
390390
mock_ws.cluster_policies.get.return_value = cluster_policy
391391
external_locations = create_autospec(ExternalLocations)
392-
external_locations.snapshot.return_value = []
392+
external_locations.external_locations_with_root.return_value = []
393393
prompts = MockPrompts({})
394394
aws = create_autospec(AWSResources)
395395
aws_resource_permissions = AWSResourcePermissions(
@@ -459,7 +459,7 @@ def test_create_uc_role_multiple_raises_error(mock_ws, installation_single_role,
459459

460460
def test_create_uc_no_roles(installation_no_roles, mock_ws, caplog):
461461
external_locations = create_autospec(ExternalLocations)
462-
external_locations.snapshot.return_value = []
462+
external_locations.external_locations_with_root.return_value = []
463463
aws = create_autospec(AWSResources)
464464
aws_resource_permissions = AWSResourcePermissions(
465465
installation_no_roles,
@@ -867,15 +867,15 @@ def command_call(_: str):
867867
aws = AWSResources("profile", command_call)
868868

869869
external_locations = create_autospec(ExternalLocations)
870-
external_locations.snapshot.return_value = [
870+
external_locations.external_locations_with_root.return_value = [
871871
ExternalLocation("s3://BUCKET1", 1),
872872
ExternalLocation("s3://BUCKET2/Folder1", 1),
873873
]
874874
resource_permissions = AWSResourcePermissions(installation_multiple_roles, mock_ws, aws, external_locations)
875875
roles = resource_permissions.get_roles_to_migrate()
876876
assert len(roles) == 1
877877
assert len(roles[0].paths) == 2
878-
external_locations.snapshot.assert_called_once()
878+
external_locations.external_locations_with_root.assert_called_once()
879879

880880

881881
def test_delete_uc_roles(mock_ws, installation_multiple_roles, backend, locations):
@@ -939,7 +939,7 @@ def command_call(cmd: str):
939939

940940
aws = AWSResources("profile", command_call)
941941
external_locations = create_autospec(ExternalLocations)
942-
external_locations.snapshot.return_value = []
942+
external_locations.external_locations_with_root.return_value = []
943943
resource_permissions = AWSResourcePermissions(installation_no_roles, mock_ws, aws, external_locations)
944944
resource_permissions.delete_uc_role("uc_role_1")
945945
assert '/path/aws iam delete-role --role-name uc_role_1 --profile profile --output json' in command_calls

0 commit comments

Comments
 (0)