|
1 | 1 | import base64 |
| 2 | +import csv |
| 3 | +import dataclasses |
| 4 | +import io |
2 | 5 | import json |
3 | 6 | import re |
4 | 7 | from collections.abc import Iterable |
5 | 8 | from dataclasses import dataclass |
6 | 9 |
|
7 | 10 | from databricks.sdk import WorkspaceClient |
| 11 | +from databricks.sdk.core import ( |
| 12 | + ApiClient, |
| 13 | + AzureCliTokenSource, |
| 14 | + Config, |
| 15 | + credentials_provider, |
| 16 | +) |
8 | 17 | from databricks.sdk.errors import NotFound |
| 18 | +from databricks.sdk.service.catalog import Privilege |
9 | 19 | from databricks.sdk.service.compute import ClusterSource, Policy |
| 20 | +from databricks.sdk.service.workspace import ImportFormat |
10 | 21 |
|
11 | 22 | from databricks.labs.ucx.assessment.crawlers import ( |
12 | 23 | _CLIENT_ENDPOINT_LENGTH, |
|
17 | 28 | logger, |
18 | 29 | ) |
19 | 30 | from databricks.labs.ucx.framework.crawlers import CrawlerBase, SqlBackend |
| 31 | +from databricks.labs.ucx.hive_metastore.locations import ExternalLocations |
20 | 32 |
|
21 | 33 |
|
22 | 34 | @dataclass |
@@ -249,3 +261,300 @@ def snapshot(self) -> Iterable[AzureServicePrincipalInfo]: |
249 | 261 | def _try_fetch(self) -> Iterable[AzureServicePrincipalInfo]: |
250 | 262 | for row in self._fetch(f"SELECT * FROM {self._schema}.{self._table}"): |
251 | 263 | yield AzureServicePrincipalInfo(*row) |
| 264 | + |
| 265 | + |
| 266 | +@dataclass |
| 267 | +class AzureSubscription: |
| 268 | + name: str |
| 269 | + subscription_id: str |
| 270 | + tenant_id: str |
| 271 | + |
| 272 | + |
| 273 | +class AzureResource: |
| 274 | + def __init__(self, resource_id: str): |
| 275 | + self._pairs = {} |
| 276 | + self._resource_id = resource_id |
| 277 | + split = resource_id.lstrip("/").split("/") |
| 278 | + if len(split) % 2 != 0: |
| 279 | + msg = f"not a list of pairs: {resource_id}" |
| 280 | + raise ValueError(msg) |
| 281 | + i = 0 |
| 282 | + while i < len(split): |
| 283 | + k = split[i] |
| 284 | + v = split[i + 1] |
| 285 | + i += 2 |
| 286 | + self._pairs[k] = v |
| 287 | + |
| 288 | + @property |
| 289 | + def subscription_id(self): |
| 290 | + return self._pairs.get("subscriptions") |
| 291 | + |
| 292 | + @property |
| 293 | + def resource_group(self): |
| 294 | + return self._pairs.get("resourceGroups") |
| 295 | + |
| 296 | + @property |
| 297 | + def storage_account(self): |
| 298 | + return self._pairs.get("storageAccounts") |
| 299 | + |
| 300 | + @property |
| 301 | + def container(self): |
| 302 | + return self._pairs.get("containers") |
| 303 | + |
| 304 | + def __eq__(self, other): |
| 305 | + if not isinstance(other, AzureResource): |
| 306 | + return NotImplemented |
| 307 | + return self._resource_id == other._resource_id |
| 308 | + |
| 309 | + def __repr__(self): |
| 310 | + properties = ["subscription_id", "resource_group", "storage_account", "container"] |
| 311 | + pairs = [f"{_}={getattr(self, _)}" for _ in properties] |
| 312 | + return f'AzureResource<{", ".join(pairs)}>' |
| 313 | + |
| 314 | + def __str__(self): |
| 315 | + return self._resource_id |
| 316 | + |
| 317 | + |
| 318 | +@dataclass |
| 319 | +class Principal: |
| 320 | + client_id: str |
| 321 | + display_name: str |
| 322 | + object_id: str |
| 323 | + |
| 324 | + |
| 325 | +@dataclass |
| 326 | +class AzureRoleAssignment: |
| 327 | + resource: AzureResource |
| 328 | + scope: AzureResource |
| 329 | + principal: Principal |
| 330 | + role_name: str |
| 331 | + |
| 332 | + |
| 333 | +class AzureResources: |
| 334 | + def __init__(self, ws: WorkspaceClient, *, include_subscriptions=None): |
| 335 | + if not include_subscriptions: |
| 336 | + include_subscriptions = [] |
| 337 | + rm_host = ws.config.arm_environment.resource_manager_endpoint |
| 338 | + self._resource_manager = ApiClient( |
| 339 | + Config( |
| 340 | + host=rm_host, |
| 341 | + credentials_provider=self._provider_for(ws.config.arm_environment.service_management_endpoint), |
| 342 | + ) |
| 343 | + ) |
| 344 | + self._graph = ApiClient( |
| 345 | + Config( |
| 346 | + host="https://graph.microsoft.com", |
| 347 | + credentials_provider=self._provider_for("https://graph.microsoft.com"), |
| 348 | + ) |
| 349 | + ) |
| 350 | + self._token_source = AzureCliTokenSource(rm_host) |
| 351 | + self._include_subscriptions = include_subscriptions |
| 352 | + self._role_definitions = {} # type: dict[str, str] |
| 353 | + self._principals: dict[str, Principal | None] = {} |
| 354 | + |
| 355 | + def _provider_for(self, endpoint: str): |
| 356 | + @credentials_provider("azure-cli", ["host"]) |
| 357 | + def _credentials(_: Config): |
| 358 | + token_source = AzureCliTokenSource(endpoint) |
| 359 | + |
| 360 | + def inner() -> dict[str, str]: |
| 361 | + token = token_source.token() |
| 362 | + return {"Authorization": f"{token.token_type} {token.access_token}"} |
| 363 | + |
| 364 | + return inner |
| 365 | + |
| 366 | + return _credentials |
| 367 | + |
| 368 | + def _get_subscriptions(self) -> Iterable[AzureSubscription]: |
| 369 | + for subscription in self._get_resource("/subscriptions", api_version="2022-12-01").get("value", []): |
| 370 | + yield AzureSubscription( |
| 371 | + name=subscription["displayName"], |
| 372 | + subscription_id=subscription["subscriptionId"], |
| 373 | + tenant_id=subscription["tenantId"], |
| 374 | + ) |
| 375 | + |
| 376 | + def _tenant_id(self): |
| 377 | + token = self._token_source.token() |
| 378 | + return token.jwt_claims().get("tid") |
| 379 | + |
| 380 | + def subscriptions(self): |
| 381 | + tenant_id = self._tenant_id() |
| 382 | + for subscription in self._get_subscriptions(): |
| 383 | + if subscription.tenant_id != tenant_id: |
| 384 | + continue |
| 385 | + if subscription.subscription_id not in self._include_subscriptions: |
| 386 | + continue |
| 387 | + yield subscription |
| 388 | + |
| 389 | + def _get_resource(self, path: str, api_version: str): |
| 390 | + headers = {"Accept": "application/json"} |
| 391 | + query = {"api-version": api_version} |
| 392 | + return self._resource_manager.do("GET", path, query=query, headers=headers) |
| 393 | + |
| 394 | + def storage_accounts(self) -> Iterable[AzureResource]: |
| 395 | + for subscription in self.subscriptions(): |
| 396 | + logger.info(f"Checking in subscription {subscription.name} for storage accounts") |
| 397 | + path = f"/subscriptions/{subscription.subscription_id}/providers/Microsoft.Storage/storageAccounts" |
| 398 | + for storage in self._get_resource(path, "2023-01-01").get("value", []): |
| 399 | + resource_id = storage.get("id") |
| 400 | + if not resource_id: |
| 401 | + continue |
| 402 | + yield AzureResource(resource_id) |
| 403 | + |
| 404 | + def containers(self, storage: AzureResource): |
| 405 | + for raw in self._get_resource(f"{storage}/blobServices/default/containers", "2023-01-01").get("value", []): |
| 406 | + resource_id = raw.get("id") |
| 407 | + if not resource_id: |
| 408 | + continue |
| 409 | + yield AzureResource(resource_id) |
| 410 | + |
| 411 | + def _get_principal(self, principal_id: str) -> Principal | None: |
| 412 | + if principal_id in self._principals: |
| 413 | + return self._principals[principal_id] |
| 414 | + try: |
| 415 | + path = f"/v1.0/directoryObjects/{principal_id}" |
| 416 | + raw: dict[str, str] = self._graph.do("GET", path) # type: ignore[assignment] |
| 417 | + client_id = raw.get("appId") |
| 418 | + display_name = raw.get("displayName") |
| 419 | + object_id = raw.get("id") |
| 420 | + assert client_id is not None |
| 421 | + assert display_name is not None |
| 422 | + assert object_id is not None |
| 423 | + self._principals[principal_id] = Principal(client_id, display_name, object_id) |
| 424 | + return self._principals[principal_id] |
| 425 | + except NotFound: |
| 426 | + # don't load principals from external directories twice |
| 427 | + self._principals[principal_id] = None |
| 428 | + return self._principals[principal_id] |
| 429 | + |
| 430 | + def role_assignments( |
| 431 | + self, resource_id: str, *, principal_types: list[str] | None = None |
| 432 | + ) -> Iterable[AzureRoleAssignment]: |
| 433 | + """See https://learn.microsoft.com/en-us/rest/api/authorization/role-assignments/list-for-resource""" |
| 434 | + if not principal_types: |
| 435 | + principal_types = ["ServicePrincipal"] |
| 436 | + result = self._get_resource(f"{resource_id}/providers/Microsoft.Authorization/roleAssignments", "2022-04-01") |
| 437 | + for role_assignment in result.get("value", []): |
| 438 | + assignment_properties = role_assignment.get("properties", {}) |
| 439 | + principal_type = assignment_properties.get("principalType") |
| 440 | + if not principal_type: |
| 441 | + continue |
| 442 | + if principal_type not in principal_types: |
| 443 | + continue |
| 444 | + principal_id = assignment_properties.get("principalId") |
| 445 | + if not principal_id: |
| 446 | + continue |
| 447 | + role_definition_id = assignment_properties.get("roleDefinitionId") |
| 448 | + if not role_definition_id: |
| 449 | + continue |
| 450 | + scope = assignment_properties.get("scope") |
| 451 | + if not scope: |
| 452 | + continue |
| 453 | + if role_definition_id not in self._role_definitions: |
| 454 | + role_definition = self._get_resource(role_definition_id, "2022-04-01") |
| 455 | + definition_properties = role_definition.get("properties", {}) |
| 456 | + role_name: str = definition_properties.get("roleName") |
| 457 | + if not role_name: |
| 458 | + continue |
| 459 | + self._role_definitions[role_definition_id] = role_name |
| 460 | + principal = self._get_principal(principal_id) |
| 461 | + if not principal: |
| 462 | + continue |
| 463 | + role_name = self._role_definitions[role_definition_id] |
| 464 | + if scope == "/": |
| 465 | + scope = resource_id |
| 466 | + yield AzureRoleAssignment( |
| 467 | + resource=AzureResource(resource_id), |
| 468 | + scope=AzureResource(scope), |
| 469 | + principal=principal, |
| 470 | + role_name=role_name, |
| 471 | + ) |
| 472 | + |
| 473 | + |
| 474 | +@dataclass |
| 475 | +class StoragePermissionMapping: |
| 476 | + prefix: str |
| 477 | + client_id: str |
| 478 | + principal: str |
| 479 | + privilege: str |
| 480 | + |
| 481 | + |
| 482 | +class AzureResourcePermissions: |
| 483 | + def __init__(self, ws: WorkspaceClient, azurerm: AzureResources, lc: ExternalLocations, folder: str | None = None): |
| 484 | + self._locations = lc |
| 485 | + self._azurerm = azurerm |
| 486 | + self._ws = ws |
| 487 | + self._field_names = [_.name for _ in dataclasses.fields(StoragePermissionMapping)] |
| 488 | + if not folder: |
| 489 | + folder = f"/Users/{ws.current_user.me().user_name}/.ucx" |
| 490 | + self._folder = folder |
| 491 | + self._levels = { |
| 492 | + "Storage Blob Data Contributor": Privilege.WRITE_FILES, |
| 493 | + "Storage Blob Data Owner": Privilege.WRITE_FILES, |
| 494 | + "Storage Blob Data Reader": Privilege.READ_FILES, |
| 495 | + } |
| 496 | + |
| 497 | + def _map_storage(self, storage: AzureResource) -> list[StoragePermissionMapping]: |
| 498 | + logger.info(f"Fetching role assignment for {storage.storage_account}") |
| 499 | + out = [] |
| 500 | + for container in self._azurerm.containers(storage): |
| 501 | + for role_assignment in self._azurerm.role_assignments(str(container)): |
| 502 | + # one principal may be assigned multiple roles with overlapping dataActions, hence appearing |
| 503 | + # here in duplicates. hence, role name -> permission level is not enough for the perfect scenario. |
| 504 | + if role_assignment.role_name not in self._levels: |
| 505 | + continue |
| 506 | + privilege = self._levels[role_assignment.role_name].value |
| 507 | + out.append( |
| 508 | + StoragePermissionMapping( |
| 509 | + prefix=f"abfss://{container.container}@{container.storage_account}.dfs.core.windows.net/", |
| 510 | + client_id=role_assignment.principal.client_id, |
| 511 | + principal=role_assignment.principal.display_name, |
| 512 | + privilege=privilege, |
| 513 | + ) |
| 514 | + ) |
| 515 | + return out |
| 516 | + |
| 517 | + def save_spn_permissions(self) -> str | None: |
| 518 | + used_storage_accounts = self._get_storage_accounts() |
| 519 | + if len(used_storage_accounts) == 0: |
| 520 | + logger.warning( |
| 521 | + "There are no external table present with azure storage account. " |
| 522 | + "Please check if assessment job is run" |
| 523 | + ) |
| 524 | + return None |
| 525 | + storage_account_infos = [] |
| 526 | + for storage in self._azurerm.storage_accounts(): |
| 527 | + if storage.storage_account not in used_storage_accounts: |
| 528 | + continue |
| 529 | + for mapping in self._map_storage(storage): |
| 530 | + storage_account_infos.append(mapping) |
| 531 | + if len(storage_account_infos) == 0: |
| 532 | + logger.error("No storage account found in current tenant with spn permission") |
| 533 | + return None |
| 534 | + return self._save(storage_account_infos) |
| 535 | + |
| 536 | + def _save(self, storage_infos: list[StoragePermissionMapping]) -> str: |
| 537 | + buffer = io.StringIO() |
| 538 | + writer = csv.DictWriter(buffer, self._field_names) |
| 539 | + writer.writeheader() |
| 540 | + for storage_info in storage_infos: |
| 541 | + writer.writerow(dataclasses.asdict(storage_info)) |
| 542 | + buffer.seek(0) |
| 543 | + return self._overwrite_mapping(buffer) |
| 544 | + |
| 545 | + def _overwrite_mapping(self, buffer) -> str: |
| 546 | + path = f"{self._folder}/azure_storage_account_info.csv" |
| 547 | + self._ws.workspace.upload(path, buffer, overwrite=True, format=ImportFormat.AUTO) |
| 548 | + return path |
| 549 | + |
| 550 | + def _get_storage_accounts(self) -> list[str]: |
| 551 | + external_locations = self._locations.snapshot() |
| 552 | + storage_accounts = [] |
| 553 | + for location in external_locations: |
| 554 | + if location.location.startswith("abfss://"): |
| 555 | + start = location.location.index("@") |
| 556 | + end = location.location.index(".dfs.core.windows.net") |
| 557 | + storage_acct = location.location[start + 1 : end] |
| 558 | + if storage_acct not in storage_accounts: |
| 559 | + storage_accounts.append(storage_acct) |
| 560 | + return storage_accounts |
0 commit comments