Commit b94df54

Creates cluster policy (ucx-policy) to be used by all UCX compute (#853)
1 parent 6496b3b commit b94df54

4 files changed: +352 additions, -107 deletions

src/databricks/labs/ucx/config.py

Lines changed: 1 addition & 1 deletion
@@ -34,7 +34,7 @@ class WorkspaceConfig: # pylint: disable=too-many-instance-attributes
     spark_conf: dict[str, str] | None = None
 
     override_clusters: dict[str, str] | None = None
-    custom_cluster_policy_id: str | None = None
+    policy_id: str | None = None
 
     def replace_inventory_variable(self, text: str) -> str:
         return text.replace("$inventory", f"hive_metastore.{self.inventory_database}")

src/databricks/labs/ucx/install.py

Lines changed: 70 additions & 64 deletions
@@ -242,11 +242,12 @@ def warehouse_type(_):
             cluster_policy = json.loads(self._prompts.choice_from_dict("Choose a cluster policy", cluster_policies))
             instance_profile, spark_conf_dict = self._get_ext_hms_conf_from_policy(cluster_policy)
 
-        if self._prompts.confirm("Do you want to follow a policy to create clusters?"):
-            cluster_policies_list = {f"{_.name} ({_.policy_id})": _.policy_id for _ in self._ws.cluster_policies.list()}
-            custom_cluster_policy_id = self._prompts.choice_from_dict("Choose a cluster policy", cluster_policies_list)
-        else:
-            custom_cluster_policy_id = None
+        logger.info("Creating UCX cluster policy.")
+        policy_id = self._ws.cluster_policies.create(
+            name=f"Unity Catalog Migration ({inventory_database})",
+            definition=self._cluster_policy_definition(conf=spark_conf_dict, instance_profile=instance_profile),
+            description="Custom cluster policy for Unity Catalog Migration (UCX)",
+        ).policy_id
 
         config = WorkspaceConfig(
             inventory_database=inventory_database,
@@ -261,13 +262,41 @@ def warehouse_type(_):
             num_threads=num_threads,
             instance_profile=instance_profile,
             spark_conf=spark_conf_dict,
-            custom_cluster_policy_id=custom_cluster_policy_id,
+            policy_id=policy_id,
         )
         ws_file_url = self._installation.save(config)
         if self._prompts.confirm("Open config file in the browser and continue installing?"):
             webbrowser.open(ws_file_url)
         return config
 
+    @staticmethod
+    def _policy_config(value: str):
+        return {"type": "fixed", "value": value}
+
+    def _cluster_policy_definition(self, conf: dict, instance_profile: str | None) -> str:
+        policy_definition = {
+            "spark_version": self._policy_config(self._ws.clusters.select_spark_version(latest=True)),
+            "node_type_id": self._policy_config(self._ws.clusters.select_node_type(local_disk=True)),
+        }
+        if conf:
+            for key, value in conf.items():
+                policy_definition[f"spark_conf.{key}"] = self._policy_config(value)
+        if self._ws.config.is_aws:
+            policy_definition["aws_attributes.availability"] = self._policy_config(
+                compute.AwsAvailability.ON_DEMAND.value
+            )
+            if instance_profile:
+                policy_definition["aws_attributes.instance_profile_arn"] = self._policy_config(instance_profile)
+        elif self._ws.config.is_azure:
+            policy_definition["azure_attributes.availability"] = self._policy_config(
+                compute.AzureAvailability.ON_DEMAND_AZURE.value
+            )
+        else:
+            policy_definition["gcp_attributes.availability"] = self._policy_config(
+                compute.GcpAvailability.ON_DEMAND_GCP.value
+            )
+        return json.dumps(policy_definition)
+
     @staticmethod
     def _get_ext_hms_conf_from_policy(cluster_policy):
         spark_conf_dict = {}
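
Note: illustrative only. On an Azure workspace with no external-metastore Spark conf, _cluster_policy_definition serializes roughly the dictionary below; the spark_version and node_type_id values come from the workspace recommendation APIs and will differ per workspace.

{
    "spark_version": {"type": "fixed", "value": "14.2.x-scala2.12"},  # example value
    "node_type_id": {"type": "fixed", "value": "Standard_D4s_v5"},    # example value
    "azure_attributes.availability": {"type": "fixed", "value": "ON_DEMAND_AZURE"},
}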
@@ -277,7 +306,7 @@ def _get_ext_hms_conf_from_policy(cluster_policy):
             logger.info(f"Instance Profile is Set to {instance_profile}")
         for key in cluster_policy.keys():
             if (
-                key.startswith("spark_conf.sql.hive.metastore")
+                key.startswith("spark_conf.spark.sql.hive.metastore")
                 or key.startswith("spark_conf.spark.hadoop.javax.jdo.option")
                 or key.startswith("spark_conf.spark.databricks.hive.metastore")
                 or key.startswith("spark_conf.spark.hadoop.hive.metastore.glue")
@@ -293,7 +322,7 @@ def _get_cluster_policies_with_external_hive_metastores(self):
                 yield policy
                 continue
             for key in def_json.keys():
-                if key.startswith("spark_config.spark.sql.hive.metastore"):
+                if key.startswith("spark_conf.spark.sql.hive.metastore"):
                     yield policy
                     break
 
@@ -512,13 +541,26 @@ def _upload_wheel(self):
     def create_jobs(self):
         logger.debug(f"Creating jobs from tasks in {main.__name__}")
         remote_wheel = self._upload_wheel()
+        try:
+            policy_definition = self._ws.cluster_policies.get(policy_id=self.config.policy_id).definition
+        except NotFound as e:
+            msg = f"UCX Policy {self.config.policy_id} not found, please reinstall UCX"
+            logger.error(msg)
+            raise NotFound(msg) from e
+
+        self._ws.cluster_policies.edit(
+            policy_id=self.config.policy_id,
+            name=f"Unity Catalog Migration ({self.config.inventory_database})",
+            definition=policy_definition,
+            libraries=[compute.Library(whl=f"dbfs:{remote_wheel}")],
+        )
         desired_steps = {t.workflow for t in _TASKS.values() if t.cloud_compatible(self._ws.config)}
         wheel_runner = None
 
         if self._config.override_clusters:
             wheel_runner = self._upload_wheel_runner(remote_wheel)
         for step_name in desired_steps:
-            settings = self._job_settings(step_name, remote_wheel)
+            settings = self._job_settings(step_name)
             if self._config.override_clusters:
                 settings = self._apply_cluster_overrides(settings, self._config.override_clusters, wheel_runner)
             self._deploy_workflow(step_name, settings)
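
Note: the wheel that was previously attached to every wheel task is now attached once to the UCX policy, so clusters created from the policy install it automatically. A minimal standalone sketch of the same SDK calls, assuming default authentication; the policy id and wheel path are hypothetical:

from databricks.sdk import WorkspaceClient
from databricks.sdk.service import compute

ws = WorkspaceClient()

policy = ws.cluster_policies.get(policy_id="0123456789ABCDEF")  # hypothetical policy id
ws.cluster_policies.edit(
    policy_id=policy.policy_id,
    name=policy.name,
    definition=policy.definition,  # keep the definition, only refresh the attached library
    libraries=[compute.Library(whl="dbfs:/FileStore/ucx/databricks_labs_ucx-0.9.0-py3-none-any.whl")],
)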
@@ -618,7 +660,7 @@ def _create_debug(self, remote_wheel: str):
         ).encode("utf8")
         self._installation.upload('DEBUG.py', content)
 
-    def _job_settings(self, step_name: str, remote_wheel: str):
+    def _job_settings(self, step_name: str):
         email_notifications = None
         if not self._config.override_clusters and "@" in self._my_username:
             # set email notifications only if we're running the real
@@ -637,7 +679,7 @@ def _job_settings(self, step_name: str, remote_wheel: str):
             "tags": {"version": f"v{version}"},
             "job_clusters": self._job_clusters({t.job_cluster for t in tasks}),
             "email_notifications": email_notifications,
-            "tasks": [self._job_task(task, remote_wheel) for task in tasks],
+            "tasks": [self._job_task(task) for task in tasks],
         }
 
     def _upload_wheel_runner(self, remote_wheel: str):
@@ -661,7 +703,7 @@ def _apply_cluster_overrides(settings: dict[str, Any], overrides: dict[str, str]
             job_task.notebook_task = jobs.NotebookTask(notebook_path=wheel_runner, base_parameters=params)
         return settings
 
-    def _job_task(self, task: Task, remote_wheel: str) -> jobs.Task:
+    def _job_task(self, task: Task) -> jobs.Task:
         jobs_task = jobs.Task(
             task_key=task.name,
             job_cluster_key=task.job_cluster,
@@ -674,7 +716,7 @@ def _job_task(self, task: Task, remote_wheel: str) -> jobs.Task:
             return retried_job_dashboard_task(jobs_task, task)
         if task.notebook:
             return self._job_notebook_task(jobs_task, task)
-        return self._job_wheel_task(jobs_task, task, remote_wheel)
+        return self._job_wheel_task(jobs_task, task)
 
     def _job_dashboard_task(self, jobs_task: jobs.Task, task: Task) -> jobs.Task:
         assert task.dashboard is not None
@@ -706,11 +748,10 @@ def _job_notebook_task(self, jobs_task: jobs.Task, task: Task) -> jobs.Task:
             ),
         )
 
-    def _job_wheel_task(self, jobs_task: jobs.Task, task: Task, remote_wheel: str) -> jobs.Task:
+    def _job_wheel_task(self, jobs_task: jobs.Task, task: Task) -> jobs.Task:
         return replace(
             jobs_task,
             # TODO: check when we can install wheels from WSFS properly
-            libraries=[compute.Library(whl=f"dbfs:{remote_wheel}")],
             python_wheel_task=jobs.PythonWheelTask(
                 package_name="databricks_labs_ucx",
                 entry_point="runtime",  # [project.entry-points.databricks] in pyproject.toml
@@ -726,21 +767,13 @@ def _job_clusters(self, names: set[str]):
         }
         if self._config.spark_conf is not None:
             spark_conf = spark_conf | self._config.spark_conf
-        spec = self._cluster_node_type(
-            compute.ClusterSpec(
-                spark_version=self._ws.clusters.select_spark_version(latest=True),
-                data_security_mode=compute.DataSecurityMode.LEGACY_SINGLE_USER,
-                spark_conf=spark_conf,
-                custom_tags={"ResourceClass": "SingleNode"},
-                num_workers=0,
-            )
+        spec = compute.ClusterSpec(
+            data_security_mode=compute.DataSecurityMode.LEGACY_SINGLE_USER,
+            spark_conf=spark_conf,
+            custom_tags={"ResourceClass": "SingleNode"},
+            num_workers=0,
+            policy_id=self.config.policy_id,
         )
-        if self._config.custom_cluster_policy_id is not None:
-            spec = replace(spec, policy_id=self._config.custom_cluster_policy_id)
-        if self._ws.config.is_aws and spec.aws_attributes is not None:
-            # TODO: we might not need spec.aws_attributes, if we have a cluster policy
-            aws_attributes = replace(spec.aws_attributes, instance_profile_arn=self._config.instance_profile)
-            spec = replace(spec, aws_attributes=aws_attributes)
         if "main" in names:
             clusters.append(
                 jobs.JobCluster(
@@ -763,41 +796,6 @@ def _job_clusters(self, names: set[str]):
             )
         return clusters
 
-    def _cluster_node_type(self, spec: compute.ClusterSpec) -> compute.ClusterSpec:
-        cfg = self._config
-        valid_node_type = False
-        if cfg.custom_cluster_policy_id is not None:
-            if self._check_policy_has_instance_pool(cfg.custom_cluster_policy_id):
-                valid_node_type = True
-        if not valid_node_type:
-            if cfg.instance_pool_id is not None:
-                return replace(spec, instance_pool_id=cfg.instance_pool_id)
-            spec = replace(spec, node_type_id=self._ws.clusters.select_node_type(local_disk=True))
-        if self._ws.config.is_aws:
-            return replace(spec, aws_attributes=compute.AwsAttributes(availability=compute.AwsAvailability.ON_DEMAND))
-        if self._ws.config.is_azure:
-            return replace(
-                spec, azure_attributes=compute.AzureAttributes(availability=compute.AzureAvailability.ON_DEMAND_AZURE)
-            )
-        return replace(spec, gcp_attributes=compute.GcpAttributes(availability=compute.GcpAvailability.ON_DEMAND_GCP))
-
-    def _check_policy_has_instance_pool(self, policy_id):
-        try:
-            policy = self._ws.cluster_policies.get(policy_id=policy_id)
-        except NotFound:
-            logger.warning(f"removed on the backend {policy_id}")
-            return False
-        def_json = json.loads(policy.definition)
-        instance_pool = def_json.get("instance_pool_id")
-        if instance_pool is not None:
-            return True
-        return False
-
-    def _instance_profiles(self):
-        return {"No Instance Profile": None} | {
-            profile.instance_profile_arn: profile.instance_profile_arn for profile in self._ws.instance_profiles.list()
-        }
-
     @staticmethod
     def _readable_timedelta(epoch):
         when = datetime.fromtimestamp(epoch)
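
Note: with node type, availability, and instance profile now fixed by the policy, _cluster_node_type and its helpers are no longer needed and the job cluster spec only carries what the policy leaves open. A sketch of the resulting "main" job cluster, with a placeholder policy id and an illustrative Spark conf:

from databricks.sdk.service import compute, jobs

spec = compute.ClusterSpec(
    data_security_mode=compute.DataSecurityMode.LEGACY_SINGLE_USER,
    spark_conf={"spark.databricks.cluster.profile": "singleNode"},  # illustrative; UCX merges its defaults with config.spark_conf
    custom_tags={"ResourceClass": "SingleNode"},
    num_workers=0,
    policy_id="0123456789ABCDEF",  # the UCX policy created at install time (placeholder)
)
main_cluster = jobs.JobCluster(job_cluster_key="main", new_cluster=spec)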
@@ -899,6 +897,7 @@ def uninstall(self):
         self._remove_database()
         self._remove_jobs()
         self._remove_warehouse()
+        self._remove_policies()
         self._installation.remove()
         logger.info("UnInstalling UCX complete")
 
@@ -911,6 +910,13 @@ def _remove_database(self):
         deployer = SchemaDeployer(self._sql_backend, self._config.inventory_database, Any)
         deployer.delete_schema()
 
+    def _remove_policies(self):
+        logger.info("Deleting cluster policy")
+        try:
+            self._ws.cluster_policies.delete(policy_id=self.config.policy_id)
+        except NotFound:
+            logger.error("UCX Policy already deleted")
+
     def _remove_jobs(self):
         logger.info("Deleting jobs")
         if not self._state.jobs:

tests/integration/test_installation.py

Lines changed: 63 additions & 19 deletions
@@ -1,4 +1,5 @@
 import functools
+import json
 import logging
 from collections.abc import Callable
 from dataclasses import replace
@@ -11,11 +12,15 @@
 from databricks.labs.blueprint.wheels import WheelsV2
 from databricks.sdk.errors import InvalidParameterValue, NotFound, Unknown
 from databricks.sdk.retries import retried
-from databricks.sdk.service import sql
+from databricks.sdk.service import compute, sql
 from databricks.sdk.service.iam import PermissionLevel
 
 from databricks.labs.ucx.config import WorkspaceConfig
-from databricks.labs.ucx.install import PRODUCT_INFO, WorkspaceInstallation
+from databricks.labs.ucx.install import (
+    PRODUCT_INFO,
+    WorkspaceInstallation,
+    WorkspaceInstaller,
+)
 from databricks.labs.ucx.workspace_access import redash
 from databricks.labs.ucx.workspace_access.generic import (
     GenericPermissionsSupport,
@@ -29,20 +34,25 @@
 
 
 @pytest.fixture
-def new_installation(ws, sql_backend, env_or_skip, inventory_schema, make_random):
+def new_installation(ws, sql_backend, env_or_skip, inventory_schema, make_random, make_cluster_policy):
     cleanup = []
 
-    prompts = MockPrompts(
-        {
-            r'Open job overview in your browser.*': 'no',
-            r'Do you want to uninstall ucx.*': 'yes',
-            r'Do you want to delete the inventory database.*': 'yes',
-        }
-    )
-
     def factory(config_transform: Callable[[WorkspaceConfig], WorkspaceConfig] | None = None):
         prefix = make_random(4)
         renamed_group_prefix = f"rename-{prefix}-"
+        prompts = MockPrompts(
+            {
+                r'Open job overview in your browser.*': 'no',
+                r'Do you want to uninstall ucx.*': 'yes',
+                r'Do you want to delete the inventory database.*': 'yes',
+                r".*PRO or SERVERLESS SQL warehouse.*": "1",
+                r"Choose how to map the workspace groups.*": "1",
+                r".*connect to the external metastore?.*": "yes",
+                r".*Inventory Database.*": inventory_schema,
+                r".*Backup prefix*": renamed_group_prefix,
+                r".*": "",
+            }
+        )
         workspace_start_path = f"/Users/{ws.current_user.me().user_name}/.{prefix}"
         default_cluster_id = env_or_skip("TEST_DEFAULT_CLUSTER_ID")
         tacl_cluster_id = env_or_skip("TEST_LEGACY_TABLE_ACL_CLUSTER_ID")
@@ -53,16 +63,14 @@ def factory(config_transform: Callable[[WorkspaceConfig], WorkspaceConfig] | Non
                 functools.partial(ws.clusters.ensure_cluster_is_running, tacl_cluster_id),
             ],
         )
-        workspace_config = WorkspaceConfig(
-            inventory_database=inventory_schema,
-            log_level="DEBUG",
-            renamed_group_prefix=renamed_group_prefix,
-            workspace_start_path=workspace_start_path,
-            override_clusters={"main": default_cluster_id, "tacl": tacl_cluster_id},
-        )
+        installation = Installation(ws, prefix)
+        installer = WorkspaceInstaller(prompts, installation, ws)
+        workspace_config = installer.configure()
+        overrides = {"main": default_cluster_id, "tacl": tacl_cluster_id}
+        workspace_config.override_clusters = overrides
+        workspace_config.workspace_start_path = workspace_start_path
         if config_transform:
             workspace_config = config_transform(workspace_config)
-        installation = Installation(ws, prefix)
         installation.save(workspace_config)
 
         # TODO: see if we want to move building wheel as a context manager for yield factory,
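
Note: the fixture now drives WorkspaceInstaller.configure() with canned answers instead of building WorkspaceConfig by hand. Each MockPrompts key is a regex matched against the prompt text, and the catch-all r".*" answers "" for anything unlisted; a minimal sketch, assuming blueprint's MockPrompts behaves like the interactive Prompts it replaces:

from databricks.labs.blueprint.tui import MockPrompts

prompts = MockPrompts({
    r"Do you want to uninstall ucx.*": "yes",
    r".*Inventory Database.*": "ucx",  # example answer
    r".*": "",                         # fallback for every other question
})
assert prompts.confirm("Do you want to uninstall ucx from the workspace?") is True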
@@ -104,6 +112,42 @@ def test_job_failure_propagates_correct_error_message_and_logs(ws, sql_backend,
 
 
 @retried(on=[NotFound, Unknown, InvalidParameterValue], timeout=timedelta(minutes=18))
+def test_job_cluster_policy(ws, new_installation):
+    install = new_installation(lambda wc: replace(wc, override_clusters=None))
+    cluster_policy = ws.cluster_policies.get(policy_id=install.config.policy_id)
+    policy_definition = json.loads(cluster_policy.definition)
+
+    assert cluster_policy.name == f"Unity Catalog Migration ({install.config.inventory_database})"
+
+    assert policy_definition["spark_version"]["value"] == ws.clusters.select_spark_version(latest=True)
+    assert policy_definition["node_type_id"]["value"] == ws.clusters.select_node_type(local_disk=True)
+    assert (
+        policy_definition["azure_attributes.availability"]["value"] == compute.AzureAvailability.ON_DEMAND_AZURE.value
+    )
+
+
+@pytest.mark.skip
+@retried(on=[NotFound, TimeoutError], timeout=timedelta(minutes=15))
+def test_new_job_cluster_with_policy_assessment(
+    ws, new_installation, make_ucx_group, make_cluster_policy, make_cluster_policy_permissions
+):
+    ws_group_a, acc_group_a = make_ucx_group()
+    cluster_policy = make_cluster_policy()
+    make_cluster_policy_permissions(
+        object_id=cluster_policy.policy_id,
+        permission_level=PermissionLevel.CAN_USE,
+        group_name=ws_group_a.display_name,
+    )
+    install = new_installation(
+        lambda wc: replace(wc, override_clusters=None, include_group_names=[ws_group_a.display_name])
+    )
+    install.run_workflow("assessment")
+    generic_permissions = GenericPermissionsSupport(ws, [])
+    before = generic_permissions.load_as_dict("cluster-policies", cluster_policy.policy_id)
+    assert before[ws_group_a.display_name] == PermissionLevel.CAN_USE
+
+
+@retried(on=[NotFound, Unknown, InvalidParameterValue], timeout=timedelta(minutes=20))
 def test_running_real_assessment_job(
     ws, new_installation, make_ucx_group, make_cluster_policy, make_cluster_policy_permissions
 ):
