Commit fbfcdb8

HariGS-DB and nfx authored
Added automated upgrade option to set up cluster policy (#1024)
## Changes

This PR has the following changes:
- Separating the cluster-policy creation logic from install.py into installer.policy.py
- adding an upgrade script to set up the cluster policy for older versions of ucx
- removing the reference to libraries from the cluster policy (the wheel is now attached to each wheel task)

### Linked issues

Resolves #1023
Resolves #1012

### Functionality

- [ ] added relevant user documentation
- [ ] added new CLI command
- [ ] modified existing command: `databricks labs ucx ...`
- [ ] added a new workflow
- [ ] modified existing workflow: `...`
- [ ] added a new table
- [ ] modified existing table: `...`

### Tests

- [X] manually tested
- [X] added unit tests
- [X] added integration tests
- [ ] verified on staging environment (screenshot attached)

---------

Co-authored-by: Serge Smertin <[email protected]>
1 parent bdafc44 commit fbfcdb8

File tree

5 files changed: +401 −439 lines changed

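At a glance, cluster-policy handling now sits behind a single call on the new `ClusterPolicyInstaller`. A minimal sketch of the new surface, assuming a `WorkspaceClient`, `Installation`, and `Prompts` are already constructed (the variable names here are illustrative; the actual wiring is in the diffs that follow):

    from databricks.labs.ucx.installer.policy import ClusterPolicyInstaller

    # Assumed to exist already: ws (WorkspaceClient), installation (Installation), prompts (Prompts).
    policy_installer = ClusterPolicyInstaller(installation, ws, prompts)

    # create() looks for cluster policies pointing at an external Hive metastore, optionally
    # copies their Spark conf and instance profile, then creates (or reuses) the
    # "Unity Catalog Migration (...)" policy and returns its id plus the extracted settings.
    policy_id, instance_profile, spark_conf_dict = policy_installer.create("ucx")  # "ucx" = inventory database name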

src/databricks/labs/ucx/install.py

Lines changed: 10 additions & 118 deletions
@@ -1,5 +1,4 @@
 import functools
-import json
 import logging
 import os
 import re
@@ -72,6 +71,7 @@
 from databricks.labs.ucx.hive_metastore.table_size import TableSize
 from databricks.labs.ucx.hive_metastore.tables import Table, TableError
 from databricks.labs.ucx.installer.hms_lineage import HiveMetastoreLineageEnabler
+from databricks.labs.ucx.installer.policy import ClusterPolicyInstaller
 from databricks.labs.ucx.runtime import main
 from databricks.labs.ucx.workspace_access.base import Permissions
 from databricks.labs.ucx.workspace_access.generic import WorkspaceObjectInfo
@@ -178,6 +178,7 @@ def __init__(self, prompts: Prompts, installation: Installation, ws: WorkspaceCl
         self._ws = ws
         self._installation = installation
         self._prompts = prompts
+        self._policy_installer = ClusterPolicyInstaller(installation, ws, prompts)

     def run(
         self,
@@ -244,21 +245,7 @@ def _configure_new_installation(self) -> WorkspaceConfig:
         log_level = self._prompts.question("Log level", default="INFO").upper()
         num_threads = int(self._prompts.question("Number of threads", default="8", valid_number=True))

-        # Checking for external HMS
-        instance_profile = None
-        spark_conf_dict = {}
-        policies_with_external_hms = list(self._get_cluster_policies_with_external_hive_metastores())
-        if len(policies_with_external_hms) > 0 and self._prompts.confirm(
-            "We have identified one or more cluster policies set up for an external metastore"
-            "Would you like to set UCX to connect to the external metastore?"
-        ):
-            logger.info("Setting up an external metastore")
-            cluster_policies = {conf.name: conf.definition for conf in policies_with_external_hms}
-            if len(cluster_policies) >= 1:
-                cluster_policy = json.loads(self._prompts.choice_from_dict("Choose a cluster policy", cluster_policies))
-                instance_profile, spark_conf_dict = self._get_ext_hms_conf_from_policy(cluster_policy)
-
-        policy_id = self._create_cluster_policy(inventory_database, spark_conf_dict, instance_profile)
+        policy_id, instance_profile, spark_conf_dict = self._policy_installer.create(inventory_database)

         # Check if terraform is being used
         is_terraform_used = self._prompts.confirm("Do you use Terraform to deploy your infrastructure?")
@@ -318,83 +305,6 @@ def warehouse_type(_):
             warehouse_id = new_warehouse.id
         return warehouse_id

-    @staticmethod
-    def _policy_config(value: str):
-        return {"type": "fixed", "value": value}
-
-    def _create_cluster_policy(
-        self, inventory_database: str, spark_conf: dict, instance_profile: str | None
-    ) -> str | None:
-        policy_name = f"Unity Catalog Migration ({inventory_database}) ({self._ws.current_user.me().user_name})"
-        policies = self._ws.cluster_policies.list()
-        policy_id = None
-        for policy in policies:
-            if policy.name == policy_name:
-                policy_id = policy.policy_id
-                logger.info(f"Cluster policy {policy_name} already present, reusing the same.")
-                break
-        if not policy_id:
-            logger.info("Creating UCX cluster policy.")
-            policy_id = self._ws.cluster_policies.create(
-                name=policy_name,
-                definition=self._cluster_policy_definition(conf=spark_conf, instance_profile=instance_profile),
-                description="Custom cluster policy for Unity Catalog Migration (UCX)",
-            ).policy_id
-        return policy_id
-
-    def _cluster_policy_definition(self, conf: dict, instance_profile: str | None) -> str:
-        policy_definition = {
-            "spark_version": self._policy_config(self._ws.clusters.select_spark_version(latest=True)),
-            "node_type_id": self._policy_config(self._ws.clusters.select_node_type(local_disk=True)),
-        }
-        if conf:
-            for key, value in conf.items():
-                policy_definition[f"spark_conf.{key}"] = self._policy_config(value)
-        if self._ws.config.is_aws:
-            policy_definition["aws_attributes.availability"] = self._policy_config(
-                compute.AwsAvailability.ON_DEMAND.value
-            )
-            if instance_profile:
-                policy_definition["aws_attributes.instance_profile_arn"] = self._policy_config(instance_profile)
-        elif self._ws.config.is_azure:  # pylint: disable=confusing-consecutive-elif
-            policy_definition["azure_attributes.availability"] = self._policy_config(
-                compute.AzureAvailability.ON_DEMAND_AZURE.value
-            )
-        else:
-            policy_definition["gcp_attributes.availability"] = self._policy_config(
-                compute.GcpAvailability.ON_DEMAND_GCP.value
-            )
-        return json.dumps(policy_definition)
-
-    @staticmethod
-    def _get_ext_hms_conf_from_policy(cluster_policy):
-        spark_conf_dict = {}
-        instance_profile = None
-        if cluster_policy.get("aws_attributes.instance_profile_arn") is not None:
-            instance_profile = cluster_policy.get("aws_attributes.instance_profile_arn").get("value")
-            logger.info(f"Instance Profile is Set to {instance_profile}")
-        for key in cluster_policy.keys():
-            if (
-                key.startswith("spark_conf.spark.sql.hive.metastore")
-                or key.startswith("spark_conf.spark.hadoop.javax.jdo.option")
-                or key.startswith("spark_conf.spark.databricks.hive.metastore")
-                or key.startswith("spark_conf.spark.hadoop.hive.metastore.glue")
-            ):
-                spark_conf_dict[key[11:]] = cluster_policy[key]["value"]
-        return instance_profile, spark_conf_dict
-
-    def _get_cluster_policies_with_external_hive_metastores(self):
-        for policy in self._ws.cluster_policies.list():
-            def_json = json.loads(policy.definition)
-            glue_node = def_json.get("spark_conf.spark.databricks.hive.metastore.glueCatalog.enabled")
-            if glue_node is not None and glue_node.get("value") == "true":
-                yield policy
-                continue
-            for key in def_json.keys():
-                if key.startswith("spark_conf.spark.sql.hive.metastore"):
-                    yield policy
-                    break
-

 class WorkspaceInstallation:
     def __init__(
@@ -625,35 +535,16 @@ def _upload_wheel(self):
         self._installation.save(self._config)
         return self._wheels.upload_to_wsfs()

-    def _upload_cluster_policy(self, remote_wheel: str):
-        try:
-            if self.config.policy_id is None:
-                msg = "Cluster policy not present, please uninstall and reinstall ucx completely."
-                raise InvalidParameterValue(msg)
-            policy = self._ws.cluster_policies.get(policy_id=self.config.policy_id)
-        except NotFound as err:
-            msg = f"UCX Policy {self.config.policy_id} not found, please reinstall UCX"
-            logger.error(msg)
-            raise NotFound(msg) from err
-        if policy.name is not None:
-            self._ws.cluster_policies.edit(
-                policy_id=self.config.policy_id,
-                name=policy.name,
-                definition=policy.definition,
-                libraries=[compute.Library(whl=f"dbfs:{remote_wheel}")],
-            )
-
     def create_jobs(self):
         logger.debug(f"Creating jobs from tasks in {main.__name__}")
         remote_wheel = self._upload_wheel()
-        self._upload_cluster_policy(remote_wheel)
         desired_steps = {t.workflow for t in _TASKS.values() if t.cloud_compatible(self._ws.config)}
         wheel_runner = None

         if self._config.override_clusters:
             wheel_runner = self._upload_wheel_runner(remote_wheel)
         for step_name in desired_steps:
-            settings = self._job_settings(step_name)
+            settings = self._job_settings(step_name, remote_wheel)
             if self._config.override_clusters:
                 settings = self._apply_cluster_overrides(settings, self._config.override_clusters, wheel_runner)
             self._deploy_workflow(step_name, settings)
@@ -753,7 +644,7 @@ def _create_debug(self, remote_wheel: str):
         ).encode("utf8")
         self._installation.upload('DEBUG.py', content)

-    def _job_settings(self, step_name: str):
+    def _job_settings(self, step_name: str, remote_wheel: str):
         email_notifications = None
         if not self._config.override_clusters and "@" in self._my_username:
             # set email notifications only if we're running the real
@@ -772,7 +663,7 @@ def _job_settings(self, step_name: str):
             "tags": {"version": f"v{version}"},
             "job_clusters": self._job_clusters({t.job_cluster for t in tasks}),
             "email_notifications": email_notifications,
-            "tasks": [self._job_task(task) for task in tasks],
+            "tasks": [self._job_task(task, remote_wheel) for task in tasks],
         }

     def _upload_wheel_runner(self, remote_wheel: str):
@@ -796,7 +687,7 @@ def _apply_cluster_overrides(settings: dict[str, Any], overrides: dict[str, str]
             job_task.notebook_task = jobs.NotebookTask(notebook_path=wheel_runner, base_parameters=params)
         return settings

-    def _job_task(self, task: Task) -> jobs.Task:
+    def _job_task(self, task: Task, remote_wheel: str) -> jobs.Task:
         jobs_task = jobs.Task(
             task_key=task.name,
             job_cluster_key=task.job_cluster,
@@ -809,7 +700,7 @@ def _job_task(self, task: Task) -> jobs.Task:
             return retried_job_dashboard_task(jobs_task, task)
         if task.notebook:
             return self._job_notebook_task(jobs_task, task)
-        return self._job_wheel_task(jobs_task, task)
+        return self._job_wheel_task(jobs_task, task, remote_wheel)

     def _job_dashboard_task(self, jobs_task: jobs.Task, task: Task) -> jobs.Task:
         assert task.dashboard is not None
@@ -841,10 +732,11 @@ def _job_notebook_task(self, jobs_task: jobs.Task, task: Task) -> jobs.Task:
             ),
         )

-    def _job_wheel_task(self, jobs_task: jobs.Task, task: Task) -> jobs.Task:
+    def _job_wheel_task(self, jobs_task: jobs.Task, task: Task, remote_wheel: str) -> jobs.Task:
         return replace(
             jobs_task,
             # TODO: check when we can install wheels from WSFS properly
+            libraries=[compute.Library(whl=f"dbfs:{remote_wheel}")],
             python_wheel_task=jobs.PythonWheelTask(
                 package_name="databricks_labs_ucx",
                 entry_point="runtime",  # [project.entry-points.databricks] in pyproject.toml
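With `_upload_cluster_policy` removed, the wheel is no longer attached to the cluster policy; each Python-wheel task now carries the library itself. As a rough sketch, the task that `_job_wheel_task` builds ends up shaped like this (the task key and DBFS path are placeholders; the real values come from the UCX task registry and the uploaded `remote_wheel`):

    from databricks.sdk.service import compute, jobs

    task = jobs.Task(
        task_key="assessment",  # placeholder; one task per UCX workflow step
        libraries=[compute.Library(whl="dbfs:/path/to/databricks_labs_ucx-x.y.z-py3-none-any.whl")],  # placeholder path
        python_wheel_task=jobs.PythonWheelTask(
            package_name="databricks_labs_ucx",
            entry_point="runtime",
        ),
    )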
src/databricks/labs/ucx/installer/policy.py (new file)

Lines changed: 136 additions & 0 deletions
@@ -0,0 +1,136 @@
+import json
+import logging
+
+from databricks.labs.blueprint.installation import Installation
+from databricks.labs.blueprint.installer import InstallState
+from databricks.labs.blueprint.tui import Prompts
+from databricks.sdk import WorkspaceClient
+from databricks.sdk.errors import NotFound
+from databricks.sdk.service import compute
+
+logger = logging.getLogger(__name__)
+
+
+class ClusterPolicyInstaller:
+    def __init__(self, installation: Installation, ws: WorkspaceClient, prompts: Prompts):
+        self._ws = ws
+        self._installation = installation
+        self._prompts = prompts
+
+    @staticmethod
+    def _policy_config(value: str):
+        return {"type": "fixed", "value": value}
+
+    def create(self, inventory_database: str) -> tuple[str, str, dict]:
+        instance_profile = ""
+        spark_conf_dict = {}
+        policies_with_external_hms = list(self._get_cluster_policies_with_external_hive_metastores())
+        if len(policies_with_external_hms) > 0 and self._prompts.confirm(
+            "We have identified one or more cluster policies set up for an external metastore"
+            "Would you like to set UCX to connect to the external metastore?"
+        ):
+            logger.info("Setting up an external metastore")
+            cluster_policies = {conf.name: conf.definition for conf in policies_with_external_hms}
+            if len(cluster_policies) >= 1:
+                cluster_policy = json.loads(self._prompts.choice_from_dict("Choose a cluster policy", cluster_policies))
+                instance_profile, spark_conf_dict = self._extract_external_hive_metastore_conf(cluster_policy)
+        policy_name = f"Unity Catalog Migration ({inventory_database}) ({self._ws.current_user.me().user_name})"
+        policies = self._ws.cluster_policies.list()
+        for policy in policies:
+            if policy.name == policy_name:
+                logger.info(f"Cluster policy {policy_name} already present, reusing the same.")
+                policy_id = policy.policy_id
+                assert policy_id is not None
+                return policy_id, instance_profile, spark_conf_dict
+        logger.info("Creating UCX cluster policy.")
+        policy_id = self._ws.cluster_policies.create(
+            name=policy_name,
+            definition=self._definition(spark_conf_dict, instance_profile),
+            description="Custom cluster policy for Unity Catalog Migration (UCX)",
+        ).policy_id
+        assert policy_id is not None
+        return (
+            policy_id,
+            instance_profile,
+            spark_conf_dict,
+        )
+
+    def _definition(self, conf: dict, instance_profile: str | None) -> str:
+        policy_definition = {
+            "spark_version": self._policy_config(self._ws.clusters.select_spark_version(latest=True)),
+            "node_type_id": self._policy_config(self._ws.clusters.select_node_type(local_disk=True)),
+        }
+        for key, value in conf.items():
+            policy_definition[f"spark_conf.{key}"] = self._policy_config(value)
+        if self._ws.config.is_aws:
+            policy_definition["aws_attributes.availability"] = self._policy_config(
+                compute.AwsAvailability.ON_DEMAND.value
+            )
+            if instance_profile:
+                policy_definition["aws_attributes.instance_profile_arn"] = self._policy_config(instance_profile)
+        elif self._ws.config.is_azure:  # pylint: disable=confusing-consecutive-elif
+            policy_definition["azure_attributes.availability"] = self._policy_config(
+                compute.AzureAvailability.ON_DEMAND_AZURE.value
+            )
+        else:
+            policy_definition["gcp_attributes.availability"] = self._policy_config(
+                compute.GcpAvailability.ON_DEMAND_GCP.value
+            )
+        return json.dumps(policy_definition)
+
+    @staticmethod
+    def _extract_external_hive_metastore_conf(cluster_policy):
+        spark_conf_dict = {}
+        instance_profile = None
+        if cluster_policy.get("aws_attributes.instance_profile_arn") is not None:
+            instance_profile = cluster_policy.get("aws_attributes.instance_profile_arn").get("value")
+            logger.info(f"Instance Profile is Set to {instance_profile}")
+        for key in cluster_policy.keys():
+            if (
+                key.startswith("spark_conf.spark.sql.hive.metastore")
+                or key.startswith("spark_conf.spark.hadoop.javax.jdo.option")
+                or key.startswith("spark_conf.spark.databricks.hive.metastore")
+                or key.startswith("spark_conf.spark.hadoop.hive.metastore.glue")
+            ):
+                spark_conf_dict[key[11:]] = cluster_policy[key]["value"]
+        return instance_profile, spark_conf_dict
+
+    def _get_cluster_policies_with_external_hive_metastores(self):
+        for policy in self._ws.cluster_policies.list():
+            def_json = json.loads(policy.definition)
+            glue_node = def_json.get("spark_conf.spark.databricks.hive.metastore.glueCatalog.enabled")
+            if glue_node is not None and glue_node.get("value") == "true":
+                yield policy
+                continue
+            for key in def_json.keys():
+                if key.startswith("spark_conf.spark.sql.hive.metastore"):
+                    yield policy
+                    break
+
+    def update_job_policy(self, state: InstallState, policy_id: str):
+        if not state.jobs:
+            logger.error("No jobs found in states")
+            return
+        for _, job_id in state.jobs.items():
+            try:
+                job = self._ws.jobs.get(job_id)
+                job_settings = job.settings
+                assert job.job_id is not None
+                assert job_settings is not None
+                if job_settings.job_clusters is None:
+                    # if job_clusters is None, it means override cluster is being set and hence policy should not be applied
+                    return
+            except NotFound:
+                logger.error(f"Job id {job_id} not found. Please check if the job is present in the workspace")
+                continue
+            try:
+                job_clusters = []
+                for cluster in job_settings.job_clusters:
+                    assert cluster.new_cluster is not None
+                    cluster.new_cluster.policy_id = policy_id
+                    job_clusters.append(cluster)
+                job_settings.job_clusters = job_clusters
+                self._ws.jobs.update(job.job_id, new_settings=job_settings)
+            except NotFound:
+                logger.error(f"Job id {job_id} not found.")
+                continue
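For orientation only: `_definition` builds a dict of fixed-value policy controls and serializes it with `json.dumps`. On an Azure workspace with no external-metastore configuration carried over, the `policy_definition` dict would look roughly like this (the spark_version and node_type_id values are placeholders chosen at runtime by `select_spark_version` / `select_node_type`):

    {
        "spark_version": {"type": "fixed", "value": "<latest spark version id>"},
        "node_type_id": {"type": "fixed", "value": "<smallest local-disk node type>"},
        "azure_attributes.availability": {"type": "fixed", "value": "ON_DEMAND_AZURE"},
    }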
Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
+# pylint: disable=invalid-name
+
+import logging
+
+from databricks.labs.blueprint.installation import Installation
+from databricks.labs.blueprint.installer import InstallState
+from databricks.labs.blueprint.tui import Prompts
+from databricks.sdk import WorkspaceClient
+
+from databricks.labs.ucx.config import WorkspaceConfig
+from databricks.labs.ucx.installer.policy import ClusterPolicyInstaller
+
+logger = logging.getLogger(__name__)
+
+
+def upgrade(installation: Installation, ws: WorkspaceClient):
+    config = installation.load(WorkspaceConfig)
+    policy_installer = ClusterPolicyInstaller(installation, ws, Prompts())
+    config.policy_id, _, _ = policy_installer.create(config.inventory_database)
+    installation.save(config)
+    states = InstallState.from_installation(installation)
+    assert config.policy_id is not None
+    policy_installer.update_job_policy(states, config.policy_id)
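This `upgrade()` hook is normally driven by UCX's automated upgrade flow for existing installations. As a rough sketch only, it could also be exercised by hand, assuming blueprint's `Installation.assume_user_home` resolves the current user's UCX installation folder:

    from databricks.labs.blueprint.installation import Installation
    from databricks.sdk import WorkspaceClient

    ws = WorkspaceClient()  # picks up default Databricks authentication from the environment
    installation = Installation.assume_user_home(ws, "ucx")  # assumed helper pointing at the per-user .ucx folder
    upgrade(installation, ws)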
