Skip to content

Commit be3d45f

Browse files
authored
Add support for migrating Table ACL for SQL Warehouse cluster in AWS using Instance Profile and Azure using SPN (#2258)
This PR contains the following changes for migrating Table ACL for SQL warehouse cluster used through Instance Profile in AWS scenario and spn info for Azure Identifies all the SQL Warehouse identifies all the instance profiles used by the data access config for aws and spn info for azure identify the list of principals who have access to the warehouse (user, group, spn) get the prefix these instance role arn or spn have access and the type of permission get the list of external locations matching the list of prefix get the list of tables having the matching external location prefix from the table mapping CSV Resolves #2238
1 parent 815c7e8 commit be3d45f

File tree

5 files changed

+244
-62
lines changed

5 files changed

+244
-62
lines changed

src/databricks/labs/ucx/assessment/azure.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,15 +181,27 @@ def _get_azure_spn_from_config(self, config: dict) -> set[AzureServicePrincipalI
181181
)
182182
return set_service_principals
183183

184-
def get_cluster_to_storage_mapping(self):
184+
def get_cluster_to_storage_mapping(self) -> list[ServicePrincipalClusterMapping]:
185185
# this function gives a mapping between an interactive cluster and the spn used by it
186186
# either directly or through a cluster policy.
187187
set_service_principals = set[AzureServicePrincipalInfo]()
188188
spn_cluster_mapping = []
189189
for cluster in self._ws.clusters.list():
190190
if cluster.cluster_source != ClusterSource.JOB and (
191191
cluster.data_security_mode in [DataSecurityMode.LEGACY_SINGLE_USER, DataSecurityMode.NONE]
192+
and cluster.cluster_id is not None
192193
):
193194
set_service_principals = self._get_azure_spn_from_cluster_config(cluster)
194195
spn_cluster_mapping.append(ServicePrincipalClusterMapping(cluster.cluster_id, set_service_principals))
195196
return spn_cluster_mapping
197+
198+
def get_warehouse_to_storage_mapping(self) -> list[ServicePrincipalClusterMapping]:
199+
# this function gives a mapping between a sql warehouse and the spn used by it
200+
spn_warehouse_mapping = []
201+
set_service_principals = self._list_all_spn_in_sql_warehouses_spark_conf()
202+
if len(set_service_principals) == 0:
203+
return []
204+
for warehouse in self._ws.warehouses.list():
205+
if warehouse.id is not None:
206+
spn_warehouse_mapping.append(ServicePrincipalClusterMapping(warehouse.id, set_service_principals))
207+
return spn_warehouse_mapping

src/databricks/labs/ucx/hive_metastore/grants.py

Lines changed: 58 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,10 @@
4343

4444

4545
@dataclass
46-
class LocationACL:
47-
location_name: str
48-
principal: str
46+
class ComputeLocations:
47+
compute_id: str
48+
locations: dict
49+
compute_type: str
4950

5051

5152
@dataclass(frozen=True)
@@ -386,6 +387,7 @@ def _get_cluster_to_instance_profile_mapping(self) -> dict[str, str]:
386387
# this function gives a mapping between an interactive cluster and the instance profile used by it
387388
# either directly or through a cluster policy.
388389
cluster_instance_profiles = {}
390+
389391
for cluster in self._ws.clusters.list():
390392
if (
391393
cluster.cluster_id is None
@@ -403,14 +405,26 @@ def _get_cluster_to_instance_profile_mapping(self) -> dict[str, str]:
403405

404406
return cluster_instance_profiles
405407

406-
def get_eligible_locations_principals(self) -> dict[str, dict]:
407-
cluster_locations = {}
408-
eligible_locations = {}
408+
def _update_warehouse_to_instance_profile_mapping(
409+
self,
410+
) -> dict[str, str]:
411+
warehouse_instance_profiles = {}
412+
sql_config = self._ws.warehouses.get_workspace_warehouse_config()
413+
if sql_config.instance_profile_arn is not None:
414+
role_name = sql_config.instance_profile_arn
415+
for warehouse in self._ws.warehouses.list():
416+
if warehouse.id is not None:
417+
warehouse_instance_profiles[warehouse.id] = role_name
418+
return warehouse_instance_profiles
419+
420+
def get_eligible_locations_principals(self) -> list[ComputeLocations]:
409421
cluster_instance_profiles = self._get_cluster_to_instance_profile_mapping()
410-
if len(cluster_instance_profiles) == 0:
411-
# if there are no interactive clusters , then return empty grants
412-
logger.info("No interactive cluster found with instance profiles configured")
413-
return {}
422+
warehouse_instance_profiles = self._update_warehouse_to_instance_profile_mapping()
423+
compute_locations = []
424+
if len(cluster_instance_profiles) == 0 and len(warehouse_instance_profiles) == 0:
425+
# if there are no interactive clusters or warehouse with instance profile , then return empty grants
426+
logger.info("No interactive cluster or sql warehouse found with instance profiles configured")
427+
return []
414428
external_locations = list(self._ws.external_locations.list())
415429
if len(external_locations) == 0:
416430
# if there are no external locations, then throw an error to run migrate_locations cli command
@@ -434,12 +448,17 @@ def get_eligible_locations_principals(self) -> dict[str, dict]:
434448
logger.error(msg)
435449
raise ResourceDoesNotExist(msg) from None
436450

437-
for cluster_id, role_name in cluster_instance_profiles.items():
438-
eligible_locations.update(self._get_external_locations(role_name, external_locations, permission_mappings))
451+
for cluster_id, role_compute in cluster_instance_profiles.items():
452+
eligible_locations = self._get_external_locations(role_compute, external_locations, permission_mappings)
439453
if len(eligible_locations) == 0:
440454
continue
441-
cluster_locations[cluster_id] = eligible_locations
442-
return cluster_locations
455+
compute_locations.append(ComputeLocations(cluster_id, eligible_locations, "clusters"))
456+
for warehouse_id, role_compute in warehouse_instance_profiles.items():
457+
eligible_locations = self._get_external_locations(role_compute, external_locations, permission_mappings)
458+
if len(eligible_locations) == 0:
459+
continue
460+
compute_locations.append(ComputeLocations(warehouse_id, eligible_locations, "warehouses"))
461+
return compute_locations
443462

444463
@staticmethod
445464
def _get_external_locations(
@@ -475,14 +494,14 @@ def __init__(
475494
self._spn_crawler = spn_crawler
476495
self._installation = installation
477496

478-
def get_eligible_locations_principals(self) -> dict[str, dict]:
479-
cluster_locations = {}
480-
eligible_locations = {}
497+
def get_eligible_locations_principals(self) -> list[ComputeLocations]:
498+
compute_locations = []
481499
spn_cluster_mapping = self._spn_crawler.get_cluster_to_storage_mapping()
482-
if len(spn_cluster_mapping) == 0:
500+
spn_warehouse_mapping = self._spn_crawler.get_warehouse_to_storage_mapping()
501+
if len(spn_cluster_mapping) == 0 and len(spn_warehouse_mapping) == 0:
483502
# if there are no interactive clusters , then return empty grants
484503
logger.info("No interactive cluster found with spn configured")
485-
return {}
504+
return []
486505
external_locations = list(self._ws.external_locations.list())
487506
if len(external_locations) == 0:
488507
# if there are no external locations, then throw an error to run migrate_locations cli command
@@ -507,10 +526,18 @@ def get_eligible_locations_principals(self) -> dict[str, dict]:
507526
raise ResourceDoesNotExist(msg) from None
508527

509528
for cluster_spn in spn_cluster_mapping:
529+
eligible_locations = {}
510530
for spn in cluster_spn.spn_info:
511531
eligible_locations.update(self._get_external_locations(spn, external_locations, permission_mappings))
512-
cluster_locations[cluster_spn.cluster_id] = eligible_locations
513-
return cluster_locations
532+
compute_locations.append(ComputeLocations(cluster_spn.cluster_id, eligible_locations, "clusters"))
533+
534+
for warehouse_spn in spn_warehouse_mapping:
535+
eligible_locations = {}
536+
for spn in warehouse_spn.spn_info:
537+
eligible_locations.update(self._get_external_locations(spn, external_locations, permission_mappings))
538+
compute_locations.append(ComputeLocations(warehouse_spn.cluster_id, eligible_locations, "warehouses"))
539+
540+
return compute_locations
514541

515542
def _get_external_locations(
516543
self,
@@ -543,25 +570,25 @@ def __init__(
543570
installation: Installation,
544571
tables_crawler: TablesCrawler,
545572
mounts_crawler: Mounts,
546-
cluster_locations: dict[str, dict],
573+
cluster_locations: list[ComputeLocations],
547574
):
548575
self._backend = backend
549576
self._ws = ws
550577
self._installation = installation
551578
self._tables_crawler = tables_crawler
552579
self._mounts_crawler = mounts_crawler
553-
self._cluster_locations = cluster_locations
580+
self._compute_locations = cluster_locations
554581

555582
def get_interactive_cluster_grants(self) -> list[Grant]:
556583
tables = self._tables_crawler.snapshot()
557584
mounts = list(self._mounts_crawler.snapshot())
558585
grants: set[Grant] = set()
559586

560-
for cluster_id, locations in self._cluster_locations.items():
561-
principals = self._get_cluster_principal_mapping(cluster_id)
587+
for compute_location in self._compute_locations:
588+
principals = self._get_cluster_principal_mapping(compute_location.compute_id, compute_location.compute_type)
562589
if len(principals) == 0:
563590
continue
564-
cluster_usage = self._get_grants(locations, principals, tables, mounts)
591+
cluster_usage = self._get_grants(compute_location.locations, principals, tables, mounts)
565592
grants.update(cluster_usage)
566593
return list(grants)
567594

@@ -628,11 +655,11 @@ def _get_grants(
628655

629656
return grants
630657

631-
def _get_cluster_principal_mapping(self, cluster_id: str) -> list[str]:
658+
def _get_cluster_principal_mapping(self, cluster_id: str, object_type: str) -> list[str]:
632659
# gets all the users,groups,spn which have access to the clusters and returns a dataclass of that mapping
633660
principal_list = []
634661
try:
635-
cluster_permission = self._ws.permissions.get("clusters", cluster_id)
662+
cluster_permission = self._ws.permissions.get(object_type, cluster_id)
636663
except ResourceDoesNotExist:
637664
return []
638665
if cluster_permission.access_control_list is None:
@@ -661,12 +688,12 @@ def apply_location_acl(self):
661688
"CREATE EXTERNAL VOLUME and READ_FILES for existing eligible interactive cluster users"
662689
)
663690
# get the eligible location mapped for each interactive cluster
664-
for cluster_id, locations in self._cluster_locations.items():
691+
for compute_location in self._compute_locations:
665692
# get interactive cluster users
666-
principals = self._get_cluster_principal_mapping(cluster_id)
693+
principals = self._get_cluster_principal_mapping(compute_location.compute_id, compute_location.compute_type)
667694
if len(principals) == 0:
668695
continue
669-
for location_url in locations.keys():
696+
for location_url in compute_location.locations.keys():
670697
# get the location name for the given url
671698
location_name = self._get_location_name(location_url)
672699
if location_name is None:

tests/integration/hive_metastore/test_migrate.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from databricks.sdk.service.compute import DataSecurityMode, AwsAttributes
88
from databricks.sdk.service.catalog import Privilege, SecurableType, TableInfo, TableType
99
from databricks.sdk.service.iam import PermissionLevel
10-
1110
from databricks.labs.ucx.config import WorkspaceConfig
1211
from databricks.labs.ucx.hive_metastore.mapping import Rule, TableMapping
1312
from databricks.labs.ucx.hive_metastore.tables import AclMigrationWhat, Table, What
@@ -564,6 +563,33 @@ def test_migrate_external_tables_with_principal_acl_aws(
564563
assert match
565564

566565

566+
@retried(on=[NotFound], timeout=timedelta(minutes=3))
567+
def test_migrate_external_tables_with_principal_acl_aws_warehouse(
568+
ws, make_user, prepared_principal_acl, make_warehouse_permissions, make_warehouse, env_or_skip
569+
):
570+
if not ws.config.is_aws:
571+
pytest.skip("temporary: only works in aws test env")
572+
ctx, table_full_name, _, _ = prepared_principal_acl
573+
ctx.with_dummy_resource_permission()
574+
warehouse = make_warehouse()
575+
table_migrate = ctx.tables_migrator
576+
user = make_user()
577+
make_warehouse_permissions(
578+
object_id=warehouse.id,
579+
permission_level=PermissionLevel.CAN_USE,
580+
user_name=user.user_name,
581+
)
582+
table_migrate.migrate_tables(what=What.EXTERNAL_SYNC, acl_strategy=[AclMigrationWhat.PRINCIPAL])
583+
584+
target_table_grants = ws.grants.get(SecurableType.TABLE, table_full_name)
585+
match = False
586+
for _ in target_table_grants.privilege_assignments:
587+
if _.principal == user.user_name and _.privileges == [Privilege.ALL_PRIVILEGES]:
588+
match = True
589+
break
590+
assert match
591+
592+
567593
def test_migrate_table_in_mount(
568594
ws,
569595
sql_backend,

tests/unit/assessment/test_azure.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
ClusterSource,
88
DataSecurityMode,
99
)
10-
10+
from databricks.sdk.service.sql import EndpointInfo
1111
from databricks.labs.ucx.assessment.azure import (
1212
AzureServicePrincipalCrawler,
1313
AzureServicePrincipalInfo,
@@ -219,6 +219,13 @@ def test_get_cluster_to_storage_mapping_no_cluster_return_empty():
219219
assert not crawler.get_cluster_to_storage_mapping()
220220

221221

222+
def test_get_warehouse_to_storage_mapping_no_warehouse_return_empty():
223+
ws = create_autospec(WorkspaceClient)
224+
ws.warehouses.list.return_value = []
225+
crawler = AzureServicePrincipalCrawler(ws, MockBackend(), "ucx")
226+
assert not crawler.get_warehouse_to_storage_mapping()
227+
228+
222229
def test_get_cluster_to_storage_mapping_no_interactive_cluster_return_empty():
223230
ws = mock_workspace_client(cluster_ids=['azure-spn-secret'])
224231
ws.clusters.list.return_value = [
@@ -229,6 +236,15 @@ def test_get_cluster_to_storage_mapping_no_interactive_cluster_return_empty():
229236
assert not crawler.get_cluster_to_storage_mapping()
230237

231238

239+
def test_get_warehouse_to_storage_mapping_no_spn_info_return_empty():
240+
ws = mock_workspace_client()
241+
ws.warehouses.list.return_value = [
242+
EndpointInfo(id="123", name="warehouse1"),
243+
]
244+
crawler = AzureServicePrincipalCrawler(ws, MockBackend(), "ucx")
245+
assert not crawler.get_warehouse_to_storage_mapping()
246+
247+
232248
def test_get_cluster_to_storage_mapping_interactive_cluster_no_spn_return_empty():
233249
ws = mock_workspace_client(cluster_ids=['azure-spn-secret-interactive-multiple-spn'])
234250

@@ -254,3 +270,29 @@ def test_get_cluster_to_storage_mapping_interactive_cluster_no_spn_return_empty(
254270
assert cluster_spn_info[0].cluster_id == "azure-spn-secret-interactive"
255271
assert len(cluster_spn_info[0].spn_info) == 2
256272
assert cluster_spn_info[0].spn_info == spn_info
273+
274+
275+
def test_get_warehouse_to_storage_mapping_spn():
276+
ws = mock_workspace_client(warehouse_config="spn-secret-config")
277+
ws.warehouses.list.return_value = [
278+
EndpointInfo(id="azure-spn-secret-warehouse", name="warehouse1"),
279+
]
280+
281+
crawler = AzureServicePrincipalCrawler(ws, MockBackend(), "ucx")
282+
cluster_spn_info = crawler.get_warehouse_to_storage_mapping()
283+
spn_info = {
284+
AzureServicePrincipalInfo(
285+
application_id='Hello, World!',
286+
tenant_id='dummy_tenant_id2',
287+
storage_account='xyz',
288+
),
289+
AzureServicePrincipalInfo(
290+
application_id='dummy_application_id',
291+
tenant_id='dummy_tenant_id',
292+
storage_account='abcde',
293+
),
294+
}
295+
296+
assert cluster_spn_info[0].cluster_id == "azure-spn-secret-warehouse"
297+
assert len(cluster_spn_info[0].spn_info) == 2
298+
assert cluster_spn_info[0].spn_info == spn_info

0 commit comments

Comments
 (0)