|
| 1 | +import base64 |
| 2 | +import json |
| 3 | +import logging |
| 4 | +from dataclasses import dataclass |
| 5 | +from pathlib import Path |
| 6 | +from typing import ClassVar |
| 7 | + |
| 8 | +import requests |
| 9 | +from databricks.sdk import WorkspaceClient |
| 10 | +from databricks.sdk.core import AzureCliTokenSource, Config, DatabricksError |
| 11 | +from databricks.sdk.service.provisioning import PricingTier, Workspace |
| 12 | +from requests.exceptions import ConnectionError |
| 13 | + |
| 14 | +from databricks.labs.ucx.__about__ import __version__ |
| 15 | +from databricks.labs.ucx.config import AccountConfig |
| 16 | + |
| 17 | +logger = logging.getLogger(__name__) |
| 18 | + |
| 19 | + |
| 20 | +@dataclass |
| 21 | +class AzureSubscription: |
| 22 | + name: str |
| 23 | + subscription_id: str |
| 24 | + tenant_id: str |
| 25 | + |
| 26 | + |
| 27 | +class AzureWorkspaceLister: |
| 28 | + def __init__(self, cfg: Config): |
| 29 | + endpoint = cfg.arm_environment.resource_manager_endpoint |
| 30 | + self._token_source = AzureCliTokenSource(endpoint) |
| 31 | + self._endpoint = endpoint |
| 32 | + |
| 33 | + def _get(self, path: str, *, api_version=None) -> dict: |
| 34 | + token = self._token_source.token() |
| 35 | + headers = {"Authorization": f"{token.token_type} {token.access_token}"} |
| 36 | + return requests.get( |
| 37 | + self._endpoint + path, headers=headers, params={"api-version": api_version}, timeout=10 |
| 38 | + ).json() |
| 39 | + |
| 40 | + def _all_subscriptions(self): |
| 41 | + for sub in self._get("/subscriptions", api_version="2022-12-01").get("value", []): |
| 42 | + yield AzureSubscription( |
| 43 | + name=sub["displayName"], subscription_id=sub["subscriptionId"], tenant_id=sub["tenantId"] |
| 44 | + ) |
| 45 | + |
| 46 | + def _tenant_id(self): |
| 47 | + token = self._token_source.token() |
| 48 | + _, payload, _ = token.access_token.split(".") |
| 49 | + b64_decoded = base64.standard_b64decode(payload + "==").decode("utf8") |
| 50 | + claims = json.loads(b64_decoded) |
| 51 | + return claims["tid"] |
| 52 | + |
| 53 | + def current_tenant_subscriptions(self): |
| 54 | + tenant_id = self._tenant_id() |
| 55 | + for sub in self._all_subscriptions(): |
| 56 | + if sub.tenant_id != tenant_id: |
| 57 | + continue |
| 58 | + yield sub |
| 59 | + |
| 60 | + def subscriptions_name_to_id(self): |
| 61 | + return {sub.name: sub.subscription_id for sub in self.current_tenant_subscriptions()} |
| 62 | + |
| 63 | + def list_workspaces(self, subscription_id): |
| 64 | + endpoint = f"/subscriptions/{subscription_id}/providers/Microsoft.Databricks/workspaces" |
| 65 | + sku_tiers = { |
| 66 | + "premium": PricingTier.PREMIUM, |
| 67 | + "enterprise": PricingTier.ENTERPRISE, |
| 68 | + "standard": PricingTier.STANDARD, |
| 69 | + "unknown": PricingTier.UNKNOWN, |
| 70 | + } |
| 71 | + items = self._get(endpoint, api_version="2023-02-01").get("value", []) |
| 72 | + for item in sorted(items, key=lambda _: _["name"].lower()): |
| 73 | + properties = item["properties"] |
| 74 | + if properties["provisioningState"] != "Succeeded": |
| 75 | + continue |
| 76 | + if "workspaceUrl" not in properties: |
| 77 | + continue |
| 78 | + parameters = properties.get("parameters", {}) |
| 79 | + workspace_url = properties["workspaceUrl"] |
| 80 | + tags = item.get("tags", {}) |
| 81 | + if "AzureSubscriptionID" not in tags: |
| 82 | + tags["AzureSubscriptionID"] = subscription_id |
| 83 | + if "AzureResourceGroup" not in tags: |
| 84 | + tags["AzureResourceGroup"] = item["id"].split("resourceGroups/")[1].split("/")[0] |
| 85 | + yield Workspace( |
| 86 | + cloud="azure", |
| 87 | + location=item["location"], |
| 88 | + workspace_name=item["name"], |
| 89 | + workspace_id=int(properties["workspaceId"]), |
| 90 | + workspace_status_message=properties["provisioningState"], |
| 91 | + deployment_name=workspace_url.replace(".azuredatabricks.net", ""), |
| 92 | + pricing_tier=sku_tiers.get(item.get("sku", {"name": None})["name"], None), |
| 93 | + # These fields are just approximation for the fields with same meaning in AWS and GCP |
| 94 | + storage_configuration_id=parameters.get("storageAccountName", {"value": None})["value"], |
| 95 | + network_id=parameters.get("customVirtualNetworkId", {"value": None})["value"], |
| 96 | + custom_tags=tags, |
| 97 | + ) |
| 98 | + |
| 99 | + |
| 100 | +class Workspaces: |
| 101 | + _tlds: ClassVar[dict[str, str]] = { |
| 102 | + "aws": "cloud.databricks.com", |
| 103 | + "azure": "azuredatabricks.net", |
| 104 | + "gcp": "gcp.databricks.com", |
| 105 | + } |
| 106 | + |
| 107 | + def __init__(self, cfg: AccountConfig): |
| 108 | + self._ac = cfg.to_account_client() |
| 109 | + self._cfg = cfg |
| 110 | + |
| 111 | + def configured_workspaces(self): |
| 112 | + for workspace in self._all_workspaces(): |
| 113 | + if self._cfg.include_workspace_names: |
| 114 | + if workspace.workspace_name not in self._cfg.include_workspace_names: |
| 115 | + logger.debug( |
| 116 | + f"skipping {workspace.workspace_name} ({workspace.workspace_id} because " |
| 117 | + f"its not explicitly included" |
| 118 | + ) |
| 119 | + continue |
| 120 | + yield workspace |
| 121 | + |
| 122 | + def client_for(self, workspace: Workspace) -> WorkspaceClient: |
| 123 | + config = self._ac.config.as_dict() |
| 124 | + # copy current config and swap with a host relevant to a workspace |
| 125 | + config["host"] = f"https://{workspace.deployment_name}.{self._tlds[workspace.cloud]}" |
| 126 | + return WorkspaceClient(**config, product="ucx", product_version=__version__) |
| 127 | + |
| 128 | + def _all_workspaces(self): |
| 129 | + if self._ac.config.is_azure: |
| 130 | + yield from self._azure_workspaces() |
| 131 | + else: |
| 132 | + yield from self._native_workspaces() |
| 133 | + |
| 134 | + def _native_workspaces(self): |
| 135 | + yield from self._ac.workspaces.list() |
| 136 | + |
| 137 | + def _azure_workspaces(self): |
| 138 | + azure_lister = AzureWorkspaceLister(self._ac.config) |
| 139 | + for sub in azure_lister.current_tenant_subscriptions(): |
| 140 | + if self._cfg.include_azure_subscription_ids: |
| 141 | + if sub.subscription_id not in self._cfg.include_azure_subscription_ids: |
| 142 | + logger.debug(f"skipping {sub.name} ({sub.subscription_id} because its not explicitly included") |
| 143 | + continue |
| 144 | + if self._cfg.include_azure_subscription_names: |
| 145 | + if sub.name not in self._cfg.include_azure_subscription_names: |
| 146 | + logger.debug(f"skipping {sub.name} ({sub.subscription_id} because its not explicitly included") |
| 147 | + continue |
| 148 | + for workspace in azure_lister.list_workspaces(sub.subscription_id): |
| 149 | + if "AzureSubscription" not in workspace.custom_tags: |
| 150 | + workspace.custom_tags["AzureSubscription"] = sub.name |
| 151 | + yield workspace |
| 152 | + |
| 153 | + |
| 154 | +if __name__ == "__main__": |
| 155 | + logger.setLevel("INFO") |
| 156 | + |
| 157 | + config_file = Path.home() / ".ucx/config.yml" |
| 158 | + cfg = AccountConfig.from_file(config_file) |
| 159 | + wss = Workspaces(cfg) |
| 160 | + |
| 161 | + for workspace in wss.configured_workspaces(): |
| 162 | + ws = wss.client_for(workspace) |
| 163 | + |
| 164 | + metastore_id = "NOT ASSIGNED" |
| 165 | + default_catalog = "hive_metastore" |
| 166 | + try: |
| 167 | + metastore = ws.metastores.current() |
| 168 | + default_catalog = metastore.default_catalog_name |
| 169 | + metastore_id = metastore.metastore_id |
| 170 | + except DatabricksError: |
| 171 | + pass |
| 172 | + except ConnectionError: |
| 173 | + logger.warning(f"Private DNS for {workspace.workspace_name} is not yet supported?..") |
| 174 | + |
| 175 | + logger.info( |
| 176 | + f"workspace: {workspace.workspace_name}: " |
| 177 | + f"metastore {metastore_id}, " |
| 178 | + f"default catalog: {default_catalog}" |
| 179 | + ) |
0 commit comments