4343
4444
4545@dataclass
46- class LocationACL :
47- location_name : str
48- principal : str
46+ class ComputeLocations :
47+ compute_id : str
48+ locations : dict
49+ compute_type : str
4950
5051
5152@dataclass (frozen = True )
@@ -386,6 +387,7 @@ def _get_cluster_to_instance_profile_mapping(self) -> dict[str, str]:
386387 # this function gives a mapping between an interactive cluster and the instance profile used by it
387388 # either directly or through a cluster policy.
388389 cluster_instance_profiles = {}
390+
389391 for cluster in self ._ws .clusters .list ():
390392 if (
391393 cluster .cluster_id is None
@@ -403,14 +405,26 @@ def _get_cluster_to_instance_profile_mapping(self) -> dict[str, str]:
403405
404406 return cluster_instance_profiles
405407
406- def get_eligible_locations_principals (self ) -> dict [str , dict ]:
407- cluster_locations = {}
408- eligible_locations = {}
408+ def _update_warehouse_to_instance_profile_mapping (
409+ self ,
410+ ) -> dict [str , str ]:
411+ warehouse_instance_profiles = {}
412+ sql_config = self ._ws .warehouses .get_workspace_warehouse_config ()
413+ if sql_config .instance_profile_arn is not None :
414+ role_name = sql_config .instance_profile_arn
415+ for warehouse in self ._ws .warehouses .list ():
416+ if warehouse .id is not None :
417+ warehouse_instance_profiles [warehouse .id ] = role_name
418+ return warehouse_instance_profiles
419+
420+ def get_eligible_locations_principals (self ) -> list [ComputeLocations ]:
409421 cluster_instance_profiles = self ._get_cluster_to_instance_profile_mapping ()
410- if len (cluster_instance_profiles ) == 0 :
411- # if there are no interactive clusters , then return empty grants
412- logger .info ("No interactive cluster found with instance profiles configured" )
413- return {}
422+ warehouse_instance_profiles = self ._update_warehouse_to_instance_profile_mapping ()
423+ compute_locations = []
424+ if len (cluster_instance_profiles ) == 0 and len (warehouse_instance_profiles ) == 0 :
425+ # if there are no interactive clusters or warehouse with instance profile , then return empty grants
426+ logger .info ("No interactive cluster or sql warehouse found with instance profiles configured" )
427+ return []
414428 external_locations = list (self ._ws .external_locations .list ())
415429 if len (external_locations ) == 0 :
416430 # if there are no external locations, then throw an error to run migrate_locations cli command
@@ -434,12 +448,17 @@ def get_eligible_locations_principals(self) -> dict[str, dict]:
434448 logger .error (msg )
435449 raise ResourceDoesNotExist (msg ) from None
436450
437- for cluster_id , role_name in cluster_instance_profiles .items ():
438- eligible_locations . update ( self ._get_external_locations (role_name , external_locations , permission_mappings ) )
451+ for cluster_id , role_compute in cluster_instance_profiles .items ():
452+ eligible_locations = self ._get_external_locations (role_compute , external_locations , permission_mappings )
439453 if len (eligible_locations ) == 0 :
440454 continue
441- cluster_locations [cluster_id ] = eligible_locations
442- return cluster_locations
455+ compute_locations .append (ComputeLocations (cluster_id , eligible_locations , "clusters" ))
456+ for warehouse_id , role_compute in warehouse_instance_profiles .items ():
457+ eligible_locations = self ._get_external_locations (role_compute , external_locations , permission_mappings )
458+ if len (eligible_locations ) == 0 :
459+ continue
460+ compute_locations .append (ComputeLocations (warehouse_id , eligible_locations , "warehouses" ))
461+ return compute_locations
443462
444463 @staticmethod
445464 def _get_external_locations (
@@ -475,14 +494,14 @@ def __init__(
475494 self ._spn_crawler = spn_crawler
476495 self ._installation = installation
477496
478- def get_eligible_locations_principals (self ) -> dict [str , dict ]:
479- cluster_locations = {}
480- eligible_locations = {}
497+ def get_eligible_locations_principals (self ) -> list [ComputeLocations ]:
498+ compute_locations = []
481499 spn_cluster_mapping = self ._spn_crawler .get_cluster_to_storage_mapping ()
482- if len (spn_cluster_mapping ) == 0 :
500+ spn_warehouse_mapping = self ._spn_crawler .get_warehouse_to_storage_mapping ()
501+ if len (spn_cluster_mapping ) == 0 and len (spn_warehouse_mapping ) == 0 :
483502 # if there are no interactive clusters , then return empty grants
484503 logger .info ("No interactive cluster found with spn configured" )
485- return {}
504+ return []
486505 external_locations = list (self ._ws .external_locations .list ())
487506 if len (external_locations ) == 0 :
488507 # if there are no external locations, then throw an error to run migrate_locations cli command
@@ -507,10 +526,18 @@ def get_eligible_locations_principals(self) -> dict[str, dict]:
507526 raise ResourceDoesNotExist (msg ) from None
508527
509528 for cluster_spn in spn_cluster_mapping :
529+ eligible_locations = {}
510530 for spn in cluster_spn .spn_info :
511531 eligible_locations .update (self ._get_external_locations (spn , external_locations , permission_mappings ))
512- cluster_locations [cluster_spn .cluster_id ] = eligible_locations
513- return cluster_locations
532+ compute_locations .append (ComputeLocations (cluster_spn .cluster_id , eligible_locations , "clusters" ))
533+
534+ for warehouse_spn in spn_warehouse_mapping :
535+ eligible_locations = {}
536+ for spn in warehouse_spn .spn_info :
537+ eligible_locations .update (self ._get_external_locations (spn , external_locations , permission_mappings ))
538+ compute_locations .append (ComputeLocations (warehouse_spn .cluster_id , eligible_locations , "warehouses" ))
539+
540+ return compute_locations
514541
515542 def _get_external_locations (
516543 self ,
@@ -543,25 +570,25 @@ def __init__(
543570 installation : Installation ,
544571 tables_crawler : TablesCrawler ,
545572 mounts_crawler : Mounts ,
546- cluster_locations : dict [ str , dict ],
573+ cluster_locations : list [ ComputeLocations ],
547574 ):
548575 self ._backend = backend
549576 self ._ws = ws
550577 self ._installation = installation
551578 self ._tables_crawler = tables_crawler
552579 self ._mounts_crawler = mounts_crawler
553- self ._cluster_locations = cluster_locations
580+ self ._compute_locations = cluster_locations
554581
555582 def get_interactive_cluster_grants (self ) -> list [Grant ]:
556583 tables = self ._tables_crawler .snapshot ()
557584 mounts = list (self ._mounts_crawler .snapshot ())
558585 grants : set [Grant ] = set ()
559586
560- for cluster_id , locations in self ._cluster_locations . items () :
561- principals = self ._get_cluster_principal_mapping (cluster_id )
587+ for compute_location in self ._compute_locations :
588+ principals = self ._get_cluster_principal_mapping (compute_location . compute_id , compute_location . compute_type )
562589 if len (principals ) == 0 :
563590 continue
564- cluster_usage = self ._get_grants (locations , principals , tables , mounts )
591+ cluster_usage = self ._get_grants (compute_location . locations , principals , tables , mounts )
565592 grants .update (cluster_usage )
566593 return list (grants )
567594
@@ -628,11 +655,11 @@ def _get_grants(
628655
629656 return grants
630657
631- def _get_cluster_principal_mapping (self , cluster_id : str ) -> list [str ]:
658+ def _get_cluster_principal_mapping (self , cluster_id : str , object_type : str ) -> list [str ]:
632659 # gets all the users,groups,spn which have access to the clusters and returns a dataclass of that mapping
633660 principal_list = []
634661 try :
635- cluster_permission = self ._ws .permissions .get ("clusters" , cluster_id )
662+ cluster_permission = self ._ws .permissions .get (object_type , cluster_id )
636663 except ResourceDoesNotExist :
637664 return []
638665 if cluster_permission .access_control_list is None :
@@ -661,12 +688,12 @@ def apply_location_acl(self):
661688 "CREATE EXTERNAL VOLUME and READ_FILES for existing eligible interactive cluster users"
662689 )
663690 # get the eligible location mapped for each interactive cluster
664- for cluster_id , locations in self ._cluster_locations . items () :
691+ for compute_location in self ._compute_locations :
665692 # get interactive cluster users
666- principals = self ._get_cluster_principal_mapping (cluster_id )
693+ principals = self ._get_cluster_principal_mapping (compute_location . compute_id , compute_location . compute_type )
667694 if len (principals ) == 0 :
668695 continue
669- for location_url in locations .keys ():
696+ for location_url in compute_location . locations .keys ():
670697 # get the location name for the given url
671698 location_name = self ._get_location_name (location_url )
672699 if location_name is None :
0 commit comments