Skip to content

Commit 87b6262

Browse files
authored
Cleanup installer framework and speed up test execution (#711)
1 parent f8ae4fb commit 87b6262

File tree

9 files changed

+550
-535
lines changed

9 files changed

+550
-535
lines changed

pyproject.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ dependencies = [
4646
"pytest-xdist",
4747
"pytest-cov>=4.0.0,<5.0.0",
4848
"pytest-mock>=3.0.0,<4.0.0",
49+
"pytest-timeout",
4950
"black>=23.1.0",
5051
"ruff>=0.0.243",
5152
"isort>=2.5.0",
@@ -60,8 +61,8 @@ python="3.10"
6061
path = ".venv"
6162

6263
[tool.hatch.envs.default.scripts]
63-
test = "pytest -n auto --cov src --cov-report=xml tests/unit"
64-
coverage = "pytest -n auto --cov src tests/unit --cov-report=html"
64+
test = "pytest -n auto --cov src --cov-report=xml --timeout 10 tests/unit"
65+
coverage = "pytest -n auto --cov src tests/unit --timeout 10 --cov-report=html"
6566
integration = "pytest -n 10 --cov src tests/integration"
6667
fmt = ["isort .",
6768
"black .",

src/databricks/labs/ucx/configure.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def __init__(self, ws, choice_from_dict):
1717
def _valid_cluster_id(self, cluster_id: str) -> bool:
1818
return cluster_id is not None and CLUSTER_ID_LENGTH == len(cluster_id)
1919

20-
def _configure_override_clusters(self):
20+
def configure(self):
2121
"""User may override standard job clusters with interactive clusters"""
2222
logger.info("Configuring cluster overrides from existing clusters")
2323

src/databricks/labs/ucx/framework/tui.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
1+
import logging
12
import re
3+
from collections.abc import Callable
24
from typing import Any
35

6+
logger = logging.getLogger(__name__)
7+
48

59
class Prompts:
610
def multi_select_from_dict(self, all_prompt: str, item_prompt: str, choices: dict[str, Any]) -> list[Any]:
@@ -39,6 +43,10 @@ def choice(self, text: str, choices: list[Any], *, max_attempts: int = 10, sort:
3943
msg = f"cannot get answer within {max_attempts} attempt"
4044
raise ValueError(msg)
4145

46+
def confirm(self, text: str, *, max_attempts: int = 10):
47+
answer = self.question(text, valid_regex=r"[Yy][Ee][Ss]|[Nn][Oo]", default="no", max_attempts=max_attempts)
48+
return answer.lower() == "yes"
49+
4250
def question(
4351
self,
4452
text: str,
@@ -47,6 +55,7 @@ def question(
4755
max_attempts: int = 10,
4856
valid_number: bool = False,
4957
valid_regex: str | None = None,
58+
validate: Callable[[str], bool] | None = None,
5059
) -> str:
5160
default_help = "" if default is None else f"\033[36m (default: {default})\033[0m"
5261
prompt = f"\033[1m{text}{default_help}: \033[0m"
@@ -59,6 +68,9 @@ def question(
5968
while attempt < max_attempts:
6069
attempt += 1
6170
res = input(prompt)
71+
if res and validate:
72+
if not validate(res):
73+
continue
6274
if res and match_regex:
6375
if not match_regex.match(res):
6476
print(f"\033[31m[ERROR] Not a '{valid_regex}' match: {res}\033[0m\n")
@@ -74,12 +86,18 @@ def question(
7486

7587

7688
class MockPrompts(Prompts):
77-
def __init__(self, patterns_to_answers: dict):
78-
self._questions_to_answers = {re.compile(k): v for k, v in patterns_to_answers.items()}
89+
def __init__(self, patterns_to_answers: dict[str, str]):
90+
self._questions_to_answers = sorted(
91+
[(re.compile(k), v) for k, v in patterns_to_answers.items()], key=lambda _: len(_[0].pattern), reverse=True
92+
)
7993

80-
def question(self, text: str, **_) -> str:
81-
for question, answer in self._questions_to_answers.items():
82-
if question.match(text):
83-
return answer
94+
def question(self, text: str, default: str | None = None, **_) -> str:
95+
logger.info(f"Asking prompt: {text}")
96+
for question, answer in self._questions_to_answers:
97+
if not question.search(text):
98+
continue
99+
if not answer and default:
100+
return default
101+
return answer
84102
mocked = f"not mocked: {text}"
85103
raise ValueError(mocked)
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
import datetime
2+
import logging
3+
import os
4+
import shutil
5+
import subprocess
6+
import sys
7+
import tempfile
8+
from contextlib import AbstractContextManager
9+
from pathlib import Path
10+
11+
from databricks.sdk import WorkspaceClient
12+
from databricks.sdk.mixins.compute import SemVer
13+
from databricks.sdk.service.workspace import ImportFormat
14+
15+
logger = logging.getLogger(__name__)
16+
17+
18+
class Wheels(AbstractContextManager):
19+
def __init__(self, ws: WorkspaceClient, install_folder: str, released_version: str):
20+
self._ws = ws
21+
self._this_file = Path(__file__)
22+
self._install_folder = install_folder
23+
self._released_version = released_version
24+
25+
def version(self):
26+
if hasattr(self, "__version"):
27+
return self.__version
28+
project_root = find_project_root()
29+
if not (project_root / ".git/config").exists():
30+
# normal install, downloaded releases won't have the .git folder
31+
return self._released_version
32+
try:
33+
out = subprocess.run(["git", "describe", "--tags"], stdout=subprocess.PIPE, check=True) # noqa S607
34+
git_detached_version = out.stdout.decode("utf8")
35+
dv = SemVer.parse(git_detached_version)
36+
datestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
37+
# new commits on main branch since the last tag
38+
new_commits = dv.pre_release.split("-")[0] if dv.pre_release else None
39+
# show that it's a version different from the released one in stats
40+
bump_patch = dv.patch + 1
41+
# create something that is both https://semver.org and https://peps.python.org/pep-0440/
42+
semver_and_pep0440 = f"{dv.major}.{dv.minor}.{bump_patch}+{new_commits}{datestamp}"
43+
# validate the semver
44+
SemVer.parse(semver_and_pep0440)
45+
self.__version = semver_and_pep0440
46+
return semver_and_pep0440
47+
except Exception as err:
48+
msg = (
49+
f"Cannot determine unreleased version. Please report this error "
50+
f"message that you see on https://github.com/databrickslabs/ucx/issues/new. "
51+
f"Meanwhile, download, unpack, and install the latest released version from "
52+
f"https://github.com/databrickslabs/ucx/releases. Original error is: {err!s}"
53+
)
54+
raise OSError(msg) from None
55+
56+
def __enter__(self) -> "Wheels":
57+
self._tmp_dir = tempfile.TemporaryDirectory()
58+
self._local_wheel = self._build_wheel(self._tmp_dir.name)
59+
self._remote_wheel = f"{self._install_folder}/wheels/{self._local_wheel.name}"
60+
self._remote_dirname = os.path.dirname(self._remote_wheel)
61+
return self
62+
63+
def __exit__(self, __exc_type, __exc_value, __traceback):
64+
self._tmp_dir.cleanup()
65+
66+
def upload_to_dbfs(self) -> str:
67+
with self._local_wheel.open("rb") as f:
68+
self._ws.dbfs.mkdirs(self._remote_dirname)
69+
logger.info(f"Uploading wheel to dbfs:{self._remote_wheel}")
70+
self._ws.dbfs.upload(self._remote_wheel, f, overwrite=True)
71+
return self._remote_wheel
72+
73+
def upload_to_wsfs(self) -> str:
74+
with self._local_wheel.open("rb") as f:
75+
self._ws.workspace.mkdirs(self._remote_dirname)
76+
logger.info(f"Uploading wheel to /Workspace{self._remote_wheel}")
77+
self._ws.workspace.upload(self._remote_wheel, f, overwrite=True, format=ImportFormat.AUTO)
78+
return self._remote_wheel
79+
80+
def _build_wheel(self, tmp_dir: str, *, verbose: bool = False):
81+
"""Helper to build the wheel package"""
82+
stdout = subprocess.STDOUT
83+
stderr = subprocess.STDOUT
84+
if not verbose:
85+
stdout = subprocess.DEVNULL
86+
stderr = subprocess.DEVNULL
87+
project_root = find_project_root()
88+
is_non_released_version = "+" in self.version()
89+
if (project_root / ".git" / "config").exists() and is_non_released_version:
90+
tmp_dir_path = Path(tmp_dir) / "working-copy"
91+
# copy everything to a temporary directory
92+
shutil.copytree(project_root, tmp_dir_path)
93+
# and override the version file
94+
# TODO: make it configurable
95+
version_file = tmp_dir_path / "src/databricks/labs/ucx/__about__.py"
96+
with version_file.open("w") as f:
97+
f.write(f'__version__ = "{self.version()}"')
98+
# working copy becomes project root for building a wheel
99+
project_root = tmp_dir_path
100+
logger.debug(f"Building wheel for {project_root} in {tmp_dir}")
101+
subprocess.run(
102+
[sys.executable, "-m", "pip", "wheel", "--no-deps", "--wheel-dir", tmp_dir, project_root.as_posix()],
103+
check=True,
104+
stdout=stdout,
105+
stderr=stderr,
106+
)
107+
# get wheel name as first file in the temp directory
108+
return next(Path(tmp_dir).glob("*.whl"))
109+
110+
111+
def find_project_root() -> Path:
112+
def _find_dir_with_leaf(folder: Path, leaf: str) -> Path | None:
113+
root = folder.root
114+
while str(folder.absolute()) != root:
115+
if (folder / leaf).exists():
116+
return folder
117+
folder = folder.parent
118+
return None
119+
120+
for leaf in ["pyproject.toml", "setup.py"]:
121+
root = _find_dir_with_leaf(Path(__file__), leaf)
122+
if root is not None:
123+
return root
124+
msg = "Cannot find project root"
125+
raise NotADirectoryError(msg)

0 commit comments

Comments
 (0)