
Commit 2ab1321

Extract command codes and unify the checks for spark_conf, cluster_policy, init_scripts (#855)
1 parent a2b741a commit 2ab1321

7 files changed: +169 −92 lines changed

src/databricks/labs/ucx/assessment/clusters.py

Lines changed: 82 additions & 48 deletions
@@ -1,22 +1,30 @@
+import base64
 import json
+import logging
 from collections.abc import Iterable
 from dataclasses import dataclass
 
 from databricks.sdk import WorkspaceClient
 from databricks.sdk.errors import NotFound
-from databricks.sdk.service.compute import ClusterDetails, ClusterSource, Policy
+from databricks.sdk.service.compute import (
+    ClusterDetails,
+    ClusterSource,
+    InitScriptInfo,
+    Policy,
+)
 
 from databricks.labs.ucx.assessment.crawlers import (
     _AZURE_SP_CONF_FAILURE_MSG,
+    _INIT_SCRIPT_DBFS_PATH,
     INCOMPATIBLE_SPARK_CONFIG_KEYS,
-    _azure_sp_conf_in_init_scripts,
     _azure_sp_conf_present_check,
-    _get_init_script_data,
-    logger,
     spark_version_compatibility,
 )
+from databricks.labs.ucx.assessment.init_scripts import CheckInitScriptMixin
 from databricks.labs.ucx.framework.crawlers import CrawlerBase, SqlBackend
 
+logger = logging.getLogger(__name__)
+
 
 @dataclass
 class ClusterInfo:
@@ -27,7 +35,7 @@ class ClusterInfo:
     creator: str | None = None
 
 
-class ClustersMixin:
+class CheckClusterMixin(CheckInitScriptMixin):
     _ws: WorkspaceClient
 
     def _safe_get_cluster_policy(self, policy_id: str) -> Policy | None:
@@ -37,62 +45,77 @@ def _safe_get_cluster_policy(self, policy_id: str) -> Policy | None:
             logger.warning(f"The cluster policy was deleted: {policy_id}")
             return None
 
-    def _check_spark_conf(self, cluster, failures):
+    def _check_cluster_policy(self, policy_id: str, source: str) -> list[str]:
+        failures: list[str] = []
+        policy = self._safe_get_cluster_policy(policy_id)
+        if policy:
+            if policy.definition:
+                if _azure_sp_conf_present_check(json.loads(policy.definition)):
+                    failures.append(f"{_AZURE_SP_CONF_FAILURE_MSG} {source}.")
+            if policy.policy_family_definition_overrides:
+                if _azure_sp_conf_present_check(json.loads(policy.policy_family_definition_overrides)):
+                    failures.append(f"{_AZURE_SP_CONF_FAILURE_MSG} {source}.")
+        return failures
+
+    def _get_init_script_data(self, init_script_info: InitScriptInfo) -> str | None:
+        if init_script_info.dbfs is not None and init_script_info.dbfs.destination is not None:
+            if len(init_script_info.dbfs.destination.split(":")) == _INIT_SCRIPT_DBFS_PATH:
+                file_api_format_destination = init_script_info.dbfs.destination.split(":")[1]
+                if file_api_format_destination:
+                    try:
+                        data = self._ws.dbfs.read(file_api_format_destination).data
+                        if data is not None:
+                            return base64.b64decode(data).decode("utf-8")
+                    except NotFound:
+                        return None
+        if init_script_info.workspace is not None and init_script_info.workspace.destination is not None:
+            workspace_file_destination = init_script_info.workspace.destination
+            try:
+                data = self._ws.workspace.export(workspace_file_destination).content
+                if data is not None:
+                    return base64.b64decode(data).decode("utf-8")
+            except NotFound:
+                return None
+        return None
+
+    def _check_cluster_init_script(self, init_scripts: list[InitScriptInfo], source: str) -> list[str]:
+        failures: list[str] = []
+        for init_script_info in init_scripts:
+            init_script_data = self._get_init_script_data(init_script_info)
+            failures.extend(self.check_init_script(init_script_data, source))
+        return failures
+
+    def check_spark_conf(self, conf: dict[str, str], source: str) -> list[str]:
+        failures: list[str] = []
         for k in INCOMPATIBLE_SPARK_CONFIG_KEYS:
-            if k in cluster.spark_conf:
+            if k in conf:
                 failures.append(f"unsupported config: {k}")
-        for value in cluster.spark_conf.values():
+        for value in conf.values():
             if "dbfs:/mnt" in value or "/dbfs/mnt" in value:
                 failures.append(f"using DBFS mount in configuration: {value}")
         # Checking if Azure cluster config is present in spark config
-        if _azure_sp_conf_present_check(cluster.spark_conf):
-            failures.append(f"{_AZURE_SP_CONF_FAILURE_MSG} cluster.")
+        if _azure_sp_conf_present_check(conf):
+            failures.append(f"{_AZURE_SP_CONF_FAILURE_MSG} {source}.")
+        return failures
 
-    def _check_cluster_policy(self, cluster, failures):
-        policy = self._safe_get_cluster_policy(cluster.policy_id)
-        if policy:
-            if policy.definition:
-                if _azure_sp_conf_present_check(json.loads(policy.definition)):
-                    failures.append(f"{_AZURE_SP_CONF_FAILURE_MSG} cluster.")
-            if policy.policy_family_definition_overrides:
-                if _azure_sp_conf_present_check(json.loads(policy.policy_family_definition_overrides)):
-                    failures.append(f"{_AZURE_SP_CONF_FAILURE_MSG} cluster.")
+    def check_cluster_failures(self, cluster: ClusterDetails, source: str) -> list[str]:
+        failures: list[str] = []
 
-    def _check_init_scripts(self, cluster, failures):
-        for init_script_info in cluster.init_scripts:
-            init_script_data = _get_init_script_data(self._ws, init_script_info)
-            if not init_script_data:
-                continue
-            if not _azure_sp_conf_in_init_scripts(init_script_data):
-                continue
-            failures.append(f"{_AZURE_SP_CONF_FAILURE_MSG} cluster.")
-
-    def _check_cluster_failures(self, cluster: ClusterDetails):
-        failures = []
-        cluster_info = ClusterInfo(
-            cluster_id=cluster.cluster_id if cluster.cluster_id else "",
-            cluster_name=cluster.cluster_name,
-            creator=cluster.creator_user_name,
-            success=1,
-            failures="[]",
-        )
         support_status = spark_version_compatibility(cluster.spark_version)
         if support_status != "supported":
             failures.append(f"not supported DBR: {cluster.spark_version}")
         if cluster.spark_conf is not None:
-            self._check_spark_conf(cluster, failures)
+            failures.extend(self.check_spark_conf(cluster.spark_conf, source))
         # Checking if Azure cluster config is present in cluster policies
-        if cluster.policy_id:
-            self._check_cluster_policy(cluster, failures)
-        if cluster.init_scripts:
-            self._check_init_scripts(cluster, failures)
-        cluster_info.failures = json.dumps(failures)
-        if len(failures) > 0:
-            cluster_info.success = 0
-        return cluster_info
+        if cluster.policy_id is not None:
+            failures.extend(self._check_cluster_policy(cluster.policy_id, source))
+        if cluster.init_scripts is not None:
+            failures.extend(self._check_cluster_init_script(cluster.init_scripts, source))
+
+        return failures
 
 
-class ClustersCrawler(CrawlerBase[ClusterInfo], ClustersMixin):
+class ClustersCrawler(CrawlerBase[ClusterInfo], CheckClusterMixin):
     def __init__(self, ws: WorkspaceClient, sbe: SqlBackend, schema):
         super().__init__(sbe, "hive_metastore", schema, "clusters", ClusterInfo)
         self._ws = ws
@@ -110,7 +133,18 @@ def _assess_clusters(self, all_clusters):
                     f"Cluster {cluster.cluster_id} have Unknown creator, it means that the original creator "
                     f"has been deleted and should be re-created"
                 )
-            yield self._check_cluster_failures(cluster)
+            cluster_info = ClusterInfo(
+                cluster_id=cluster.cluster_id if cluster.cluster_id else "",
+                cluster_name=cluster.cluster_name,
+                creator=cluster.creator_user_name,
+                success=1,
+                failures="[]",
+            )
+            failures = self.check_cluster_failures(cluster, "cluster")
+            if len(failures) > 0:
+                cluster_info.success = 0
+            cluster_info.failures = json.dumps(failures)
+            yield cluster_info
 
     def snapshot(self) -> Iterable[ClusterInfo]:
         return self._snapshot(self._try_fetch, self._crawl)
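
A note on the relocated `_get_init_script_data` helper shown above: the `_INIT_SCRIPT_DBFS_PATH == 2` guard only verifies that a DBFS destination splits into a scheme and a path on ":". A minimal standalone sketch, with a hypothetical destination that is not part of the commit:

# Sketch: why splitting the DBFS destination on ":" yields the file-API path.
destination = "dbfs:/databricks/scripts/install-sp-conf.sh"  # hypothetical destination
parts = destination.split(":")
assert len(parts) == 2  # this is what the _INIT_SCRIPT_DBFS_PATH == 2 guard checks
file_api_format_destination = parts[1]  # "/databricks/scripts/install-sp-conf.sh"
# the helper then reads this path via self._ws.dbfs.read(...) and base64-decodes the result
print(file_api_format_destination)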

src/databricks/labs/ucx/assessment/crawlers.py

Lines changed: 0 additions & 24 deletions
@@ -1,9 +1,6 @@
-import base64
 import logging
 import re
 
-from databricks.sdk.errors import NotFound
-
 logger = logging.getLogger(__name__)
 
 INCOMPATIBLE_SPARK_CONFIG_KEYS = [
@@ -27,27 +24,6 @@
 _INIT_SCRIPT_DBFS_PATH = 2
 
 
-def _get_init_script_data(w, init_script_info):
-    if init_script_info.dbfs:
-        if len(init_script_info.dbfs.destination.split(":")) == _INIT_SCRIPT_DBFS_PATH:
-            file_api_format_destination = init_script_info.dbfs.destination.split(":")[1]
-            if file_api_format_destination:
-                try:
-                    data = w.dbfs.read(file_api_format_destination).data
-                    return base64.b64decode(data).decode("utf-8")
-                except NotFound:
-                    return None
-    if init_script_info.workspace:
-        workspace_file_destination = init_script_info.workspace.destination
-        if workspace_file_destination:
-            try:
-                data = w.workspace.export(workspace_file_destination).content
-                return base64.b64decode(data).decode("utf-8")
-            except NotFound:
-                return None
-    return None
-
-
 def _azure_sp_conf_in_init_scripts(init_script_data: str) -> bool:
     for conf in _AZURE_SP_CONF:
         if re.search(conf, init_script_data):

src/databricks/labs/ucx/assessment/init_scripts.py

Lines changed: 18 additions & 5 deletions
@@ -1,5 +1,6 @@
 import base64
 import json
+import logging
 from collections.abc import Iterable
 from dataclasses import dataclass
 
@@ -8,10 +9,11 @@
 from databricks.labs.ucx.assessment.crawlers import (
     _AZURE_SP_CONF_FAILURE_MSG,
     _azure_sp_conf_in_init_scripts,
-    logger,
 )
 from databricks.labs.ucx.framework.crawlers import CrawlerBase, SqlBackend
 
+logger = logging.getLogger(__name__)
+
 
 @dataclass
 class GlobalInitScriptInfo:
@@ -23,7 +25,19 @@ class GlobalInitScriptInfo:
     enabled: bool | None = None
 
 
-class GlobalInitScriptCrawler(CrawlerBase[GlobalInitScriptInfo]):
+class CheckInitScriptMixin:
+    _ws: WorkspaceClient
+
+    def check_init_script(self, init_script_data: str | None, source: str) -> list[str]:
+        failures: list[str] = []
+        if not init_script_data:
+            return failures
+        if _azure_sp_conf_in_init_scripts(init_script_data):
+            failures.append(f"{_AZURE_SP_CONF_FAILURE_MSG} {source}.")
+        return failures
+
+
+class GlobalInitScriptCrawler(CrawlerBase[GlobalInitScriptInfo], CheckInitScriptMixin):
     def __init__(self, ws: WorkspaceClient, sbe: SqlBackend, schema):
         super().__init__(sbe, "hive_metastore", schema, "global_init_scripts", GlobalInitScriptInfo)
         self._ws = ws
@@ -52,9 +66,8 @@ def _assess_global_init_scripts(self, all_global_init_scripts):
             global_init_script = base64.b64decode(script.script).decode("utf-8")
             if not global_init_script:
                 continue
-            if _azure_sp_conf_in_init_scripts(global_init_script):
-                failures.append(f"{_AZURE_SP_CONF_FAILURE_MSG} global init script.")
-            global_init_script_info.failures = json.dumps(failures)
+            failures.extend(self.check_init_script(global_init_script, "global init script"))
+            global_init_script_info.failures = json.dumps(failures)
 
             if len(failures) > 0:
                 global_init_script_info.success = 0
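
The extracted `check_init_script` above makes no workspace calls, so it can be exercised on its own. A minimal sketch with hypothetical script content, assuming the mixin import shown in the diff:

from databricks.labs.ucx.assessment.init_scripts import CheckInitScriptMixin


class _InitScriptChecker(CheckInitScriptMixin):  # hypothetical consumer, for illustration only
    pass


# decoded init-script text that writes an Azure service principal setting (hypothetical content)
script = 'echo "fs.azure.account.oauth2.client.id=app-client-id" >> /databricks/driver/conf/spark.conf'
failures = _InitScriptChecker().check_init_script(script, "global init script")
print(failures)  # expected: a single failure flagging Azure SP config, labelled with the source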

src/databricks/labs/ucx/assessment/jobs.py

Lines changed: 8 additions & 7 deletions
@@ -1,15 +1,17 @@
 import json
+import logging
 from collections.abc import Iterable
 from dataclasses import dataclass
 
 from databricks.sdk import WorkspaceClient
 from databricks.sdk.service.compute import ClusterDetails
 from databricks.sdk.service.jobs import BaseJob
 
-from databricks.labs.ucx.assessment.clusters import ClustersMixin
-from databricks.labs.ucx.assessment.crawlers import logger
+from databricks.labs.ucx.assessment.clusters import CheckClusterMixin
 from databricks.labs.ucx.framework.crawlers import CrawlerBase, SqlBackend
 
+logger = logging.getLogger(__name__)
+
 
 @dataclass
 class JobInfo:
@@ -20,7 +22,7 @@ class JobInfo:
     creator: str | None = None
 
 
-class JobsMixin(ClustersMixin):
+class JobsMixin:
     @staticmethod
     def _get_cluster_configs_from_all_jobs(all_jobs, all_clusters_by_id):
         for j in all_jobs:
@@ -44,7 +46,7 @@ def _get_cluster_configs_from_all_jobs(all_jobs, all_clusters_by_id):
                     yield j, t.new_cluster
 
 
-class JobsCrawler(CrawlerBase[JobInfo], JobsMixin):
+class JobsCrawler(CrawlerBase[JobInfo], JobsMixin, CheckClusterMixin):
     def __init__(self, ws: WorkspaceClient, sbe: SqlBackend, schema):
         super().__init__(sbe, "hive_metastore", schema, "jobs", JobInfo)
         self._ws = ws
@@ -86,9 +88,8 @@ def _assess_jobs(self, all_jobs: list[BaseJob], all_clusters_by_id) -> Iterable[
             if not job_id:
                 continue
             cluster_details = ClusterDetails.from_dict(cluster_config.as_dict())
-            cluster_failures = self._check_cluster_failures(cluster_details)
-            for failure in json.loads(cluster_failures.failures):
-                job_assessment[job_id].add(failure)
+            cluster_failures = self.check_cluster_failures(cluster_details, "Job cluster")
+            job_assessment[job_id].update(cluster_failures)
 
         # TODO: next person looking at this - rewrite, as this code makes no sense
         for job_key in job_details.keys():  # pylint: disable=consider-using-dict-items,consider-iterating-dictionary
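
The jobs change above removes a round trip: the old code got failures back as a JSON string inside a ClusterInfo and re-parsed it, while `check_cluster_failures` now returns `list[str]`, so the per-job set can be updated directly. A small sketch of that accumulation, with a hypothetical job id and messages:

# check_cluster_failures returns list[str]; set.update replaces the old json.loads loop,
# and duplicate findings across a job's tasks collapse automatically.
from collections import defaultdict

job_assessment: dict[int, set[str]] = defaultdict(set)

task_1_failures = ["not supported DBR: 9.3.x-cpu-ml-scala2.12"]
task_2_failures = ["not supported DBR: 9.3.x-cpu-ml-scala2.12", "using DBFS mount in configuration: dbfs:/mnt/data"]

job_assessment[42].update(task_1_failures)  # was: for f in json.loads(result.failures): ...add(f)
job_assessment[42].update(task_2_failures)
print(sorted(job_assessment[42]))  # two unique findings for job 42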

src/databricks/labs/ucx/assessment/pipelines.py

Lines changed: 6 additions & 8 deletions
@@ -1,16 +1,15 @@
 import json
+import logging
 from collections.abc import Iterable
 from dataclasses import dataclass
 
 from databricks.sdk import WorkspaceClient
 
-from databricks.labs.ucx.assessment.crawlers import (
-    _AZURE_SP_CONF_FAILURE_MSG,
-    _azure_sp_conf_present_check,
-    logger,
-)
+from databricks.labs.ucx.assessment.clusters import CheckClusterMixin
 from databricks.labs.ucx.framework.crawlers import CrawlerBase, SqlBackend
 
+logger = logging.getLogger(__name__)
+
 
 @dataclass
 class PipelineInfo:
@@ -21,7 +20,7 @@ class PipelineInfo:
     creator_name: str | None = None
 
 
-class PipelinesCrawler(CrawlerBase[PipelineInfo]):
+class PipelinesCrawler(CrawlerBase[PipelineInfo], CheckClusterMixin):
     def __init__(self, ws: WorkspaceClient, sbe: SqlBackend, schema):
         super().__init__(sbe, "hive_metastore", schema, "pipelines", PipelineInfo)
         self._ws = ws
@@ -50,8 +49,7 @@ def _assess_pipelines(self, all_pipelines) -> Iterable[PipelineInfo]:
             assert pipeline_response.spec is not None
             pipeline_config = pipeline_response.spec.configuration
             if pipeline_config:
-                if _azure_sp_conf_present_check(pipeline_config):
-                    failures.append(f"{_AZURE_SP_CONF_FAILURE_MSG} pipeline.")
+                failures.extend(self.check_spark_conf(pipeline_config, "pipeline"))
 
             pipeline_info.failures = json.dumps(failures)
             if len(failures) > 0:
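
Because PipelinesCrawler now mixes in CheckClusterMixin, a pipeline's configuration goes through the same `check_spark_conf` path as a cluster's spark_conf (incompatible keys, DBFS mounts, Azure SP config), with "pipeline" as the source label. A minimal sketch, with hypothetical configuration values:

from databricks.labs.ucx.assessment.clusters import CheckClusterMixin


class _ConfChecker(CheckClusterMixin):  # hypothetical; check_spark_conf makes no workspace calls
    pass


pipeline_config = {
    "fs.azure.account.oauth2.client.id": "app-client-id",     # hypothetical Azure SP setting
    "my.pipeline.setting": "dbfs:/mnt/landing/config.json",   # hypothetical DBFS mount reference
}
failures = _ConfChecker().check_spark_conf(pipeline_config, "pipeline")
print(failures)  # expected: one DBFS-mount finding and one Azure SP finding, both ending in "pipeline."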
Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
+[
+    {
+        "autoscale": {
+            "max_workers": 6,
+            "min_workers": 1
+        },
+        "cluster_source": "JOB",
+        "creator_user_name": "[email protected]",
+        "cluster_id": "0123-190044-1122334422",
+        "cluster_name": "Single User Cluster Name",
+        "policy_id": "single-user-with-spn",
+        "spark_version": "9.3.x-cpu-ml-scala2.12",
+        "spark_conf": {
+            "spark.databricks.delta.preview.enabled": "true"
+        },
+        "spark_context_id": "5134472582179565315"
+    },
+    {
+        "autoscale": {
+            "max_workers": 6,
+            "min_workers": 1
+        },
+        "creator_user_name": "[email protected]",
+        "cluster_id": "0123-190044-1122334411",
+        "cluster_name": "Single User Cluster Name",
+        "policy_id": "azure-oauth",
+        "spark_version": "13.3.x-cpu-ml-scala2.12",
+        "spark_conf": {
+            "spark.databricks.delta.preview.enabled": "true"
+        },
+        "spark_context_id": "5134472582179565315"
+    }
+]
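
The JSON above appears to be a test fixture of cluster specs. A sketch of how one entry could drive the unified checker with a mocked client; the helper class and stand-in policy below are assumptions for illustration, not part of the commit:

from unittest.mock import create_autospec

from databricks.sdk import WorkspaceClient
from databricks.sdk.service.compute import ClusterDetails, Policy

from databricks.labs.ucx.assessment.clusters import CheckClusterMixin


class _ClusterChecker(CheckClusterMixin):  # hypothetical test helper
    def __init__(self, ws: WorkspaceClient):
        self._ws = ws


raw = {  # second fixture entry from the diff above
    "cluster_id": "0123-190044-1122334411",
    "cluster_name": "Single User Cluster Name",
    "policy_id": "azure-oauth",
    "spark_version": "13.3.x-cpu-ml-scala2.12",
    "spark_conf": {"spark.databricks.delta.preview.enabled": "true"},
}

ws = create_autospec(WorkspaceClient)
# stand-in policy definition; a real test would return a policy fixture with or without Azure SP config
ws.cluster_policies.get.return_value = Policy(definition='{"spark_version": {"type": "unlimited"}}')

failures = _ClusterChecker(ws).check_cluster_failures(ClusterDetails.from_dict(raw), "cluster")
print(failures)  # with this stand-in policy and a supported DBR, the list is expected to be empty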
