Skip to content

Commit c3c2295

Browse files
authored
Add support for account-level configuration and multi-cloud workspace list (#264)
This PR renames `MigrationConfig` to `WorkspaceConfig` and adds `AccountConfig`, which persists configuration file at `~/.ucx/config.yml` - same homefolder, but this time on the local machine, instead of the workspace FS.
1 parent 20d1b42 commit c3c2295

File tree

15 files changed

+505
-116
lines changed

15 files changed

+505
-116
lines changed

notebooks/toolkit.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from databricks.labs.ucx.config import (
1616
GroupsConfig,
17-
MigrationConfig,
17+
WorkspaceConfig,
1818
)
1919
from databricks.labs.ucx.framework.crawlers import RuntimeBackend
2020
from databricks.labs.ucx.workspace_access import GroupMigrationToolkit
@@ -32,7 +32,7 @@
3232
selected_groups = dbutils.widgets.get("selected_groups").split(",")
3333
databases = dbutils.widgets.get("databases").split(",")
3434

35-
config = MigrationConfig(
35+
config = WorkspaceConfig(
3636
inventory_database=inventory_database,
3737
groups=GroupsConfig(
3838
# use this option to select specific groups manually

src/databricks/labs/ucx/account/__init__.py

Whitespace-only changes.
Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
import base64
2+
import json
3+
import logging
4+
from dataclasses import dataclass
5+
from pathlib import Path
6+
from typing import ClassVar
7+
8+
import requests
9+
from databricks.sdk import WorkspaceClient
10+
from databricks.sdk.core import AzureCliTokenSource, Config, DatabricksError
11+
from databricks.sdk.service.provisioning import PricingTier, Workspace
12+
from requests.exceptions import ConnectionError
13+
14+
from databricks.labs.ucx.__about__ import __version__
15+
from databricks.labs.ucx.config import AccountConfig
16+
17+
logger = logging.getLogger(__name__)
18+
19+
20+
@dataclass
21+
class AzureSubscription:
22+
name: str
23+
subscription_id: str
24+
tenant_id: str
25+
26+
27+
class AzureWorkspaceLister:
28+
def __init__(self, cfg: Config):
29+
endpoint = cfg.arm_environment.resource_manager_endpoint
30+
self._token_source = AzureCliTokenSource(endpoint)
31+
self._endpoint = endpoint
32+
33+
def _get(self, path: str, *, api_version=None) -> dict:
34+
token = self._token_source.token()
35+
headers = {"Authorization": f"{token.token_type} {token.access_token}"}
36+
return requests.get(
37+
self._endpoint + path, headers=headers, params={"api-version": api_version}, timeout=10
38+
).json()
39+
40+
def _all_subscriptions(self):
41+
for sub in self._get("/subscriptions", api_version="2022-12-01").get("value", []):
42+
yield AzureSubscription(
43+
name=sub["displayName"], subscription_id=sub["subscriptionId"], tenant_id=sub["tenantId"]
44+
)
45+
46+
def _tenant_id(self):
47+
token = self._token_source.token()
48+
_, payload, _ = token.access_token.split(".")
49+
b64_decoded = base64.standard_b64decode(payload + "==").decode("utf8")
50+
claims = json.loads(b64_decoded)
51+
return claims["tid"]
52+
53+
def current_tenant_subscriptions(self):
54+
tenant_id = self._tenant_id()
55+
for sub in self._all_subscriptions():
56+
if sub.tenant_id != tenant_id:
57+
continue
58+
yield sub
59+
60+
def subscriptions_name_to_id(self):
61+
return {sub.name: sub.subscription_id for sub in self.current_tenant_subscriptions()}
62+
63+
def list_workspaces(self, subscription_id):
64+
endpoint = f"/subscriptions/{subscription_id}/providers/Microsoft.Databricks/workspaces"
65+
sku_tiers = {
66+
"premium": PricingTier.PREMIUM,
67+
"enterprise": PricingTier.ENTERPRISE,
68+
"standard": PricingTier.STANDARD,
69+
"unknown": PricingTier.UNKNOWN,
70+
}
71+
items = self._get(endpoint, api_version="2023-02-01").get("value", [])
72+
for item in sorted(items, key=lambda _: _["name"].lower()):
73+
properties = item["properties"]
74+
if properties["provisioningState"] != "Succeeded":
75+
continue
76+
if "workspaceUrl" not in properties:
77+
continue
78+
parameters = properties.get("parameters", {})
79+
workspace_url = properties["workspaceUrl"]
80+
tags = item.get("tags", {})
81+
if "AzureSubscriptionID" not in tags:
82+
tags["AzureSubscriptionID"] = subscription_id
83+
if "AzureResourceGroup" not in tags:
84+
tags["AzureResourceGroup"] = item["id"].split("resourceGroups/")[1].split("/")[0]
85+
yield Workspace(
86+
cloud="azure",
87+
location=item["location"],
88+
workspace_name=item["name"],
89+
workspace_id=int(properties["workspaceId"]),
90+
workspace_status_message=properties["provisioningState"],
91+
deployment_name=workspace_url.replace(".azuredatabricks.net", ""),
92+
pricing_tier=sku_tiers.get(item.get("sku", {"name": None})["name"], None),
93+
# These fields are just approximation for the fields with same meaning in AWS and GCP
94+
storage_configuration_id=parameters.get("storageAccountName", {"value": None})["value"],
95+
network_id=parameters.get("customVirtualNetworkId", {"value": None})["value"],
96+
custom_tags=tags,
97+
)
98+
99+
100+
class Workspaces:
101+
_tlds: ClassVar[dict[str, str]] = {
102+
"aws": "cloud.databricks.com",
103+
"azure": "azuredatabricks.net",
104+
"gcp": "gcp.databricks.com",
105+
}
106+
107+
def __init__(self, cfg: AccountConfig):
108+
self._ac = cfg.to_account_client()
109+
self._cfg = cfg
110+
111+
def configured_workspaces(self):
112+
for workspace in self._all_workspaces():
113+
if self._cfg.include_workspace_names:
114+
if workspace.workspace_name not in self._cfg.include_workspace_names:
115+
logger.debug(
116+
f"skipping {workspace.workspace_name} ({workspace.workspace_id} because "
117+
f"its not explicitly included"
118+
)
119+
continue
120+
yield workspace
121+
122+
def client_for(self, workspace: Workspace) -> WorkspaceClient:
123+
config = self._ac.config.as_dict()
124+
# copy current config and swap with a host relevant to a workspace
125+
config["host"] = f"https://{workspace.deployment_name}.{self._tlds[workspace.cloud]}"
126+
return WorkspaceClient(**config, product="ucx", product_version=__version__)
127+
128+
def _all_workspaces(self):
129+
if self._ac.config.is_azure:
130+
yield from self._azure_workspaces()
131+
else:
132+
yield from self._native_workspaces()
133+
134+
def _native_workspaces(self):
135+
yield from self._ac.workspaces.list()
136+
137+
def _azure_workspaces(self):
138+
azure_lister = AzureWorkspaceLister(self._ac.config)
139+
for sub in azure_lister.current_tenant_subscriptions():
140+
if self._cfg.include_azure_subscription_ids:
141+
if sub.subscription_id not in self._cfg.include_azure_subscription_ids:
142+
logger.debug(f"skipping {sub.name} ({sub.subscription_id} because its not explicitly included")
143+
continue
144+
if self._cfg.include_azure_subscription_names:
145+
if sub.name not in self._cfg.include_azure_subscription_names:
146+
logger.debug(f"skipping {sub.name} ({sub.subscription_id} because its not explicitly included")
147+
continue
148+
for workspace in azure_lister.list_workspaces(sub.subscription_id):
149+
if "AzureSubscription" not in workspace.custom_tags:
150+
workspace.custom_tags["AzureSubscription"] = sub.name
151+
yield workspace
152+
153+
154+
if __name__ == "__main__":
155+
logger.setLevel("INFO")
156+
157+
config_file = Path.home() / ".ucx/config.yml"
158+
cfg = AccountConfig.from_file(config_file)
159+
wss = Workspaces(cfg)
160+
161+
for workspace in wss.configured_workspaces():
162+
ws = wss.client_for(workspace)
163+
164+
metastore_id = "NOT ASSIGNED"
165+
default_catalog = "hive_metastore"
166+
try:
167+
metastore = ws.metastores.current()
168+
default_catalog = metastore.default_catalog_name
169+
metastore_id = metastore.metastore_id
170+
except DatabricksError:
171+
pass
172+
except ConnectionError:
173+
logger.warning(f"Private DNS for {workspace.workspace_name} is not yet supported?..")
174+
175+
logger.info(
176+
f"workspace: {workspace.workspace_name}: "
177+
f"metastore {metastore_id}, "
178+
f"default catalog: {default_catalog}"
179+
)

0 commit comments

Comments
 (0)