
Commit bc843c9

Make code more readable by enforcing max-nested-blocks = 3 with pylint (#1018)
No logic changes, just for readability and to spare code reviewer's sanity.
1 parent e442d63 commit bc843c9

File tree

11 files changed (+193, -168 lines)


pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -598,7 +598,7 @@ default-docstring-type = "default"
 
 [tool.pylint.refactoring]
 # Maximum number of nested blocks for function / method body
-max-nested-blocks = 5
+max-nested-blocks = 3
 
 # Complete name of functions that never returns. When checking for inconsistent-
 # return-statements if a never returning function is called then it will be
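
With the limit lowered from 5 to 3, pylint's too-many-nested-blocks check (R1702) flags any function body nested deeper than three block levels. A minimal sketch of the shape that now fails, with purely illustrative names, not code from this repository:

def count_positive(rows):
    total = 0
    for row in rows:              # nested block 1
        if row is not None:       # nested block 2
            for cell in row:      # nested block 3
                if cell > 0:      # nested block 4 -> R1702 under max-nested-blocks = 3
                    total += 1
    return total

The diffs below all resolve this the same way: hoist the inner levels into a named helper so no single body nests past three blocks.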

src/databricks/labs/ucx/account.py

Lines changed: 27 additions & 31 deletions
@@ -131,41 +131,37 @@ def _get_valid_workspaces_groups(self, prompts: Prompts, workspace_ids: list[int
         for workspace in self._workspaces():
             if workspace.workspace_id not in workspace_ids:
                 continue
-            client = self.client_for(workspace)
-            logger.info(f"Crawling groups in workspace {client.config.host}")
+            self._load_workspace_groups(prompts, workspace, all_workspaces_groups)
 
-            ws_group_ids = client.groups.list(attributes="id")
-            for group_id in ws_group_ids:
-                full_workspace_group = self._safe_groups_get(client, group_id.id)
-                if not full_workspace_group:
-                    continue
-                group_name = full_workspace_group.display_name
+        return all_workspaces_groups
 
-                if self._is_group_out_of_scope(full_workspace_group):
+    def _load_workspace_groups(self, prompts, workspace, all_workspaces_groups):
+        client = self.client_for(workspace)
+        logger.info(f"Crawling groups in workspace {client.config.host}")
+        ws_group_ids = client.groups.list(attributes="id")
+        for group_id in ws_group_ids:
+            full_workspace_group = self._safe_groups_get(client, group_id.id)
+            if not full_workspace_group:
+                continue
+            group_name = full_workspace_group.display_name
+            if self._is_group_out_of_scope(full_workspace_group):
+                continue
+            if not group_name:
+                continue
+            if group_name in all_workspaces_groups:
+                if self._has_same_members(all_workspaces_groups[group_name], full_workspace_group):
+                    logger.info(f"Workspace group {group_name} already found, ignoring")
                     continue
-
-                if group_name in all_workspaces_groups:
-                    if self._has_same_members(all_workspaces_groups[group_name], full_workspace_group):
-                        logger.info(f"Workspace group {group_name} already found, ignoring")
-                        continue
-
-                    if prompts.confirm(
-                        f"Group {group_name} does not have the same amount of members "
-                        f"in workspace {client.config.host} than previous workspaces which contains the same group name,"
-                        f"it will be created at the account with name : {workspace.workspace_name}_{group_name}"
-                    ):
-                        all_workspaces_groups[f"{workspace.workspace_name}_{group_name}"] = full_workspace_group
-                        continue
-
-                if not group_name:
+                if prompts.confirm(
+                    f"Group {group_name} does not have the same amount of members "
+                    f"in workspace {client.config.host} than previous workspaces which contains the same group name,"
+                    f"it will be created at the account with name : {workspace.workspace_name}_{group_name}"
+                ):
+                    all_workspaces_groups[f"{workspace.workspace_name}_{group_name}"] = full_workspace_group
                     continue
-
-                logger.info(f"Found new group {group_name}")
-                all_workspaces_groups[group_name] = full_workspace_group
-
-        logger.info(f"Found a total of {len(all_workspaces_groups)} groups to migrate to the account")
-
-        return all_workspaces_groups
+            logger.info(f"Found new group {group_name}")
+            all_workspaces_groups[group_name] = full_workspace_group
+        logger.info(f"Found a total of {len(all_workspaces_groups)} groups to migrate to the account")
 
     def _is_group_out_of_scope(self, group: Group) -> bool:
         if group.display_name in {"users", "admins", "account users"}:
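
The pattern throughout this commit is the extract-method move shown here: the deeply nested loop body becomes a helper that mutates an accumulator passed in by the caller, and nested if/else arms become flat guard clauses. A stripped-down sketch of that shape, with hypothetical names rather than the ones in the diff:

def collect(items):
    counts: dict[str, int] = {}
    for item in items:
        _count_one(item, counts)  # caller stays at one level of nesting
    return counts

def _count_one(item, counts):
    if not item:
        return  # guard clause replaces a nested if/continue
    counts[item] = counts.get(item, 0) + 1

Passing the dict into the helper is what keeps the refactor behavior-preserving: the helper writes to the same object the original loop body wrote to.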

src/databricks/labs/ucx/assessment/jobs.py

Lines changed: 40 additions & 33 deletions
@@ -38,27 +38,34 @@ class JobInfo:
 
 
 class JobsMixin:
-    @staticmethod
-    def _get_cluster_configs_from_all_jobs(all_jobs, all_clusters_by_id):  # pylint: disable=too-complex
-        for j in all_jobs:
-            if j.settings is None:
+    @classmethod
+    def _get_cluster_configs_from_all_jobs(cls, all_jobs, all_clusters_by_id):
+        for job in all_jobs:
+            if job.settings is None:
                 continue
-            if j.settings.job_clusters is not None:
-                for job_cluster in j.settings.job_clusters:
-                    if job_cluster.new_cluster is None:
-                        continue
-                    yield j, job_cluster.new_cluster
-            if j.settings.tasks is None:
+            if job.settings.job_clusters is not None:
+                yield from cls._job_clusters(job)
+            if job.settings.tasks is None:
                 continue
-            for task in j.settings.tasks:
-                if task.existing_cluster_id is not None:
-                    interactive_cluster = all_clusters_by_id.get(task.existing_cluster_id, None)
-                    if interactive_cluster is None:
-                        continue
-                    yield j, interactive_cluster
+            yield from cls._task_clusters(job, all_clusters_by_id)
 
-                elif task.new_cluster is not None:
-                    yield j, task.new_cluster
+    @classmethod
+    def _task_clusters(cls, job, all_clusters_by_id):
+        for task in job.settings.tasks:
+            if task.existing_cluster_id is not None:
+                interactive_cluster = all_clusters_by_id.get(task.existing_cluster_id, None)
+                if interactive_cluster is None:
+                    continue
+                yield job, interactive_cluster
+            elif task.new_cluster is not None:
+                yield job, task.new_cluster
+
+    @staticmethod
+    def _job_clusters(job):
+        for job_cluster in job.settings.job_clusters:
+            if job_cluster.new_cluster is None:
+                continue
+            yield job, job_cluster.new_cluster
 
 
 class JobsCrawler(CrawlerBase[JobInfo], JobsMixin, CheckClusterMixin):

@@ -299,22 +306,10 @@ def _assess_job_runs(self, submit_runs: Iterable[BaseRun], all_clusters_by_id) -
         runs_per_hash: dict[str, list[int | None]] = {}
 
         for submit_run in submit_runs:
-            task_failures = []
+            task_failures: list[str] = []
             # v2.1+ API, with tasks
             if submit_run.tasks:
-                all_tasks: list[RunTask] = submit_run.tasks
-                for task in sorted(all_tasks, key=lambda x: x.task_key if x.task_key is not None else ""):
-                    _task_key = task.task_key if task.task_key is not None else ""
-                    _cluster_details = None
-                    if task.new_cluster:
-                        _cluster_details = ClusterDetails.from_dict(task.new_cluster.as_dict())
-                        if self._needs_compatibility_check(task.new_cluster):
-                            task_failures.append("no data security mode specified")
-                    if task.existing_cluster_id:
-                        _cluster_details = all_clusters_by_id.get(task.existing_cluster_id, None)
-                    if _cluster_details:
-                        task_failures.extend(self._check_cluster_failures(_cluster_details, _task_key))
-
+                self._check_run_task(submit_run.tasks, all_clusters_by_id, task_failures)
             # v2.0 API, without tasks
             elif submit_run.cluster_spec:
                 _cluster_details = ClusterDetails.from_dict(submit_run.cluster_spec.as_dict())

@@ -324,11 +319,23 @@ def _assess_job_runs(self, submit_runs: Iterable[BaseRun], all_clusters_by_id) -
                 runs_per_hash[hashed_id].append(submit_run.run_id)
             else:
                 runs_per_hash[hashed_id] = [submit_run.run_id]
-
             result[hashed_id] = SubmitRunInfo(
                 run_ids=json.dumps(runs_per_hash[hashed_id]),
                 hashed_id=hashed_id,
                 failures=json.dumps(list(set(task_failures))),
             )
 
         return list(result.values())
+
+    def _check_run_task(self, all_tasks: list[RunTask], clusters: dict[str, ClusterDetails], task_failures: list[str]):
+        for task in sorted(all_tasks, key=lambda x: x.task_key if x.task_key is not None else ""):
+            _task_key = task.task_key if task.task_key is not None else ""
+            cluster_details = None
+            if task.new_cluster:
+                cluster_details = ClusterDetails.from_dict(task.new_cluster.as_dict())
+                if self._needs_compatibility_check(task.new_cluster):
+                    task_failures.append("no data security mode specified")
+            if task.existing_cluster_id:
+                cluster_details = clusters.get(task.existing_cluster_id, None)
+            if cluster_details:
+                task_failures.extend(self._check_cluster_failures(cluster_details, _task_key))
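
For generators, the equivalent flattening move is yield from: the outer generator delegates a whole nested loop to a helper generator without changing what callers receive. A self-contained sketch of that delegation, using toy data and illustrative names:

def cluster_pairs(jobs):
    for job in jobs:
        if job.get("clusters") is None:
            continue
        yield from _job_clusters(job)  # delegation keeps the outer loop one level deep

def _job_clusters(job):
    for cluster in job["clusters"]:
        if cluster is None:
            continue
        yield job["name"], cluster

jobs = [{"name": "etl", "clusters": ["c1", None, "c2"]}]
print(list(cluster_pairs(jobs)))  # [('etl', 'c1'), ('etl', 'c2')]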

src/databricks/labs/ucx/assessment/pipelines.py

Lines changed: 13 additions & 11 deletions
@@ -50,22 +50,24 @@ def _assess_pipelines(self, all_pipelines) -> Iterable[PipelineInfo]:
             pipeline_config = pipeline_response.spec.configuration
             if pipeline_config:
                 failures.extend(self._check_spark_conf(pipeline_config, "pipeline"))
-            pipeline_cluster = pipeline_response.spec.clusters
-            if pipeline_cluster:
-                for cluster in pipeline_cluster:
-                    if cluster.spark_conf:
-                        failures.extend(self._check_spark_conf(cluster.spark_conf, "pipeline cluster"))
-                    # Checking if cluster config is present in cluster policies
-                    if cluster.policy_id:
-                        failures.extend(self._check_cluster_policy(cluster.policy_id, "pipeline cluster"))
-                    if cluster.init_scripts:
-                        failures.extend(self._check_cluster_init_script(cluster.init_scripts, "pipeline cluster"))
-
+            clusters = pipeline_response.spec.clusters
+            if clusters:
+                self._pipeline_clusters(clusters, failures)
             pipeline_info.failures = json.dumps(failures)
             if len(failures) > 0:
                 pipeline_info.success = 0
             yield pipeline_info
 
+    def _pipeline_clusters(self, clusters, failures):
+        for cluster in clusters:
+            if cluster.spark_conf:
+                failures.extend(self._check_spark_conf(cluster.spark_conf, "pipeline cluster"))
+            # Checking if cluster config is present in cluster policies
+            if cluster.policy_id:
+                failures.extend(self._check_cluster_policy(cluster.policy_id, "pipeline cluster"))
+            if cluster.init_scripts:
+                failures.extend(self._check_cluster_init_script(cluster.init_scripts, "pipeline cluster"))
+
     def snapshot(self) -> Iterable[PipelineInfo]:
         return self._snapshot(self._try_fetch, self._crawl)
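
Note the design choice: _pipeline_clusters appends to the caller's failures list rather than returning one, which is what makes it a drop-in extraction with no logic change. If the accumulator passing reads oddly, a sketch of the purer alternative shape, with a hypothetical helper that is not in this diff:

def _cluster_failures(clusters) -> list[str]:
    failures: list[str] = []
    for cluster in clusters:
        if not cluster.get("spark_conf"):
            failures.append("pipeline cluster has no spark conf")
    return failures

# the caller would then do: failures.extend(_cluster_failures(clusters))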

src/databricks/labs/ucx/framework/dashboards.py

Lines changed: 11 additions & 8 deletions
@@ -126,14 +126,17 @@ def validate(self):
             dashboard_folders = [f for f in step_folder.glob("*") if f.is_dir()]
             # Create separate dashboards per step, represented as second-level folders
             for dashboard_folder in dashboard_folders:
-                dashboard_ref = f"{step_folder.stem}_{dashboard_folder.stem}".lower()
-                for query in self._desired_queries(dashboard_folder, dashboard_ref):
-                    try:
-                        self._get_viz_options(query)
-                        self._get_widget_options(query)
-                    except Exception as err:
-                        msg = f"Error in {query.name}: {err}"
-                        raise AssertionError(msg) from err
+                self._validate_folder(dashboard_folder, step_folder)
+
+    def _validate_folder(self, dashboard_folder, step_folder):
+        dashboard_ref = f"{step_folder.stem}_{dashboard_folder.stem}".lower()
+        for query in self._desired_queries(dashboard_folder, dashboard_ref):
+            try:
+                self._get_viz_options(query)
+                self._get_widget_options(query)
+            except Exception as err:
+                msg = f"Error in {query.name}: {err}"
+                raise AssertionError(msg) from err
 
     def _install_widget(self, query: SimpleQuery, dashboard_ref: str):
         dashboard_id = self._state.dashboards[dashboard_ref]
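
Moving the try/except into _validate_folder keeps the exception chaining intact: raise ... from err preserves the original traceback as __cause__. A minimal sketch of that chaining pattern, with illustrative names:

import json

def _validate_one(name: str, payload: str) -> None:
    try:
        json.loads(payload)
    except Exception as err:
        msg = f"Error in {name}: {err}"
        raise AssertionError(msg) from err  # original error kept as __cause__

# _validate_one("widget", "{not json")  # AssertionError, chained to the JSONDecodeError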

src/databricks/labs/ucx/hive_metastore/locations.py

Lines changed: 39 additions & 32 deletions
@@ -7,6 +7,7 @@
 
 from databricks.labs.blueprint.installation import Installation
 from databricks.sdk import WorkspaceClient
+from databricks.sdk.service.catalog import ExternalLocationInfo
 
 from databricks.labs.ucx.framework.crawlers import CrawlerBase, SqlBackend
 from databricks.labs.ucx.framework.utils import escape_sql_identifier

@@ -39,23 +40,27 @@ def _external_locations(self, tables: list[Row], mounts) -> Iterable[ExternalLoc
         external_locations: list[ExternalLocation] = []
         for table in tables:
             location = table.location
-            if location is not None and len(location) > 0:
-                if location.startswith("dbfs:/mnt"):
-                    for mount in mounts:
-                        if location[5:].startswith(mount.name.lower()):
-                            location = location[5:].replace(mount.name, mount.source)
-                            break
-                if (
-                    not location.startswith("dbfs")
-                    and (self._prefix_size[0] < location.find(":/") < self._prefix_size[1])
-                    and not location.startswith("jdbc")
-                ):
-                    self._dbfs_locations(external_locations, location, min_slash)
-                if location.startswith("jdbc"):
-                    self._add_jdbc_location(external_locations, location, table)
-
+            if not location:
+                continue
+            if location.startswith("dbfs:/mnt"):
+                location = self._resolve_mount(location, mounts)
+            if (
+                not location.startswith("dbfs")
+                and (self._prefix_size[0] < location.find(":/") < self._prefix_size[1])
+                and not location.startswith("jdbc")
+            ):
+                self._dbfs_locations(external_locations, location, min_slash)
+            if location.startswith("jdbc"):
+                self._add_jdbc_location(external_locations, location, table)
         return external_locations
 
+    def _resolve_mount(self, location, mounts):
+        for mount in mounts:
+            if location[5:].startswith(mount.name.lower()):
+                location = location[5:].replace(mount.name, mount.source)
+                break
+        return location
+
     @staticmethod
     def _dbfs_locations(external_locations, location, min_slash):
         dupe = False

@@ -161,31 +166,33 @@ def _get_ext_location_definitions(self, missing_locations: list[ExternalLocation
         return tf_script
 
     def match_table_external_locations(self) -> tuple[dict[str, int], list[ExternalLocation]]:
-        uc_external_locations = list(self._ws.external_locations.list())
+        existing_locations = list(self._ws.external_locations.list())
         table_locations = self.snapshot()
-        matching_locations = {}
+        matching_locations: dict[str, int] = {}
         missing_locations = []
         for table_loc in table_locations:
             # external_location.list returns url without trailing "/" but ExternalLocation.snapshot
             # does so removing the trailing slash before comparing
-            matched = False
-            for uc_loc in uc_external_locations:
-                if not uc_loc.url:
-                    continue
-                if not uc_loc.name:
-                    continue
-                uc_loc_path = uc_loc.url.lower()
-                if uc_loc_path in table_loc.location.rstrip("/").lower():
-                    if uc_loc.name not in matching_locations:
-                        matching_locations[uc_loc.name] = table_loc.table_count
-                    else:
-                        matching_locations[uc_loc.name] = matching_locations[uc_loc.name] + table_loc.table_count
-                    matched = True
-                    break
-            if not matched:
+            if not self._match_existing(table_loc, matching_locations, existing_locations):
                 missing_locations.append(table_loc)
         return matching_locations, missing_locations
 
+    @staticmethod
+    def _match_existing(table_loc, matching_locations: dict[str, int], existing_locations: list[ExternalLocationInfo]):
+        for uc_loc in existing_locations:
+            if not uc_loc.url:
+                continue
+            if not uc_loc.name:
+                continue
+            uc_loc_path = uc_loc.url.lower()
+            if uc_loc_path in table_loc.location.rstrip("/").lower():
+                if uc_loc.name not in matching_locations:
+                    matching_locations[uc_loc.name] = table_loc.table_count
+                else:
+                    matching_locations[uc_loc.name] = matching_locations[uc_loc.name] + table_loc.table_count
+                return True
+        return False
+
     def save_as_terraform_definitions_on_workspace(self, installation: Installation):
         matching_locations, missing_locations = self.match_table_external_locations()
         if len(matching_locations) > 0:
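
Besides reducing nesting, _match_existing eliminates the matched = False flag: returning True or False directly replaces the flag-plus-break bookkeeping. A compact sketch of that flag-elimination move, with hypothetical names and data:

def _matches_any(needle: str, haystacks: list[str]) -> bool:
    for hay in haystacks:
        if needle in hay:
            return True  # early return replaces `matched = True; break`
    return False

locations = ["s3://bucket/a", "s3://bucket/b"]
known = ["s3://bucket/a/table"]
missing = [loc for loc in locations if not _matches_any(loc, known)]
print(missing)  # ['s3://bucket/b']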
