
Commit 906e187

Added assessment for the incompatible RunSubmit API usages (#849)
1 parent c849cdf · commit 906e187

25 files changed: +624 -13 lines changed

src/databricks/labs/ucx/assessment/clusters.py

Lines changed: 4 additions & 4 deletions
@@ -85,7 +85,7 @@ def _check_cluster_init_script(self, init_scripts: list[InitScriptInfo], source: str) -> list[str]:
             failures.extend(self.check_init_script(init_script_data, source))
         return failures
 
-    def check_spark_conf(self, conf: dict[str, str], source: str) -> list[str]:
+    def _check_spark_conf(self, conf: dict[str, str], source: str) -> list[str]:
         failures: list[str] = []
         for k in INCOMPATIBLE_SPARK_CONFIG_KEYS:
             if k in conf:
@@ -98,7 +98,7 @@ def check_spark_conf(self, conf: dict[str, str], source: str) -> list[str]:
                 failures.append(f"{AZURE_SP_CONF_FAILURE_MSG} {source}.")
         return failures
 
-    def check_cluster_failures(self, cluster: ClusterDetails, source: str) -> list[str]:
+    def _check_cluster_failures(self, cluster: ClusterDetails, source: str) -> list[str]:
         failures: list[str] = []
 
         unsupported_cluster_types = [
@@ -110,7 +110,7 @@ def check_cluster_failures(self, cluster: ClusterDetails, source: str) -> list[str]:
         if support_status != "supported":
             failures.append(f"not supported DBR: {cluster.spark_version}")
         if cluster.spark_conf is not None:
-            failures.extend(self.check_spark_conf(cluster.spark_conf, source))
+            failures.extend(self._check_spark_conf(cluster.spark_conf, source))
         # Checking if Azure cluster config is present in cluster policies
         if cluster.policy_id is not None:
             failures.extend(self._check_cluster_policy(cluster.policy_id, source))
@@ -149,7 +149,7 @@ def _assess_clusters(self, all_clusters):
                 success=1,
                 failures="[]",
             )
-            failures = self.check_cluster_failures(cluster, "cluster")
+            failures = self._check_cluster_failures(cluster, "cluster")
            if len(failures) > 0:
                 cluster_info.success = 0
                 cluster_info.failures = json.dumps(failures)
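
The leading underscore marks these checks as internal hooks of CheckClusterMixin, intended to be called or overridden by subclasses rather than used directly. A minimal sketch of how a subclass could layer an extra rule on top of the shared hook (the MyChecker class and the spark.my.legacy.setting key are illustrative assumptions, not part of this commit):

from databricks.labs.ucx.assessment.clusters import CheckClusterMixin


class MyChecker(CheckClusterMixin):
    # Hypothetical subclass: add one extra rule, then defer to the shared checks.
    def _check_spark_conf(self, conf: dict[str, str], source: str) -> list[str]:
        failures: list[str] = []
        if "spark.my.legacy.setting" in conf:  # illustrative key only
            failures.append(f"legacy setting found in {source}")
        failures.extend(super()._check_spark_conf(conf, source))
        return failures

This override-and-delegate pattern is exactly what the new SubmitRunsCrawler below uses for _check_spark_conf and _check_cluster_failures.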

src/databricks/labs/ucx/assessment/jobs.py

Lines changed: 226 additions & 2 deletions
@@ -2,12 +2,27 @@
 import logging
 from collections.abc import Iterable
 from dataclasses import dataclass
+from datetime import datetime, timedelta, timezone
+from hashlib import sha256
 
 from databricks.sdk import WorkspaceClient
+from databricks.sdk.service import compute
 from databricks.sdk.service.compute import ClusterDetails
-from databricks.sdk.service.jobs import BaseJob
+from databricks.sdk.service.jobs import (
+    BaseJob,
+    BaseRun,
+    DbtTask,
+    GitSource,
+    ListRunsRunType,
+    PythonWheelTask,
+    RunConditionTask,
+    RunTask,
+    SparkJarTask,
+    SqlTask,
+)
 
 from databricks.labs.ucx.assessment.clusters import CheckClusterMixin
+from databricks.labs.ucx.assessment.crawlers import spark_version_compatibility
 from databricks.labs.ucx.framework.crawlers import CrawlerBase, SqlBackend
 
 logger = logging.getLogger(__name__)
@@ -63,7 +78,7 @@ def _assess_jobs(self, all_jobs: list[BaseJob], all_clusters_by_id) -> Iterable[JobInfo]:
             if not job_id:
                 continue
             cluster_details = ClusterDetails.from_dict(cluster_config.as_dict())
-            cluster_failures = self.check_cluster_failures(cluster_details, "Job cluster")
+            cluster_failures = self._check_cluster_failures(cluster_details, "Job cluster")
             job_assessment[job_id].update(cluster_failures)
 
         # TODO: next person looking at this - rewrite, as this code makes no sense
@@ -108,3 +123,212 @@ def snapshot(self) -> Iterable[JobInfo]:
     def _try_fetch(self) -> Iterable[JobInfo]:
         for row in self._fetch(f"SELECT * FROM {self._schema}.{self._table}"):
             yield JobInfo(*row)
+
+
+@dataclass
+class SubmitRunInfo:
+    run_ids: str  # JSON-encoded list of run ids
+    hashed_id: str  # a pseudo id that combines all the hashable attributes of the run
+    failures: str = "[]"  # JSON-encoded list of failures
+
+
+class SubmitRunsCrawler(CrawlerBase[SubmitRunInfo], JobsMixin, CheckClusterMixin):
+    _FS_LEVEL_CONF_SETTING_PATTERNS = [
+        "fs.s3a",
+        "fs.s3n",
+        "fs.s3",
+        "fs.azure",
+        "fs.wasb",
+        "fs.abfs",
+        "fs.adl",
+    ]
+
+    def __init__(self, ws: WorkspaceClient, sbe: SqlBackend, schema: str, num_days_history: int):
+        super().__init__(sbe, "hive_metastore", schema, "submit_runs", SubmitRunInfo)
+        self._ws = ws
+        self._num_days_history = num_days_history
+
+    def snapshot(self) -> Iterable[SubmitRunInfo]:
+        return self._snapshot(self._try_fetch, self._crawl)
+
+    @staticmethod
+    def _dt_to_ms(date_time: datetime):
+        return int(date_time.timestamp() * 1000)
+
+    @staticmethod
+    def _get_current_dttm() -> datetime:
+        return datetime.now(timezone.utc)
+
+    def _crawl(self) -> Iterable[SubmitRunInfo]:
+        end = self._dt_to_ms(self._get_current_dttm())
+        start = self._dt_to_ms(self._get_current_dttm() - timedelta(days=self._num_days_history))
+        submit_runs = self._ws.jobs.list_runs(
+            expand_tasks=True,
+            completed_only=True,
+            run_type=ListRunsRunType.SUBMIT_RUN,
+            start_time_from=start,
+            start_time_to=end,
+        )
+        all_clusters = {c.cluster_id: c for c in self._ws.clusters.list()}
+        return self._assess_job_runs(submit_runs, all_clusters)
+
+    def _try_fetch(self) -> Iterable[SubmitRunInfo]:
+        for row in self._fetch(f"SELECT * FROM {self._schema}.{self._table}"):
+            yield SubmitRunInfo(*row)
+
+    def _check_spark_conf(self, conf: dict[str, str], source: str) -> list[str]:
+        failures: list[str] = []
+        for key in conf.keys():
+            if any(pattern in key for pattern in self._FS_LEVEL_CONF_SETTING_PATTERNS):
+                failures.append(f"Potentially unsupported config property: {key}")
+
+        failures.extend(super()._check_spark_conf(conf, source))
+        return failures
+
+    def _check_cluster_failures(self, cluster: ClusterDetails, source: str) -> list[str]:
+        failures: list[str] = []
+        if cluster.aws_attributes and cluster.aws_attributes.instance_profile_arn:
+            failures.append(f"using instance profile: {cluster.aws_attributes.instance_profile_arn}")
+
+        failures.extend(super()._check_cluster_failures(cluster, source))
+        return failures
+
+    @staticmethod
+    def _needs_compatibility_check(spec: compute.ClusterSpec) -> bool:
+        """
+        # we recognize a task as a potentially incompatible one if:
+        # 1. cluster is not configured with data security mode
+        # 2. cluster's DBR version is greater than 11.3
+        """
+        if not spec.data_security_mode:
+            compatibility = spark_version_compatibility(spec.spark_version)
+            return compatibility == "supported"
+        return False
+
+    def _get_hash_from_run(self, run: BaseRun) -> str:
+        hashable_items = []
+        all_tasks: list[RunTask] = run.tasks if run.tasks is not None else []
+        for task in sorted(all_tasks, key=lambda x: x.task_key if x.task_key is not None else ""):
+            hashable_items.extend(self._run_task_values(task))
+
+        if run.git_source:
+            hashable_items.extend(self._git_source_values(run.git_source))
+
+        return sha256(bytes("|".join(hashable_items).encode("utf-8"))).hexdigest()
+
+    @classmethod
+    def _sql_task_values(cls, task: SqlTask) -> list[str]:
+        hash_values = [
+            task.file.path if task.file else None,
+            task.alert.alert_id if task.alert else None,
+            task.dashboard.dashboard_id if task.dashboard else None,
+            task.query.query_id if task.query else None,
+        ]
+        return [str(value) for value in hash_values if value is not None]
+
+    @classmethod
+    def _git_source_values(cls, source: GitSource) -> list[str]:
+        hash_values = [source.git_url]
+        return [str(value) for value in hash_values if value is not None]
+
+    @classmethod
+    def _dbt_task_values(cls, dbt_task: DbtTask) -> list[str]:
+        hash_values = [
+            dbt_task.schema,
+            dbt_task.catalog,
+            dbt_task.warehouse_id,
+            dbt_task.project_directory,
+            ",".join(sorted(dbt_task.commands)),
+        ]
+        return [str(value) for value in hash_values if value is not None]
+
+    @classmethod
+    def _jar_task_values(cls, spark_jar_task: SparkJarTask) -> list[str]:
+        hash_values = [spark_jar_task.jar_uri, spark_jar_task.main_class_name]
+        return [str(value) for value in hash_values if value is not None]
+
+    @classmethod
+    def _python_wheel_task_values(cls, pw_task: PythonWheelTask) -> list[str]:
+        hash_values = [pw_task.package_name, pw_task.entry_point]
+        return [str(value) for value in hash_values if value is not None]
+
+    @classmethod
+    def _run_condition_task_values(cls, c_task: RunConditionTask) -> list[str]:
+        hash_values = [c_task.op.value if c_task.op else None, c_task.right, c_task.left, c_task.outcome]
+        return [str(value) for value in hash_values if value is not None]
+
+    @classmethod
+    def _run_task_values(cls, task: RunTask) -> list[str]:
+        """
+        Retrieve all hashable attributes and append to a list with None removed
+        - specifically ignore parameters as these change.
+        """
+        hash_values = [
+            task.notebook_task.notebook_path if task.notebook_task else None,
+            task.spark_python_task.python_file if task.spark_python_task else None,
+            (
+                '|'.join(task.spark_submit_task.parameters)
+                if (task.spark_submit_task and task.spark_submit_task.parameters)
+                else None
+            ),
+            task.pipeline_task.pipeline_id if task.pipeline_task is not None else None,
+            task.run_job_task.job_id if task.run_job_task else None,
+        ]
+        hash_lists = [
+            cls._jar_task_values(task.spark_jar_task) if task.spark_jar_task else None,
+            (cls._python_wheel_task_values(task.python_wheel_task) if (task.python_wheel_task) else None),
+            cls._sql_task_values(task.sql_task) if task.sql_task else None,
+            cls._dbt_task_values(task.dbt_task) if task.dbt_task else None,
+            cls._run_condition_task_values(task.condition_task) if task.condition_task else None,
+            cls._git_source_values(task.git_source) if task.git_source else None,
+        ]
+        # combining all the values from the lists where the list is not "None"
+        hash_values_from_lists = sum([hash_list for hash_list in hash_lists if hash_list], [])
+        return [str(value) for value in hash_values + hash_values_from_lists]
+
+    def _assess_job_runs(self, submit_runs: Iterable[BaseRun], all_clusters_by_id) -> Iterable[SubmitRunInfo]:
+        """
+        Assessment logic:
+        1. For each submit run, we analyze all tasks inside this run.
+        2. Per each task, we calculate a unique hash based on the _retrieve_hash_values_from_task function
+        3. Then we coalesce all task hashes into a single hash for the submit run
+        4. Coalesce all runs under the same hash into a single pseudo-job
+        5. Return a list of pseudo-jobs with their assessment results
+        """
+        result: dict[str, SubmitRunInfo] = {}
+        runs_per_hash: dict[str, list[int | None]] = {}
+
+        for submit_run in submit_runs:
+            task_failures = []
+            # v2.1+ API, with tasks
+            if submit_run.tasks:
+                all_tasks: list[RunTask] = submit_run.tasks
+                for task in sorted(all_tasks, key=lambda x: x.task_key if x.task_key is not None else ""):
+                    _task_key = task.task_key if task.task_key is not None else ""
+                    _cluster_details = None
+                    if task.new_cluster:
+                        _cluster_details = ClusterDetails.from_dict(task.new_cluster.as_dict())
+                        if self._needs_compatibility_check(task.new_cluster):
+                            task_failures.append("no data security mode specified")
+                    if task.existing_cluster_id:
+                        _cluster_details = all_clusters_by_id.get(task.existing_cluster_id, None)
+                    if _cluster_details:
+                        task_failures.extend(self._check_cluster_failures(_cluster_details, _task_key))
+
+            # v2.0 API, without tasks
+            elif submit_run.cluster_spec:
+                _cluster_details = ClusterDetails.from_dict(submit_run.cluster_spec.as_dict())
+                task_failures.extend(self._check_cluster_failures(_cluster_details, "root_task"))
+            hashed_id = self._get_hash_from_run(submit_run)
+            if hashed_id in runs_per_hash:
+                runs_per_hash[hashed_id].append(submit_run.run_id)
+            else:
+                runs_per_hash[hashed_id] = [submit_run.run_id]
+
+            result[hashed_id] = SubmitRunInfo(
+                run_ids=json.dumps(runs_per_hash[hashed_id]),
+                hashed_id=hashed_id,
+                failures=json.dumps(list(set(task_failures))),
+            )
+
+        return list(result.values())
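
Submit runs have no job id to group by, so the crawler derives a pseudo-job identity from the stable parts of the run configuration and coalesces matching runs. A self-contained sketch of that idea using invented values rather than real BaseRun objects:

import json
from hashlib import sha256


def pseudo_job_id(hashable_items: list[str]) -> str:
    # Mirrors _get_hash_from_run: join the stable attributes and hash them.
    return sha256("|".join(hashable_items).encode("utf-8")).hexdigest()


# Two hypothetical submit runs of the same notebook, submitted at different times.
run_a = {"run_id": 101, "notebook_path": "/Repos/etl/ingest"}
run_b = {"run_id": 202, "notebook_path": "/Repos/etl/ingest"}

runs_per_hash: dict[str, list[int]] = {}
for run in (run_a, run_b):
    hashed_id = pseudo_job_id([run["notebook_path"]])
    runs_per_hash.setdefault(hashed_id, []).append(run["run_id"])

# Both runs collapse into one hashed_id, i.e. a single SubmitRunInfo row whose
# run_ids column would be json.dumps([101, 202]).
print(runs_per_hash)

As the _run_task_values docstring notes, per-run parameters are largely left out of the hash, so repeated submissions of the same workload group together even when their arguments differ.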

src/databricks/labs/ucx/assessment/pipelines.py

Lines changed: 2 additions & 2 deletions
@@ -49,12 +49,12 @@ def _assess_pipelines(self, all_pipelines) -> Iterable[PipelineInfo]:
             assert pipeline_response.spec is not None
             pipeline_config = pipeline_response.spec.configuration
             if pipeline_config:
-                failures.extend(self.check_spark_conf(pipeline_config, "pipeline"))
+                failures.extend(self._check_spark_conf(pipeline_config, "pipeline"))
             pipeline_cluster = pipeline_response.spec.clusters
             if pipeline_cluster:
                 for cluster in pipeline_cluster:
                     if cluster.spark_conf:
-                        failures.extend(self.check_spark_conf(cluster.spark_conf, "pipeline cluster"))
+                        failures.extend(self._check_spark_conf(cluster.spark_conf, "pipeline cluster"))
                     # Checking if cluster config is present in cluster policies
                     if cluster.policy_id:
                         failures.extend(self._check_cluster_policy(cluster.policy_id, "pipeline cluster"))

src/databricks/labs/ucx/config.py

Lines changed: 1 addition & 0 deletions
@@ -35,6 +35,7 @@ class WorkspaceConfig: # pylint: disable=too-many-instance-attributes
 
     override_clusters: dict[str, str] | None = None
     policy_id: str | None = None
+    num_days_submit_runs_history: int = 30
 
     def replace_inventory_variable(self, text: str) -> str:
         return text.replace("$inventory", f"hive_metastore.{self.inventory_database}")
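
The new num_days_submit_runs_history field bounds how far back the submit-runs crawler queries run history. A small sketch of overriding it when constructing the config in code (this assumes, as in the rest of ucx, that inventory_database is the only required field):

from databricks.labs.ucx.config import WorkspaceConfig

# Look back one week instead of the default 30 days of submit-run history.
cfg = WorkspaceConfig(inventory_database="ucx", num_days_submit_runs_history=7)
print(cfg.num_days_submit_runs_history)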

src/databricks/labs/ucx/install.py

Lines changed: 2 additions & 1 deletion
@@ -54,7 +54,7 @@
 from databricks.labs.ucx.assessment.azure import AzureServicePrincipalInfo
 from databricks.labs.ucx.assessment.clusters import ClusterInfo
 from databricks.labs.ucx.assessment.init_scripts import GlobalInitScriptInfo
-from databricks.labs.ucx.assessment.jobs import JobInfo
+from databricks.labs.ucx.assessment.jobs import JobInfo, SubmitRunInfo
 from databricks.labs.ucx.assessment.pipelines import PipelineInfo
 from databricks.labs.ucx.config import WorkspaceConfig
 from databricks.labs.ucx.configure import ConfigureClusterOverrides
@@ -160,6 +160,7 @@ def deploy_schema(sql_backend: SqlBackend, inventory_schema: str):
             functools.partial(table, "table_failures", TableError),
             functools.partial(table, "workspace_objects", WorkspaceObjectInfo),
             functools.partial(table, "permissions", Permissions),
+            functools.partial(table, "submit_runs", SubmitRunInfo),
         ],
     )
     deployer.deploy_view("objects", "queries/views/objects.sql")
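
deploy_schema keeps a list of deferred table-creation callables, and the new entry simply registers the submit_runs table alongside the others. A stripped-down illustration of the functools.partial pattern (the table function below is a simplified stand-in for the real deployer, not its actual implementation):

import functools
from dataclasses import dataclass


@dataclass
class SubmitRunInfo:  # simplified stand-in for the real dataclass
    run_ids: str
    hashed_id: str
    failures: str = "[]"


def table(name: str, klass: type) -> None:
    # The real deployer derives the table DDL from the dataclass fields.
    print(f"deploying table {name} from {klass.__name__}")


deferred = functools.partial(table, "submit_runs", SubmitRunInfo)
deferred()  # executed later, when the installer deploys the whole schema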
Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
+-- viz type=table, name=Submit Runs, columns=hashed_id,failure,run_ids
+-- widget title=Incompatible Submit Runs, row=6, col=4, size_x=3, size_y=8
+SELECT
+    hashed_id,
+    EXPLODE(FROM_JSON(failures, 'array<string>')) AS failure,
+    FROM_JSON(run_ids, 'array<string>') AS run_ids
+FROM $inventory.submit_runs
+ORDER BY hashed_id DESC
Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
+-- viz type=table, name=Submit Runs Failures, columns=failure,submit_runs,run_ids
+-- widget title=Incompatible Submit Runs Failures, row=6, col=5, size_x=3, size_y=8
+SELECT
+    EXPLODE(FROM_JSON(failures, 'array<string>')) AS failure,
+    COUNT(DISTINCT hashed_id) AS submit_runs,
+    COLLECT_LIST(DISTINCT run_ids) AS run_ids
+FROM $inventory.submit_runs
+GROUP BY 1
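
Both widget queries decode the JSON strings that SubmitRunsCrawler writes into the submit_runs table. A small sketch of that round trip from the Python side, with invented values (the hashed_id is truncated for readability):

import json

failures = ["no data security mode specified", "using instance profile: arn:aws:iam::123456789012:instance-profile/demo"]
run_ids = [101, 202]

row = {
    "hashed_id": "9f2c41...",          # sha256 hex digest of the run's hashable attributes
    "failures": json.dumps(failures),  # decoded by FROM_JSON(failures, 'array<string>')
    "run_ids": json.dumps(run_ids),    # decoded by FROM_JSON(run_ids, 'array<string>')
}
print(row)

The first query then explodes the failures array into one row per failure for each hashed_id; the second groups by failure and counts how many distinct pseudo-jobs hit it.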

src/databricks/labs/ucx/queries/views/objects.sql

Lines changed: 2 additions & 0 deletions
@@ -4,6 +4,8 @@ SELECT "clusters" AS object_type, cluster_id AS object_id, failures FROM $inventory.clusters
 UNION ALL
 SELECT "global init scripts" AS object_type, script_id AS object_id, failures FROM $inventory.global_init_scripts
 UNION ALL
+SELECT "submit_runs" AS object_type, hashed_id AS object_id, failures FROM $inventory.submit_runs
+UNION ALL
 SELECT "pipelines" AS object_type, pipeline_id AS object_id, failures FROM $inventory.pipelines
 UNION ALL
 SELECT object_type, object_id, failures FROM (

src/databricks/labs/ucx/runtime.py

Lines changed: 17 additions & 1 deletion
@@ -7,7 +7,7 @@
 from databricks.labs.ucx.assessment.azure import AzureServicePrincipalCrawler
 from databricks.labs.ucx.assessment.clusters import ClustersCrawler
 from databricks.labs.ucx.assessment.init_scripts import GlobalInitScriptCrawler
-from databricks.labs.ucx.assessment.jobs import JobsCrawler
+from databricks.labs.ucx.assessment.jobs import JobsCrawler, SubmitRunsCrawler
 from databricks.labs.ucx.assessment.pipelines import PipelinesCrawler
 from databricks.labs.ucx.config import WorkspaceConfig
 from databricks.labs.ucx.framework.crawlers import SqlBackend
@@ -139,6 +139,21 @@ def assess_pipelines(cfg: WorkspaceConfig, ws: WorkspaceClient, sql_backend: SqlBackend):
     crawler.snapshot()
 
 
+@task("assessment")
+def assess_incompatible_submit_runs(cfg: WorkspaceConfig, ws: WorkspaceClient, sql_backend: SqlBackend):
+    """This module scans through all the submit runs and identifies those that may become incompatible after
+    the workspace is attached to Unity Catalog.
+
+    It looks for:
+      - All submit runs with DBR >=11.3 and data_security_mode:None
+
+    It also combines several submit runs under a single pseudo_id based on a hash of the submit run configuration.
+    Subsequently, a list of all the incompatible runs and their failures is stored in the
+    `$inventory.submit_runs` table."""
+    crawler = SubmitRunsCrawler(ws, sql_backend, cfg.inventory_database, cfg.num_days_submit_runs_history)
+    crawler.snapshot()
+
+
 @task("assessment", cloud="azure")
 def assess_azure_service_principals(cfg: WorkspaceConfig, ws: WorkspaceClient, sql_backend: SqlBackend):
     """This module scans through all the clusters configurations, cluster policies, job cluster configurations,
@@ -220,6 +235,7 @@ def crawl_groups(cfg: WorkspaceConfig, ws: WorkspaceClient, sql_backend: SqlBackend):
     crawl_permissions,
     guess_external_locations,
     assess_jobs,
+    assess_incompatible_submit_runs,
     assess_clusters,
     assess_azure_service_principals,
     assess_pipelines,
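
At runtime the new task reduces to building the crawler from the workspace client and config and persisting a snapshot. A hedged, standalone sketch of the same wiring outside the task framework (StatementExecutionBackend and the placeholder warehouse id are assumptions here; any SqlBackend implementation would do):

from databricks.sdk import WorkspaceClient

from databricks.labs.ucx.assessment.jobs import SubmitRunsCrawler
from databricks.labs.ucx.config import WorkspaceConfig
from databricks.labs.ucx.framework.crawlers import StatementExecutionBackend

cfg = WorkspaceConfig(inventory_database="ucx")  # num_days_submit_runs_history defaults to 30
ws = WorkspaceClient()  # assumes ambient Databricks credentials
sql_backend = StatementExecutionBackend(ws, "<warehouse-id>")  # assumed constructor shape

crawler = SubmitRunsCrawler(ws, sql_backend, cfg.inventory_database, cfg.num_days_submit_runs_history)
for info in crawler.snapshot():
    print(info.hashed_id, info.run_ids, info.failures)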
