|
2 | 2 | import logging |
3 | 3 | from collections.abc import Iterable |
4 | 4 | from dataclasses import dataclass |
| 5 | +from datetime import datetime, timedelta, timezone |
| 6 | +from hashlib import sha256 |
5 | 7 |
|
6 | 8 | from databricks.sdk import WorkspaceClient |
| 9 | +from databricks.sdk.service import compute |
7 | 10 | from databricks.sdk.service.compute import ClusterDetails |
8 | | -from databricks.sdk.service.jobs import BaseJob |
| 11 | +from databricks.sdk.service.jobs import ( |
| 12 | + BaseJob, |
| 13 | + BaseRun, |
| 14 | + DbtTask, |
| 15 | + GitSource, |
| 16 | + ListRunsRunType, |
| 17 | + PythonWheelTask, |
| 18 | + RunConditionTask, |
| 19 | + RunTask, |
| 20 | + SparkJarTask, |
| 21 | + SqlTask, |
| 22 | +) |
9 | 23 |
|
10 | 24 | from databricks.labs.ucx.assessment.clusters import CheckClusterMixin |
| 25 | +from databricks.labs.ucx.assessment.crawlers import spark_version_compatibility |
11 | 26 | from databricks.labs.ucx.framework.crawlers import CrawlerBase, SqlBackend |
12 | 27 |
|
13 | 28 | logger = logging.getLogger(__name__) |
@@ -63,7 +78,7 @@ def _assess_jobs(self, all_jobs: list[BaseJob], all_clusters_by_id) -> Iterable[ |
63 | 78 | if not job_id: |
64 | 79 | continue |
65 | 80 | cluster_details = ClusterDetails.from_dict(cluster_config.as_dict()) |
66 | | - cluster_failures = self.check_cluster_failures(cluster_details, "Job cluster") |
| 81 | + cluster_failures = self._check_cluster_failures(cluster_details, "Job cluster") |
67 | 82 | job_assessment[job_id].update(cluster_failures) |
68 | 83 |
|
69 | 84 | # TODO: next person looking at this - rewrite, as this code makes no sense |
@@ -108,3 +123,212 @@ def snapshot(self) -> Iterable[JobInfo]: |
108 | 123 | def _try_fetch(self) -> Iterable[JobInfo]: |
109 | 124 | for row in self._fetch(f"SELECT * FROM {self._schema}.{self._table}"): |
110 | 125 | yield JobInfo(*row) |
| 126 | + |
| 127 | + |
| 128 | +@dataclass |
| 129 | +class SubmitRunInfo: |
| 130 | + run_ids: str # JSON-encoded list of run ids |
| 131 | + hashed_id: str # a pseudo id that combines all the hashable attributes of the run |
| 132 | + failures: str = "[]" # JSON-encoded list of failures |
| 133 | + |
| 134 | + |
| 135 | +class SubmitRunsCrawler(CrawlerBase[SubmitRunInfo], JobsMixin, CheckClusterMixin): |
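| | + # Spark conf key prefixes for direct cloud-storage filesystem access; matching settings are flagged as potentially unsupported |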
| 136 | + _FS_LEVEL_CONF_SETTING_PATTERNS = [ |
| 137 | + "fs.s3a", |
| 138 | + "fs.s3n", |
| 139 | + "fs.s3", |
| 140 | + "fs.azure", |
| 141 | + "fs.wasb", |
| 142 | + "fs.abfs", |
| 143 | + "fs.adl", |
| 144 | + ] |
| 145 | + |
| 146 | + def __init__(self, ws: WorkspaceClient, sbe: SqlBackend, schema: str, num_days_history: int): |
| 147 | + super().__init__(sbe, "hive_metastore", schema, "submit_runs", SubmitRunInfo) |
| 148 | + self._ws = ws |
| 149 | + self._num_days_history = num_days_history |
| 150 | + |
| 151 | + def snapshot(self) -> Iterable[SubmitRunInfo]: |
| 152 | + return self._snapshot(self._try_fetch, self._crawl) |
| 153 | + |
| 154 | + @staticmethod |
| 155 | + def _dt_to_ms(date_time: datetime) -> int: |
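| | + """Convert a datetime to epoch milliseconds, the timestamp format used by the Jobs list-runs API.""" |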
| 156 | + return int(date_time.timestamp() * 1000) |
| 157 | + |
| 158 | + @staticmethod |
| 159 | + def _get_current_dttm() -> datetime: |
| 160 | + return datetime.now(timezone.utc) |
| 161 | + |
| 162 | + def _crawl(self) -> Iterable[SubmitRunInfo]: |
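| | + """List completed submit-only runs started within the last `num_days_history` days and assess them against the workspace clusters.""" |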
| 163 | + end = self._dt_to_ms(self._get_current_dttm()) |
| 164 | + start = self._dt_to_ms(self._get_current_dttm() - timedelta(days=self._num_days_history)) |
| 165 | + submit_runs = self._ws.jobs.list_runs( |
| 166 | + expand_tasks=True, |
| 167 | + completed_only=True, |
| 168 | + run_type=ListRunsRunType.SUBMIT_RUN, |
| 169 | + start_time_from=start, |
| 170 | + start_time_to=end, |
| 171 | + ) |
| 172 | + all_clusters = {c.cluster_id: c for c in self._ws.clusters.list()} |
| 173 | + return self._assess_job_runs(submit_runs, all_clusters) |
| 174 | + |
| 175 | + def _try_fetch(self) -> Iterable[SubmitRunInfo]: |
| 176 | + for row in self._fetch(f"SELECT * FROM {self._schema}.{self._table}"): |
| 177 | + yield SubmitRunInfo(*row) |
| 178 | + |
| 179 | + def _check_spark_conf(self, conf: dict[str, str], source: str) -> list[str]: |
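| | + """Flag filesystem-level (fs.*) Spark conf settings in addition to the inherited Spark conf checks.""" |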
| 180 | + failures: list[str] = [] |
| 181 | + for key in conf.keys(): |
| 182 | + if any(pattern in key for pattern in self._FS_LEVEL_CONF_SETTING_PATTERNS): |
| 183 | + failures.append(f"Potentially unsupported config property: {key}") |
| 184 | + |
| 185 | + failures.extend(super()._check_spark_conf(conf, source)) |
| 186 | + return failures |
| 187 | + |
| 188 | + def _check_cluster_failures(self, cluster: ClusterDetails, source: str) -> list[str]: |
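| | + """Flag AWS instance profile usage in addition to the inherited cluster checks.""" |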
| 189 | + failures: list[str] = [] |
| 190 | + if cluster.aws_attributes and cluster.aws_attributes.instance_profile_arn: |
| 191 | + failures.append(f"using instance profile: {cluster.aws_attributes.instance_profile_arn}") |
| 192 | + |
| 193 | + failures.extend(super()._check_cluster_failures(cluster, source)) |
| 194 | + return failures |
| 195 | + |
| 196 | + @staticmethod |
| 197 | + def _needs_compatibility_check(spec: compute.ClusterSpec) -> bool: |
| 198 | + """ |
| 199 | + We recognize a task as potentially incompatible if: |
| 200 | + 1. its cluster is not configured with a data security mode, and |
| 201 | + 2. its cluster's DBR version is greater than 11.3. |
| 202 | + """ |
| 203 | + if not spec.data_security_mode: |
| 204 | + compatibility = spark_version_compatibility(spec.spark_version) |
| 205 | + return compatibility == "supported" |
| 206 | + return False |
| 207 | + |
| 208 | + def _get_hash_from_run(self, run: BaseRun) -> str: |
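| | + """Build a stable hash from the run's task attributes (and git source) so that equivalent submit runs can be grouped into one pseudo-job.""" |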
| 209 | + hashable_items = [] |
| 210 | + all_tasks: list[RunTask] = run.tasks if run.tasks is not None else [] |
| 211 | + for task in sorted(all_tasks, key=lambda x: x.task_key if x.task_key is not None else ""): |
| 212 | + hashable_items.extend(self._run_task_values(task)) |
| 213 | + |
| 214 | + if run.git_source: |
| 215 | + hashable_items.extend(self._git_source_values(run.git_source)) |
| 216 | + |
| 217 | + return sha256("|".join(hashable_items).encode("utf-8")).hexdigest() |
| 218 | + |
| 219 | + @classmethod |
| 220 | + def _sql_task_values(cls, task: SqlTask) -> list[str]: |
| 221 | + hash_values = [ |
| 222 | + task.file.path if task.file else None, |
| 223 | + task.alert.alert_id if task.alert else None, |
| 224 | + task.dashboard.dashboard_id if task.dashboard else None, |
| 225 | + task.query.query_id if task.query else None, |
| 226 | + ] |
| 227 | + return [str(value) for value in hash_values if value is not None] |
| 228 | + |
| 229 | + @classmethod |
| 230 | + def _git_source_values(cls, source: GitSource) -> list[str]: |
| 231 | + hash_values = [source.git_url] |
| 232 | + return [str(value) for value in hash_values if value is not None] |
| 233 | + |
| 234 | + @classmethod |
| 235 | + def _dbt_task_values(cls, dbt_task: DbtTask) -> list[str]: |
| 236 | + hash_values = [ |
| 237 | + dbt_task.schema, |
| 238 | + dbt_task.catalog, |
| 239 | + dbt_task.warehouse_id, |
| 240 | + dbt_task.project_directory, |
| 241 | + ",".join(sorted(dbt_task.commands)), |
| 242 | + ] |
| 243 | + return [str(value) for value in hash_values if value is not None] |
| 244 | + |
| 245 | + @classmethod |
| 246 | + def _jar_task_values(cls, spark_jar_task: SparkJarTask) -> list[str]: |
| 247 | + hash_values = [spark_jar_task.jar_uri, spark_jar_task.main_class_name] |
| 248 | + return [str(value) for value in hash_values if value is not None] |
| 249 | + |
| 250 | + @classmethod |
| 251 | + def _python_wheel_task_values(cls, pw_task: PythonWheelTask) -> list[str]: |
| 252 | + hash_values = [pw_task.package_name, pw_task.entry_point] |
| 253 | + return [str(value) for value in hash_values if value is not None] |
| 254 | + |
| 255 | + @classmethod |
| 256 | + def _run_condition_task_values(cls, c_task: RunConditionTask) -> list[str]: |
| 257 | + hash_values = [c_task.op.value if c_task.op else None, c_task.right, c_task.left, c_task.outcome] |
| 258 | + return [str(value) for value in hash_values if value is not None] |
| 259 | + |
| 260 | + @classmethod |
| 261 | + def _run_task_values(cls, task: RunTask) -> list[str]: |
| 262 | + """ |
| 263 | + Retrieve all hashable attributes of the task and return them as a list with None values removed. |
| 264 | + Parameters are deliberately ignored because they change between runs. |
| 265 | + """ |
| 266 | + hash_values = [ |
| 267 | + task.notebook_task.notebook_path if task.notebook_task else None, |
| 268 | + task.spark_python_task.python_file if task.spark_python_task else None, |
| 269 | + ( |
| 270 | + '|'.join(task.spark_submit_task.parameters) |
| 271 | + if (task.spark_submit_task and task.spark_submit_task.parameters) |
| 272 | + else None |
| 273 | + ), |
| 274 | + task.pipeline_task.pipeline_id if task.pipeline_task is not None else None, |
| 275 | + task.run_job_task.job_id if task.run_job_task else None, |
| 276 | + ] |
| 277 | + hash_lists = [ |
| 278 | + cls._jar_task_values(task.spark_jar_task) if task.spark_jar_task else None, |
| 279 | + cls._python_wheel_task_values(task.python_wheel_task) if task.python_wheel_task else None, |
| 280 | + cls._sql_task_values(task.sql_task) if task.sql_task else None, |
| 281 | + cls._dbt_task_values(task.dbt_task) if task.dbt_task else None, |
| 282 | + cls._run_condition_task_values(task.condition_task) if task.condition_task else None, |
| 283 | + cls._git_source_values(task.git_source) if task.git_source else None, |
| 284 | + ] |
| 285 | + # flatten the per-task-type value lists, skipping those that are None |
| 286 | + hash_values_from_lists = sum([hash_list for hash_list in hash_lists if hash_list], []) |
| 287 | + return [str(value) for value in hash_values + hash_values_from_lists] |
| 288 | + |
| 289 | + def _assess_job_runs(self, submit_runs: Iterable[BaseRun], all_clusters_by_id) -> Iterable[SubmitRunInfo]: |
| 290 | + """ |
| 291 | + Assessment logic: |
| 292 | + 1. For each submit run, we analyze all tasks inside the run. |
| 293 | + 2. For each task, we calculate a unique hash based on the _run_task_values function |
| 294 | + 3. Then we coalesce all task hashes into a single hash for the submit run |
| 295 | + 4. Coalesce all runs under the same hash into a single pseudo-job |
| 296 | + 5. Return a list of pseudo-jobs with their assessment results |
| 297 | + """ |
| 298 | + result: dict[str, SubmitRunInfo] = {} |
| 299 | + runs_per_hash: dict[str, list[int | None]] = {} |
| 300 | + |
| 301 | + for submit_run in submit_runs: |
| 302 | + task_failures = [] |
| 303 | + # v2.1+ API, with tasks |
| 304 | + if submit_run.tasks: |
| 305 | + all_tasks: list[RunTask] = submit_run.tasks |
| 306 | + for task in sorted(all_tasks, key=lambda x: x.task_key if x.task_key is not None else ""): |
| 307 | + _task_key = task.task_key if task.task_key is not None else "" |
| 308 | + _cluster_details = None |
| 309 | + if task.new_cluster: |
| 310 | + _cluster_details = ClusterDetails.from_dict(task.new_cluster.as_dict()) |
| 311 | + if self._needs_compatibility_check(task.new_cluster): |
| 312 | + task_failures.append("no data security mode specified") |
| 313 | + if task.existing_cluster_id: |
| 314 | + _cluster_details = all_clusters_by_id.get(task.existing_cluster_id, None) |
| 315 | + if _cluster_details: |
| 316 | + task_failures.extend(self._check_cluster_failures(_cluster_details, _task_key)) |
| 317 | + |
| 318 | + # v2.0 API, without tasks |
| 319 | + elif submit_run.cluster_spec: |
| 320 | + _cluster_details = ClusterDetails.from_dict(submit_run.cluster_spec.as_dict()) |
| 321 | + task_failures.extend(self._check_cluster_failures(_cluster_details, "root_task")) |
| 322 | + hashed_id = self._get_hash_from_run(submit_run) |
| 323 | + if hashed_id in runs_per_hash: |
| 324 | + runs_per_hash[hashed_id].append(submit_run.run_id) |
| 325 | + else: |
| 326 | + runs_per_hash[hashed_id] = [submit_run.run_id] |
| 327 | + |
| 328 | + result[hashed_id] = SubmitRunInfo( |
| 329 | + run_ids=json.dumps(runs_per_hash[hashed_id]), |
| 330 | + hashed_id=hashed_id, |
| 331 | + failures=json.dumps(list(set(task_failures))), |
| 332 | + ) |
| 333 | + |
| 334 | + return list(result.values()) |