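"""Interactive installer for UCX: configures the migration and deploys its workflows as Databricks jobs."""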
import logging
import os
import re
import subprocess
import sys
import tempfile
import webbrowser
from dataclasses import replace
from pathlib import Path

import yaml
from databricks.sdk import WorkspaceClient
from databricks.sdk.core import DatabricksError
from databricks.sdk.service import compute, jobs
from databricks.sdk.service.workspace import ImportFormat

from databricks.labs.ucx.__about__ import __version__
from databricks.labs.ucx.config import GroupsConfig, MigrationConfig, TaclConfig
from databricks.labs.ucx.runtime import main
from databricks.labs.ucx.tasks import _TASKS

logger = logging.getLogger(__name__)


class Installer:
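    """Installs UCX into a workspace: writes config.yml, uploads the project wheel, and creates one job per workflow step."""
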
    def __init__(self, ws: WorkspaceClient):
        if "DATABRICKS_RUNTIME_VERSION" in os.environ:
            msg = "Installer is not supposed to be executed in Databricks Runtime"
            raise SystemExit(msg)
        self._ws = ws

    def run(self):
        self._configure()
        self._create_jobs()

    @property
    def _my_username(self):
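        # Cache the current user; installation requires a workspace admin.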
        if not hasattr(self, "_me"):
            self._me = self._ws.current_user.me()
            is_workspace_admin = any(g.display == "admins" for g in self._me.groups)
            if not is_workspace_admin:
                msg = "Current user is not a workspace admin"
                raise PermissionError(msg)
        return self._me.user_name

    @property
    def _install_folder(self):
        return f"/Users/{self._my_username}/.ucx"

    @property
    def _config_file(self):
        return f"{self._install_folder}/config.yml"

    @property
    def _current_config(self):
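        # Lazily download config.yml from the workspace and cache the parsed result.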
        if hasattr(self, "_config"):
            return self._config
        with self._ws.workspace.download(self._config_file) as f:
            self._config = MigrationConfig.from_bytes(f.read())
        return self._config

    def _configure(self):
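        """Prompts for migration settings and saves them as config.yml in the
        install folder; if the file already exists, offers to open it instead."""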
        config_path = self._config_file
        ws_file_url = f"{self._ws.config.host}/#workspace{config_path}"
        try:
            self._ws.workspace.get_status(config_path)
            logger.info(f"UCX is already configured. See {ws_file_url}")
            if self._question("Type 'yes' to open config file in the browser") == "yes":
                webbrowser.open(ws_file_url)
            return config_path
        except DatabricksError as err:
            if err.error_code != "RESOURCE_DOES_NOT_EXIST":
                raise err

        logger.info("Please answer a couple of questions to configure Unity Catalog migration")
        self._config = MigrationConfig(
            inventory_database=self._question("Inventory Database", default="ucx"),
            groups=GroupsConfig(
                selected=self._question("Comma-separated list of workspace group names to migrate").split(","),
                backup_group_prefix=self._question("Backup prefix", default="db-temp-"),
            ),
            tacl=TaclConfig(auto=True),
            log_level=self._question("Log level", default="INFO"),
            num_threads=int(self._question("Number of threads", default="8")),
        )

        config_bytes = yaml.dump(self._config.as_dict()).encode("utf8")
        self._ws.workspace.upload(config_path, config_bytes, format=ImportFormat.AUTO)
        logger.info(f"Created configuration file: {config_path}")
        if self._question("Open config file in the browser and continue installing?", default="yes") == "yes":
            webbrowser.open(ws_file_url)

    def _create_jobs(self):
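        """Creates or updates one job per workflow step, and deletes jobs whose step no longer exists."""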
        logger.debug(f"Creating jobs from tasks in {main.__name__}")
        dbfs_path = self._upload_wheel()
        deployed_steps = self._deployed_steps()
        desired_steps = {t.workflow for t in _TASKS.values()}
        for step_name in desired_steps:
            settings = self._job_settings(step_name, dbfs_path)
            if step_name in deployed_steps:
                job_id = deployed_steps[step_name]
                logger.info(f"Updating configuration for step={step_name} job_id={job_id}")
                self._ws.jobs.reset(job_id, jobs.JobSettings(**settings))
            else:
                logger.info(f"Creating new job configuration for step={step_name}")
                deployed_steps[step_name] = self._ws.jobs.create(**settings).job_id

        for step_name, job_id in deployed_steps.items():
            if step_name not in desired_steps:
                logger.info(f"Removing job_id={job_id}, as it is no longer needed")
                self._ws.jobs.delete(job_id)

        self._create_readme(deployed_steps)

    def _create_readme(self, deployed_steps):
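        """Uploads a README.py notebook to the install folder with links to every deployed job."""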
        md = [
            "# UCX - The Unity Catalog Migration Assistant",
            "Here are the descriptions of jobs that trigger various stages of migration.",
        ]
        for step_name, job_id in deployed_steps.items():
            md.append(f"## [[UCX] {step_name}]({self._ws.config.host}#job/{job_id})\n")
            for t in _TASKS.values():
                if t.workflow != step_name:
                    continue
                doc = re.sub(r"\s+", " ", t.doc)
                md.append(f" - `{t.name}`: {doc}")
            md.append("")
        preamble = ["# Databricks notebook source", "# MAGIC %md"]
        intro = "\n".join(preamble + [f"# MAGIC {line}" for line in md])
        path = f"{self._install_folder}/README.py"
        self._ws.workspace.upload(path, intro.encode("utf8"), overwrite=True)
        url = f"{self._ws.config.host}/#workspace{path}"
        logger.info(f"Created notebook with job overview: {url}")
        msg = "Type 'yes' to open job overview in README notebook in your home directory"
        if self._question(msg) == "yes":
            webbrowser.open(url)

    @staticmethod
    def _question(text: str, *, default: str | None = None) -> str:
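        # Keep prompting until a non-empty answer is given; an empty answer returns the default, when one is set.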
        default_help = "" if default is None else f"\033[36m (default: {default})\033[0m"
        prompt = f"\033[1m{text}{default_help}: \033[0m"
        res = None
        while not res:
            res = input(prompt)
            if not res and default is not None:
                return default
        return res

    def _upload_wheel(self):
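        # Build the wheel in a temporary directory and upload it to DBFS so job clusters can install it.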
        with tempfile.TemporaryDirectory() as tmp_dir:
            wheel = self._build_wheel(tmp_dir)
            dbfs_path = f"{self._install_folder}/wheels/{wheel.name}"
            with wheel.open("rb") as f:
                logger.info(f"Uploading wheel to dbfs:{dbfs_path}")
                self._ws.dbfs.upload(dbfs_path, f, overwrite=True)
            return dbfs_path

    def _job_settings(self, step_name, dbfs_path):
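        # Assemble the job settings for one workflow step: tasks, job cluster, wheel library, and notifications.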
        config_file = f"/Workspace{self._install_folder}/config.yml"
        email_notifications = None
        if "@" in self._my_username:
            email_notifications = jobs.JobEmailNotifications(
                on_success=[self._my_username], on_failure=[self._my_username]
            )
        tasks = sorted([t for t in _TASKS.values() if t.workflow == step_name], key=lambda t: t.name)
        return {
            "name": f"[UCX] {step_name}",
            "tags": {"App": "ucx", "step": step_name},
            "job_clusters": self._job_clusters({t.job_cluster for t in tasks}),
            "email_notifications": email_notifications,
            "tasks": [
                jobs.Task(
                    task_key=task.name,
                    job_cluster_key=task.job_cluster,
                    depends_on=[jobs.TaskDependency(task_key=d) for d in _TASKS[task.name].depends_on],
                    libraries=[compute.Library(whl=f"dbfs:{dbfs_path}")],
                    python_wheel_task=jobs.PythonWheelTask(
                        package_name="databricks_labs_ucx",
                        entry_point="runtime",  # [project.entry-points.databricks] in pyproject.toml
                        named_parameters={"task": task.name, "config": config_file},
                    ),
                )
                for task in tasks
            ],
        }

    def _job_clusters(self, names: set[str]):
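        # All steps run on single-node job clusters; "tacl" additionally needs the legacy Table ACL security mode.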
        clusters = []
        spec = self._cluster_node_type(
            compute.ClusterSpec(
                spark_version=self._ws.clusters.select_spark_version(latest=True),
                data_security_mode=compute.DataSecurityMode.NONE,
                spark_conf={"spark.databricks.cluster.profile": "singleNode", "spark.master": "local[*]"},
                custom_tags={"ResourceClass": "SingleNode"},
                num_workers=0,
            )
        )
        if "main" in names:
            clusters.append(
                jobs.JobCluster(
                    job_cluster_key="main",
                    new_cluster=spec,
                )
            )
        if "tacl" in names:
            clusters.append(
                jobs.JobCluster(
                    job_cluster_key="tacl",
                    new_cluster=replace(
                        spec,
                        data_security_mode=compute.DataSecurityMode.LEGACY_TABLE_ACL,
                        spark_conf={"spark.databricks.acl.sqlOnly": "true"},
                        custom_tags={},
                    ),
                )
            )
        return clusters

    @staticmethod
    def _build_wheel(tmp_dir: str, *, verbose: bool = False):
        """Builds the project wheel in tmp_dir with pip and returns the path to the .whl file."""
        streams = {}
        if not verbose:
            streams = {
                "stdout": subprocess.DEVNULL,
                "stderr": subprocess.DEVNULL,
            }
        project_root = Installer._find_project_root(Path(__file__))
        if not project_root:
            msg = "Cannot find project root"
            raise NotADirectoryError(msg)
        logger.debug(f"Building wheel for {project_root} in {tmp_dir}")
        subprocess.run(
            [sys.executable, "-m", "pip", "wheel", "--no-deps", "--wheel-dir", tmp_dir, project_root],
            **streams,
            check=True,
        )
        # the only *.whl in the temp directory is the one that was just built
        return next(Path(tmp_dir).glob("*.whl"))

    @staticmethod
    def _find_project_root(folder: Path) -> Path | None:
        for leaf in ["pyproject.toml", "setup.py"]:
            root = Installer._find_dir_with_leaf(folder, leaf)
            if root is not None:
                return root
        return None

    @staticmethod
    def _find_dir_with_leaf(folder: Path, leaf: str) -> Path | None:
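        # Walk up the directory tree until the filesystem root, looking for a folder that contains the leaf file.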
        root = folder.root
        while str(folder.absolute()) != root:
            if (folder / leaf).exists():
                return folder
            folder = folder.parent
        return None

    def _cluster_node_type(self, spec: compute.ClusterSpec) -> compute.ClusterSpec:
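        # Prefer the configured instance pool; otherwise pick a local-disk node type with cloud-specific on-demand instances.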
        cfg = self._current_config
        if cfg.instance_pool_id is not None:
            return replace(spec, instance_pool_id=cfg.instance_pool_id)
        spec = replace(spec, node_type_id=self._ws.clusters.select_node_type(local_disk=True))
        if self._ws.config.is_aws:
            return replace(spec, aws_attributes=compute.AwsAttributes(availability=compute.AwsAvailability.ON_DEMAND))
        if self._ws.config.is_azure:
            return replace(
                spec, azure_attributes=compute.AzureAttributes(availability=compute.AzureAvailability.ON_DEMAND_AZURE)
            )
        return replace(spec, gcp_attributes=compute.GcpAttributes(availability=compute.GcpAvailability.ON_DEMAND_GCP))

    def _deployed_steps(self):
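        # Discover jobs deployed by a previous installation, keyed by their "step" tag.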
        deployed_steps = {}
        for j in self._ws.jobs.list():
            tags = j.settings.tags
            if tags is None:
                continue
            if tags.get("App", None) != "ucx":
                continue
            deployed_steps[tags.get("step", "_")] = j.job_id
        return deployed_steps


if __name__ == "__main__":
    ws = WorkspaceClient(product="ucx", product_version=__version__)
    installer = Installer(ws)
    installer.run()