Commit 35bfe7c

Add a check for external metastore in SQL warehouse configuration (#1046)
1 parent 50d57e8 commit 35bfe7c

2 files changed (+249, -11 lines)

src/databricks/labs/ucx/installer/policy.py

Lines changed: 49 additions & 4 deletions
@@ -7,6 +7,7 @@
 from databricks.sdk import WorkspaceClient
 from databricks.sdk.errors import NotFound
 from databricks.sdk.service import compute
+from databricks.sdk.service.sql import GetWorkspaceWarehouseConfigResponse
 
 logger = logging.getLogger(__name__)
 
@@ -34,6 +35,14 @@ def create(self, inventory_database: str) -> tuple[str, str, dict]:
         if len(cluster_policies) >= 1:
             cluster_policy = json.loads(self._prompts.choice_from_dict("Choose a cluster policy", cluster_policies))
             instance_profile, spark_conf_dict = self._extract_external_hive_metastore_conf(cluster_policy)
+        else:
+            warehouse_config = self._get_warehouse_config_with_external_hive_metastore()
+            if warehouse_config and self._prompts.confirm(
+                "We have identified the workspace warehouse is set up for an external metastore. "
+                "Would you like to set UCX to connect to the external metastore?"
+            ):
+                logger.info("Setting up an external metastore")
+                instance_profile, spark_conf_dict = self._extract_external_hive_metastore_sql_conf(warehouse_config)
         policy_name = f"Unity Catalog Migration ({inventory_database}) ({self._ws.current_user.me().user_name})"
         policies = self._ws.cluster_policies.list()
         for policy in policies:
@@ -63,12 +72,12 @@ def _definition(self, conf: dict, instance_profile: str | None) -> str:
         for key, value in conf.items():
             policy_definition[f"spark_conf.{key}"] = self._policy_config(value)
         if self._ws.config.is_aws:
+            if instance_profile:
+                policy_definition["aws_attributes.instance_profile_arn"] = self._policy_config(instance_profile)
             policy_definition["aws_attributes.availability"] = self._policy_config(
                 compute.AwsAvailability.ON_DEMAND.value
             )
-            if instance_profile:
-                policy_definition["aws_attributes.instance_profile_arn"] = self._policy_config(instance_profile)
-        elif self._ws.config.is_azure:  # pylint: disable=confusing-consecutive-elif
+        elif self._ws.config.is_azure:
             policy_definition["azure_attributes.availability"] = self._policy_config(
                 compute.AzureAvailability.ON_DEMAND_AZURE.value
             )
@@ -107,6 +116,41 @@ def _get_cluster_policies_with_external_hive_metastores(self):
                     yield policy
                     break
 
+    @staticmethod
+    def _extract_external_hive_metastore_sql_conf(sql_config: GetWorkspaceWarehouseConfigResponse):
+        spark_conf_dict: dict[str, str] = {}
+        instance_profile = None
+        if sql_config.instance_profile_arn is not None:
+            instance_profile = sql_config.instance_profile_arn
+            logger.info(f"Instance Profile is Set to {instance_profile}")
+        if sql_config.data_access_config is None:
+            return instance_profile, spark_conf_dict
+        for conf in sql_config.data_access_config:
+            if conf.key is None:
+                continue
+            if conf.value is None:
+                continue
+            if (
+                conf.key.startswith("spark.sql.hive.metastore")
+                or conf.key.startswith("spark.hadoop.javax.jdo.option")
+                or conf.key.startswith("spark.databricks.hive.metastore")
+                or conf.key.startswith("spark.hadoop.hive.metastore.glue")
+            ):
+                spark_conf_dict[conf.key] = conf.value
+        return instance_profile, spark_conf_dict
+
+    def _get_warehouse_config_with_external_hive_metastore(self) -> GetWorkspaceWarehouseConfigResponse | None:
+        sql_config = self._ws.warehouses.get_workspace_warehouse_config()
+        if sql_config.data_access_config is None:
+            return None
+        for conf in sql_config.data_access_config:
+            if conf.key is None:
+                continue
+            is_glue = conf.key.startswith("spark.databricks.hive.metastore.glueCatalog.enabled")
+            if conf.key.startswith("spark.sql.hive.metastore") or is_glue:
+                return sql_config
+        return None
+
     def update_job_policy(self, state: InstallState, policy_id: str):
         if not state.jobs:
             logger.error("No jobs found in states")
@@ -118,7 +162,8 @@ def update_job_policy(self, state: InstallState, policy_id: str):
                 assert job.job_id is not None
                 assert job_settings is not None
                 if job_settings.job_clusters is None:
-                    # if job_clusters is None, it means override cluster is being set and hence policy should not be applied
+                    # if job_clusters is None, it means override cluster is being set
+                    # and hence policy should not be applied
                     return
            except NotFound:
                logger.error(f"Job id {job_id} not found. Please check if the job is present in the workspace")

tests/unit/installer/test_policy.py

Lines changed: 200 additions & 7 deletions
@@ -9,6 +9,10 @@
 from databricks.sdk.service import iam
 from databricks.sdk.service.compute import ClusterSpec, Policy
 from databricks.sdk.service.jobs import Job, JobCluster, JobSettings
+from databricks.sdk.service.sql import (
+    EndpointConfPair,
+    GetWorkspaceWarehouseConfigResponse,
+)
 
 from databricks.labs.ucx.installer.policy import ClusterPolicyInstaller
 
@@ -67,12 +71,6 @@ def test_cluster_policy_definition_azure_hms():
             description="Custom cluster policy for Unity Catalog Migration (UCX)",
         )
     ]
-    prompts = MockPrompts(
-        {
-            r".*We have identified one or more cluster.*": "Yes",
-            r".*Choose a cluster policy.*": "0",
-        }
-    )
     policy_installer = ClusterPolicyInstaller(MockInstallation(), ws, prompts)
     policy_id, _, _ = policy_installer.create('ucx')
     policy_definition_actual = {
@@ -120,8 +118,8 @@ def test_cluster_policy_definition_aws_glue():
         "spark_version": {"type": "fixed", "value": "14.2.x-scala2.12"},
         "node_type_id": {"type": "fixed", "value": "Standard_F4s"},
         "spark_conf.spark.databricks.hive.metastore.glueCatalog.enabled": {"type": "fixed", "value": "true"},
-        "aws_attributes.availability": {"type": "fixed", "value": "ON_DEMAND"},
         "aws_attributes.instance_profile_arn": {"type": "fixed", "value": "role_arn_1"},
+        "aws_attributes.availability": {"type": "fixed", "value": "ON_DEMAND"},
     }
     assert policy_id == "foo1"
     assert instance_profile == "role_arn_1"
@@ -221,3 +219,198 @@ def test_update_job_policy():
     ws.jobs.get.return_value = job
     policy_installer.update_job_policy(states, 'foobar')
     ws.jobs.update.assert_called_with(123, new_settings=job_setting)
+
+
+def test_cluster_policy_definition_azure_hms_warehouse():
+    ws, _ = common()
+    ws.config.is_aws = False
+    ws.config.is_azure = True
+    ws.config.is_gcp = False
+    endpoint_conf = [
+        EndpointConfPair("spark.hadoop.javax.jdo.option.ConnectionURL", "url"),
+        EndpointConfPair("spark.hadoop.javax.jdo.option.ConnectionUserName", "user1"),
+        EndpointConfPair("spark.hadoop.javax.jdo.option.ConnectionPassword", "pwd"),
+        EndpointConfPair("spark.hadoop.javax.jdo.option.ConnectionDriverName", "SQLServerDriver"),
+        EndpointConfPair("spark.sql.hive.metastore.version", "0.13"),
+        EndpointConfPair("spark.sql.hive.metastore.jars", "jar1"),
+    ]
+
+    ws.warehouses.get_workspace_warehouse_config.return_value = GetWorkspaceWarehouseConfigResponse(
+        data_access_config=endpoint_conf
+    )
+
+    ws.cluster_policies.list.return_value = [
+        Policy(
+            policy_id="id1",
+            name="foo",
+            definition=json.dumps({}),
+            description="Custom cluster policy for Unity Catalog Migration (UCX)",
+        )
+    ]
+
+    prompts = MockPrompts(
+        {
+            r".*We have identified one or more cluster.*": "No",
+            r".*We have identified the workspace warehouse.*": "Yes",
+        }
+    )
+    policy_installer = ClusterPolicyInstaller(MockInstallation(), ws, prompts)
+    policy_id, _, _ = policy_installer.create('ucx')
+    policy_definition_actual = {
+        "spark_version": {"type": "fixed", "value": "14.2.x-scala2.12"},
+        "node_type_id": {"type": "fixed", "value": "Standard_F4s"},
+        "spark_conf.spark.hadoop.javax.jdo.option.ConnectionURL": {"type": "fixed", "value": "url"},
+        "spark_conf.spark.hadoop.javax.jdo.option.ConnectionUserName": {"type": "fixed", "value": "user1"},
+        "spark_conf.spark.hadoop.javax.jdo.option.ConnectionPassword": {"type": "fixed", "value": "pwd"},
+        "spark_conf.spark.hadoop.javax.jdo.option.ConnectionDriverName": {"type": "fixed", "value": "SQLServerDriver"},
+        "spark_conf.spark.sql.hive.metastore.version": {"type": "fixed", "value": "0.13"},
+        "spark_conf.spark.sql.hive.metastore.jars": {"type": "fixed", "value": "jar1"},
+        "azure_attributes.availability": {"type": "fixed", "value": "ON_DEMAND_AZURE"},
+    }
+    assert policy_id == "foo1"
+
+    ws.cluster_policies.create.assert_called_with(
+        name="Unity Catalog Migration (ucx) ([email protected])",
+        definition=json.dumps(policy_definition_actual),
+        description="Custom cluster policy for Unity Catalog Migration (UCX)",
+    )
+
+
+def test_cluster_policy_definition_aws_glue_warehouse():
+    ws, _ = common()
+    ws.config.is_aws = True
+    ws.config.is_azure = False
+    ws.config.is_gcp = False
+    endpoint_conf = [
+        EndpointConfPair("spark.databricks.hive.metastore.glueCatalog.enabled", "true"),
+    ]
+
+    ws.warehouses.get_workspace_warehouse_config.return_value = GetWorkspaceWarehouseConfigResponse(
+        data_access_config=endpoint_conf,
+        instance_profile_arn="role_arn_1",
+    )
+
+    ws.cluster_policies.list.return_value = [
+        Policy(
+            policy_id="id1",
+            name="foo",
+            definition=json.dumps({}),
+            description="Custom cluster policy for Unity Catalog Migration (UCX)",
+        )
+    ]
+
+    prompts = MockPrompts(
+        {
+            r".*We have identified one or more cluster.*": "No",
+            r".*We have identified the workspace warehouse.*": "Yes",
+        }
+    )
+    policy_installer = ClusterPolicyInstaller(MockInstallation(), ws, prompts)
+    policy_id, instance_profile, _ = policy_installer.create('ucx')
+    policy_definition_actual = {
+        "spark_version": {"type": "fixed", "value": "14.2.x-scala2.12"},
+        "node_type_id": {"type": "fixed", "value": "Standard_F4s"},
+        "spark_conf.spark.databricks.hive.metastore.glueCatalog.enabled": {"type": "fixed", "value": "true"},
+        "aws_attributes.instance_profile_arn": {"type": "fixed", "value": "role_arn_1"},
+        "aws_attributes.availability": {"type": "fixed", "value": "ON_DEMAND"},
+    }
+    assert policy_id == "foo1"
+    assert instance_profile == "role_arn_1"
+
+    ws.cluster_policies.create.assert_called_with(
+        name="Unity Catalog Migration (ucx) ([email protected])",
+        definition=json.dumps(policy_definition_actual),
+        description="Custom cluster policy for Unity Catalog Migration (UCX)",
+    )
+
+
+def test_cluster_policy_definition_gcp_hms_warehouse():
+    ws, _ = common()
+    ws.config.is_aws = False
+    ws.config.is_azure = False
+    ws.config.is_gcp = True
+    endpoint_conf = [
+        EndpointConfPair(None, None),
+        EndpointConfPair("random", None),
+        EndpointConfPair("spark.hadoop.javax.jdo.option.ConnectionURL", "url"),
+        EndpointConfPair("spark.hadoop.javax.jdo.option.ConnectionUserName", "user1"),
+        EndpointConfPair("spark.hadoop.javax.jdo.option.ConnectionPassword", "pwd"),
+        EndpointConfPair("spark.hadoop.javax.jdo.option.ConnectionDriverName", "SQLServerDriver"),
+        EndpointConfPair("spark.sql.hive.metastore.version", "0.13"),
+        EndpointConfPair("spark.sql.hive.metastore.jars", "jar1"),
+    ]
+
+    ws.warehouses.get_workspace_warehouse_config.return_value = GetWorkspaceWarehouseConfigResponse(
+        data_access_config=endpoint_conf
+    )
+
+    ws.cluster_policies.list.return_value = [
+        Policy(
+            policy_id="id1",
+            name="foo",
+            definition=json.dumps({}),
+            description="Custom cluster policy for Unity Catalog Migration (UCX)",
+        )
+    ]
+
+    prompts = MockPrompts(
+        {
+            r".*We have identified one or more cluster.*": "No",
+            r".*We have identified the workspace warehouse.*": "Yes",
+        }
+    )
+    policy_installer = ClusterPolicyInstaller(MockInstallation(), ws, prompts)
+    policy_id, _, _ = policy_installer.create('ucx')
+    policy_definition_actual = {
+        "spark_version": {"type": "fixed", "value": "14.2.x-scala2.12"},
+        "node_type_id": {"type": "fixed", "value": "Standard_F4s"},
+        "spark_conf.spark.hadoop.javax.jdo.option.ConnectionURL": {"type": "fixed", "value": "url"},
+        "spark_conf.spark.hadoop.javax.jdo.option.ConnectionUserName": {"type": "fixed", "value": "user1"},
+        "spark_conf.spark.hadoop.javax.jdo.option.ConnectionPassword": {"type": "fixed", "value": "pwd"},
+        "spark_conf.spark.hadoop.javax.jdo.option.ConnectionDriverName": {"type": "fixed", "value": "SQLServerDriver"},
+        "spark_conf.spark.sql.hive.metastore.version": {"type": "fixed", "value": "0.13"},
+        "spark_conf.spark.sql.hive.metastore.jars": {"type": "fixed", "value": "jar1"},
+        "gcp_attributes.availability": {"type": "fixed", "value": "ON_DEMAND_GCP"},
+    }
+    assert policy_id == "foo1"
+
+    ws.cluster_policies.create.assert_called_with(
+        name="Unity Catalog Migration (ucx) ([email protected])",
+        definition=json.dumps(policy_definition_actual),
+        description="Custom cluster policy for Unity Catalog Migration (UCX)",
+    )
+
+
+def test_cluster_policy_definition_empty_config():
+    ws, prompts = common()
+    ws.config.is_aws = True
+    ws.config.is_azure = False
+    ws.config.is_gcp = False
+
+    ws.warehouses.get_workspace_warehouse_config.return_value = GetWorkspaceWarehouseConfigResponse(
+        data_access_config=None
+    )
+
+    ws.cluster_policies.list.return_value = [
+        Policy(
+            policy_id="id1",
+            name="foo",
+            definition=json.dumps({}),
+            description="Custom cluster policy for Unity Catalog Migration (UCX)",
+        )
+    ]
+
+    policy_installer = ClusterPolicyInstaller(MockInstallation(), ws, prompts)
+    policy_id, _, _ = policy_installer.create('ucx')
+    policy_definition_actual = {
+        "spark_version": {"type": "fixed", "value": "14.2.x-scala2.12"},
+        "node_type_id": {"type": "fixed", "value": "Standard_F4s"},
+        "aws_attributes.availability": {"type": "fixed", "value": "ON_DEMAND"},
+    }
+    assert policy_id == "foo1"
+
+    ws.cluster_policies.create.assert_called_with(
+        name="Unity Catalog Migration (ucx) ([email protected])",
+        definition=json.dumps(policy_definition_actual),
+        description="Custom cluster policy for Unity Catalog Migration (UCX)",
+    )
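
Note: these tests stub WorkspaceClient.warehouses.get_workspace_warehouse_config() to return a canned response. Below is a short sketch of building such a response and wiring it into a faked client, useful when reproducing the tests interactively; it assumes the databricks-sdk package is installed, the values are placeholders mirroring the fixtures above, and create_autospec is just one way to fake the client (the test helper common() may do it differently).

from unittest.mock import create_autospec

from databricks.sdk import WorkspaceClient
from databricks.sdk.service.sql import EndpointConfPair, GetWorkspaceWarehouseConfigResponse

# Canned warehouse config pointing at an external Hive metastore (placeholder values).
warehouse_config = GetWorkspaceWarehouseConfigResponse(
    data_access_config=[
        EndpointConfPair("spark.sql.hive.metastore.version", "0.13"),
        EndpointConfPair("spark.sql.hive.metastore.jars", "jar1"),
    ],
    instance_profile_arn=None,
)

# Fake workspace client whose warehouse-config call returns the canned response.
ws = create_autospec(WorkspaceClient)
ws.warehouses.get_workspace_warehouse_config.return_value = warehouse_config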
