@@ -1,5 +1,4 @@
 import functools
-import json
 import logging
 import os
 import re
@@ -72,6 +71,7 @@
 from databricks.labs.ucx.hive_metastore.table_size import TableSize
 from databricks.labs.ucx.hive_metastore.tables import Table, TableError
 from databricks.labs.ucx.installer.hms_lineage import HiveMetastoreLineageEnabler
+from databricks.labs.ucx.installer.policy import ClusterPolicyInstaller
 from databricks.labs.ucx.runtime import main
 from databricks.labs.ucx.workspace_access.base import Permissions
 from databricks.labs.ucx.workspace_access.generic import WorkspaceObjectInfo
@@ -178,6 +178,7 @@ def __init__(self, prompts: Prompts, installation: Installation, ws: WorkspaceCl
         self._ws = ws
         self._installation = installation
         self._prompts = prompts
+        self._policy_installer = ClusterPolicyInstaller(installation, ws, prompts)

     def run(
         self,
@@ -244,21 +245,7 @@ def _configure_new_installation(self) -> WorkspaceConfig:
         log_level = self._prompts.question("Log level", default="INFO").upper()
         num_threads = int(self._prompts.question("Number of threads", default="8", valid_number=True))

-        # Checking for external HMS
-        instance_profile = None
-        spark_conf_dict = {}
-        policies_with_external_hms = list(self._get_cluster_policies_with_external_hive_metastores())
-        if len(policies_with_external_hms) > 0 and self._prompts.confirm(
-            "We have identified one or more cluster policies set up for an external metastore"
-            "Would you like to set UCX to connect to the external metastore?"
-        ):
-            logger.info("Setting up an external metastore")
-            cluster_policies = {conf.name: conf.definition for conf in policies_with_external_hms}
-            if len(cluster_policies) >= 1:
-                cluster_policy = json.loads(self._prompts.choice_from_dict("Choose a cluster policy", cluster_policies))
-                instance_profile, spark_conf_dict = self._get_ext_hms_conf_from_policy(cluster_policy)
-
-        policy_id = self._create_cluster_policy(inventory_database, spark_conf_dict, instance_profile)
+        policy_id, instance_profile, spark_conf_dict = self._policy_installer.create(inventory_database)

         # Check if terraform is being used
         is_terraform_used = self._prompts.confirm("Do you use Terraform to deploy your infrastructure?")
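Note: the external-metastore detection and prompting removed above now lives behind ClusterPolicyInstaller. A minimal sketch of the interface this call site assumes (the real implementation is in databricks/labs/ucx/installer/policy.py; import paths and method bodies here are placeholders):

    from databricks.sdk import WorkspaceClient

    from databricks.labs.blueprint.installation import Installation
    from databricks.labs.blueprint.tui import Prompts

    class ClusterPolicyInstaller:
        def __init__(self, installation: Installation, ws: WorkspaceClient, prompts: Prompts):
            self._installation = installation
            self._ws = ws
            self._prompts = prompts

        def create(self, inventory_database: str) -> tuple[str | None, str | None, dict]:
            # Detect cluster policies configured for an external Hive metastore,
            # let the user pick one to copy Spark conf / instance profile from,
            # then create or reuse the UCX policy. Returns (policy_id,
            # instance_profile, spark_conf_dict), matching the unpacking above.
            ...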
@@ -318,83 +305,6 @@ def warehouse_type(_):
             warehouse_id = new_warehouse.id
         return warehouse_id

-    @staticmethod
-    def _policy_config(value: str):
-        return {"type": "fixed", "value": value}
-
-    def _create_cluster_policy(
-        self, inventory_database: str, spark_conf: dict, instance_profile: str | None
-    ) -> str | None:
-        policy_name = f"Unity Catalog Migration ({inventory_database}) ({self._ws.current_user.me().user_name})"
-        policies = self._ws.cluster_policies.list()
-        policy_id = None
-        for policy in policies:
-            if policy.name == policy_name:
-                policy_id = policy.policy_id
-                logger.info(f"Cluster policy {policy_name} already present, reusing the same.")
-                break
-        if not policy_id:
-            logger.info("Creating UCX cluster policy.")
-            policy_id = self._ws.cluster_policies.create(
-                name=policy_name,
-                definition=self._cluster_policy_definition(conf=spark_conf, instance_profile=instance_profile),
-                description="Custom cluster policy for Unity Catalog Migration (UCX)",
-            ).policy_id
-        return policy_id
-
-    def _cluster_policy_definition(self, conf: dict, instance_profile: str | None) -> str:
-        policy_definition = {
-            "spark_version": self._policy_config(self._ws.clusters.select_spark_version(latest=True)),
-            "node_type_id": self._policy_config(self._ws.clusters.select_node_type(local_disk=True)),
-        }
-        if conf:
-            for key, value in conf.items():
-                policy_definition[f"spark_conf.{key}"] = self._policy_config(value)
-        if self._ws.config.is_aws:
-            policy_definition["aws_attributes.availability"] = self._policy_config(
-                compute.AwsAvailability.ON_DEMAND.value
-            )
-            if instance_profile:
-                policy_definition["aws_attributes.instance_profile_arn"] = self._policy_config(instance_profile)
-        elif self._ws.config.is_azure:  # pylint: disable=confusing-consecutive-elif
-            policy_definition["azure_attributes.availability"] = self._policy_config(
-                compute.AzureAvailability.ON_DEMAND_AZURE.value
-            )
-        else:
-            policy_definition["gcp_attributes.availability"] = self._policy_config(
-                compute.GcpAvailability.ON_DEMAND_GCP.value
-            )
-        return json.dumps(policy_definition)
-
-    @staticmethod
-    def _get_ext_hms_conf_from_policy(cluster_policy):
-        spark_conf_dict = {}
-        instance_profile = None
-        if cluster_policy.get("aws_attributes.instance_profile_arn") is not None:
-            instance_profile = cluster_policy.get("aws_attributes.instance_profile_arn").get("value")
-            logger.info(f"Instance Profile is Set to {instance_profile}")
-        for key in cluster_policy.keys():
-            if (
-                key.startswith("spark_conf.spark.sql.hive.metastore")
-                or key.startswith("spark_conf.spark.hadoop.javax.jdo.option")
-                or key.startswith("spark_conf.spark.databricks.hive.metastore")
-                or key.startswith("spark_conf.spark.hadoop.hive.metastore.glue")
-            ):
-                spark_conf_dict[key[11:]] = cluster_policy[key]["value"]
-        return instance_profile, spark_conf_dict
-
-    def _get_cluster_policies_with_external_hive_metastores(self):
-        for policy in self._ws.cluster_policies.list():
-            def_json = json.loads(policy.definition)
-            glue_node = def_json.get("spark_conf.spark.databricks.hive.metastore.glueCatalog.enabled")
-            if glue_node is not None and glue_node.get("value") == "true":
-                yield policy
-                continue
-            for key in def_json.keys():
-                if key.startswith("spark_conf.spark.sql.hive.metastore"):
-                    yield policy
-                    break
-

 class WorkspaceInstallation:
     def __init__(
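For reference, the removed _cluster_policy_definition built a policy in which every setting is pinned as a "fixed" element. The definition it serialized looks roughly like this Python dict (values illustrative, shown for the AWS branch with an instance profile):

    # Shape of the generated UCX cluster policy definition (illustrative values):
    policy_definition = {
        "spark_version": {"type": "fixed", "value": "14.3.x-scala2.12"},  # latest version at install time
        "node_type_id": {"type": "fixed", "value": "i3.xlarge"},  # smallest local-disk node type
        # one entry per Spark conf key copied from the external-metastore policy:
        "spark_conf.spark.sql.hive.metastore.version": {"type": "fixed", "value": "3.1.0"},
        # cloud-specific availability; on AWS the instance profile is pinned too, if one was found:
        "aws_attributes.availability": {"type": "fixed", "value": "ON_DEMAND"},
        "aws_attributes.instance_profile_arn": {"type": "fixed", "value": "arn:aws:iam::123456789012:instance-profile/hms-access"},
    }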
@@ -625,35 +535,16 @@ def _upload_wheel(self):
         self._installation.save(self._config)
         return self._wheels.upload_to_wsfs()

-    def _upload_cluster_policy(self, remote_wheel: str):
-        try:
-            if self.config.policy_id is None:
-                msg = "Cluster policy not present, please uninstall and reinstall ucx completely."
-                raise InvalidParameterValue(msg)
-            policy = self._ws.cluster_policies.get(policy_id=self.config.policy_id)
-        except NotFound as err:
-            msg = f"UCX Policy {self.config.policy_id} not found, please reinstall UCX"
-            logger.error(msg)
-            raise NotFound(msg) from err
-        if policy.name is not None:
-            self._ws.cluster_policies.edit(
-                policy_id=self.config.policy_id,
-                name=policy.name,
-                definition=policy.definition,
-                libraries=[compute.Library(whl=f"dbfs:{remote_wheel}")],
-            )
-
     def create_jobs(self):
         logger.debug(f"Creating jobs from tasks in {main.__name__}")
         remote_wheel = self._upload_wheel()
-        self._upload_cluster_policy(remote_wheel)
         desired_steps = {t.workflow for t in _TASKS.values() if t.cloud_compatible(self._ws.config)}
         wheel_runner = None

         if self._config.override_clusters:
             wheel_runner = self._upload_wheel_runner(remote_wheel)
         for step_name in desired_steps:
-            settings = self._job_settings(step_name)
+            settings = self._job_settings(step_name, remote_wheel)
             if self._config.override_clusters:
                 settings = self._apply_cluster_overrides(settings, self._config.override_clusters, wheel_runner)
             self._deploy_workflow(step_name, settings)
@@ -753,7 +644,7 @@ def _create_debug(self, remote_wheel: str):
         ).encode("utf8")
         self._installation.upload('DEBUG.py', content)

-    def _job_settings(self, step_name: str):
+    def _job_settings(self, step_name: str, remote_wheel: str):
         email_notifications = None
         if not self._config.override_clusters and "@" in self._my_username:
             # set email notifications only if we're running the real
@@ -772,7 +663,7 @@ def _job_settings(self, step_name: str):
772663 "tags" : {"version" : f"v{ version } " },
773664 "job_clusters" : self ._job_clusters ({t .job_cluster for t in tasks }),
774665 "email_notifications" : email_notifications ,
775- "tasks" : [self ._job_task (task ) for task in tasks ],
666+ "tasks" : [self ._job_task (task , remote_wheel ) for task in tasks ],
776667 }
777668
778669 def _upload_wheel_runner (self , remote_wheel : str ):
@@ -796,7 +687,7 @@ def _apply_cluster_overrides(settings: dict[str, Any], overrides: dict[str, str]
             job_task.notebook_task = jobs.NotebookTask(notebook_path=wheel_runner, base_parameters=params)
         return settings

-    def _job_task(self, task: Task) -> jobs.Task:
+    def _job_task(self, task: Task, remote_wheel: str) -> jobs.Task:
         jobs_task = jobs.Task(
             task_key=task.name,
             job_cluster_key=task.job_cluster,
@@ -809,7 +700,7 @@ def _job_task(self, task: Task) -> jobs.Task:
             return retried_job_dashboard_task(jobs_task, task)
         if task.notebook:
             return self._job_notebook_task(jobs_task, task)
-        return self._job_wheel_task(jobs_task, task)
+        return self._job_wheel_task(jobs_task, task, remote_wheel)

     def _job_dashboard_task(self, jobs_task: jobs.Task, task: Task) -> jobs.Task:
         assert task.dashboard is not None
@@ -841,10 +732,11 @@ def _job_notebook_task(self, jobs_task: jobs.Task, task: Task) -> jobs.Task:
             ),
         )

-    def _job_wheel_task(self, jobs_task: jobs.Task, task: Task) -> jobs.Task:
+    def _job_wheel_task(self, jobs_task: jobs.Task, task: Task, remote_wheel: str) -> jobs.Task:
         return replace(
             jobs_task,
             # TODO: check when we can install wheels from WSFS properly
+            libraries=[compute.Library(whl=f"dbfs:{remote_wheel}")],
             python_wheel_task=jobs.PythonWheelTask(
                 package_name="databricks_labs_ucx",
                 entry_point="runtime",  # [project.entry-points.databricks] in pyproject.toml
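With _upload_cluster_policy gone, the wheel is attached per task instead of on the shared cluster policy, so upgrading the wheel only requires redeploying the jobs rather than editing the policy. A self-contained sketch of the resulting task (task key, cluster key, and wheel path are illustrative):

    from databricks.sdk.service import compute, jobs

    remote_wheel = "/wheels/databricks_labs_ucx-0.9.0-py3-none-any.whl"  # illustrative path
    task = jobs.Task(
        task_key="crawl_tables",  # illustrative
        job_cluster_key="main",  # illustrative
        # install the UCX wheel onto this task's cluster from DBFS:
        libraries=[compute.Library(whl=f"dbfs:{remote_wheel}")],
        python_wheel_task=jobs.PythonWheelTask(
            package_name="databricks_labs_ucx",
            entry_point="runtime",
        ),
    )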