Skip to content

Commit 3cdf2ce

Browse files
authored
Add instance pool to cluster policy (#1078)
1 parent db099c6 commit 3cdf2ce

File tree

3 files changed

+77
-4
lines changed

3 files changed

+77
-4
lines changed

src/databricks/labs/ucx/installer/policy.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ def _policy_config(value: str):
2525
def create(self, inventory_database: str) -> tuple[str, str, dict]:
2626
instance_profile = ""
2727
spark_conf_dict = {}
28+
# get instance pool id to be put into the cluster policy
29+
instance_pool_id = self._get_instance_pool_id()
2830
policies_with_external_hms = list(self._get_cluster_policies_with_external_hive_metastores())
2931
if len(policies_with_external_hms) > 0 and self._prompts.confirm(
3032
"We have identified one or more cluster policies set up for an external metastore"
@@ -54,7 +56,7 @@ def create(self, inventory_database: str) -> tuple[str, str, dict]:
5456
logger.info("Creating UCX cluster policy.")
5557
policy_id = self._ws.cluster_policies.create(
5658
name=policy_name,
57-
definition=self._definition(spark_conf_dict, instance_profile),
59+
definition=self._definition(spark_conf_dict, instance_profile, instance_pool_id),
5860
description="Custom cluster policy for Unity Catalog Migration (UCX)",
5961
).policy_id
6062
assert policy_id is not None
@@ -64,11 +66,35 @@ def create(self, inventory_database: str) -> tuple[str, str, dict]:
6466
spark_conf_dict,
6567
)
6668

67-
def _definition(self, conf: dict, instance_profile: str | None) -> str:
69+
def _get_instance_pool_id(self) -> str | None:
    """Prompt for an optional instance pool id to embed in the UCX cluster policy.

    Returns the pool id only when the user supplied one and it exists in the
    workspace; returns ``None`` otherwise, in which case no instance pool is
    set in the policy and the caller falls back to a fixed node type.
    """
    try:
        instance_pool_id = self._prompts.question(
            "Instance pool id to be set in cluster policy for all workflow clusters", default="None"
        )
    except OSError:
        # when unit test v0.15.0_added_cluster_policy.py MockPrompts cannot be injected to ClusterPolicyInstaller
        # return None to pass the test
        return None
    # Treat an empty answer like the "None" default: skip the workspace
    # lookup instead of querying the API with an invalid (empty) pool id.
    if not instance_pool_id or instance_pool_id.lower() == "none":
        return None
    try:
        # Validate the pool exists before baking it into the policy definition.
        self._ws.instance_pools.get(instance_pool_id)
        return instance_pool_id
    except NotFound:
        logger.warning(
            f"Instance pool id {instance_pool_id} does not exist. Will not set instance pool in the cluster policy. You can manually edit the cluster policy after installation."
        )
        return None
88+
89+
def _definition(self, conf: dict, instance_profile: str | None, instance_pool_id: str | None) -> str:
6890
policy_definition = {
6991
"spark_version": self._policy_config(self._ws.clusters.select_spark_version(latest=True)),
7092
"node_type_id": self._policy_config(self._ws.clusters.select_node_type(local_disk=True)),
7193
}
94+
if instance_pool_id:
95+
policy_definition["instance_pool_id"] = self._policy_config(instance_pool_id)
96+
# 'node_type_id' cannot be supplied when an instance pool ID is provided
97+
policy_definition.pop("node_type_id")
7298
for key, value in conf.items():
7399
policy_definition[f"spark_conf.{key}"] = self._policy_config(value)
74100
if self._ws.config.is_aws:

tests/integration/test_installation.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -466,11 +466,13 @@ def test_check_inventory_database_exists(ws, new_installation):
466466
assert err.value.args[0] == f"Inventory database '{inventory_database}' already exists in another installation"
467467

468468

469-
@pytest.mark.skip
470469
@retried(on=[NotFound], timeout=timedelta(minutes=10))
471470
def test_table_migration_job( # pylint: disable=too-many-locals
472471
ws, new_installation, make_catalog, make_schema, make_table, env_or_skip, make_random, make_dbfs_data_copy
473472
):
473+
# skip this test if not in nightly test job: TEST_NIGHTLY is missing or is not set to "true"
474+
if env_or_skip("TEST_NIGHTLY").lower() != "true":
475+
pytest.skip("TEST_NIGHTLY is not true")
474476
# create external and managed tables to be migrated
475477
src_schema = make_schema(catalog_name="hive_metastore")
476478
src_managed_table = make_table(schema_name=src_schema.name)
@@ -489,6 +491,7 @@ def test_table_migration_job( # pylint: disable=too-many-locals
489491
r"Parallelism for migrating.*": "1000",
490492
r"Min workers for auto-scale.*": "2",
491493
r"Max workers for auto-scale.*": "20",
494+
r"Instance pool id to be set.*": env_or_skip("TEST_INSTANCE_POOL_ID"),
492495
},
493496
)
494497
# save table mapping for migration before trigger the run

tests/unit/installer/test_policy.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from databricks.sdk import WorkspaceClient
88
from databricks.sdk.errors import NotFound
99
from databricks.sdk.service import iam
10-
from databricks.sdk.service.compute import ClusterSpec, Policy
10+
from databricks.sdk.service.compute import ClusterSpec, GetInstancePool, Policy
1111
from databricks.sdk.service.jobs import Job, JobCluster, JobSettings
1212
from databricks.sdk.service.sql import (
1313
EndpointConfPair,
@@ -34,6 +34,7 @@ def common():
3434
{
3535
r".*We have identified one or more cluster.*": "Yes",
3636
r".*Choose a cluster policy.*": "0",
37+
r".*Instance pool id to be set in cluster policy.*": "",
3738
}
3839
)
3940
return w, prompts
@@ -252,6 +253,7 @@ def test_cluster_policy_definition_azure_hms_warehouse():
252253
{
253254
r".*We have identified one or more cluster.*": "No",
254255
r".*We have identified the workspace warehouse.*": "Yes",
256+
r".*Instance pool id to be set in cluster policy.*": "",
255257
}
256258
)
257259
policy_installer = ClusterPolicyInstaller(MockInstallation(), ws, prompts)
@@ -303,6 +305,7 @@ def test_cluster_policy_definition_aws_glue_warehouse():
303305
{
304306
r".*We have identified one or more cluster.*": "No",
305307
r".*We have identified the workspace warehouse.*": "Yes",
308+
r".*Instance pool id to be set in cluster policy.*": "",
306309
}
307310
)
308311
policy_installer = ClusterPolicyInstaller(MockInstallation(), ws, prompts)
@@ -357,6 +360,7 @@ def test_cluster_policy_definition_gcp_hms_warehouse():
357360
{
358361
r".*We have identified one or more cluster.*": "No",
359362
r".*We have identified the workspace warehouse.*": "Yes",
363+
r".*Instance pool id to be set in cluster policy.*": "",
360364
}
361365
)
362366
policy_installer = ClusterPolicyInstaller(MockInstallation(), ws, prompts)
@@ -414,3 +418,43 @@ def test_cluster_policy_definition_empty_config():
414418
definition=json.dumps(policy_definition_actual),
415419
description="Custom cluster policy for Unity Catalog Migration (UCX)",
416420
)
421+
422+
423+
def test_cluster_policy_instance_pool():
    """A user-supplied instance pool id is validated against the workspace and
    embedded in the generated UCX cluster policy; an unknown id is dropped."""
    ws, prompts = common()
    prompts = prompts.extend({r".*Instance pool id to be set in cluster policy.*": "instance_pool_1"})

    # Pretend the pool exists in the workspace, and no policies exist yet.
    ws.instance_pools.get.return_value = GetInstancePool("instance_pool_1")
    ws.cluster_policies.list.return_value = []
    ws.config.is_aws = True
    ws.config.is_azure = False
    ws.config.is_gcp = False

    policy_installer = ClusterPolicyInstaller(MockInstallation(), ws, prompts)
    policy_installer.create('ucx')

    # With a valid pool id, 'instance_pool_id' appears in the definition and
    # 'node_type_id' is omitted (the two cannot be supplied together).
    policy_expected = {
        "spark_version": {"type": "fixed", "value": "14.2.x-scala2.12"},
        "instance_pool_id": {"type": "fixed", "value": "instance_pool_1"},
        "aws_attributes.availability": {"type": "fixed", "value": "ON_DEMAND"},
    }
    # test the instance pool is added to the cluster policy
    ws.cluster_policies.create.assert_called_with(
        name="Unity Catalog Migration (ucx) ([email protected])",
        definition=json.dumps(policy_expected),
        description="Custom cluster policy for Unity Catalog Migration (UCX)",
    )

    # test the instance pool is not found
    # When the lookup raises NotFound, the pool id is discarded and the
    # policy keeps the fixed 'node_type_id' instead.
    ws.instance_pools.get.side_effect = NotFound()
    policy_expected = {
        "spark_version": {"type": "fixed", "value": "14.2.x-scala2.12"},
        "node_type_id": {"type": "fixed", "value": "Standard_F4s"},
        "aws_attributes.availability": {"type": "fixed", "value": "ON_DEMAND"},
    }
    policy_installer.create('ucx')
    ws.cluster_policies.create.assert_called_with(
        name="Unity Catalog Migration (ucx) ([email protected])",
        definition=json.dumps(policy_expected),
        description="Custom cluster policy for Unity Catalog Migration (UCX)",
    )

0 commit comments

Comments
 (0)