@@ -242,11 +242,12 @@ def warehouse_type(_):
         cluster_policy = json.loads(self._prompts.choice_from_dict("Choose a cluster policy", cluster_policies))
         instance_profile, spark_conf_dict = self._get_ext_hms_conf_from_policy(cluster_policy)
 
-        if self._prompts.confirm("Do you want to follow a policy to create clusters?"):
-            cluster_policies_list = {f"{_.name} ({_.policy_id})": _.policy_id for _ in self._ws.cluster_policies.list()}
-            custom_cluster_policy_id = self._prompts.choice_from_dict("Choose a cluster policy", cluster_policies_list)
-        else:
-            custom_cluster_policy_id = None
+        logger.info("Creating UCX cluster policy.")
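+        # One policy per installation; its id is persisted in the workspace config below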
+        policy_id = self._ws.cluster_policies.create(
+            name=f"Unity Catalog Migration ({inventory_database})",
+            definition=self._cluster_policy_definition(conf=spark_conf_dict, instance_profile=instance_profile),
+            description="Custom cluster policy for Unity Catalog Migration (UCX)",
+        ).policy_id
 
         config = WorkspaceConfig(
             inventory_database=inventory_database,
@@ -261,13 +262,41 @@ def warehouse_type(_):
             num_threads=num_threads,
             instance_profile=instance_profile,
             spark_conf=spark_conf_dict,
-            custom_cluster_policy_id=custom_cluster_policy_id,
+            policy_id=policy_id,
         )
         ws_file_url = self._installation.save(config)
         if self._prompts.confirm("Open config file in the browser and continue installing?"):
            webbrowser.open(ws_file_url)
         return config
 
+    @staticmethod
+    def _policy_config(value: str):
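+        # "fixed" policy elements are locked: clusters created from the policy cannot override them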
+        return {"type": "fixed", "value": value}
+
+    def _cluster_policy_definition(self, conf: dict, instance_profile: str | None) -> str:
+        policy_definition = {
+            "spark_version": self._policy_config(self._ws.clusters.select_spark_version(latest=True)),
+            "node_type_id": self._policy_config(self._ws.clusters.select_node_type(local_disk=True)),
+        }
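+        # Pin the external Hive metastore settings picked up from the legacy cluster policy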
+        if conf:
+            for key, value in conf.items():
+                policy_definition[f"spark_conf.{key}"] = self._policy_config(value)
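+        # Cloud-specific defaults: on-demand capacity, plus the instance profile on AWS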
+        if self._ws.config.is_aws:
+            policy_definition["aws_attributes.availability"] = self._policy_config(
+                compute.AwsAvailability.ON_DEMAND.value
+            )
+            if instance_profile:
+                policy_definition["aws_attributes.instance_profile_arn"] = self._policy_config(instance_profile)
+        elif self._ws.config.is_azure:
+            policy_definition["azure_attributes.availability"] = self._policy_config(
+                compute.AzureAvailability.ON_DEMAND_AZURE.value
+            )
+        else:
+            policy_definition["gcp_attributes.availability"] = self._policy_config(
+                compute.GcpAvailability.ON_DEMAND_GCP.value
+            )
+        return json.dumps(policy_definition)
+
     @staticmethod
     def _get_ext_hms_conf_from_policy(cluster_policy):
         spark_conf_dict = {}
@@ -277,7 +306,7 @@ def _get_ext_hms_conf_from_policy(cluster_policy):
             logger.info(f"Instance Profile is Set to {instance_profile}")
         for key in cluster_policy.keys():
             if (
-                key.startswith("spark_conf.sql.hive.metastore")
+                key.startswith("spark_conf.spark.sql.hive.metastore")
                 or key.startswith("spark_conf.spark.hadoop.javax.jdo.option")
                 or key.startswith("spark_conf.spark.databricks.hive.metastore")
                 or key.startswith("spark_conf.spark.hadoop.hive.metastore.glue")
@@ -293,7 +322,7 @@ def _get_cluster_policies_with_external_hive_metastores(self):
                 yield policy
                 continue
             for key in def_json.keys():
-                if key.startswith("spark_config.spark.sql.hive.metastore"):
+                if key.startswith("spark_conf.spark.sql.hive.metastore"):
                     yield policy
                     break
 
@@ -512,13 +541,26 @@ def _upload_wheel(self):
     def create_jobs(self):
         logger.debug(f"Creating jobs from tasks in {main.__name__}")
         remote_wheel = self._upload_wheel()
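+        # Fail fast if the policy created at install time was deleted out-of-band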
+        try:
+            policy_definition = self._ws.cluster_policies.get(policy_id=self.config.policy_id).definition
+        except NotFound as e:
+            msg = f"UCX Policy {self.config.policy_id} not found, please reinstall UCX"
+            logger.error(msg)
+            raise NotFound(msg) from e
+
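+        # Re-edit the policy on every deployment so its library list points at the freshly uploaded wheel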
+        self._ws.cluster_policies.edit(
+            policy_id=self.config.policy_id,
+            name=f"Unity Catalog Migration ({self.config.inventory_database})",
+            definition=policy_definition,
+            libraries=[compute.Library(whl=f"dbfs:{remote_wheel}")],
+        )
         desired_steps = {t.workflow for t in _TASKS.values() if t.cloud_compatible(self._ws.config)}
         wheel_runner = None
 
         if self._config.override_clusters:
             wheel_runner = self._upload_wheel_runner(remote_wheel)
         for step_name in desired_steps:
-            settings = self._job_settings(step_name, remote_wheel)
+            settings = self._job_settings(step_name)
             if self._config.override_clusters:
                 settings = self._apply_cluster_overrides(settings, self._config.override_clusters, wheel_runner)
             self._deploy_workflow(step_name, settings)
@@ -618,7 +660,7 @@ def _create_debug(self, remote_wheel: str):
         ).encode("utf8")
         self._installation.upload('DEBUG.py', content)
 
-    def _job_settings(self, step_name: str, remote_wheel: str):
+    def _job_settings(self, step_name: str):
         email_notifications = None
         if not self._config.override_clusters and "@" in self._my_username:
             # set email notifications only if we're running the real
@@ -637,7 +679,7 @@ def _job_settings(self, step_name: str, remote_wheel: str):
             "tags": {"version": f"v{version}"},
             "job_clusters": self._job_clusters({t.job_cluster for t in tasks}),
             "email_notifications": email_notifications,
-            "tasks": [self._job_task(task, remote_wheel) for task in tasks],
+            "tasks": [self._job_task(task) for task in tasks],
         }
 
     def _upload_wheel_runner(self, remote_wheel: str):
@@ -661,7 +703,7 @@ def _apply_cluster_overrides(settings: dict[str, Any], overrides: dict[str, str]
                 job_task.notebook_task = jobs.NotebookTask(notebook_path=wheel_runner, base_parameters=params)
         return settings
 
-    def _job_task(self, task: Task, remote_wheel: str) -> jobs.Task:
+    def _job_task(self, task: Task) -> jobs.Task:
         jobs_task = jobs.Task(
             task_key=task.name,
             job_cluster_key=task.job_cluster,
@@ -674,7 +716,7 @@ def _job_task(self, task: Task, remote_wheel: str) -> jobs.Task:
             return retried_job_dashboard_task(jobs_task, task)
         if task.notebook:
             return self._job_notebook_task(jobs_task, task)
-        return self._job_wheel_task(jobs_task, task, remote_wheel)
+        return self._job_wheel_task(jobs_task, task)
 
     def _job_dashboard_task(self, jobs_task: jobs.Task, task: Task) -> jobs.Task:
         assert task.dashboard is not None
@@ -706,11 +748,10 @@ def _job_notebook_task(self, jobs_task: jobs.Task, task: Task) -> jobs.Task:
             ),
         )
 
-    def _job_wheel_task(self, jobs_task: jobs.Task, task: Task, remote_wheel: str) -> jobs.Task:
+    def _job_wheel_task(self, jobs_task: jobs.Task, task: Task) -> jobs.Task:
         return replace(
             jobs_task,
             # TODO: check when we can install wheels from WSFS properly
-            libraries=[compute.Library(whl=f"dbfs:{remote_wheel}")],
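+            # The wheel is now installed through the cluster policy's library list (see create_jobs)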
             python_wheel_task=jobs.PythonWheelTask(
                 package_name="databricks_labs_ucx",
                 entry_point="runtime",  # [project.entry-points.databricks] in pyproject.toml
@@ -726,21 +767,13 @@ def _job_clusters(self, names: set[str]):
         }
         if self._config.spark_conf is not None:
             spark_conf = spark_conf | self._config.spark_conf
-        spec = self._cluster_node_type(
-            compute.ClusterSpec(
-                spark_version=self._ws.clusters.select_spark_version(latest=True),
-                data_security_mode=compute.DataSecurityMode.LEGACY_SINGLE_USER,
-                spark_conf=spark_conf,
-                custom_tags={"ResourceClass": "SingleNode"},
-                num_workers=0,
-            )
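+        # spark_version, node_type_id and cloud availability now come from the UCX cluster policy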
+        spec = compute.ClusterSpec(
+            data_security_mode=compute.DataSecurityMode.LEGACY_SINGLE_USER,
+            spark_conf=spark_conf,
+            custom_tags={"ResourceClass": "SingleNode"},
+            num_workers=0,
+            policy_id=self.config.policy_id,
         )
-        if self._config.custom_cluster_policy_id is not None:
-            spec = replace(spec, policy_id=self._config.custom_cluster_policy_id)
-        if self._ws.config.is_aws and spec.aws_attributes is not None:
-            # TODO: we might not need spec.aws_attributes, if we have a cluster policy
-            aws_attributes = replace(spec.aws_attributes, instance_profile_arn=self._config.instance_profile)
-            spec = replace(spec, aws_attributes=aws_attributes)
         if "main" in names:
             clusters.append(
                 jobs.JobCluster(
@@ -763,41 +796,6 @@ def _job_clusters(self, names: set[str]):
             )
         return clusters
 
-    def _cluster_node_type(self, spec: compute.ClusterSpec) -> compute.ClusterSpec:
-        cfg = self._config
-        valid_node_type = False
-        if cfg.custom_cluster_policy_id is not None:
-            if self._check_policy_has_instance_pool(cfg.custom_cluster_policy_id):
-                valid_node_type = True
-        if not valid_node_type:
-            if cfg.instance_pool_id is not None:
-                return replace(spec, instance_pool_id=cfg.instance_pool_id)
-            spec = replace(spec, node_type_id=self._ws.clusters.select_node_type(local_disk=True))
-        if self._ws.config.is_aws:
-            return replace(spec, aws_attributes=compute.AwsAttributes(availability=compute.AwsAvailability.ON_DEMAND))
-        if self._ws.config.is_azure:
-            return replace(
-                spec, azure_attributes=compute.AzureAttributes(availability=compute.AzureAvailability.ON_DEMAND_AZURE)
-            )
-        return replace(spec, gcp_attributes=compute.GcpAttributes(availability=compute.GcpAvailability.ON_DEMAND_GCP))
-
-    def _check_policy_has_instance_pool(self, policy_id):
-        try:
-            policy = self._ws.cluster_policies.get(policy_id=policy_id)
-        except NotFound:
-            logger.warning(f"removed on the backend {policy_id}")
-            return False
-        def_json = json.loads(policy.definition)
-        instance_pool = def_json.get("instance_pool_id")
-        if instance_pool is not None:
-            return True
-        return False
-
-    def _instance_profiles(self):
-        return {"No Instance Profile": None} | {
-            profile.instance_profile_arn: profile.instance_profile_arn for profile in self._ws.instance_profiles.list()
-        }
-
     @staticmethod
     def _readable_timedelta(epoch):
         when = datetime.fromtimestamp(epoch)
@@ -899,6 +897,7 @@ def uninstall(self):
         self._remove_database()
         self._remove_jobs()
         self._remove_warehouse()
+        self._remove_policies()
         self._installation.remove()
         logger.info("UnInstalling UCX complete")
 
@@ -911,6 +910,13 @@ def _remove_database(self):
         deployer = SchemaDeployer(self._sql_backend, self._config.inventory_database, Any)
         deployer.delete_schema()
 
+    def _remove_policies(self):
+        logger.info("Deleting cluster policy")
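+        # The policy may already be gone, e.g. removed manually or by a previous uninstall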
+        try:
+            self._ws.cluster_policies.delete(policy_id=self.config.policy_id)
+        except NotFound:
+            logger.error("UCX Policy already deleted")
+
     def _remove_jobs(self):
         logger.info("Deleting jobs")
         if not self._state.jobs: