Skip to content

Commit 2a6b090

Browse files
authored
Refactor Azure service principal crawler and fix bug where tenant_id inside secret scope is not detected (#942)
## Changes - Fix tenant_id detection logic in the case where the SPN endpoint configuration is also inside a secret scope. - Refactor the SPN crawler logic to make it more readable ### Linked issues Closes #896 ### Tests <!-- How is this tested? Please see the checklist below and also describe any other relevant tests --> - [x] manually tested - [ ] added unit tests - [ ] added integration tests - [ ] verified on staging environment (screenshot attached)
1 parent fd0b604 commit 2a6b090

File tree

8 files changed

+225
-325
lines changed

8 files changed

+225
-325
lines changed

docs/assessment.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ These are Global Init Scripts that are incompatible with Unity Catalog compute.
9292
# Assessment Finding Index
9393
This section will help explain UCX Assessment findings and provide a recommended action.
9494
The assessment finding index is grouped by:
95-
- The 100 serieds findings are Databricks Runtime and compute configuration findings
95+
- The 100 series findings are Databricks Runtime and compute configuration findings.
9696
- The 200 series findings are centered around data related observations.
9797

9898
### AF101 - not supported DBR: ##.#.x-scala2.12

src/databricks/labs/ucx/assessment/azure.py

Lines changed: 150 additions & 181 deletions
Large diffs are not rendered by default.

src/databricks/labs/ucx/assessment/clusters.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,10 @@
1515
)
1616

1717
from databricks.labs.ucx.assessment.crawlers import (
18-
_AZURE_SP_CONF_FAILURE_MSG,
19-
_INIT_SCRIPT_DBFS_PATH,
18+
AZURE_SP_CONF_FAILURE_MSG,
2019
INCOMPATIBLE_SPARK_CONFIG_KEYS,
21-
_azure_sp_conf_present_check,
20+
INIT_SCRIPT_DBFS_PATH,
21+
azure_sp_conf_present_check,
2222
spark_version_compatibility,
2323
)
2424
from databricks.labs.ucx.assessment.init_scripts import CheckInitScriptMixin
@@ -51,16 +51,16 @@ def _check_cluster_policy(self, policy_id: str, source: str) -> list[str]:
5151
policy = self._safe_get_cluster_policy(policy_id)
5252
if policy:
5353
if policy.definition:
54-
if _azure_sp_conf_present_check(json.loads(policy.definition)):
55-
failures.append(f"{_AZURE_SP_CONF_FAILURE_MSG} {source}.")
54+
if azure_sp_conf_present_check(json.loads(policy.definition)):
55+
failures.append(f"{AZURE_SP_CONF_FAILURE_MSG} {source}.")
5656
if policy.policy_family_definition_overrides:
57-
if _azure_sp_conf_present_check(json.loads(policy.policy_family_definition_overrides)):
58-
failures.append(f"{_AZURE_SP_CONF_FAILURE_MSG} {source}.")
57+
if azure_sp_conf_present_check(json.loads(policy.policy_family_definition_overrides)):
58+
failures.append(f"{AZURE_SP_CONF_FAILURE_MSG} {source}.")
5959
return failures
6060

6161
def _get_init_script_data(self, init_script_info: InitScriptInfo) -> str | None:
6262
if init_script_info.dbfs is not None and init_script_info.dbfs.destination is not None:
63-
if len(init_script_info.dbfs.destination.split(":")) == _INIT_SCRIPT_DBFS_PATH:
63+
if len(init_script_info.dbfs.destination.split(":")) == INIT_SCRIPT_DBFS_PATH:
6464
file_api_format_destination = init_script_info.dbfs.destination.split(":")[1]
6565
if file_api_format_destination:
6666
try:
@@ -95,8 +95,8 @@ def check_spark_conf(self, conf: dict[str, str], source: str) -> list[str]:
9595
if "dbfs:/mnt" in value or "/dbfs/mnt" in value:
9696
failures.append(f"using DBFS mount in configuration: {value}")
9797
# Checking if Azure cluster config is present in spark config
98-
if _azure_sp_conf_present_check(conf):
99-
failures.append(f"{_AZURE_SP_CONF_FAILURE_MSG} {source}.")
98+
if azure_sp_conf_present_check(conf):
99+
failures.append(f"{AZURE_SP_CONF_FAILURE_MSG} {source}.")
100100
return failures
101101

102102
def check_cluster_failures(self, cluster: ClusterDetails, source: str) -> list[str]:

src/databricks/labs/ucx/assessment/crawlers.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,31 +9,31 @@
99
"spark.databricks.hive.metastore.glueCatalog.enabled",
1010
]
1111

12-
_AZURE_SP_CONF = [
12+
AZURE_SP_CONF = [
1313
"fs.azure.account.auth.type",
1414
"fs.azure.account.oauth.provider.type",
1515
"fs.azure.account.oauth2.client.id",
1616
"fs.azure.account.oauth2.client.secret",
1717
"fs.azure.account.oauth2.client.endpoint",
1818
]
19-
_SECRET_PATTERN = r"{{(secrets.*?)}}"
20-
_STORAGE_ACCOUNT_EXTRACT_PATTERN = r"(?:id|endpoint)(.*?)dfs"
21-
_AZURE_SP_CONF_FAILURE_MSG = "Uses azure service principal credentials config in"
22-
_SECRET_LIST_LENGTH = 3
23-
_CLIENT_ENDPOINT_LENGTH = 6
24-
_INIT_SCRIPT_DBFS_PATH = 2
19+
SECRET_PATTERN = r"{{(secrets.*?)}}"
20+
STORAGE_ACCOUNT_EXTRACT_PATTERN = r"(?:id|endpoint)(.*?)dfs"
21+
AZURE_SP_CONF_FAILURE_MSG = "Uses azure service principal credentials config in"
22+
SECRET_LIST_LENGTH = 3
23+
CLIENT_ENDPOINT_LENGTH = 6
24+
INIT_SCRIPT_DBFS_PATH = 2
2525

2626

27-
def _azure_sp_conf_in_init_scripts(init_script_data: str) -> bool:
28-
for conf in _AZURE_SP_CONF:
27+
def azure_sp_conf_in_init_scripts(init_script_data: str) -> bool:
28+
for conf in AZURE_SP_CONF:
2929
if re.search(conf, init_script_data):
3030
return True
3131
return False
3232

3333

34-
def _azure_sp_conf_present_check(config: dict) -> bool:
34+
def azure_sp_conf_present_check(config: dict) -> bool:
3535
for key in config.keys():
36-
for conf in _AZURE_SP_CONF:
36+
for conf in AZURE_SP_CONF:
3737
if re.search(conf, key):
3838
return True
3939
return False

src/databricks/labs/ucx/assessment/init_scripts.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
from databricks.sdk.errors import ResourceDoesNotExist
99

1010
from databricks.labs.ucx.assessment.crawlers import (
11-
_AZURE_SP_CONF_FAILURE_MSG,
12-
_azure_sp_conf_in_init_scripts,
11+
AZURE_SP_CONF_FAILURE_MSG,
12+
azure_sp_conf_in_init_scripts,
1313
)
1414
from databricks.labs.ucx.framework.crawlers import CrawlerBase, SqlBackend
1515

@@ -33,8 +33,8 @@ def check_init_script(self, init_script_data: str | None, source: str) -> list[s
3333
failures: list[str] = []
3434
if not init_script_data:
3535
return failures
36-
if _azure_sp_conf_in_init_scripts(init_script_data):
37-
failures.append(f"{_AZURE_SP_CONF_FAILURE_MSG} {source}.")
36+
if azure_sp_conf_in_init_scripts(init_script_data):
37+
failures.append(f"{AZURE_SP_CONF_FAILURE_MSG} {source}.")
3838
return failures
3939

4040

tests/unit/assessment/test_azure.py

Lines changed: 35 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,6 @@
66
from . import workspace_client_mock
77

88

9-
def test_azure_spn_info_without_secret():
10-
ws = workspace_client_mock(clusters="single-cluster-spn.json")
11-
sample_spns = [{"application_id": "test123456789", "secret_scope": "", "secret_key": ""}]
12-
AzureServicePrincipalCrawler(ws, MockBackend(), "ucx")._list_all_cluster_with_spn_in_spark_conf()
13-
crawler = AzureServicePrincipalCrawler(ws, MockBackend(), "ucx")._assess_service_principals(sample_spns)
14-
result_set = list(crawler)
15-
16-
assert len(result_set) == 1
17-
assert result_set[0].application_id == "test123456789"
18-
19-
209
def test_azure_service_principal_info_crawl():
2110
ws = workspace_client_mock(
2211
clusters="assortment-spn.json",
@@ -25,7 +14,7 @@ def test_azure_service_principal_info_crawl():
2514
warehouse_config="spn-config.json",
2615
secret_exists=True,
2716
)
28-
spn_crawler = AzureServicePrincipalCrawler(ws, MockBackend(), "ucx")._crawl()
17+
spn_crawler = AzureServicePrincipalCrawler(ws, MockBackend(), "ucx").snapshot()
2918

3019
assert len(spn_crawler) == 5
3120

@@ -38,7 +27,7 @@ def test_azure_service_principal_info_spark_conf_crawl():
3827
warehouse_config="spn-config.json",
3928
)
4029

41-
spn_crawler = AzureServicePrincipalCrawler(ws, MockBackend(), "ucx")._crawl()
30+
spn_crawler = AzureServicePrincipalCrawler(ws, MockBackend(), "ucx").snapshot()
4231

4332
assert len(spn_crawler) == 3
4433

@@ -51,14 +40,14 @@ def test_azure_service_principal_info_no_spark_conf_crawl():
5140
warehouse_config="single-config.json",
5241
)
5342

54-
spn_crawler = AzureServicePrincipalCrawler(ws, MockBackend(), "ucx")._crawl()
43+
spn_crawler = AzureServicePrincipalCrawler(ws, MockBackend(), "ucx").snapshot()
5544

5645
assert len(spn_crawler) == 0
5746

5847

5948
def test_azure_service_principal_info_policy_family_conf_crawl(mocker):
6049
ws = workspace_client_mock()
61-
spn_crawler = AzureServicePrincipalCrawler(ws, MockBackend(), "ucx")._crawl()
50+
spn_crawler = AzureServicePrincipalCrawler(ws, MockBackend(), "ucx").snapshot()
6251

6352
assert len(spn_crawler) == 0
6453

@@ -67,20 +56,10 @@ def test_azure_service_principal_info_null_applid_crawl():
6756
ws = workspace_client_mock(
6857
clusters="single-cluster-spn-with-policy.json", pipelines="single-pipeline.json", jobs="single-job.json"
6958
)
70-
spn_crawler = AzureServicePrincipalCrawler(ws, MockBackend(), "ucx")._crawl()
59+
spn_crawler = AzureServicePrincipalCrawler(ws, MockBackend(), "ucx").snapshot()
7160
assert len(spn_crawler) == 0
7261

7362

74-
def test_azure_spn_info_with_secret():
75-
ws = workspace_client_mock(clusters="single-cluster-spn.json", secret_exists=True)
76-
sample_spns = [{"application_id": "test123456780", "secret_scope": "abcff", "secret_key": "sp_app_client_id"}]
77-
crawler = AzureServicePrincipalCrawler(ws, MockBackend(), "ucx")._assess_service_principals(sample_spns)
78-
result_set = list(crawler)
79-
80-
assert len(result_set) == 1
81-
assert result_set[0].application_id == "test123456780"
82-
83-
8463
def test_spn_with_spark_config_snapshot_try_fetch():
8564
sample_spns = [
8665
{
@@ -120,37 +99,36 @@ def test_spn_with_spark_config_snapshot():
12099

121100
def test_list_all_cluster_with_spn_in_spark_conf_with_secret():
122101
ws = workspace_client_mock(clusters="single-cluster-spn.json")
123-
crawler = AzureServicePrincipalCrawler(ws, MockBackend(), "ucx")._list_all_cluster_with_spn_in_spark_conf()
124-
result_set = list(crawler)
102+
result_set = AzureServicePrincipalCrawler(ws, MockBackend(), "ucx").snapshot()
125103

126104
assert len(result_set) == 1
127105

128106

129107
def test_list_all_wh_config_with_spn_no_secret():
130108
ws = workspace_client_mock(warehouse_config="spn-config.json")
131-
result_set = AzureServicePrincipalCrawler(ws, MockBackend(), "ucx")._list_all_spn_in_sql_warehouses_spark_conf()
109+
result_set = AzureServicePrincipalCrawler(ws, MockBackend(), "ucx").snapshot()
132110

133111
assert len(result_set) == 2
134-
assert result_set[0].get("application_id") == "dummy_application_id"
135-
assert result_set[0].get("tenant_id") == "dummy_tenant_id"
136-
assert result_set[0].get("storage_account") == "storage_acct2"
112+
assert any(_ for _ in result_set if _.application_id == "dummy_application_id")
113+
assert any(_ for _ in result_set if _.tenant_id == "dummy_tenant_id")
114+
assert any(_ for _ in result_set if _.storage_account == "storage_acct2")
137115

138116

139117
def test_list_all_wh_config_with_spn_and_secret():
140118
ws = workspace_client_mock(warehouse_config="spn-secret-config.json", secret_exists=True)
141-
result_set = AzureServicePrincipalCrawler(ws, MockBackend(), "ucx")._list_all_spn_in_sql_warehouses_spark_conf()
119+
result_set = AzureServicePrincipalCrawler(ws, MockBackend(), "ucx").snapshot()
142120

143121
assert len(result_set) == 2
144-
assert result_set[0].get("tenant_id") == "dummy_tenant_id"
145-
assert result_set[0].get("storage_account") == "abcde"
122+
assert any(_ for _ in result_set if _.tenant_id == "dummy_tenant_id")
123+
assert any(_ for _ in result_set if _.storage_account == "abcde")
146124

147125

148126
def test_list_all_clusters_spn_in_spark_conf_with_tenant():
149127
ws = workspace_client_mock(clusters="single-cluster-spn.json", secret_exists=True)
150-
result_set = AzureServicePrincipalCrawler(ws, MockBackend(), "ucx")._list_all_cluster_with_spn_in_spark_conf()
128+
result_set = AzureServicePrincipalCrawler(ws, MockBackend(), "ucx").snapshot()
151129

152130
assert len(result_set) == 1
153-
assert result_set[0].get("tenant_id") == "dummy_tenant_id"
131+
assert result_set[0].tenant_id == "dummy_tenant_id"
154132

155133

156134
def test_azure_service_principal_info_policy_conf():
@@ -161,7 +139,7 @@ def test_azure_service_principal_info_policy_conf():
161139
warehouse_config="spn-config.json",
162140
secret_exists=True,
163141
)
164-
spn_crawler = AzureServicePrincipalCrawler(ws, MockBackend(), "ucx")._crawl()
142+
spn_crawler = AzureServicePrincipalCrawler(ws, MockBackend(), "ucx").snapshot()
165143

166144
assert len(spn_crawler) == 4
167145

@@ -174,48 +152,48 @@ def test_azure_service_principal_info_dedupe():
174152
warehouse_config="dupe-spn-config.json",
175153
secret_exists=True,
176154
)
177-
spn_crawler = AzureServicePrincipalCrawler(ws, MockBackend(), "ucx")._crawl()
155+
spn_crawler = AzureServicePrincipalCrawler(ws, MockBackend(), "ucx").snapshot()
178156

179157
assert len(spn_crawler) == 2
180158

181159

182160
def test_list_all_pipeline_with_conf_spn_in_spark_conf():
183161
ws = workspace_client_mock(pipelines="single-pipeline-with-spn.json")
184-
result_set = AzureServicePrincipalCrawler(ws, MockBackend(), "ucx")._list_all_pipeline_with_spn_in_spark_conf()
162+
result_set = AzureServicePrincipalCrawler(ws, MockBackend(), "ucx").snapshot()
185163

186164
assert len(result_set) == 1
187-
assert result_set[0].get("storage_account") == "newstorageacct"
188-
assert result_set[0].get("tenant_id") == "directory_12345"
189-
assert result_set[0].get("application_id") == "pipeline_dummy_application_id"
165+
assert result_set[0].storage_account == "newstorageacct"
166+
assert result_set[0].tenant_id == "directory_12345"
167+
assert result_set[0].application_id == "pipeline_dummy_application_id"
190168

191169

192170
def test_list_all_pipeline_wo_conf_spn_in_spark_conf():
193171
ws = workspace_client_mock(pipelines="single-pipeline.json")
194-
result_set = AzureServicePrincipalCrawler(ws, MockBackend(), "ucx")._list_all_pipeline_with_spn_in_spark_conf()
172+
result_set = AzureServicePrincipalCrawler(ws, MockBackend(), "ucx").snapshot()
195173

196174
assert len(result_set) == 0
197175

198176

199177
def test_list_all_pipeline_with_conf_spn_tenant():
200178
ws = workspace_client_mock(pipelines="single-pipeline-with-spn.json")
201-
result_set = AzureServicePrincipalCrawler(ws, MockBackend(), "ucx")._list_all_pipeline_with_spn_in_spark_conf()
179+
result_set = AzureServicePrincipalCrawler(ws, MockBackend(), "ucx").snapshot()
202180

203181
assert len(result_set) == 1
204-
assert result_set[0].get("storage_account") == "newstorageacct"
205-
assert result_set[0].get("application_id") == "pipeline_dummy_application_id"
182+
assert result_set[0].storage_account == "newstorageacct"
183+
assert result_set[0].application_id == "pipeline_dummy_application_id"
206184

207185

208186
def test_list_all_pipeline_with_conf_spn_secret():
209187
ws = workspace_client_mock(pipelines="single-pipeline-with-spn.json", secret_exists=True)
210-
result_set = AzureServicePrincipalCrawler(ws, MockBackend(), "ucx")._list_all_pipeline_with_spn_in_spark_conf()
188+
result_set = AzureServicePrincipalCrawler(ws, MockBackend(), "ucx").snapshot()
211189

212190
assert len(result_set) == 1
213-
assert result_set[0].get("storage_account") == "newstorageacct"
191+
assert result_set[0].storage_account == "newstorageacct"
214192

215193

216194
def test_azure_service_principal_info_policy_family():
217195
ws = workspace_client_mock(clusters="single-cluster-spn-with-policy-overrides.json")
218-
spn_crawler = AzureServicePrincipalCrawler(ws, MockBackend(), "ucx")._crawl()
196+
spn_crawler = AzureServicePrincipalCrawler(ws, MockBackend(), "ucx").snapshot()
219197

220198
assert len(spn_crawler) == 1
221199
assert spn_crawler[0].application_id == "dummy_appl_id"
@@ -225,19 +203,19 @@ def test_azure_service_principal_info_policy_family():
225203
def test_list_all_pipeline_with_conf_spn_secret_unavlbl():
226204
ws = workspace_client_mock(pipelines="single-pipeline.json", secret_exists=False)
227205
crawler = AzureServicePrincipalCrawler(ws, MockBackend(), "ucx")
228-
result_set = crawler._list_all_pipeline_with_spn_in_spark_conf()
206+
result_set = crawler.snapshot()
229207

230208
assert len(result_set) == 0
231209

232210

233211
def test_list_all_pipeline_with_conf_spn_secret_avlb():
234212
ws = workspace_client_mock(pipelines="single-pipeline-with-spn.json", secret_exists=True)
235-
result_set = AzureServicePrincipalCrawler(ws, MockBackend(), "ucx")._list_all_pipeline_with_spn_in_spark_conf()
213+
result_set = AzureServicePrincipalCrawler(ws, MockBackend(), "ucx").snapshot()
236214

237215
assert len(result_set) > 0
238-
assert result_set[0].get("application_id") == "pipeline_dummy_application_id"
239-
assert result_set[0].get("tenant_id") == "directory_12345"
240-
assert result_set[0].get("storage_account") == "newstorageacct"
216+
assert result_set[0].application_id == "pipeline_dummy_application_id"
217+
assert result_set[0].tenant_id == "directory_12345"
218+
assert result_set[0].storage_account == "newstorageacct"
241219

242220

243221
def test_azure_spn_info_with_secret_unavailable():
@@ -251,6 +229,6 @@ def test_azure_spn_info_with_secret_unavailable():
251229
"spark.hadoop.fs.azure.account."
252230
"oauth2.client.secret.abcde.dfs.core.windows.net": "{{secrets/abcff/sp_secret}}",
253231
}
254-
crawler = AzureServicePrincipalCrawler(ws, MockBackend(), "ucx")._get_azure_spn_list(spark_conf)
232+
crawler = AzureServicePrincipalCrawler(ws, MockBackend(), "ucx")._get_azure_spn_from_config(spark_conf)
255233

256-
assert crawler == []
234+
assert crawler == set()

0 commit comments

Comments
 (0)